[RL] Adapt async rollout checkpoint update flow (#7042)

* update checkpoint-transfer flow and control update_weights params

* test: add update_weights route validation
This commit is contained in:
jackyYang6
2026-03-30 19:19:34 +08:00
committed by GitHub
parent 8789329457
commit 05f2d95729
9 changed files with 58 additions and 88 deletions
+8 -8
View File
@@ -604,13 +604,13 @@ async def test_update_weights_route_validation():
api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_control_response)
valid_req = MagicMock()
valid_req.body = AsyncMock(return_value=b'{"version":"v2","rsync_config":{"etcd_server":"127.0.0.1"}}')
valid_req.json = AsyncMock(return_value={"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}})
valid_req.body = AsyncMock(return_value=b'{"version":"v2","verify_checksum":true}')
valid_req.json = AsyncMock(return_value={"version": "v2", "verify_checksum": True})
valid_resp = await api_server.update_weights(valid_req)
assert valid_resp.status_code == 200
control_request = api_server.app.state.engine_client.run_control_method.await_args.args[0]
assert control_request.method == "update_weights"
assert control_request.args == {"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}}
assert control_request.args == {"version": "v2", "verify_checksum": True}
invalid_version_req = MagicMock()
invalid_version_req.body = AsyncMock(return_value=b'{"version":1}')
@@ -618,11 +618,11 @@ async def test_update_weights_route_validation():
invalid_version_resp = await api_server.update_weights(invalid_version_req)
assert invalid_version_resp.status_code == 400
invalid_rsync_req = MagicMock()
invalid_rsync_req.body = AsyncMock(return_value=b'{"rsync_config":{"user":"u"}}')
invalid_rsync_req.json = AsyncMock(return_value={"rsync_config": {"user": "u"}})
invalid_rsync_resp = await api_server.update_weights(invalid_rsync_req)
assert invalid_rsync_resp.status_code == 400
invalid_checksum_req = MagicMock()
invalid_checksum_req.body = AsyncMock(return_value=b'{"verify_checksum":"true"}')
invalid_checksum_req.json = AsyncMock(return_value={"verify_checksum": "true"})
invalid_checksum_resp = await api_server.update_weights(invalid_checksum_req)
assert invalid_checksum_resp.status_code == 400
@pytest.mark.asyncio