自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments
2024-01-09 13:38:46
自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments
自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments
推理代码
# text embedding
toks = self.tokenizer([text])
if self.debug:
print('toks', toks)
text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
错误提示
Traceback (most recent call last):
File "/xx/workspace/model/test_onnx.py", line 90, in <module>
res = inferencer.inference(text, img_path)
File "/xx/workspace/model/test_onnx.py", line 58, in inference
text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]
Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899, 102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None
核心错误
TypeError: run(): incompatible function arguments. The following argument types are supported:
1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]
解决方法
核对参数
arg0: List[str]
arg1: Dict[str, object]
对应的参数
output_names=['output'], input_feed=toks
arg0=[‘output’] 参数类型正确
arg1=toks 表面看参数也正常,打印看看toks的每个值的类型
type(toks[‘input_ids’]) 输出为 <class ‘torch.Tensor’>, 实际需要输入类型为 <class ‘numpy.ndarray’>
修改代码
# text embedding
toks = self.tokenizer([text])
if self.debug:
print('toks', toks)
text_input = {}
text_input['input_ids'] = toks['input_ids'].numpy()
text_input['token_type_ids'] = toks['token_type_ids'].numpy()
text_input['attention_mask'] = toks['attention_mask'].numpy()
text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)
再次执行代码,正常运行,无报错!!
文章来源:https://blog.csdn.net/zengNLP/article/details/135477431
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!