diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py
index 537043d69..d00e5c5a9 100644
--- a/llama_stack/providers/remote/inference/groq/groq_utils.py
+++ b/llama_stack/providers/remote/inference/groq/groq_utils.py
@@ -6,7 +6,7 @@
 
 import json
 import warnings
-from typing import AsyncGenerator, Literal, Union
+from typing import AsyncGenerator, Literal
 
 from groq import Stream
 from groq.types.chat.chat_completion import ChatCompletion
@@ -15,9 +15,6 @@ from groq.types.chat.chat_completion_assistant_message_param import (
 )
 from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
 from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
-from groq.types.chat.chat_completion_message_tool_call import (
-    ChatCompletionMessageToolCall,
-)
 from groq.types.chat.chat_completion_system_message_param import (
     ChatCompletionSystemMessageParam,
 )
@@ -30,7 +27,6 @@ from groq.types.shared.function_definition import FunctionDefinition
 
 from llama_models.llama3.api.datatypes import ToolParamDefinition
 
-from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import (
     TextDelta,
@@ -52,6 +48,8 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_strategy_options,
+    convert_tool_call,
+    UnparseableToolCall,
 )
 
 
@@ -143,7 +141,7 @@ def convert_chat_completion_response(
     # groq only supports n=1 at time of writing, so there is only one choice
     choice = response.choices[0]
     if choice.finish_reason == "tool_calls":
-        tool_calls = [_convert_groq_tool_call(tool_call) for tool_call in choice.message.tool_calls]
+        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
         if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
             # If we couldn't parse a tool call, jsonify the tool calls and return them
             return ChatCompletionResponse(
@@ -216,7 +214,7 @@ async def convert_chat_completion_response_stream(
                 warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
 
             # We assume Groq produces fully formed tool calls for each chunk
-            tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
+            tool_call = convert_tool_call(choice.delta.tool_calls[0])
             if isinstance(tool_call, ToolCall):
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
@@ -247,37 +245,3 @@ async def convert_chat_completion_response_stream(
                 )
             )
             event_type = ChatCompletionResponseEventType.progress
-
-
-class UnparseableToolCall(BaseModel):
-    """
-    A ToolCall with arguments that are not valid JSON.
-    Mirrors the ToolCall schema, but with arguments as a string.
-    """
-
-    call_id: str
-    tool_name: str
-    arguments: str
-
-
-def _convert_groq_tool_call(
-    tool_call: ChatCompletionMessageToolCall,
-) -> Union[ToolCall, UnparseableToolCall]:
-    """
-    Convert a Groq tool call to a ToolCall.
-    Returns an UnparseableToolCall if the tool call is not valid JSON.
-    """
-    try:
-        arguments = json.loads(tool_call.function.arguments)
-    except Exception as e:
-        return UnparseableToolCall(
-            call_id=tool_call.id,
-            tool_name=tool_call.function.name,
-            arguments=tool_call.function.arguments,
-        )
-
-    return ToolCall(
-        call_id=tool_call.id,
-        tool_name=tool_call.function.name,
-        arguments=arguments,
-    )
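
Note: groq_utils.py now delegates to the shared convert_tool_call / UnparseableToolCall exported by openai_compat (added at the bottom of this diff) instead of its private _convert_groq_tool_call helper. Below is a minimal sketch of how a caller branches on the result; the SimpleNamespace objects are hypothetical stand-ins that only mimic the id / function.name / function.arguments attributes the helper reads.

from types import SimpleNamespace

from llama_stack.providers.utils.inference.openai_compat import (
    UnparseableToolCall,
    convert_tool_call,
)

# Hypothetical stand-ins for OpenAI-style tool calls; only the attributes
# convert_tool_call reads are populated.
ok = SimpleNamespace(
    id="call_1",
    function=SimpleNamespace(name="get_weather", arguments='{"city": "Tokyo"}'),
)
bad = SimpleNamespace(
    id="call_2",
    function=SimpleNamespace(name="get_weather", arguments='{"city": '),  # truncated JSON
)

for raw in (ok, bad):
    converted = convert_tool_call(raw)
    if isinstance(converted, UnparseableToolCall):
        # Arguments stay as the raw string so callers can surface them verbatim.
        print("unparseable:", converted.arguments)
    else:
        # A ToolCall with arguments parsed into a dict.
        print("parsed:", converted.tool_name, converted.arguments)
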
- """ - try: - arguments = json.loads(tool_call.function.arguments) - except Exception as e: - return UnparseableToolCall( - call_id=tool_call.id, - tool_name=tool_call.function.name, - arguments=tool_call.function.arguments, - ) - - return ToolCall( - call_id=tool_call.id, - tool_name=tool_call.function.name, - arguments=arguments, - ) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 2e13a6262..02594891b 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -13,7 +13,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import all_registered_models from openai import OpenAI -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, ToolCallDelta, ToolCallParseStatus, TextDelta from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -32,6 +32,9 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, CompletionMessage, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + ChatCompletionResponseEvent, ) from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate @@ -42,9 +45,12 @@ from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, - process_chat_completion_stream_response, process_completion_response, process_completion_stream_response, + OpenAICompatCompletionResponse, + UnparseableToolCall, + convert_tool_call, + process_chat_completion_stream_response, ) from llama_stack.providers.utils.inference.prompt_adapter import ( completion_request_to_prompt, @@ -136,6 +142,51 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason: }.get(finish_reason, StopReason.end_of_turn) +async def _process_vllm_chat_completion_stream_response( + stream: AsyncGenerator[OpenAICompatCompletionResponse, None], +) -> AsyncGenerator: + event_type = ChatCompletionResponseEventType.start + tool_call_buf = UnparseableToolCall() + async for chunk in stream: + choice = chunk.choices[0] + if choice.finish_reason: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=ToolCallDelta( + tool_call=ToolCall( + call_id=tool_call_buf.call_id, + tool_name=tool_call_buf.tool_name, + arguments=json.loads(tool_call_buf.arguments), + ), + parse_status=ToolCallParseStatus.succeeded, + ), + ) + ) + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta=TextDelta(text=choice.delta.content or ""), + logprobs=None, + stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason), + ) + ) + elif choice.delta.tool_calls: + tool_call = convert_tool_call(choice.delta.tool_calls[0]) + tool_call_buf.tool_name += tool_call.tool_name + tool_call_buf.call_id += tool_call.call_id + tool_call_buf.arguments += tool_call.arguments + else: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=TextDelta(text=choice.delta.content or ""), + logprobs=None, + ) + ) + event_type = ChatCompletionResponseEventType.progress + + class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): def __init__(self, config: 
@@ -232,7 +283,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        if len(request.tools) > 0:
+            res = _process_vllm_chat_completion_stream_response(stream)
+        else:
+            res = process_chat_completion_stream_response(stream, self.formatter, request)
+        async for chunk in res:
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 1047c9a58..7480ff2c7 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
 import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union
 
@@ -14,7 +15,8 @@ from llama_models.datatypes import (
 )
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import StopReason
+from llama_models.llama3.api.datatypes import StopReason, ToolCall
+from openai.types.chat import ChatCompletionMessageToolCall
 from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import (
@@ -408,3 +410,38 @@ async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
         "role": message.role,
         "content": content,
     }
+
+
+class UnparseableToolCall(BaseModel):
+    """
+    A ToolCall with arguments that are not valid JSON.
+    Mirrors the ToolCall schema, but with arguments as a string.
+    """
+
+    call_id: str = ""
+    tool_name: str = ""
+    arguments: str = ""
+
+
+def convert_tool_call(
+    tool_call: ChatCompletionMessageToolCall,
+) -> Union[ToolCall, UnparseableToolCall]:
+    """
+    Convert a ChatCompletionMessageToolCall tool call to either a
+    ToolCall or UnparseableToolCall. Returns an UnparseableToolCall
+    if the tool call is not valid JSON.
+    """
+    try:
+        arguments = json.loads(tool_call.function.arguments)
+    except Exception as e:
+        return UnparseableToolCall(
+            call_id=tool_call.id or "",
+            tool_name=tool_call.function.name or "",
+            arguments=tool_call.function.arguments or "",
+        )
+
+    return ToolCall(
+        call_id=tool_call.id,
+        tool_name=tool_call.function.name,
+        arguments=arguments,
+    )
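
Note: a rough, self-contained harness for the tools-enabled streaming path, assuming the module layout in this diff. The SimpleNamespace chunks are hypothetical and carry only the attributes _process_vllm_chat_completion_stream_response reads (choices[0].delta.content, .delta.tool_calls, .finish_reason). Each fragment's arguments are deliberately not valid JSON on their own, since the handler concatenates fragments as strings and parses only at end-of-stream.

import asyncio
from types import SimpleNamespace

from llama_stack.providers.remote.inference.vllm.vllm import (
    _process_vllm_chat_completion_stream_response,
)


def _chunk(content=None, tool_calls=None, finish_reason=None):
    # Minimal fake of an OpenAI-compatible streaming chunk.
    delta = SimpleNamespace(content=content, tool_calls=tool_calls)
    return SimpleNamespace(choices=[SimpleNamespace(delta=delta, finish_reason=finish_reason)])


def _frag(call_id, name, arguments):
    # Minimal fake of a streamed tool-call fragment.
    return SimpleNamespace(id=call_id, function=SimpleNamespace(name=name, arguments=arguments))


async def fake_stream():
    # Two partial tool-call fragments, then the terminating chunk.
    yield _chunk(tool_calls=[_frag("call_1", "get_weather", '{"city": ')])
    yield _chunk(tool_calls=[_frag("", "", '"Tokyo"}')])
    yield _chunk(finish_reason="tool_calls")


async def main():
    async for chunk in _process_vllm_chat_completion_stream_response(fake_stream()):
        event = chunk.event
        print(event.event_type, type(event.delta).__name__)


asyncio.run(main())
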