mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[Other] Adjust GPUModelRunner to enhance compatibility (#6851)
This commit is contained in:
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
@@ -420,7 +421,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
|
||||
self.model.load_state_dict(state_dict)
|
||||
self.lm_head.load_state_dict(state_dict)
|
||||
|
||||
def compute_logits(self, hidden_states: paddle.Tensor):
|
||||
def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None):
|
||||
""" """
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.astype(paddle.float32)
|
||||
@@ -444,10 +445,10 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
inputs: Dict,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
""" """
|
||||
ids_remove_padding = inputs["ids_remove_padding"]
|
||||
hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
|
||||
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user