do not popualte tool calls if it is not in request

This commit is contained in:
Xi Yan 2025-02-10 18:43:51 -08:00
parent 3856927ee8
commit c9ef88854c
2 changed files with 45 additions and 2 deletions

View file

@ -226,14 +226,19 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self._get_client().chat.completions.acreate(**params)
else:
r = await self._get_client().completion.acreate(**params)
return process_chat_completion_response(r, self.formatter)
return process_chat_completion_response(r, self.formatter, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
from rich.pretty import pprint
print("!! FIREWORKS STREAM PARAMS !!")
pprint(params)
async def _to_async_generator():
if "messages" in params:

View file

@ -26,6 +26,7 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
@ -170,13 +171,44 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format
def process_chat_completion_response(
response: OpenAICompatCompletionResponse, formatter: ChatFormat
response: OpenAICompatCompletionResponse,
formatter: ChatFormat,
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
from rich.pretty import pprint
print("!! FIREWORKS NON_STREAM !!")
pprint(response)
choice = response.choices[0]
raw_message = formatter.decode_assistant_message_from_content(
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)
# NOTE: If we do not set tools in chat-completion request, we should not
# expect the ToolCall in the response. Instead, we should return the raw
# response from the model.
if raw_message.tool_calls:
if not request.tools:
print("Parsed tool calls but no tools provided in request")
raw_message.tool_calls = []
raw_message.content = text_from_choice(choice)
# check if tool_calls is provided in the request
elif request.tools:
new_tool_calls = []
request_tools = {t.tool_name: t for t in request.tools}
for t in raw_message.tool_calls:
if t.tool_name in request_tools:
new_tool_calls.append(t)
else:
print(f"Tool {t.tool_name} not found in request tools")
if len(new_tool_calls) < len(raw_message.tool_calls):
print("Some tool calls were not provided in the request")
raw_message.tool_calls = new_tool_calls
raw_message.content = text_from_choice(choice)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
@ -236,8 +268,12 @@ async def process_chat_completion_stream_response(
buffer = ""
ipython = False
stop_reason = None
from rich.pretty import pprint
async for chunk in stream:
print("!! CHUNK !!")
pprint(chunk)
choice = chunk.choices[0]
finish_reason = choice.finish_reason
@ -303,6 +339,8 @@ async def process_chat_completion_stream_response(
# parse tool calls and report errors
message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
print(f"Parse TOOL CALLS message: {message}")
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(