mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Support mtp super ultra overlap in pd-split mode with insert_task overlap (#7323)
* support mtp overlap in pd-split mode with insert_task overlap
This commit is contained in:
@@ -116,36 +116,33 @@ from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, SamplerOu
|
||||
|
||||
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
|
||||
|
||||
if current_platform.is_cuda():
|
||||
|
||||
def async_set_value(tgt, src):
|
||||
if isinstance(src, (int, float, bool)):
|
||||
src = paddle.full(tgt.shape, fill_value=src, dtype=tgt.dtype)
|
||||
elif isinstance(src, (list, np.array)):
|
||||
dtype_str = str(tgt.dtype).split(".")[1]
|
||||
if isinstance(src, list):
|
||||
src = np.array(src, dtype=dtype_str if dtype_str != "bfloat16" else "float32")
|
||||
def async_set_value(tgt, src):
|
||||
if isinstance(src, (int, float, bool)):
|
||||
src = paddle.full(tgt.shape, fill_value=src, dtype=tgt.dtype)
|
||||
elif isinstance(src, (list, np.ndarray)):
|
||||
dtype_str = str(tgt.dtype).split(".")[1]
|
||||
if isinstance(src, list):
|
||||
src = np.array(src, dtype=dtype_str if dtype_str != "bfloat16" else "float32")
|
||||
if current_platform.is_cuda():
|
||||
if str(src.dtype) != dtype_str:
|
||||
srt_tensor = paddle.empty(tgt.shape, dtype=str(src.dtype))
|
||||
src = custom_numpy_to_tensor(src, srt_tensor)
|
||||
else:
|
||||
return custom_numpy_to_tensor(src, tgt)
|
||||
elif isinstance(src, paddle.Tensor):
|
||||
pass
|
||||
else:
|
||||
raise ValueError("async_set_value unsupported src type: {}".format(type(src)))
|
||||
if src.shape != tgt.shape:
|
||||
src = src.reshape(tgt.shape)
|
||||
if src.dtype != tgt.dtype:
|
||||
src = src.cast(tgt.dtype)
|
||||
if src.place != tgt.place:
|
||||
src = src.to(tgt.place)
|
||||
tgt.copy_(src, blocking=False)
|
||||
|
||||
else:
|
||||
|
||||
def async_set_value(*args, **kwargs):
|
||||
raise RuntimeError("async_set_value is only available on CUDA")
|
||||
src = paddle.to_tensor(src, dtype=tgt.dtype)
|
||||
elif isinstance(src, paddle.Tensor):
|
||||
pass
|
||||
else:
|
||||
raise ValueError("async_set_value unsupported src type: {}".format(type(src)))
|
||||
if src.shape != tgt.shape:
|
||||
src = src.reshape(tgt.shape)
|
||||
if src.dtype != tgt.dtype:
|
||||
src = src.cast(tgt.dtype)
|
||||
if src.place != tgt.place:
|
||||
src = src.to(tgt.place)
|
||||
tgt.copy_(src, blocking=False)
|
||||
|
||||
|
||||
def pre_process(
|
||||
|
||||
Reference in New Issue
Block a user