diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 51691c546..2f397f438 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -513,6 +513,9 @@ class ChatAgent(ShieldRunnerMixin):
             if delta.type == "tool_call":
                 if delta.parse_status == ToolCallParseStatus.succeeded:
                     tool_calls.append(delta.tool_call)
+                elif delta.parse_status == ToolCallParseStatus.failed:
+                    # If we cannot parse the tools, set the content to the unparsed raw text
+                    content = delta.tool_call
             if stream:
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 691737c15..77c95cc7e 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -201,7 +201,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, self.formatter, request)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
@@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 03a0a40c3..54a674d7e 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -134,7 +134,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         )
 
         response = OpenAICompatCompletionResponse(choices=[choice])
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, self.formatter, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params_for_chat_completion(request)
@@ -152,7 +152,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             yield OpenAICompatCompletionResponse(choices=[choice])
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index bd12c56c8..47f208129 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -155,14 +155,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
 
         r = await self.client.completions.create(**params)
 
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)
 
     async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         stream = await self.client.completions.create(**params)
 
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 37070b4ce..ee3c6e99b 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -112,7 +112,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -123,7 +123,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
             yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index d47c035b8..d978cb02e 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -230,7 +230,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             r = await self._get_client().chat.completions.acreate(**params)
         else:
             r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -244,7 +244,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index ecd195854..05a5d2d7a 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -304,7 +304,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, self.formatter, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -330,7 +330,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     async def embeddings(
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index a62b0c97f..c7b20b9a1 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -99,7 +99,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -110,7 +110,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
             yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 87aab1e88..18a78e69c 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -160,7 +160,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     async def embeddings(
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 2281319b3..97a6621fb 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -236,7 +236,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, self.formatter, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -252,7 +252,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index cf24daf60..a165b01d9 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -220,7 +220,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             r = self._get_client().chat.completions.create(**params)
         else:
             r = self._get_client().completions.create(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -235,7 +235,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 8618abccf..2e13a6262 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -232,7 +232,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
+        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 8ee838d84..1047c9a58 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union
 
 from llama_models.datatypes import (
@@ -26,6 +26,7 @@ from llama_stack.apis.common.content_types import (
 )
 
 from llama_stack.apis.inference import (
+    ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
@@ -41,6 +42,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class OpenAICompatCompletionChoiceDelta(BaseModel):
     content: str
@@ -170,7 +173,9 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format
 
 
 def process_chat_completion_response(
-    response: OpenAICompatCompletionResponse, formatter: ChatFormat
+    response: OpenAICompatCompletionResponse,
+    formatter: ChatFormat,
+    request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
 
@@ -179,6 +184,28 @@ def process_chat_completion_response(
     raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
+
+    # NOTE: If we do not set tools in chat-completion request, we should not
+    # expect the ToolCall in the response. Instead, we should return the raw
+    # response from the model.
+    if raw_message.tool_calls:
+        if not request.tools:
+            raw_message.tool_calls = []
+            raw_message.content = text_from_choice(choice)
+        else:
+            # only return tool_calls if provided in the request
+            new_tool_calls = []
+            request_tools = {t.tool_name: t for t in request.tools}
+            for t in raw_message.tool_calls:
+                if t.tool_name in request_tools:
+                    new_tool_calls.append(t)
+                else:
+                    logger.warning(f"Tool {t.tool_name} not found in request tools")
+
+            if len(new_tool_calls) < len(raw_message.tool_calls):
+                raw_message.tool_calls = new_tool_calls
+                raw_message.content = text_from_choice(choice)
+
     return ChatCompletionResponse(
         completion_message=CompletionMessage(
             content=raw_message.content,
@@ -226,7 +253,9 @@ async def process_completion_stream_response(
 
 
 async def process_chat_completion_stream_response(
-    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
+    formatter: ChatFormat,
+    request: ChatCompletionRequest,
 ) -> AsyncGenerator:
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
@@ -305,6 +334,7 @@
 
     # parse tool calls and report errors
     message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
+
     parsed_tool_calls = len(message.tool_calls) > 0
     if ipython and not parsed_tool_calls:
         yield ChatCompletionResponseStreamChunk(
@@ -318,17 +348,33 @@
             )
         )
 
+    request_tools = {t.tool_name: t for t in request.tools}
     for tool_call in message.tool_calls:
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.progress,
-                delta=ToolCallDelta(
-                    tool_call=tool_call,
-                    parse_status=ToolCallParseStatus.succeeded,
-                ),
-                stop_reason=stop_reason,
+        if tool_call.tool_name in request_tools:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        tool_call=tool_call,
+                        parse_status=ToolCallParseStatus.succeeded,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+        else:
+            logger.warning(f"Tool {tool_call.tool_name} not found in request tools")
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        # Parsing tool call failed due to tool call not being found in request tools,
+                        # We still add the raw message text inside tool_call for responding back to the user
+                        tool_call=buffer,
+                        parse_status=ToolCallParseStatus.failed,
+                    ),
+                    stop_reason=stop_reason,
+                )
             )
-        )
 
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py
index 81b476218..206629602 100644
--- a/tests/client-sdk/inference/test_text_inference.py
+++ b/tests/client-sdk/inference/test_text_inference.py
@@ -158,7 +158,10 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
     "question,expected",
     [
         ("Which planet do humans live on?", "Earth"),
-        ("Which planet has rings around it with a name starting with letter S?", "Saturn"),
+        (
+            "Which planet has rings around it with a name starting with letter S?",
+            "Saturn",
+        ),
     ],
 )
 def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected):
@@ -280,3 +283,82 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
     assert answer.last_name == "Jordan"
     assert answer.year_of_birth == 1963
     assert answer.num_seasons_in_nba == 15
+
+
+@pytest.mark.parametrize(
+    "streaming",
+    [
+        True,
+        False,
+    ],
+)
+def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming):
+    # TODO: more dynamic lookup on tool_prompt_format for model family
+    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+    request = {
+        "model_id": text_model_id,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "What pods are in the namespace openshift-lightspeed?",
+            },
+            {
+                "role": "assistant",
+                "content": "",
+                "stop_reason": "end_of_turn",
+                "tool_calls": [
+                    {
+                        "call_id": "1",
+                        "tool_name": "get_object_namespace_list",
+                        "arguments": {
+                            "kind": "pod",
+                            "namespace": "openshift-lightspeed",
+                        },
+                    }
+                ],
+            },
+            {
+                "role": "tool",
+                "call_id": "1",
+                "tool_name": "get_object_namespace_list",
+                "content": "the objects are pod1, pod2, pod3",
+            },
+        ],
+        "tools": [
+            {
+                "tool_name": "get_object_namespace_list",
+                "description": "Get the list of objects in a namespace",
+                "parameters": {
+                    "kind": {
+                        "param_type": "string",
+                        "description": "the type of object",
+                        "required": True,
+                    },
+                    "namespace": {
+                        "param_type": "string",
+                        "description": "the name of the namespace",
+                        "required": True,
+                    },
+                },
+            }
+        ],
+        "tool_choice": "auto",
+        "tool_prompt_format": tool_prompt_format,
+        "stream": streaming,
+    }
+
+    response = llama_stack_client.inference.chat_completion(**request)
+
+    if streaming:
+        for chunk in response:
+            delta = chunk.event.delta
+            if delta.type == "tool_call" and delta.parse_status == "succeeded":
+                assert delta.tool_call.tool_name == "get_object_namespace_list"
+            if delta.type == "tool_call" and delta.parse_status == "failed":
+                # expect raw message that failed to parse in tool_call
+                assert type(delta.tool_call) == str
+                assert len(delta.tool_call) > 0
+    else:
+        for tc in response.completion_message.tool_calls:
+            assert tc.tool_name == "get_object_namespace_list"