mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL] Adapt async rollout checkpoint update flow (#7042)
* update checkpoint-transfer flow and control update_weights params * test: add update_weights route validation
This commit is contained in:
@@ -604,13 +604,13 @@ async def test_update_weights_route_validation():
|
||||
api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_control_response)
|
||||
|
||||
valid_req = MagicMock()
|
||||
valid_req.body = AsyncMock(return_value=b'{"version":"v2","rsync_config":{"etcd_server":"127.0.0.1"}}')
|
||||
valid_req.json = AsyncMock(return_value={"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}})
|
||||
valid_req.body = AsyncMock(return_value=b'{"version":"v2","verify_checksum":true}')
|
||||
valid_req.json = AsyncMock(return_value={"version": "v2", "verify_checksum": True})
|
||||
valid_resp = await api_server.update_weights(valid_req)
|
||||
assert valid_resp.status_code == 200
|
||||
control_request = api_server.app.state.engine_client.run_control_method.await_args.args[0]
|
||||
assert control_request.method == "update_weights"
|
||||
assert control_request.args == {"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}}
|
||||
assert control_request.args == {"version": "v2", "verify_checksum": True}
|
||||
|
||||
invalid_version_req = MagicMock()
|
||||
invalid_version_req.body = AsyncMock(return_value=b'{"version":1}')
|
||||
@@ -618,11 +618,11 @@ async def test_update_weights_route_validation():
|
||||
invalid_version_resp = await api_server.update_weights(invalid_version_req)
|
||||
assert invalid_version_resp.status_code == 400
|
||||
|
||||
invalid_rsync_req = MagicMock()
|
||||
invalid_rsync_req.body = AsyncMock(return_value=b'{"rsync_config":{"user":"u"}}')
|
||||
invalid_rsync_req.json = AsyncMock(return_value={"rsync_config": {"user": "u"}})
|
||||
invalid_rsync_resp = await api_server.update_weights(invalid_rsync_req)
|
||||
assert invalid_rsync_resp.status_code == 400
|
||||
invalid_checksum_req = MagicMock()
|
||||
invalid_checksum_req.body = AsyncMock(return_value=b'{"verify_checksum":"true"}')
|
||||
invalid_checksum_req.json = AsyncMock(return_value={"verify_checksum": "true"})
|
||||
invalid_checksum_resp = await api_server.update_weights(invalid_checksum_req)
|
||||
assert invalid_checksum_resp.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user