llama-stack/llama_stack/providers/remote/inference/vllm/vllm.py
Ben Browning c64f0d5888
fix: Get builtin tool calling working in remote-vllm (#1236)
# What does this PR do?

This PR makes a couple of changes required to get the test
`tests/client-sdk/agents/test_agents.py::test_builtin_tool_web_search`
passing on the remote-vllm provider.

First, we adjust agent_instance to also pass in the description and
parameters of builtin tools. We need these parameters so we can pass the
tool's expected parameters into vLLM. The meta-reference implementations
may not have needed these for builtin tools, as they can take
advantage of Llama-model-specific support for certain builtin tools.
However, with vLLM, our server-side chat templates for tool calling
treat all tools the same and don't separate out Llama builtin vs custom
tools. So, we need to pass the full set of parameter definitions and
list of required parameters for builtin tools as well.

Next, we adjust the vllm streaming chat completion code to fix up some
edge cases where it was returning an extra ChatCompletionResponseEvent
with an empty ToolCall with empty string call_id, tool_name, and
arguments properties. This bug surfaced after the above fix: after a
successful tool invocation, we were sending extra chunks back to the
client containing these empty ToolCalls.

## Test Plan

With these changes, the following test that previously failed now
passes:

```
VLLM_URL="http://localhost:8000/v1" \
INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \
LLAMA_STACK_CONFIG=remote-vllm \
python -m pytest -v \
tests/client-sdk/agents/test_agents.py::test_builtin_tool_web_search \
--inference-model "meta-llama/Llama-3.2-3B-Instruct"
```

Additionally, I ran the remote-vllm client-sdk and provider inference
tests as below to ensure they all still passed with this change:

```
VLLM_URL="http://localhost:8000/v1" \
INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \
LLAMA_STACK_CONFIG=remote-vllm \
python -m pytest -v \
tests/client-sdk/inference/test_text_inference.py \
--inference-model "meta-llama/Llama-3.2-3B-Instruct"
```

```
VLLM_URL="http://localhost:8000/v1" \
python -m pytest -s -v \
llama_stack/providers/tests/inference/test_text_inference.py \
--providers "inference=vllm_remote"
```


[//]: # (## Documentation)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
2025-02-26 15:25:47 -05:00

406 lines
15 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import StopReason, ToolCall
from openai import OpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionResponse,
UnparseableToolCall,
convert_message_to_openai_dict,
convert_tool_call,
get_sampling_options,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
request_has_media,
)
from .config import VLLMInferenceAdapterConfig
# Module-level logger, named after this module; used for initialization and
# tool-call parsing diagnostics below.
log = logging.getLogger(__name__)
def build_hf_repo_model_entries():
    """Build registry entries for every registered Llama model that has a Hugging Face repo.

    Models whose ``huggingface_repo`` is empty/falsy are skipped.

    :returns: a list of entries created by ``build_hf_repo_model_entry`` from each
        model's Hugging Face repo name and its descriptor.
    """
    entries = []
    for registered in all_registered_models():
        if not registered.huggingface_repo:
            continue
        entries.append(
            build_hf_repo_model_entry(
                registered.huggingface_repo,
                registered.descriptor(),
            )
        )
    return entries
def _convert_to_vllm_tool_calls_in_response(
    tool_calls,
) -> List[ToolCall]:
    """Convert the tool calls of a non-streaming vLLM chat response into ``ToolCall``s.

    Each call's JSON ``function.arguments`` string is decoded on its own, so every
    returned ``ToolCall`` carries its own arguments. (Previously the arguments of
    the last call were attached to all calls.)

    :param tool_calls: the ``message.tool_calls`` of an OpenAI-style chat completion
        choice; may be ``None`` or empty.
    :returns: one ``ToolCall`` per input call, in input order; ``[]`` when there are none.
    :raises json.JSONDecodeError: if a call's arguments are not valid JSON.
    """
    if not tool_calls:
        return []
    return [
        ToolCall(
            call_id=call.id,
            tool_name=call.function.name,
            arguments=json.loads(call.function.arguments),
        )
        for call in tool_calls
    ]
def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
    """Convert Llama Stack tool definitions into OpenAI-style ``tools`` entries for vLLM.

    Builtin tools are converted exactly like custom tools. vLLM's server-side chat
    templates treat all tools the same, so each tool needs its full parameter
    schema (properties, defaults, required list).

    :param tools: the request's tool definitions, or ``None``.
    :returns: a list of ``{"type": "function", "function": {...}}`` dicts, or
        ``None`` when ``tools`` is ``None`` or empty.
    """
    if tools is None:
        return tools
    compat_tools = []
    for tool in tools:
        properties = {}
        compat_required = []
        if tool.parameters:
            for tool_key, tool_param in tool.parameters.items():
                properties[tool_key] = {"type": tool_param.param_type}
                if tool_param.description:
                    properties[tool_key]["description"] = tool_param.description
                # Compare against None so legitimate falsy defaults (0, False, "")
                # are still advertised to the model.
                if tool_param.default is not None:
                    properties[tool_key]["default"] = tool_param.default
                if tool_param.required:
                    compat_required.append(tool_key)
        # The tool.tool_name can be a str or a BuiltinTool enum. If
        # it's the latter, convert to a string.
        tool_name = tool.tool_name
        if isinstance(tool_name, BuiltinTool):
            tool_name = tool_name.value
        compat_tool = {
            "type": "function",
            "function": {
                "name": tool_name,
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": compat_required,
                },
            },
        }
        compat_tools.append(compat_tool)
    if len(compat_tools) > 0:
        return compat_tools
    return None
def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
    """Translate an OpenAI-style ``finish_reason`` into a Llama Stack ``StopReason``.

    ``"length"`` maps to ``out_of_tokens`` and ``"tool_calls"`` to ``end_of_message``.
    ``"stop"`` and any other value, including ``None``, map to ``end_of_turn``.
    """
    if finish_reason == "length":
        return StopReason.out_of_tokens
    if finish_reason == "tool_calls":
        return StopReason.end_of_message
    return StopReason.end_of_turn
def _tool_call_buffer_to_stream_chunk(
    tool_call_buf: UnparseableToolCall,
    event_type: ChatCompletionResponseEventType,
) -> Optional[ChatCompletionResponseStreamChunk]:
    """Turn one fully buffered streamed tool call into a single stream chunk.

    :param tool_call_buf: accumulated ``call_id`` / ``tool_name`` / ``arguments`` strings.
    :param event_type: event type used when the arguments parse successfully.
    :returns: a ``succeeded`` ``ToolCallDelta`` chunk when the arguments decode to a
        JSON object (an empty argument string counts as ``{}``); a ``failed``
        chunk carrying the raw buffer otherwise; or ``None`` when nothing was
        buffered, so an empty ``ToolCall`` is never emitted.
    """
    args_str = tool_call_buf.arguments
    if not tool_call_buf.tool_name and not args_str:
        return None
    args = None
    try:
        args = json.loads(args_str) if args_str else {}
    except (ValueError, RecursionError) as e:
        # Best effort: report the call as unparseable instead of aborting the stream.
        log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
    if isinstance(args, dict):
        return ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=event_type,
                delta=ToolCallDelta(
                    tool_call=ToolCall(
                        call_id=tool_call_buf.call_id,
                        tool_name=tool_call_buf.tool_name,
                        arguments=args,
                    ),
                    parse_status=ToolCallParseStatus.succeeded,
                ),
            )
        )
    return ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.progress,
            delta=ToolCallDelta(
                tool_call=str(tool_call_buf),
                parse_status=ToolCallParseStatus.failed,
            ),
        )
    )


async def _process_vllm_chat_completion_stream_response(
    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
) -> AsyncGenerator:
    """Translate a streamed vLLM (OpenAI-compatible) chat completion into Llama Stack chunks.

    Text deltas are forwarded as they arrive. The first emitted event is ``start``
    and later ones are ``progress``. Tool-call deltas arrive as fragments (id, name,
    partial JSON arguments), so they are buffered per tool-call ``index``. When the
    chunk carrying ``finish_reason`` arrives, each buffered call is emitted, followed
    by a ``complete`` event with the mapped stop reason.
    """
    event_type = ChatCompletionResponseEventType.start
    # One buffer per streamed tool-call index, so parallel tool calls are not merged.
    tool_call_bufs: dict[int, UnparseableToolCall] = {}
    async for chunk in stream:
        if not chunk.choices:
            # Nothing to translate in a chunk without choices.
            continue
        choice = chunk.choices[0]
        delta_tool_calls = choice.delta.tool_calls or []
        # Buffer tool-call fragments first, including any that ride on the finish chunk.
        # Raw delta strings are used because the argument fragments are partial JSON
        # that only parses once fully concatenated.
        for delta_tool_call in delta_tool_calls:
            buf = tool_call_bufs.setdefault(delta_tool_call.index, UnparseableToolCall())
            if delta_tool_call.id and not buf.call_id:
                buf.call_id = delta_tool_call.id
            if delta_tool_call.function:
                buf.tool_name += delta_tool_call.function.name or ""
                buf.arguments += delta_tool_call.function.arguments or ""
        if choice.finish_reason:
            for buf in tool_call_bufs.values():
                tool_chunk = _tool_call_buffer_to_stream_chunk(buf, event_type)
                if tool_chunk is not None:
                    yield tool_chunk
                    event_type = ChatCompletionResponseEventType.progress
            tool_call_bufs.clear()
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
                    delta=TextDelta(text=choice.delta.content or ""),
                    logprobs=None,
                    stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
                )
            )
        elif not delta_tool_calls:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=event_type,
                    delta=TextDelta(text=choice.delta.content or ""),
                    logprobs=None,
                )
            )
            event_type = ChatCompletionResponseEventType.progress
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
    """Inference provider backed by a remote vLLM server.

    All calls go through an ``openai`` client created in ``initialize`` and pointed
    at ``config.url``. ``self.model_store`` is read by the request methods but never
    assigned in this class, so it must be attached externally before use.

    NOTE(review): the ``OpenAI`` client is synchronous, so its calls block the event
    loop inside these async methods; consider ``AsyncOpenAI``.
    """

    def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
        self.config = config
        # Created lazily in initialize().
        self.client = None

    async def initialize(self) -> None:
        """Create the OpenAI-compatible client for the configured vLLM endpoint."""
        base_url = self.config.url
        log.info(f"Initializing VLLM client with base_url={base_url}")
        self.client = OpenAI(base_url=base_url, api_key=self.config.api_token)

    async def shutdown(self) -> None:
        """No resources to release."""

    async def unregister_model(self, model_id: str) -> None:
        """No-op: nothing provider-side to clean up for a model."""

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
        """Run a text completion.

        Returns an async generator of chunks when ``stream`` is truthy, otherwise the
        awaited ``CompletionResponse``.
        """
        resolved = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=resolved.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if not stream:
            return await self._nonstream_completion(request)
        return self._stream_completion(request)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion.

        ``tool_choice`` and ``tool_prompt_format`` are accepted for interface
        compatibility but are not forwarded here; ``tool_config`` is. Returns an async
        generator of chunks when ``stream`` is truthy, otherwise the awaited
        ``ChatCompletionResponse``.
        """
        resolved = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=resolved.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            stream=stream,
            logprobs=logprobs,
            response_format=response_format,
            tool_config=tool_config,
        )
        if not stream:
            return await self._nonstream_chat_completion(request, self.client)
        return self._stream_chat_completion(request, self.client)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> ChatCompletionResponse:
        """Issue a blocking chat request and convert its first choice."""
        params = await self._get_params(request)
        response = client.chat.completions.create(**params)
        first_choice = response.choices[0]
        message = CompletionMessage(
            content=first_choice.message.content or "",
            stop_reason=_convert_to_vllm_finish_reason(first_choice.finish_reason),
            tool_calls=_convert_to_vllm_tool_calls_in_response(first_choice.message.tool_calls),
        )
        return ChatCompletionResponse(completion_message=message, logprobs=None)

    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
        """Stream a chat request, choosing the processor based on whether tools were given.

        Requests with tools use this module's vLLM-specific stream processor; all
        others use the shared OpenAI-compat processor.
        """
        params = await self._get_params(request)

        # TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
        # generator so this wrapper is not necessary?
        async def _iterate_chunks():
            for piece in client.chat.completions.create(**params):
                yield piece

        source = _iterate_chunks()
        if len(request.tools) > 0:
            processor = _process_vllm_chat_completion_stream_response(source)
        else:
            processor = process_chat_completion_stream_response(source, request)
        async for out in processor:
            yield out

    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
        """Issue a blocking text-completion request and convert the result."""
        params = await self._get_params(request)
        return process_completion_response(self.client.completions.create(**params))

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        """Stream a text-completion request through the shared stream processor."""
        params = await self._get_params(request)

        # Adapt the client's synchronous iterator to an async generator.
        async def _iterate_chunks():
            for piece in self.client.completions.create(**params):
                yield piece

        async for out in process_completion_stream_response(_iterate_chunks()):
            yield out

    async def register_model(self, model: Model) -> Model:
        """Register ``model`` and verify the vLLM server is actually serving it.

        :raises ValueError: if the model's provider id is not among the served models.
        """
        model = await self.register_helper.register_model(model)
        served_ids = [entry.id for entry in self.client.models.list()]
        if model.provider_resource_id in served_ids:
            return model
        raise ValueError(
            f"Model {model.provider_resource_id} is not being served by vLLM. "
            f"Available models: {', '.join(served_ids)}"
        )

    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
        """Build the keyword arguments for the OpenAI client call.

        Fills ``max_tokens`` from config when sampling params leave it unset, adds
        ``tools``/``messages`` (chat) or ``prompt`` (completion), maps a JSON-schema
        response format onto vLLM's ``guided_json`` extra body, and forwards
        ``logprobs.top_k``.
        """
        options = get_sampling_options(request.sampling_params)
        if "max_tokens" not in options:
            options["max_tokens"] = self.config.max_tokens

        input_dict = {}
        if isinstance(request, ChatCompletionRequest):
            if request.tools is not None:
                input_dict["tools"] = _convert_to_vllm_tools_in_request(request.tools)
            input_dict["messages"] = [
                await convert_message_to_openai_dict(message, download=True) for message in request.messages
            ]
        else:
            assert not request_has_media(request), "vLLM does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(request)

        fmt = request.response_format
        if fmt:
            if fmt.type == ResponseFormatType.json_schema.value:
                # vLLM-specific structured-output knob, passed through extra_body.
                input_dict["extra_body"] = {"guided_json": fmt.json_schema}
            elif fmt.type == ResponseFormatType.grammar.value:
                raise NotImplementedError("Grammar response format not supported yet")
            else:
                raise ValueError(f"Unknown response format {fmt.type}")

        if request.logprobs and request.logprobs.top_k:
            input_dict["logprobs"] = request.logprobs.top_k

        return {
            "model": request.model,
            **input_dict,
            "stream": request.stream,
            **options,
        }

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        """Embed ``contents`` with an embedding model served by vLLM.

        The model's ``embedding_dimension`` metadata is sent as ``dimensions``.
        ``text_truncation``, ``output_dimension`` and ``task_type`` are not used here.
        """
        model = await self.model_store.get_model(model_id)
        assert model.model_type == ModelType.embedding
        assert model.metadata.get("embedding_dimension")
        assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
        response = self.client.embeddings.create(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
            dimensions=model.metadata.get("embedding_dimension"),
        )
        return EmbeddingsResponse(embeddings=[item.embedding for item in response.data])