mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
add qwen-2.5-7B-PRM/ernie-rm (#4319)
This commit is contained in:
@@ -44,6 +44,12 @@ from fastdeploy.model_executor.models.model_base import (
|
||||
ModelForCasualLM,
|
||||
ModelRegistry,
|
||||
)
|
||||
from fastdeploy.model_executor.utils import (
|
||||
WeightsMapper,
|
||||
default_weight_loader,
|
||||
process_weights_after_loading,
|
||||
process_weights_before_loading,
|
||||
)
|
||||
|
||||
|
||||
class Qwen2MLP(nn.Layer):
|
||||
@@ -316,6 +322,14 @@ class Qwen2ForCausalLM(ModelForCasualLM):
|
||||
prefix="lm_head",
|
||||
)
|
||||
|
||||
self.process_weights_before_loading_fn = process_weights_before_loading(
|
||||
mapper=(
|
||||
WeightsMapper(orig_to_new_prefix={"model.": "qwen2."})
|
||||
if self.fd_config.model_config.model_format == "torch"
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
@paddle.no_grad()
|
||||
def load_weights(self, weights_iterator) -> None:
|
||||
"""
|
||||
@@ -325,11 +339,6 @@ class Qwen2ForCausalLM(ModelForCasualLM):
|
||||
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
|
||||
"""
|
||||
|
||||
from fastdeploy.model_executor.utils import (
|
||||
default_weight_loader,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -344,10 +353,13 @@ class Qwen2ForCausalLM(ModelForCasualLM):
|
||||
params_dict = dict(self.named_parameters())
|
||||
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
|
||||
for loaded_weight_name, loaded_weight in weights_iterator:
|
||||
model_format = self.fd_config.model_config.model_format
|
||||
# Paddle checkpoints use the prefix `qwen2` for these weights, whereas Hugging Face (torch-format) checkpoints use `model`, so torch weight names are renamed to match.
|
||||
if model_format == "torch":
|
||||
loaded_weight_name = loaded_weight_name.replace("model", "qwen2")
|
||||
loaded_weight_name = (
|
||||
self.process_weights_before_loading_fn(loaded_weight_name)
|
||||
if getattr(self, "process_weights_before_loading_fn", None)
|
||||
else loaded_weight_name
|
||||
)
|
||||
if loaded_weight_name is None:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in loaded_weight_name:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user