mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] support v1 update/clear api for RL (#6761)
* [Feature] support v1 update/clear api for RL * [fix] fix execute_model and add sleep/wakeup api * [fix] fix mtp and key_prefix * [chore] move _update_key_prefix to resume method * [fix] make the interface safe to call multiple times * [fix] fix some tiny bugs * [chore] make small changes against pr review * [docs] add docs for weight update * [test] add some tests and update docs * [style] fix code style check * [test] fix ci * [fix] fix stale control responses when control method timed out * [chore] remove unused code * [chore] fix code style * [chore] optimize tags and key_prefix * [test] fix ci * [chore] fix code style * [test] fix ci * [fix] fix ep control * [fix] fix ep control for engine cache queue
This commit is contained in:
@@ -58,6 +58,9 @@ class EngineCacheQueue:
|
||||
client_id: Unique identifier for client instances
|
||||
local_data_parallel_size: data parallel size
|
||||
local_data_parallel_id: local data parallel id
|
||||
|
||||
TODO(liyonghua): Remove multi-DP initialization. Each DP will have its own cache queue.
|
||||
|
||||
"""
|
||||
self.address: Tuple[str, int] = address
|
||||
self.authkey: bytes = authkey
|
||||
@@ -87,6 +90,7 @@ class EngineCacheQueue:
|
||||
]
|
||||
|
||||
# Initialize barriers
|
||||
self.barrier = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
|
||||
self.barrier0_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
|
||||
self.barrier1_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
|
||||
self.barrier2_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
|
||||
@@ -142,6 +146,7 @@ class EngineCacheQueue:
|
||||
callable=lambda idx: self.transfer_task_done_lock_init[idx],
|
||||
proxytype=AcquirerProxy,
|
||||
)
|
||||
QueueManager.register("get_barrier", callable=lambda idx: self.barrier[idx])
|
||||
QueueManager.register("get_barrier0", callable=lambda idx: self.barrier0_init[idx])
|
||||
QueueManager.register("get_barrier1", callable=lambda idx: self.barrier1_init[idx])
|
||||
QueueManager.register("get_barrier2", callable=lambda idx: self.barrier2_init[idx])
|
||||
@@ -191,6 +196,7 @@ class EngineCacheQueue:
|
||||
QueueManager.register("get_cache_sync_value")
|
||||
QueueManager.register("get_transfer_task_lock")
|
||||
QueueManager.register("get_transfer_task_done_lock")
|
||||
QueueManager.register("get_barrier")
|
||||
QueueManager.register("get_barrier0")
|
||||
QueueManager.register("get_barrier1")
|
||||
QueueManager.register("get_barrier2")
|
||||
@@ -215,6 +221,7 @@ class EngineCacheQueue:
|
||||
self.task_done_lock = self.manager.get_transfer_task_done_lock(self.local_data_parallel_id)
|
||||
|
||||
# Get barrier proxies
|
||||
self.barrier = self.manager.get_barrier(self.local_data_parallel_id)
|
||||
self.barrier0 = self.manager.get_barrier0(self.local_data_parallel_id)
|
||||
self.barrier1 = self.manager.get_barrier1(self.local_data_parallel_id)
|
||||
self.barrier2 = self.manager.get_barrier2(self.local_data_parallel_id)
|
||||
@@ -264,7 +271,12 @@ class EngineCacheQueue:
|
||||
|
||||
def put_transfer_task(self, item):
|
||||
"""
|
||||
put swap task
|
||||
Enqueue a cache transfer task (cpu/gpu swap task, read/write storage task)
|
||||
or a control task (cache clearing/restoring).
|
||||
|
||||
The queue is shared by multiple clients. A task can be enqueued only after
|
||||
the previous task has been read by all clients.
|
||||
`task_sync_value` is used as a bitmask to track per-client read status.
|
||||
"""
|
||||
self.task_lock.acquire()
|
||||
if 0 < self.task_sync_value.get() < self.total_num:
|
||||
@@ -279,7 +291,11 @@ class EngineCacheQueue:
|
||||
|
||||
def get_transfer_task(self):
|
||||
"""
|
||||
get swap task
|
||||
Get the current cache transfer task (cpu/gpu swap task, read/write storage task)
|
||||
or control signal (cache clearing/restoring) from cache task queue.
|
||||
|
||||
Each client reads the same task once. The task is removed from the queue
|
||||
only after all clients have read it, tracked by `task_sync_value`.
|
||||
"""
|
||||
data = None
|
||||
read_finish = False
|
||||
|
||||
Reference in New Issue
Block a user