more eplb offline load dtypes (#6435)

This commit is contained in:
RichardWooSJTU
2026-03-02 14:34:20 +08:00
committed by GitHub
parent 758770bc43
commit 6d83dcc1c2
+3
View File
@@ -173,6 +173,9 @@ def load_tensor_from_shm_mem(tensor_infos, shm_ptr, logger=None):
elif dtype == paddle.float8_e4m3fn: elif dtype == paddle.float8_e4m3fn:
tmp = np_array.view(np.uint8) tmp = np_array.view(np.uint8)
tensor = paddle.Tensor(tmp, dtype=paddle.float8_e4m3fn, place=paddle.CPUPlace(), zero_copy=True) tensor = paddle.Tensor(tmp, dtype=paddle.float8_e4m3fn, place=paddle.CPUPlace(), zero_copy=True)
elif dtype == paddle.int32:
tmp = np_array.view(np.int32)
tensor = paddle.Tensor(tmp, dtype=paddle.int32, place=paddle.CPUPlace(), zero_copy=True)
else: else:
raise TypeError(f"Unsupported dtype: {dtype}") raise TypeError(f"Unsupported dtype: {dtype}")