Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-10 20:19:22 +00:00)
fix: Handle tool calling in remote vLLM provider
Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Commit cc3bb0938a (parent b981b49bfa)
1 changed file with 39 additions and 0 deletions
@@ -187,12 +187,51 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             )
         return model
 
+    def convert_to_vllm_tools(self, tools: List[ToolDefinition]) -> List[dict]:
+        if tools is None:
+            return tools
+
+        compat_tools = []
+
+        for tool in tools:
+            properties = {}
+            compat_required = []
+            if tool.parameters:
+                for tool_key, tool_param in tool.parameters.items():
+                    properties[tool_key] = {"type": tool_param.param_type}
+                    if tool_param.description:
+                        properties[tool_key]["description"] = tool_param.description
+                    if tool_param.default:
+                        properties[tool_key]["default"] = tool_param.default
+                    if tool_param.required:
+                        compat_required.append(tool_key)
+
+            compat_tool = {
+                "type": "function",
+                "function": {
+                    "name": tool.tool_name,
+                    "description": tool.description,
+                    "parameters": {
+                        "type": "object",
+                        "properties": properties,
+                        "required": compat_required,
+                    },
+                },
+            }
+
+            compat_tools.append(compat_tool)
+
+        if len(compat_tools) > 0:
+            return compat_tools
+        return None
+
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         options = get_sampling_options(request.sampling_params)
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens
 
         input_dict = {}
+        input_dict["tools"] = self.convert_to_vllm_tools(request.tools)
 
         if isinstance(request, ChatCompletionRequest):
             input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
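For context: convert_to_vllm_tools() maps llama-stack ToolDefinition objects onto the OpenAI-style tool schema that vLLM's OpenAI-compatible chat completions endpoint accepts, and _get_params() threads the result through input_dict["tools"]. Below is a minimal sketch, not part of the commit, of what that payload looks like and how such a tools list could be sent to a vLLM server with the OpenAI client; the tool name, parameters, model id, and base_url are illustrative placeholders.

from openai import OpenAI

# Hand-written example of the structure convert_to_vllm_tools() produces.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # taken from tool.tool_name
            "description": "Look up the current weather for a city",  # tool.description
            "parameters": {
                "type": "object",
                "properties": {
                    # built from tool.parameters: param_type, plus optional description/default
                    "city": {"type": "string", "description": "City name"},
                },
                "required": ["city"],  # parameter keys whose required flag is set
            },
        },
    }
]

# Placeholder endpoint and model id; a real deployment would use its own values.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=tools,  # same shape as input_dict["tools"] assembled in _get_params()
)
print(response.choices[0].message.tool_calls)

Note that the helper returns None rather than an empty list when nothing is converted; whether that key is then dropped before the request reaches vLLM depends on code outside this hunk.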