do not populate tool calls if it is not in request

Xi Yan 2025-02-10 18:43:51 -08:00
parent 3856927ee8
commit c9ef88854c
2 changed files with 45 additions and 2 deletions
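In short: when a chat-completion request carries no `tools`, the response should return the model's raw text rather than a parsed `ToolCall`. A rough illustration of the intended behavior, assuming the llama-stack client Python API; the client, method, and field names below are assumptions for illustration and are not part of this commit:

# Hypothetical usage sketch (names are assumptions, not taken from this diff).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# No `tools` are passed, so even if the model emits tool-call-looking text,
# the response is expected to carry it as plain content with no tool calls.
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What's the weather in Tokyo?"}],
)
print(response.completion_message.content)
print(response.completion_message.tool_calls)  # expected: []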

View file

@@ -226,14 +226,19 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
             r = await self._get_client().chat.completions.acreate(**params)
         else:
             r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        from rich.pretty import pprint
+
+        print("!! FIREWORKS STREAM PARAMS !!")
+        pprint(params)
 
         async def _to_async_generator():
             if "messages" in params:

View file

@@ -26,6 +26,7 @@ from llama_stack.apis.common.content_types import (
 )
 from llama_stack.apis.inference import (
+    ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
@@ -170,13 +171,44 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format
 def process_chat_completion_response(
-    response: OpenAICompatCompletionResponse, formatter: ChatFormat
+    response: OpenAICompatCompletionResponse,
+    formatter: ChatFormat,
+    request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
+    from rich.pretty import pprint
+
+    print("!! FIREWORKS NON_STREAM !!")
+    pprint(response)
     choice = response.choices[0]
     raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
+
+    # NOTE: If we do not set tools in chat-completion request, we should not
+    # expect the ToolCall in the response. Instead, we should return the raw
+    # response from the model.
+    if raw_message.tool_calls:
+        if not request.tools:
+            print("Parsed tool calls but no tools provided in request")
+            raw_message.tool_calls = []
+            raw_message.content = text_from_choice(choice)
+        # check if tool_calls is provided in the request
+        elif request.tools:
+            new_tool_calls = []
+            request_tools = {t.tool_name: t for t in request.tools}
+            for t in raw_message.tool_calls:
+                if t.tool_name in request_tools:
+                    new_tool_calls.append(t)
+                else:
+                    print(f"Tool {t.tool_name} not found in request tools")
+
+            if len(new_tool_calls) < len(raw_message.tool_calls):
+                print("Some tool calls were not provided in the request")
+                raw_message.tool_calls = new_tool_calls
+                raw_message.content = text_from_choice(choice)
+
     return ChatCompletionResponse(
         completion_message=CompletionMessage(
             content=raw_message.content,
@@ -236,8 +268,12 @@ async def process_chat_completion_stream_response(
     buffer = ""
     ipython = False
     stop_reason = None
+    from rich.pretty import pprint
 
     async for chunk in stream:
+        print("!! CHUNK !!")
+        pprint(chunk)
+
         choice = chunk.choices[0]
         finish_reason = choice.finish_reason
@@ -303,6 +339,8 @@ async def process_chat_completion_stream_response(
     # parse tool calls and report errors
     message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
+    print(f"Parse TOOL CALLS message: {message}")
+
     parsed_tool_calls = len(message.tool_calls) > 0
     if ipython and not parsed_tool_calls:
         yield ChatCompletionResponseStreamChunk(
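Taken together, the openai_compat change is a guard: parsed tool calls are kept only when the request actually declared matching tools; otherwise the raw text from the choice is returned. A minimal, self-contained sketch of that guard, using stand-in dataclasses rather than the real llama_stack types:

# Self-contained sketch; ToolCall/ToolDefinition/Message here are simplified
# stand-ins for the real llama_stack types, just to make the logic runnable.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToolCall:
    tool_name: str


@dataclass
class ToolDefinition:
    tool_name: str


@dataclass
class Message:
    content: str
    tool_calls: List[ToolCall] = field(default_factory=list)


def filter_tool_calls(message: Message, raw_text: str, request_tools: Optional[List[ToolDefinition]]) -> Message:
    """Keep only tool calls that the request asked for; otherwise fall back to raw text."""
    if not message.tool_calls:
        return message
    if not request_tools:
        # The request declared no tools, so return the raw model text instead.
        message.tool_calls = []
        message.content = raw_text
        return message
    allowed = {t.tool_name for t in request_tools}
    kept = [t for t in message.tool_calls if t.tool_name in allowed]
    if len(kept) < len(message.tool_calls):
        # Some parsed tool calls were never requested; drop them and keep the raw text.
        message.tool_calls = kept
        message.content = raw_text
    return message


# Example: only the requested tool survives; content falls back to the raw text.
msg = Message(content="", tool_calls=[ToolCall("get_weather"), ToolCall("unknown_tool")])
print(filter_tool_calls(msg, "raw model text", [ToolDefinition("get_weather")]))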