diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 482e6fa97..1124afc7f 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -187,12 +187,51 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return model
 
+    def convert_to_vllm_tools(self, tools: List[ToolDefinition]) -> List[dict]:
+        if tools is None:
+            return tools
+
+        compat_tools = []
+
+        for tool in tools:
+            properties = {}
+            compat_required = []
+            if tool.parameters:
+                for tool_key, tool_param in tool.parameters.items():
+                    properties[tool_key] = {"type": tool_param.param_type}
+                    if tool_param.description:
+                        properties[tool_key]["description"] = tool_param.description
+                    if tool_param.default:
+                        properties[tool_key]["default"] = tool_param.default
+                    if tool_param.required:
+                        compat_required.append(tool_key)
+
+            compat_tool = {
+                "type": "function",
+                "function": {
+                    "name": tool.tool_name,
+                    "description": tool.description,
+                    "parameters": {
+                        "type": "object",
+                        "properties": properties,
+                        "required": compat_required,
+                    },
+                },
+            }
+
+            compat_tools.append(compat_tool)
+
+        if len(compat_tools) > 0:
+            return compat_tools
+        return None
+
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         options = get_sampling_options(request.sampling_params)
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens
 
         input_dict = {}
+        input_dict["tools"] = self.convert_to_vllm_tools(request.tools)
         if isinstance(request, ChatCompletionRequest):
             input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
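
For context, convert_to_vllm_tools maps a llama-stack ToolDefinition into the OpenAI-style tool schema that vLLM's chat completions endpoint accepts. A minimal sketch of the expected output shape, assuming a hypothetical "get_weather" tool with one required string parameter (not part of this change):

    # Hypothetical: a ToolDefinition named "get_weather" with a required "city"
    # parameter. convert_to_vllm_tools would produce a list like:
    [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string", "description": "City name"}},
                    "required": ["city"],
                },
            },
        }
    ]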