mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Bug fix] Fix the multi-input accuracy issue in the pooling model. (#5374)
* fix multi-inputs * fix threshold * fix threshold * fix
This commit is contained in:
@@ -237,4 +237,97 @@ def test_single_text_embedding(embedding_api_url, headers):
|
||||
save_embedding_baseline(embedding, baseline_file)
|
||||
else:
|
||||
print(f"Comparing with baseline: {baseline_file}")
|
||||
check_embedding_against_baseline(embedding, baseline_file, threshold=0.01)
|
||||
check_embedding_against_baseline(embedding, baseline_file, threshold=0.02)
|
||||
|
||||
|
||||
def test_multi_text_embedding(embedding_api_url, headers):
    """Test embedding generation for batch text inputs.

    Sends three texts in a single request, validates the response structure
    (one embedding per input, matching ``index`` fields, numeric values,
    consistent dimensions), then compares each embedding against a stored
    baseline file. If the baseline file does not exist yet, the current
    embeddings are saved as the new baseline instead of being compared.

    Args:
        embedding_api_url: Fixture providing the embedding endpoint URL.
        headers: Fixture providing the HTTP request headers.

    Raises:
        AssertionError: If the response is malformed or any embedding's
            mean absolute difference from the baseline exceeds the threshold.
    """
    payload = {
        "model": "default",
        "input": ["北京天安门在哪里?", "上海东方明珠有多高?", "杭州西湖的面积是多少?"],
    }

    resp = requests.post(embedding_api_url, headers=headers, json=payload)
    assert resp.status_code == 200, f"Unexpected status code: {resp.status_code}, response: {resp.text}"

    result = resp.json()
    assert "data" in result, "Response missing 'data' field"
    assert len(result["data"]) == 3, f"Expected 3 embedding results, got {len(result['data'])}"

    # Validate each embedding in the batch
    for idx, item in enumerate(result["data"]):
        assert "embedding" in item, f"Item {idx} missing 'embedding' field"
        assert "index" in item, f"Item {idx} missing 'index' field"
        assert item["index"] == idx, f"Item index mismatch: expected {idx}, got {item['index']}"

        embedding = item["embedding"]
        assert isinstance(embedding, list), f"Embedding {idx} should be a list"
        assert len(embedding) > 0, f"Embedding {idx} vector should not be empty"
        assert all(isinstance(x, (int, float)) for x in embedding), f"Embedding {idx} values should be numeric"

        print(f"Text {idx} embedding dimension: {len(embedding)}")

    # Verify all embeddings have the same dimension
    dimensions = [len(item["embedding"]) for item in result["data"]]
    assert len(set(dimensions)) == 1, f"All embeddings should have same dimension, got: {dimensions}"

    # Resolve the baseline location; MODEL_PATH points at the model root
    # with baselines stored under its "torch" subdirectory.
    base_path = os.getenv("MODEL_PATH", "")
    baseline_filename = "test-Qwen3-Embedding-0.6B-multi-input-baseline.json"
    baseline_file = os.path.join(base_path, "torch", baseline_filename) if base_path else baseline_filename

    batch_embeddings = [item["embedding"] for item in result["data"]]
    # Single source of truth for the serialized baseline payload — used both
    # when saving a new baseline and when dumping the current batch for
    # debugging on a comparison failure (previously duplicated verbatim).
    baseline_data = {
        "embeddings": batch_embeddings,
        "dimension": len(batch_embeddings[0]),
        "count": len(batch_embeddings),
        "inputs": payload["input"],
    }

    if not os.path.exists(baseline_file):
        print("Batch baseline file not found. Saving current embeddings as baseline...")
        with open(baseline_file, "w", encoding="utf-8") as f:
            json.dump(baseline_data, f, indent=2)
        print(f"Batch baseline saved to: {baseline_file}")
        return

    print(f"Comparing batch with baseline: {baseline_file}")
    with open(baseline_file, "r", encoding="utf-8") as f:
        saved_baseline = json.load(f)
    baseline_embeddings = saved_baseline["embeddings"]

    assert len(batch_embeddings) == len(
        baseline_embeddings
    ), f"Embedding count mismatch: current={len(batch_embeddings)}, baseline={len(baseline_embeddings)}"

    # Tolerance on the mean absolute difference between current and baseline
    # vectors. Named once so the check and the error message cannot drift.
    threshold = 0.05

    # Compare each embedding against its baseline counterpart
    for idx, (current_emb, baseline_emb) in enumerate(zip(batch_embeddings, baseline_embeddings)):
        print(f"\n--- Comparing embedding {idx}: '{payload['input'][idx]}' ---")
        mean_abs_diff = compare_embeddings(current_emb, baseline_emb, threshold=threshold)

        if mean_abs_diff >= threshold:
            # Save current batch for debugging before failing
            temp_file = f"{baseline_file}.current"
            print("temp_file", temp_file)
            with open(temp_file, "w", encoding="utf-8") as f:
                json.dump(baseline_data, f, indent=2)

            # BUG FIX: the original message hard-coded ">= 0.01" while the
            # actual check used 0.05 — interpolate the real threshold.
            raise AssertionError(
                f"Embedding {idx} differs from baseline by too much "
                f"(mean_abs_diff={mean_abs_diff:.6f} >= {threshold}):\n"
                f"Current batch saved to: {temp_file}\n"
                f"Please check the differences."
            )
|
||||
|
||||
Reference in New Issue
Block a user