From c9ef88854c78d110226cf3b3cf01e9059c0838e6 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Mon, 10 Feb 2025 18:43:51 -0800
Subject: [PATCH] do not populate tool calls if they are not in the request

---
 .../remote/inference/fireworks/fireworks.py  |  7 +++-
 .../utils/inference/openai_compat.py         | 40 ++++++++++++++++++
 2 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index d47c035b8..186036486 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -226,14 +226,19 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
 
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
+
         if "messages" in params:
             r = await self._get_client().chat.completions.acreate(**params)
         else:
             r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        from rich.pretty import pprint
+
+        print("!! FIREWORKS STREAM PARAMS !!")
+        pprint(params)
 
         async def _to_async_generator():
             if "messages" in params:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index a3e893d8f..3eb0977d5 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -26,6 +26,7 @@ from llama_stack.apis.common.content_types import (
 )
 
 from llama_stack.apis.inference import (
+    ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
@@ -170,13 +171,44 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format
 
 
 def process_chat_completion_response(
-    response: OpenAICompatCompletionResponse, formatter: ChatFormat
+    response: OpenAICompatCompletionResponse,
+    formatter: ChatFormat,
+    request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
+    from rich.pretty import pprint
+
+    print("!! FIREWORKS NON_STREAM !!")
+    pprint(response)
+
     choice = response.choices[0]
 
     raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
+
+    # NOTE: If we do not set tools in the chat-completion request, we should
+    # not expect ToolCalls in the response. Instead, we should return the raw
+    # response from the model.
+    if raw_message.tool_calls:
+        if not request.tools:
+            print("Parsed tool calls but no tools provided in request")
+            raw_message.tool_calls = []
+            raw_message.content = text_from_choice(choice)
+        # otherwise, keep only the tool calls whose tool is declared in the request
+        elif request.tools:
+            new_tool_calls = []
+            request_tools = {t.tool_name: t for t in request.tools}
+            for t in raw_message.tool_calls:
+                if t.tool_name in request_tools:
+                    new_tool_calls.append(t)
+                else:
+                    print(f"Tool {t.tool_name} not found in request tools")
+
+            if len(new_tool_calls) < len(raw_message.tool_calls):
+                print("Some parsed tool calls were for tools not in the request")
+            raw_message.tool_calls = new_tool_calls
+            raw_message.content = text_from_choice(choice)
+
     return ChatCompletionResponse(
         completion_message=CompletionMessage(
             content=raw_message.content,
@@ -236,8 +268,12 @@
     buffer = ""
     ipython = False
     stop_reason = None
+    from rich.pretty import pprint
 
     async for chunk in stream:
+        print("!! CHUNK !!")
+        pprint(chunk)
+
         choice = chunk.choices[0]
         finish_reason = choice.finish_reason
 
@@ -303,6 +339,8 @@
 
     # parse tool calls and report errors
     message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
+    print(f"Parse TOOL CALLS message: {message}")
+
    parsed_tool_calls = len(message.tool_calls) > 0
     if ipython and not parsed_tool_calls:
         yield ChatCompletionResponseStreamChunk(
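
For reviewers, the rule this patch adds to process_chat_completion_response can be exercised in isolation. Below is a minimal, runnable sketch of the same filtering behavior; the ToolCall, ToolDefinition, and Message dataclasses and the filter_tool_calls helper are hypothetical stand-ins for illustration only (the real types live in llama_stack.apis.inference), and one deliberate difference from the diff is flagged in a comment.

from dataclasses import dataclass, field
from typing import List, Optional

# Hypothetical stand-ins for the llama-stack types, to keep the sketch self-contained.
@dataclass
class ToolCall:
    tool_name: str

@dataclass
class ToolDefinition:
    tool_name: str

@dataclass
class Message:
    content: str
    tool_calls: List[ToolCall] = field(default_factory=list)

def filter_tool_calls(
    message: Message, raw_text: str, request_tools: Optional[List[ToolDefinition]]
) -> Message:
    """Drop parsed tool calls that the request never declared (hypothetical helper)."""
    if not message.tool_calls:
        return message
    if not request_tools:
        # Request declared no tools: surface the raw model text instead.
        message.tool_calls = []
        message.content = raw_text
        return message
    allowed = {t.tool_name for t in request_tools}
    kept = [t for t in message.tool_calls if t.tool_name in allowed]
    if len(kept) < len(message.tool_calls):
        # Unlike the diff, this sketch restores the raw text only when a call
        # was actually dropped; the diff restores it unconditionally in this branch.
        message.content = raw_text
    message.tool_calls = kept
    return message

# Example: a hallucinated "get_weather" call is dropped when only "web_search" was declared.
msg = Message(content="", tool_calls=[ToolCall("get_weather")])
out = filter_tool_calls(msg, raw_text="<raw model text>", request_tools=[ToolDefinition("web_search")])
assert out.tool_calls == [] and out.content == "<raw model text>"

Restoring content from the raw choice text matters because decode_assistant_message_from_content moves tool-call syntax out of the content; if those calls are then discarded, the caller would otherwise receive an empty message.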