Files
FastDeploy/tests/e2e/utils/rollout_routing_replay_test_utils.py
T
2026-01-22 15:46:34 +08:00

209 lines
8.4 KiB
Python

import os
import shutil
import time
import paddle
# ==========================
# Test Rollout Routing Replay
# ==========================
def calculate_routing_ratio(expected_routing: paddle.Tensor, actual_routing: paddle.Tensor) -> paddle.Tensor:
    """Calculate the mean per-row overlap ratio between two 2-D routing tensors.

    For every row that is not entirely ``-1`` in BOTH tensors, the ratio is the
    fraction of expected entries that appear anywhere in the matching actual
    row. Rows that mismatch within the real length are printed for debugging.

    Args:
        expected_routing: baseline routing tensor, 2-D ``[rows, elements_per_row]``.
        actual_routing: routing tensor under test, same shape as ``expected_routing``.

    Returns:
        A scalar tensor with the mean overlap ratio, or ``paddle.to_tensor(0.0)``
        when no row is valid in both tensors.

    Raises:
        AssertionError: if the shapes differ or the real lengths (index of the
            first row containing ``-1``) differ.
    """
    assert (
        expected_routing.shape == actual_routing.shape
    ), f"Routing shapes not equal. Expected shape {expected_routing.shape} actual shape {actual_routing.shape}."
    total_rows, elements_per_row = expected_routing.shape

    expected_routing_length = get_real_routing_length(expected_routing)
    actual_routing_length = get_real_routing_length(actual_routing)
    # get_real_routing_length returns -1 when no row contains -1, i.e. every row is real.
    # Normalize that sentinel to the row count so the diagnostics cover the whole tensor.
    if expected_routing_length == -1:
        expected_routing_length = total_rows
    if actual_routing_length == -1:
        actual_routing_length = total_rows

    # Print every mismatching row inside the longer real length (one vectorized compare).
    compare_len = max(expected_routing_length, actual_routing_length)
    if compare_len > 0:
        row_mismatch = paddle.any(expected_routing[:compare_len] != actual_routing[:compare_len], axis=1)
        for i in paddle.nonzero(row_mismatch).flatten().tolist():
            print(f"token index {i}:\n expected_routing:{expected_routing[i]}\n actual_routing: {actual_routing[i]}\n")

    assert (
        expected_routing_length == actual_routing_length
    ), f"Routing real lengths do not match. Expected length {expected_routing_length} actual length {actual_routing_length}."

    # A row is valid only if it holds at least one non -1 entry in both tensors.
    valid_mask = paddle.any(expected_routing != -1, axis=1) & paddle.any(actual_routing != -1, axis=1)
    if not valid_mask.any().item():
        return paddle.to_tensor(0.0)

    valid_expected_routing = expected_routing[valid_mask]  # [n_valid, top_k]
    valid_actual_routing = actual_routing[valid_mask]  # [n_valid, top_k]
    # valid_expected_routing: [n_valid, top_k, 1], valid_actual_routing: [n_valid, 1, top_k]
    # -> equals: [n_valid, top_k, top_k]
    equals = valid_expected_routing.unsqueeze(2) == valid_actual_routing.unsqueeze(1)
    overlap_mask = paddle.any(equals, axis=2)  # [n_valid, top_k]
    overlap_counts = paddle.sum(overlap_mask.cast("float32"), axis=1)  # [n_valid]
    overlap_ratios = overlap_counts / elements_per_row  # [n_valid]
    return paddle.mean(overlap_ratios)
def get_real_routing_length(routing: paddle.Tensor) -> int:
    """Return the index of the first row of ``routing`` that contains a ``-1``.

    Returns ``-1`` when no row contains ``-1``.
    """
    pad_rows = paddle.any((routing == -1).astype(paddle.float32), axis=1).astype(paddle.float32)
    if not pad_rows.any().item():
        return -1
    # Over a 0/1 vector, argmax yields the first flagged row
    # (assumes paddle.argmax breaks ties toward the lowest index, as the original relied on).
    return paddle.argmax(pad_rows, axis=0).item()
# Streaming test
def send_r3_streaming_chat(openai_client, user_id: str = ""):
    """
    Send the fixed multi-turn R3 chat request with ``stream=True``.

    The result of ``openai_client.chat.completions.create`` is returned as-is;
    it is not consumed here.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "List 3 countries and their capitals."},
        {"role": "assistant", "content": "China(Beijing), France(Paris), Australia(Canberra)."},
        {"role": "user", "content": "OK, tell more."},
    ]
    request_kwargs = dict(
        model="default",
        messages=conversation,
        temperature=1,
        top_p=0,
        max_tokens=1024,
        seed=13,
        stream=True,
        user=user_id,
    )
    return openai_client.chat.completions.create(**request_kwargs)
def send_r3_non_streaming_chat(openai_client, user_id: str = ""):
    """
    Send the fixed single-turn R3 chat request with ``stream=False``.

    Returns whatever ``openai_client.chat.completions.create`` returns.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ]
    request_kwargs = dict(
        model="default",
        messages=conversation,
        temperature=1,
        top_p=0,
        max_tokens=1024,
        seed=13,
        stream=False,
        user=user_id,
    )
    return openai_client.chat.completions.create(**request_kwargs)
def generated_base_line_routing_index(openai_client, cur_save_routing_path, baseline_path):
    """Generate baseline routing-index dumps and move them into ``baseline_path``.

    Sends the streaming and non-streaming R3 requests, waits for their per-user
    output directories to appear under ``cur_save_routing_path``, then moves
    each one to ``baseline_path/<user_id>``.

    An existing destination directory is left untouched and the freshly
    generated copy is discarded. This avoids nesting a new dump inside an
    existing baseline, which ``shutil.move`` does when the destination exists.

    Raises:
        TimeoutError: if an output directory does not appear within the
            ``wait_for_file`` timeout.
    """
    user_ids = ("r3_chat_completion_stream", "r3_chat_completion_nonstream")

    # Remove stale dumps from earlier runs so wait_for_file only observes fresh output.
    for user_id in user_ids:
        shutil.rmtree(os.path.join(cur_save_routing_path, user_id), ignore_errors=True)

    # Generate streaming chat routing index; drain the stream so the request runs to completion.
    for _ in send_r3_streaming_chat(openai_client, user_id="r3_chat_completion_stream"):
        pass
    # Generate non streaming chat routing index
    send_r3_non_streaming_chat(openai_client, user_id="r3_chat_completion_nonstream")

    os.makedirs(baseline_path, exist_ok=True)
    for user_id in user_ids:
        src = os.path.join(cur_save_routing_path, user_id)
        dst = os.path.join(baseline_path, user_id)
        # Note: only the directory's existence is awaited; files inside may still be being written.
        if not wait_for_file(src):
            raise TimeoutError(f"Routing replay output {src} was not generated in time.")
        if os.path.exists(dst):
            # Keep the existing baseline and drop the fresh copy instead of nesting it inside dst.
            shutil.rmtree(src)
            continue
        shutil.move(src, dst)
def wait_for_file(file_path, timeout: float = 20, check_interval: float = 0.1) -> bool:
    """Poll until ``file_path`` exists or ``timeout`` seconds have elapsed.

    Existence is checked before the deadline, so a path that already exists is
    reported even with ``timeout=0``. A final check also runs after the last
    sleep, so a path created during that sleep is still seen.

    Args:
        file_path: file or directory path to wait for.
        timeout: maximum number of seconds to wait.
        check_interval: seconds to sleep between checks (capped at the time remaining).

    Returns:
        True if the path exists, False if the timeout elapsed first.
    """
    deadline = time.perf_counter() + timeout
    while True:
        if os.path.exists(file_path):
            return True
        remaining = deadline - time.perf_counter()
        if remaining <= 0:
            return False
        time.sleep(min(check_interval, remaining))
def check_routing_replay_chat_completion(openai_client, moe_layer_num: int, model_name: str):
    """Test rollout routing replay chat completion against stored baselines.

    The baseline directory is ``$MODEL_PATH/R3_BaseLine_dev/...`` when
    ``MODEL_PATH`` is set, otherwise ``./R3_BaseLine_dev/...``.

    If either baseline is missing, a new one is generated and FileNotFoundError
    is raised. Otherwise the streaming and non-streaming requests are re-sent
    and every layer dump is compared with the baseline.

    Raises:
        FileNotFoundError: when a baseline was missing (a new one is generated first).
        AssertionError: on a layer-count mismatch, a dump not produced in time,
            or an overlap ratio below 0.999.
    """
    cur_save_routing_path = f"./R3_tmp/routing_replay_output_{model_name}/"
    model_path = os.getenv("MODEL_PATH")
    baseline_path = os.path.join(model_path or ".", f"R3_BaseLine_dev/routing_replay_output_baseline_{model_name}")
    stream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_stream")
    nonstream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_nonstream")

    # Maybe need to generate baseline routing index
    if not (os.path.exists(stream_baseline_path) and os.path.exists(nonstream_baseline_path)):
        generated_base_line_routing_index(openai_client, cur_save_routing_path, baseline_path)
        raise FileNotFoundError(f"Not find the R3 baseline file {nonstream_baseline_path} or {stream_baseline_path} .")

    for path in (stream_baseline_path, nonstream_baseline_path):
        routing_layer_num = len(os.listdir(path))
        assert (
            routing_layer_num == moe_layer_num
        ), f"routing index number {routing_layer_num} should equal to moe layer number {moe_layer_num}"

    # Test streaming chat. Clear stale output first so wait_for_file cannot pick up old dumps,
    # and drain the stream so the request runs to completion.
    stream_output_path = os.path.join(cur_save_routing_path, "r3_chat_completion_stream")
    shutil.rmtree(stream_output_path, ignore_errors=True)
    for _ in send_r3_streaming_chat(openai_client, user_id="r3_chat_completion_stream"):
        pass
    _assert_routing_matches_baseline(stream_output_path, stream_baseline_path, moe_layer_num)

    # Test non streaming chat
    nonstream_output_path = os.path.join(cur_save_routing_path, "r3_chat_completion_nonstream")
    shutil.rmtree(nonstream_output_path, ignore_errors=True)
    send_r3_non_streaming_chat(openai_client, user_id="r3_chat_completion_nonstream")
    _assert_routing_matches_baseline(nonstream_output_path, nonstream_baseline_path, moe_layer_num)


def _assert_routing_matches_baseline(cur_dir, baseline_dir, moe_layer_num):
    """Assert that each ``layer_<i>.pdtensor`` in ``cur_dir`` overlaps the matching baseline by >= 0.999."""
    for layer_index in range(moe_layer_num):
        file_name = f"layer_{layer_index}.pdtensor"
        cur_routing_path = os.path.join(cur_dir, file_name)
        baseline_routing_path = os.path.join(baseline_dir, file_name)
        assert wait_for_file(cur_routing_path), f"Routing file {cur_routing_path} was not generated in time."
        generated_routing = paddle.load(cur_routing_path)
        baseline_routing = paddle.load(baseline_routing_path)
        overlap_ratio = calculate_routing_ratio(baseline_routing, generated_routing)
        assert (
            overlap_ratio >= 0.999
        ), f"the routing overlap ratio of the layer {layer_index} should be equal to baseline routing index, but got {overlap_ratio}"