# Mirror of https://github.com/PaddlePaddle/FastDeploy.git
# Synced 2026-04-23 00:17:25 +08:00, commit 955785e2e0 ("fix typo")
# 209 lines, 8.4 KiB, Python
import os
import shutil
import time

import paddle
# ==========================
# Test Rollout Routing Replay
# ==========================
def calculate_routing_ratio(expected_routing: paddle.Tensor, actual_routing: paddle.Tensor) -> paddle.Tensor:
    """Calculate the routing overlap ratio between two routing-index tensors.

    Both tensors have shape [num_tokens, top_k]; rows consisting entirely of
    -1 are treated as padding (assumption based on the -1 masks below —
    confirm against the routing dump writer).

    Args:
        expected_routing: Baseline routing indices.
        actual_routing: Routing indices produced by the current run.

    Returns:
        A 0-D paddle tensor holding the mean per-token overlap ratio in [0, 1].
    """
    # Fix: the message was a plain string missing the f prefix (placeholders
    # never interpolated) and contained the typo ".shap".
    assert (
        expected_routing.shape == actual_routing.shape
    ), f"Routing shapes not equal. Expected shape {expected_routing.shape} actual shape {actual_routing.shape}."
    expected_routing_length = get_real_routing_length(expected_routing)
    actual_routing_length = get_real_routing_length(actual_routing)

    # Print every mismatching token row before asserting, to ease debugging.
    for i in range(max(expected_routing_length, actual_routing_length)):
        if not paddle.all(paddle.equal(expected_routing[i], actual_routing[i])).item():
            print(f"token index {i}:\n expected_routing:{expected_routing[i]}\n actual_routing: {actual_routing[i]}\n")

    assert (
        expected_routing_length == actual_routing_length
    ), f"Routing real lengths do not match. Expected length {expected_routing_length} actual length {actual_routing_length}."
    _, elements_per_row = expected_routing.shape

    # A row is valid when it contains at least one real (non -1) index in
    # both tensors.
    mask1 = paddle.any(expected_routing != -1, axis=1)
    mask2 = paddle.any(actual_routing != -1, axis=1)
    valid_mask = mask1 & mask2

    if paddle.sum(valid_mask.cast("int32")) == 0:
        return paddle.to_tensor(0.0)

    valid_expected_routing = expected_routing[valid_mask]  # [n_valid, top_k]
    valid_actual_routing = actual_routing[valid_mask]  # [n_valid, top_k]

    # Broadcast-compare every expected index against every actual index:
    # [n_valid, top_k, 1] == [n_valid, 1, top_k] -> [n_valid, top_k, top_k]
    equals = valid_expected_routing.unsqueeze(2) == valid_actual_routing.unsqueeze(1)

    # True where an expected index appears anywhere in the actual row.
    overlap_mask = paddle.any(equals, axis=2)  # [n_valid, top_k]

    overlap_counts = paddle.sum(overlap_mask.cast("float32"), axis=1)  # [n_valid]
    overlap_ratios = overlap_counts / elements_per_row  # [n_valid]

    return paddle.mean(overlap_ratios)
def get_real_routing_length(routing: paddle.Tensor) -> int:
    """Return the index of the first row of *routing* that contains a -1.

    -1 marks unused/padding routing slots, so the first row containing one
    bounds the real routing length. Returns -1 when no row contains a -1.
    """
    padding_flags = (routing == -1).astype(paddle.float32)
    padded_rows = paddle.any(padding_flags, axis=1).astype(paddle.float32)

    if not padded_rows.any().item():
        return -1
    # argmax over the 0/1 per-row flags locates the first padded row.
    return paddle.argmax(padded_rows, axis=0).item()
# Streaming test
def send_r3_streaming_chat(openai_client, user_id: str = ""):
    """Send a multi-turn streaming chat request to the local service.

    Deterministic sampling settings (top_p=0, fixed seed) are used so that
    the routing indices dumped for this request are reproducible.

    Returns the streaming response object from the client.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "List 3 countries and their capitals."},
        {
            "role": "assistant",
            "content": "China(Beijing), France(Paris), Australia(Canberra).",
        },
        {"role": "user", "content": "OK, tell more."},
    ]
    return openai_client.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=1,
        top_p=0,
        max_tokens=1024,
        seed=13,
        stream=True,
        user=user_id,  # e.g. "r3_chat_completion_stream_test"
    )
def send_r3_non_streaming_chat(openai_client, user_id: str = ""):
    """Send a single-turn non-streaming chat request to the local service.

    Uses the same deterministic sampling settings as the streaming variant
    (top_p=0, fixed seed) so routing dumps are reproducible.

    Returns the completed response object from the client.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ]
    return openai_client.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=1,
        top_p=0,
        max_tokens=1024,
        seed=13,
        stream=False,
        user=user_id,  # e.g. "rollout_routing_replay_chat_completion_nonstream_test"
    )
def generated_base_line_routing_index(openai_client, cur_save_routing_path, baseline_path):
    """Generate fresh routing-index dumps and promote them to baseline files.

    Triggers one streaming and one non-streaming request (the server dumps
    routing indices per request), waits for both dumps to appear under
    *cur_save_routing_path*, then moves them into *baseline_path*.
    """
    send_r3_streaming_chat(openai_client, user_id="r3_chat_completion_stream")
    send_r3_non_streaming_chat(openai_client, user_id="r3_chat_completion_nonstream")

    dump_names = ("r3_chat_completion_stream", "r3_chat_completion_nonstream")
    generated_paths = [os.path.join(cur_save_routing_path, name) for name in dump_names]

    # Make sure the routing dumps were actually written before moving them.
    for generated in generated_paths:
        wait_for_file(generated)

    # Move the generated dumps into the routing_replay_output_baseline folder.
    for name, generated in zip(dump_names, generated_paths):
        shutil.move(generated, os.path.join(baseline_path, name))
def wait_for_file(file_path, timeout=20, check_interval=0.1):
    """Poll until *file_path* exists or *timeout* seconds elapse.

    Args:
        file_path: Path whose existence is awaited.
        timeout: Maximum number of seconds to wait.
        check_interval: Seconds to sleep between existence checks.

    Returns:
        True if the file exists / appeared within the timeout, False otherwise.
    """
    deadline = time.perf_counter() + timeout

    while True:
        # Fix: check for the file BEFORE checking the deadline, so a file
        # that already exists is reported even with timeout=0, and a file
        # appearing right at the deadline is not missed.
        if os.path.exists(file_path):
            return True

        remaining = deadline - time.perf_counter()
        if remaining <= 0:
            return False

        # Never sleep past the deadline.
        time.sleep(min(check_interval, remaining))
def check_routing_replay_chat_completion(openai_client, moe_layer_num: int, model_name: str):
    """Test rollout routing replay (R3) chat completion against a stored baseline.

    Sends one streaming and one non-streaming chat request, then compares the
    routing indices dumped for every MoE layer with the baseline files.

    Args:
        openai_client: OpenAI-compatible client connected to the local service.
        moe_layer_num: Expected number of MoE layers (one routing file each).
        model_name: Model identifier used in the dump/baseline directory names.

    Raises:
        FileNotFoundError: When the baseline does not exist yet; a fresh
            candidate baseline is generated first so the next run can use it.
        AssertionError: When layer counts or routing overlap checks fail.
    """
    cur_save_routing_path = f"./R3_tmp/routing_replay_output_{model_name}/"
    model_path = os.getenv("MODEL_PATH")
    if model_path:
        baseline_path = os.path.join(model_path, f"R3_BaseLine_dev/routing_replay_output_baseline_{model_name}")
    else:
        baseline_path = f"./R3_BaseLine_dev/routing_replay_output_baseline_{model_name}"
    stream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_stream")
    nonstream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_nonstream")

    # Maybe need to generate the baseline routing index first.
    if not os.path.exists(stream_baseline_path) or not os.path.exists(nonstream_baseline_path):
        generated_base_line_routing_index(openai_client, cur_save_routing_path, baseline_path)
        raise FileNotFoundError(f"Not find the R3 baseline file {nonstream_baseline_path} or {stream_baseline_path} .")

    # Each baseline directory must hold exactly one routing file per MoE layer.
    for baseline_dir in (stream_baseline_path, nonstream_baseline_path):
        routing_layer_num = len(os.listdir(baseline_dir))
        assert (
            routing_layer_num == moe_layer_num
        ), f"routing index number {routing_layer_num} should equal to moe layer number {moe_layer_num}"

    # Test streaming chat
    send_r3_streaming_chat(openai_client, user_id="r3_chat_completion_stream")
    _compare_layer_routing(
        os.path.join(cur_save_routing_path, "r3_chat_completion_stream"),
        stream_baseline_path,
        moe_layer_num,
    )

    # Test non streaming chat
    send_r3_non_streaming_chat(openai_client, user_id="r3_chat_completion_nonstream")
    _compare_layer_routing(
        os.path.join(cur_save_routing_path, "r3_chat_completion_nonstream"),
        nonstream_baseline_path,
        moe_layer_num,
    )

    # shutil.rmtree(cur_save_routing_path)


def _compare_layer_routing(generated_dir: str, baseline_dir: str, moe_layer_num: int):
    """Compare per-layer routing dumps in *generated_dir* against *baseline_dir*.

    Shared by the streaming and non-streaming checks (the two loops were
    previously duplicated verbatim).
    """
    for layer_index in range(moe_layer_num):
        cur_routing_path = os.path.join(generated_dir, f"layer_{layer_index}.pdtensor")
        baseline_routing_path = os.path.join(baseline_dir, f"layer_{layer_index}.pdtensor")

        wait_for_file(cur_routing_path)

        generated_routing = paddle.load(cur_routing_path)
        baseline_routing = paddle.load(baseline_routing_path)
        overlap_ratio = calculate_routing_ratio(baseline_routing, generated_routing)
        assert (
            overlap_ratio >= 0.999
        ), f"the routing overlap ratio of the layer {layer_index} should be equal to baseline routing index, but got {overlap_ratio}"