mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[RL][BugFix][Optimization] Support chunked part files loading and fix model path format in IPC snapshot strategy (#6852)
* [RL] Support chunked part files loading in IPC snapshot strategy
## Motivation
When using IPC snapshot for elastic recovery in RL training, loading a single large pdparams file causes a significant memory spike. This PR refactors `_update_ipc_snapshot` to support loading chunked part files to avoid the memory spike.
## Modifications
Refactored `_update_ipc_snapshot` in `fastdeploy/rl/dynamic_weight_manager.py` with a three-level loading priority:
1. **Chunked part files** (`model_state.tpR{id}.part{N}.pdparams`): load multiple smaller shards sequentially, freeing memory between chunks via `gc.collect()` to avoid a memory spike.
2. **Single full file** (`model_state.tpR{id}.pdparams`): Legacy single-file loading path (preserved for backward compatibility).
3. **Shared fallback directory** (`/shared_ipc_meta/...`): Oldest legacy fallback path (preserved for backward compatibility).
Also fixed the rank ID in the file name pattern from hardcoded `tp0` to dynamic `paddle.distributed.get_rank()`.
## Checklist
- [ ] Add at least a tag in the PR title.
- [ ] Format your code, run `pre-commit` before commit.
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.
Co-Authored-By: lishuaihui <lishuaihui@baidu.com>
* [RL][BugFix] Fix ambiguous model path format and add legacy fallback in IPC snapshot
## Motivation
The previous snapshot file naming `model_state.tp{rank}{id}` concatenated
rank and id without a separator, causing ambiguity (e.g., rank=1, id=234
and rank=12, id=34 both produce `tp1234`). Additionally, after the naming
format is updated, existing checkpoints saved in the old format would fail
to load during elastic recovery, causing unnecessary failures.
## Modifications
- Add dot separator between rank and id in snapshot file name:
`model_state.tp{rank}{id}` → `model_state.tp{rank}.{id}`
- Add Priority 3 legacy fallback to load old-format files
(`model_state.tp0{id}.pdparams`) for backward compatibility during
rolling upgrades
- Update docstring and error message to reflect the new 4-level priority
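The collision the dot separator removes is easy to demonstrate (plain illustration, not code from the patch):

```python
def old_name(rank, src_id):
    return f"model_state.tp{rank}{src_id}.pdparams"   # no separator

def new_name(rank, src_id):
    return f"model_state.tp{rank}.{src_id}.pdparams"  # dot-separated

# Distinct (rank, id) pairs collide under the old scheme...
assert old_name(1, 234) == old_name(12, 34) == "model_state.tp1234.pdparams"
# ...but stay distinct once the separator is added.
assert new_name(1, 234) != new_name(12, 34)
```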
Co-Authored-By: lishuaihui <lishuaihui@baidu.com>
* [RL][Test] Add unit tests for DynamicWeightManager._update_ipc_snapshot
Cover all 4 loading priority branches (chunked part files, single full
pdparams, legacy format, shared directory fallback) with mock-based
tests to verify correct behavior without filesystem or GPU dependencies.
Co-Authored-By: lishuaihui <lishuaihui@baidu.com>
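A mock-based branch test in this style might look like the following. This is a hypothetical sketch, not the actual test file: `pick_branch` is a toy stand-in for the loader's branch selection, and `mock.patch` replaces `glob.glob` / `os.path.exists` so no real files or GPU are needed.

```python
import glob
import os
from unittest import mock


def pick_branch(model_dir, base):
    """Toy stand-in for the loader's branch selection."""
    if glob.glob(os.path.join(model_dir, f"{base}.part*.pdparams")):
        return "chunked"
    if os.path.exists(os.path.join(model_dir, f"{base}.pdparams")):
        return "single"
    return "fallback"


# Each branch is driven purely by patched filesystem calls.
with mock.patch("glob.glob", return_value=["x.part0.pdparams"]):
    assert pick_branch("/snap", "model_state.tp0.1") == "chunked"
with mock.patch("glob.glob", return_value=[]), mock.patch("os.path.exists", return_value=True):
    assert pick_branch("/snap", "model_state.tp0.1") == "single"
with mock.patch("glob.glob", return_value=[]), mock.patch("os.path.exists", return_value=False):
    assert pick_branch("/snap", "model_state.tp0.1") == "fallback"
```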
* [RL][Test] Remove unused import 'call' in test_update_ipc_snapshot.py
Co-Authored-By: lishuaihui <lishuaihui@baidu.com>
* Potential fix for pull request finding
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
* [RL] Fix snapshot part index to match filename numbering
Parse part index from filename (e.g. .part0.) instead of using
enumerate index, so that logs and src_type stay consistent with
the actual file naming convention.
Co-Authored-By: wikilsh <wiki_hui@qq.com>
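Parsing the index out of the filename, rather than trusting iteration order, takes only a small regex. A minimal sketch (the helper name is hypothetical):

```python
import re
from typing import Optional

_PART_RE = re.compile(r"\.part(\d+)\.")


def part_index(path: str) -> Optional[int]:
    """Return the part number embedded in a snapshot filename, or None."""
    m = _PART_RE.search(path)
    return int(m.group(1)) if m else None


# part_index("model_state.tp0.7.part3.pdparams") -> 3
# part_index("model_state.tp0.7.pdparams")       -> None
```

This way a directory containing `.part0.` and `.part2.` but missing `.part1.` still logs the indices actually present, instead of renumbering them 0 and 1.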
---------
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
The diff to `fastdeploy/rl/dynamic_weight_manager.py`:

```diff
@@ -14,8 +14,11 @@
 # limitations under the License.
 """

+import gc
+import glob
 import io
 import os
+import re
 import time
 from multiprocessing.shared_memory import SharedMemory
 from typing import Any, Dict, List
@@ -199,20 +202,88 @@ class DynamicWeightManager:
         # step6: update weight status signal

     def _update_ipc_snapshot(self):
-        """Update using IPC snapshot strategy for elastic recovery."""
-        model_path = os.path.join(
-            self.fd_config.model_config.model,
-            f"model_state.tp0{self.meta_src_id}.pdparams",
-        )
-
-        try:
-            ipc_state_dict = paddle.load(model_path, safetensors=True)
-        except FileNotFoundError:
-            fallback_path = f"/shared_ipc_meta/model_state.tp0{self.meta_src_id}.pdparams"
-            ipc_state_dict = paddle.load(fallback_path)
-        self._update_model_from_state(ipc_state_dict, "snapshot")
-        logger.info(f"IPC snapshot update parameters completed from {model_path}")
+        """Update using IPC snapshot strategy for elastic recovery.
+
+        Loading priority:
+        1. Chunked part files (model_state.tp{rank}.{id}.part{N}.pdparams)
+        2. Single full file (model_state.tp{rank}.{id}.pdparams)
+        3. Legacy format (model_state.tp0{id}.pdparams)
+        4. Shared fallback dir (/shared_ipc_meta/...)
+        """
+        model_dir = self.fd_config.model_config.model
+        base_name = f"model_state.tp{paddle.distributed.get_rank()}.{self.meta_src_id}"
+        legacy_base_name = f"model_state.tp0{self.meta_src_id}"
+
+        # --- Priority 1: load from chunked part files to avoid memory spike ---
+        part_pattern = os.path.join(model_dir, f"{base_name}.part*.pdparams")
+        all_part_files = glob.glob(part_pattern)
+
+        valid_part_files = []
+        invalid_part_files = []
+        part_regex = re.compile(r"\.part(\d+)\.")
+
+        for path in all_part_files:
+            match = part_regex.search(path)
+            if not match:
+                invalid_part_files.append(os.path.basename(path))
+                continue
+            try:
+                part_idx = int(match.group(1))
+            except (TypeError, ValueError):
+                invalid_part_files.append(os.path.basename(path))
+                continue
+            valid_part_files.append((part_idx, path))
+
+        if invalid_part_files:
+            logger.warning(
+                "Found snapshot part files with invalid naming pattern under %s: %s. "
+                "These files will be ignored when loading IPC snapshot parts.",
+                model_dir,
+                ", ".join(invalid_part_files),
+            )
+
+        part_files = [p for _, p in sorted(valid_part_files, key=lambda item: item[0])]
+
+        if part_files:
+            logger.info(f"Found {len(part_files)} snapshot part files for {base_name}")
+            for load_idx, part_path in enumerate(part_files):
+                match = re.search(r"\.part(\d+)\.", part_path)
+                # Use part index parsed from filename to keep logs and src_type consistent with file naming
+                part_index = int(match.group(1)) if match else load_idx
+                logger.info(f"Loading snapshot part {part_index+1}/{len(part_files)} from {part_path}")
+                ipc_state_dict = paddle.load(part_path, safetensors=True)
+                self._update_model_from_state(ipc_state_dict, f"snapshot-part{part_index}")
+                del ipc_state_dict
+                gc.collect()
+            logger.info(f"IPC snapshot update completed from {len(part_files)} part files under {model_dir}")
+            return
+
+        # --- Priority 2: single full pdparams file ---
+        model_path = os.path.join(model_dir, f"{base_name}.pdparams")
+        if os.path.exists(model_path):
+            ipc_state_dict = paddle.load(model_path, safetensors=True)
+            self._update_model_from_state(ipc_state_dict, "snapshot")
+            logger.info(f"IPC snapshot update completed from {model_path}")
+            return
+
+        # --- Priority 3: legacy format (model_state.tp0{id}.pdparams) ---
+        legacy_path = os.path.join(model_dir, f"{legacy_base_name}.pdparams")
+        if os.path.exists(legacy_path):
+            ipc_state_dict = paddle.load(legacy_path, safetensors=True)
+            self._update_model_from_state(ipc_state_dict, "snapshot")
+            logger.info(f"IPC snapshot update completed from legacy format {legacy_path}")
+            return
+
+        # --- Priority 4: shared directory fallback ---
+        fallback_path = f"/shared_ipc_meta/{base_name}.pdparams"
+        if not os.path.exists(fallback_path):
+            raise FileNotFoundError(
+                f"No snapshot found for {base_name}: " f"checked {model_dir} (new/legacy) and {fallback_path}"
+            )
+        logger.info(f"No local snapshot in {model_dir}, fallback to {fallback_path}")
+        ipc_state_dict = paddle.load(fallback_path)
+        self._update_model_from_state(ipc_state_dict, "snapshot")
+        logger.info(f"IPC snapshot update completed from {fallback_path}")

     def _update_ipc(self):
         """Update using standard IPC strategy (requires Training Worker)."""
```