mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
# What does this PR do? The `tool_name` attribute of `ToolDefinition` instances can either be a str or a BuiltinTool enum type. This fixes the remote vLLM provider to use the value of those BuiltinTool enums when serializing to JSON instead of attempting to serialize the actual enum to JSON. Reference of how this is handled in some other areas, since I followed that same pattern for the remote vLLM provider here: - [remote nvidia provider](https://github.com/meta-llama/llama-stack/blob/v0.1.3/llama_stack/providers/remote/inference/nvidia/openai_utils.py#L137-L140) - [meta reference provider](https://github.com/meta-llama/llama-stack/blob/v0.1.3/llama_stack/providers/inline/agents/meta_reference/agent_instance.py#L635-L636) There is opportunity to potentially reconcile the remove nvidia and remote vllm bits where they are both translating Llama Stack Inference APIs to OpenAI client requests, but that's a can of worms I didn't want to open for this bug fix. This explicitly fixes this error when using the remote vLLM provider and the agent tests: ``` TypeError: Object of type BuiltinTool is not JSON serializable ``` So, this is related to #1144 and addresses the immediate issue raised there. With this fix, `tests/client-sdk/agents/test_agents.py::test_builtin_tool_web_search` now gets past the JSON serialization error when using the remote vLLM provider and actually attempts to call the web search tool. I don't have any API keys setup for the actual web search providers yet, so I cannot verify everything works after that point. ## Test Plan I ran the `test_builtin_tool_web_search` locally with the remote vLLM provider like: ``` VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/client-sdk/agents/test_agents.py::test_builtin_tool_web_search --inference-model "meta-llama/Llama-3.2-3B-Instruct" ``` Before my change, that reproduced the `TypeError: Object of type BuiltinTool is not JSON serializable` error. After my change, that error is gone and the test actually attempts the web search. That failed for me locally, due to lack of API key, but it gets past the JSON serialization error. Signed-off-by: Ben Browning <bbrownin@redhat.com>
395 lines
14 KiB
Python
395 lines
14 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
import json
|
|
import logging
|
|
from typing import AsyncGenerator, List, Optional, Union
|
|
|
|
from llama_models.datatypes import StopReason, ToolCall
|
|
from openai import OpenAI
|
|
|
|
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
|
|
from llama_stack.apis.inference import (
|
|
ChatCompletionRequest,
|
|
ChatCompletionResponse,
|
|
ChatCompletionResponseEvent,
|
|
ChatCompletionResponseEventType,
|
|
ChatCompletionResponseStreamChunk,
|
|
CompletionMessage,
|
|
CompletionRequest,
|
|
CompletionResponse,
|
|
CompletionResponseStreamChunk,
|
|
EmbeddingsResponse,
|
|
Inference,
|
|
LogProbConfig,
|
|
Message,
|
|
ResponseFormat,
|
|
ResponseFormatType,
|
|
SamplingParams,
|
|
ToolChoice,
|
|
ToolConfig,
|
|
ToolDefinition,
|
|
ToolPromptFormat,
|
|
)
|
|
from llama_stack.apis.models import Model, ModelType
|
|
from llama_stack.models.llama.datatypes import BuiltinTool
|
|
from llama_stack.models.llama.sku_list import all_registered_models
|
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
|
from llama_stack.providers.utils.inference.model_registry import (
|
|
ModelRegistryHelper,
|
|
build_hf_repo_model_entry,
|
|
)
|
|
from llama_stack.providers.utils.inference.openai_compat import (
|
|
OpenAICompatCompletionResponse,
|
|
UnparseableToolCall,
|
|
convert_message_to_openai_dict,
|
|
convert_tool_call,
|
|
get_sampling_options,
|
|
process_chat_completion_stream_response,
|
|
process_completion_response,
|
|
process_completion_stream_response,
|
|
)
|
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
completion_request_to_prompt,
|
|
content_has_media,
|
|
interleaved_content_as_str,
|
|
request_has_media,
|
|
)
|
|
|
|
from .config import VLLMInferenceAdapterConfig
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def build_hf_repo_model_entries():
|
|
return [
|
|
build_hf_repo_model_entry(
|
|
model.huggingface_repo,
|
|
model.descriptor(),
|
|
)
|
|
for model in all_registered_models()
|
|
if model.huggingface_repo
|
|
]
|
|
|
|
|
|
def _convert_to_vllm_tool_calls_in_response(
|
|
tool_calls,
|
|
) -> List[ToolCall]:
|
|
if not tool_calls:
|
|
return []
|
|
|
|
call_function_arguments = None
|
|
for call in tool_calls:
|
|
call_function_arguments = json.loads(call.function.arguments)
|
|
|
|
return [
|
|
ToolCall(
|
|
call_id=call.id,
|
|
tool_name=call.function.name,
|
|
arguments=call_function_arguments,
|
|
)
|
|
for call in tool_calls
|
|
]
|
|
|
|
|
|
def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
|
|
if tools is None:
|
|
return tools
|
|
|
|
compat_tools = []
|
|
|
|
for tool in tools:
|
|
properties = {}
|
|
compat_required = []
|
|
if tool.parameters:
|
|
for tool_key, tool_param in tool.parameters.items():
|
|
properties[tool_key] = {"type": tool_param.param_type}
|
|
if tool_param.description:
|
|
properties[tool_key]["description"] = tool_param.description
|
|
if tool_param.default:
|
|
properties[tool_key]["default"] = tool_param.default
|
|
if tool_param.required:
|
|
compat_required.append(tool_key)
|
|
|
|
# The tool.tool_name can be a str or a BuiltinTool enum. If
|
|
# it's the latter, convert to a string.
|
|
tool_name = tool.tool_name
|
|
if isinstance(tool_name, BuiltinTool):
|
|
tool_name = tool_name.value
|
|
|
|
compat_tool = {
|
|
"type": "function",
|
|
"function": {
|
|
"name": tool_name,
|
|
"description": tool.description,
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": properties,
|
|
"required": compat_required,
|
|
},
|
|
},
|
|
}
|
|
|
|
compat_tools.append(compat_tool)
|
|
|
|
if len(compat_tools) > 0:
|
|
return compat_tools
|
|
return None
|
|
|
|
|
|
def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
|
|
return {
|
|
"stop": StopReason.end_of_turn,
|
|
"length": StopReason.out_of_tokens,
|
|
"tool_calls": StopReason.end_of_message,
|
|
}.get(finish_reason, StopReason.end_of_turn)
|
|
|
|
|
|
async def _process_vllm_chat_completion_stream_response(
|
|
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
|
|
) -> AsyncGenerator:
|
|
event_type = ChatCompletionResponseEventType.start
|
|
tool_call_buf = UnparseableToolCall()
|
|
async for chunk in stream:
|
|
choice = chunk.choices[0]
|
|
if choice.finish_reason:
|
|
args_str = tool_call_buf.arguments
|
|
args = None
|
|
try:
|
|
args = {} if not args_str else json.loads(args_str)
|
|
except Exception as e:
|
|
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
|
|
if args is not None:
|
|
yield ChatCompletionResponseStreamChunk(
|
|
event=ChatCompletionResponseEvent(
|
|
event_type=event_type,
|
|
delta=ToolCallDelta(
|
|
tool_call=ToolCall(
|
|
call_id=tool_call_buf.call_id,
|
|
tool_name=tool_call_buf.tool_name,
|
|
arguments=args,
|
|
),
|
|
parse_status=ToolCallParseStatus.succeeded,
|
|
),
|
|
)
|
|
)
|
|
else:
|
|
yield ChatCompletionResponseStreamChunk(
|
|
event=ChatCompletionResponseEvent(
|
|
event_type=ChatCompletionResponseEventType.progress,
|
|
delta=ToolCallDelta(
|
|
tool_call=str(tool_call_buf),
|
|
parse_status=ToolCallParseStatus.failed,
|
|
),
|
|
)
|
|
)
|
|
yield ChatCompletionResponseStreamChunk(
|
|
event=ChatCompletionResponseEvent(
|
|
event_type=ChatCompletionResponseEventType.complete,
|
|
delta=TextDelta(text=choice.delta.content or ""),
|
|
logprobs=None,
|
|
stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
|
|
)
|
|
)
|
|
elif choice.delta.tool_calls:
|
|
tool_call = convert_tool_call(choice.delta.tool_calls[0])
|
|
tool_call_buf.tool_name += tool_call.tool_name
|
|
tool_call_buf.call_id += tool_call.call_id
|
|
tool_call_buf.arguments += tool_call.arguments
|
|
else:
|
|
yield ChatCompletionResponseStreamChunk(
|
|
event=ChatCompletionResponseEvent(
|
|
event_type=event_type,
|
|
delta=TextDelta(text=choice.delta.content or ""),
|
|
logprobs=None,
|
|
)
|
|
)
|
|
event_type = ChatCompletionResponseEventType.progress
|
|
|
|
|
|
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
|
|
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
|
self.config = config
|
|
self.client = None
|
|
|
|
async def initialize(self) -> None:
|
|
log.info(f"Initializing VLLM client with base_url={self.config.url}")
|
|
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
|
|
|
async def shutdown(self) -> None:
|
|
pass
|
|
|
|
async def unregister_model(self, model_id: str) -> None:
|
|
pass
|
|
|
|
async def completion(
|
|
self,
|
|
model_id: str,
|
|
content: InterleavedContent,
|
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
|
response_format: Optional[ResponseFormat] = None,
|
|
stream: Optional[bool] = False,
|
|
logprobs: Optional[LogProbConfig] = None,
|
|
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
|
model = await self.model_store.get_model(model_id)
|
|
request = CompletionRequest(
|
|
model=model.provider_resource_id,
|
|
content=content,
|
|
sampling_params=sampling_params,
|
|
response_format=response_format,
|
|
stream=stream,
|
|
logprobs=logprobs,
|
|
)
|
|
if stream:
|
|
return self._stream_completion(request)
|
|
else:
|
|
return await self._nonstream_completion(request)
|
|
|
|
async def chat_completion(
|
|
self,
|
|
model_id: str,
|
|
messages: List[Message],
|
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
|
response_format: Optional[ResponseFormat] = None,
|
|
tools: Optional[List[ToolDefinition]] = None,
|
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
|
stream: Optional[bool] = False,
|
|
logprobs: Optional[LogProbConfig] = None,
|
|
tool_config: Optional[ToolConfig] = None,
|
|
) -> AsyncGenerator:
|
|
model = await self.model_store.get_model(model_id)
|
|
request = ChatCompletionRequest(
|
|
model=model.provider_resource_id,
|
|
messages=messages,
|
|
sampling_params=sampling_params,
|
|
tools=tools or [],
|
|
stream=stream,
|
|
logprobs=logprobs,
|
|
response_format=response_format,
|
|
tool_config=tool_config,
|
|
)
|
|
if stream:
|
|
return self._stream_chat_completion(request, self.client)
|
|
else:
|
|
return await self._nonstream_chat_completion(request, self.client)
|
|
|
|
async def _nonstream_chat_completion(
|
|
self, request: ChatCompletionRequest, client: OpenAI
|
|
) -> ChatCompletionResponse:
|
|
params = await self._get_params(request)
|
|
r = client.chat.completions.create(**params)
|
|
choice = r.choices[0]
|
|
result = ChatCompletionResponse(
|
|
completion_message=CompletionMessage(
|
|
content=choice.message.content or "",
|
|
stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
|
|
tool_calls=_convert_to_vllm_tool_calls_in_response(choice.message.tool_calls),
|
|
),
|
|
logprobs=None,
|
|
)
|
|
return result
|
|
|
|
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
|
|
params = await self._get_params(request)
|
|
|
|
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
|
|
# generator so this wrapper is not necessary?
|
|
async def _to_async_generator():
|
|
s = client.chat.completions.create(**params)
|
|
for chunk in s:
|
|
yield chunk
|
|
|
|
stream = _to_async_generator()
|
|
if len(request.tools) > 0:
|
|
res = _process_vllm_chat_completion_stream_response(stream)
|
|
else:
|
|
res = process_chat_completion_stream_response(stream, request)
|
|
async for chunk in res:
|
|
yield chunk
|
|
|
|
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
|
params = await self._get_params(request)
|
|
r = self.client.completions.create(**params)
|
|
return process_completion_response(r)
|
|
|
|
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
|
params = await self._get_params(request)
|
|
|
|
# Wrapper for async generator similar
|
|
async def _to_async_generator():
|
|
stream = self.client.completions.create(**params)
|
|
for chunk in stream:
|
|
yield chunk
|
|
|
|
stream = _to_async_generator()
|
|
async for chunk in process_completion_stream_response(stream):
|
|
yield chunk
|
|
|
|
async def register_model(self, model: Model) -> Model:
|
|
model = await self.register_helper.register_model(model)
|
|
res = self.client.models.list()
|
|
available_models = [m.id for m in res]
|
|
if model.provider_resource_id not in available_models:
|
|
raise ValueError(
|
|
f"Model {model.provider_resource_id} is not being served by vLLM. "
|
|
f"Available models: {', '.join(available_models)}"
|
|
)
|
|
return model
|
|
|
|
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
|
|
options = get_sampling_options(request.sampling_params)
|
|
if "max_tokens" not in options:
|
|
options["max_tokens"] = self.config.max_tokens
|
|
|
|
input_dict = {}
|
|
if isinstance(request, ChatCompletionRequest) and request.tools is not None:
|
|
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
|
|
|
|
if isinstance(request, ChatCompletionRequest):
|
|
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
|
|
else:
|
|
assert not request_has_media(request), "vLLM does not support media for Completion requests"
|
|
input_dict["prompt"] = await completion_request_to_prompt(request)
|
|
|
|
if fmt := request.response_format:
|
|
if fmt.type == ResponseFormatType.json_schema.value:
|
|
input_dict["extra_body"] = {"guided_json": request.response_format.json_schema}
|
|
elif fmt.type == ResponseFormatType.grammar.value:
|
|
raise NotImplementedError("Grammar response format not supported yet")
|
|
else:
|
|
raise ValueError(f"Unknown response format {fmt.type}")
|
|
|
|
if request.logprobs and request.logprobs.top_k:
|
|
input_dict["logprobs"] = request.logprobs.top_k
|
|
|
|
return {
|
|
"model": request.model,
|
|
**input_dict,
|
|
"stream": request.stream,
|
|
**options,
|
|
}
|
|
|
|
async def embeddings(
|
|
self,
|
|
model_id: str,
|
|
contents: List[InterleavedContent],
|
|
) -> EmbeddingsResponse:
|
|
model = await self.model_store.get_model(model_id)
|
|
|
|
kwargs = {}
|
|
assert model.model_type == ModelType.embedding
|
|
assert model.metadata.get("embedding_dimension")
|
|
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
|
|
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
|
|
response = self.client.embeddings.create(
|
|
model=model.provider_resource_id,
|
|
input=[interleaved_content_as_str(content) for content in contents],
|
|
**kwargs,
|
|
)
|
|
|
|
embeddings = [data.embedding for data in response.data]
|
|
return EmbeddingsResponse(embeddings=embeddings)
|