mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing (#3552)
* [feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing * infer engine support temp_scaled_logprobs and top_p_normalized_logprobs * delete some code * code check * code check and add doc * fix tokenizer.decoder(-1), return 'Invalid Token' * add ci for temp_scaled and top_p logprobs * check test * check seq len time shape * logprob clip inf --------- Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
This commit is contained in:
@@ -154,8 +154,101 @@ def test_stream_without_logprobs():
|
||||
assert result_chunk["choices"][0]["logprobs"] is None
|
||||
|
||||
|
||||
def test_stream_with_temp_scaled_logprobs():
    """
    Verify the first streamed token's probability info when temp_scaled_logprobs
    is enabled.

    Sends a streaming chat request with temperature scaling applied to the
    logprobs post-processing, then checks the token text, its logprob, and the
    top_logprobs entry of the first content-bearing chunk.
    """
    data = {
        "stream": True,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 3,
        "temperature": 0.8,
        # top_p=0 makes sampling greedy, so the first token is deterministic.
        "top_p": 0,
        "temp_scaled_logprobs": True,
    }

    payload = build_request_payload(TEMPLATE, data)
    response = send_request(URL, payload)

    # Parse the first streamed chunk that carries content.
    result_chunk = {}
    for line in response.iter_lines():
        if not line:
            continue
        decoded = line.decode("utf-8").removeprefix("data: ")
        if decoded == "[DONE]":
            break

        chunk = json.loads(decoded)
        content = chunk["choices"][0]["delta"].get("content")
        if content:
            result_chunk = chunk
            print(json.dumps(result_chunk, indent=2, ensure_ascii=False))
            break

    # Fail with a clear message instead of a KeyError if the stream produced
    # no content chunk at all.
    assert result_chunk, "stream ended without any content chunk"

    # Validate the probability fields.  Float logprobs come from model
    # inference + JSON serialization, so compare with a tolerance instead of
    # exact equality, which is brittle across rounding/ULP differences.
    expected_logprob = -0.006811376195400953
    choice = result_chunk["choices"][0]
    first = choice["logprobs"]["content"][0]

    assert choice["delta"]["content"] == "牛顿"
    assert first["token"] == "牛顿"
    assert abs(first["logprob"] - expected_logprob) < 1e-6

    top0 = first["top_logprobs"][0]
    assert top0["token"] == "牛顿"
    assert abs(top0["logprob"] - expected_logprob) < 1e-6
    # UTF-8 bytes of "牛顿".
    assert top0["bytes"] == [231, 137, 155, 233, 161, 191]
|
||||
|
||||
|
||||
def test_stream_with_top_p_normalized_logprobs():
    """
    Verify the first streamed token's probability info when
    top_p_normalized_logprobs is enabled.

    With top_p=0 only the argmax token remains in the nucleus, so its
    normalized probability is 1.0 and the expected logprob is 0.0.
    """
    data = {
        "stream": True,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 3,
        # top_p=0 makes sampling greedy and the normalized top-1 prob exactly 1.
        "top_p": 0,
        "top_p_normalized_logprobs": True,
    }

    payload = build_request_payload(TEMPLATE, data)
    response = send_request(URL, payload)

    # Parse the first streamed chunk that carries content.
    result_chunk = {}
    for line in response.iter_lines():
        if not line:
            continue
        decoded = line.decode("utf-8").removeprefix("data: ")
        if decoded == "[DONE]":
            break

        chunk = json.loads(decoded)
        content = chunk["choices"][0]["delta"].get("content")
        if content:
            result_chunk = chunk
            print(json.dumps(result_chunk, indent=2, ensure_ascii=False))
            break

    # Fail with a clear message instead of a KeyError if the stream produced
    # no content chunk at all.
    assert result_chunk, "stream ended without any content chunk"

    # Validate the probability fields.  Use a tolerance for the float compare:
    # 0.0 is the mathematical expectation, but serialization may emit -0.0 or
    # a value within float rounding of zero.
    choice = result_chunk["choices"][0]
    first = choice["logprobs"]["content"][0]

    assert choice["delta"]["content"] == "牛顿"
    assert first["token"] == "牛顿"
    assert abs(first["logprob"]) < 1e-6

    top0 = first["top_logprobs"][0]
    assert top0["token"] == "牛顿"
    assert abs(top0["logprob"]) < 1e-6
    # UTF-8 bytes of "牛顿".
    assert top0["bytes"] == [231, 137, 155, 233, 161, 191]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run every logprobs test case in order when executed as a script.
    for case in (
        test_unstream_with_logprobs,
        test_unstream_without_logprobs,
        test_stream_with_logprobs,
        test_stream_without_logprobs,
        test_stream_with_temp_scaled_logprobs,
        test_stream_with_top_p_normalized_logprobs,
    ):
        case()
|
||||
|
||||
Reference in New Issue
Block a user