Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-07 02:58:21 +00:00)
do not populate tool calls if they are not in the request
parent 3856927ee8
commit c9ef88854c
2 changed files with 45 additions and 2 deletions
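
In effect, the change teaches the OpenAI-compat response path to trust only tool calls the client actually asked for. A minimal standalone sketch of that rule, under the attribute names used in the diff below (`tool_calls`, `tools`, `tool_name`); the helper itself is illustrative, not part of the commit:

    def filter_unrequested_tool_calls(tool_calls, requested_tools):
        """Keep only tool calls whose names were declared in the request."""
        if not requested_tools:
            # No tools were offered, so any parsed ToolCall is spurious.
            return []
        requested = {t.tool_name for t in requested_tools}
        return [t for t in tool_calls if t.tool_name in requested]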
@@ -226,14 +226,19 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)

         if "messages" in params:
             r = await self._get_client().chat.completions.acreate(**params)
         else:
             r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return process_chat_completion_response(r, self.formatter, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        from rich.pretty import pprint
+
+        print("!! FIREWORKS STREAM PARAMS !!")
+        pprint(params)
+
         async def _to_async_generator():
             if "messages" in params:
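
For context, the adapter picks its endpoint based on whether `_get_params` produced a chat-style payload. A rough sketch of that dispatch, assuming a Fireworks-style async client like the one returned by `self._get_client()` (the `client` parameter here is a stand-in):

    async def _dispatch(client, params: dict):
        # Chat-formatted payloads carry a "messages" key; raw prompts do not.
        if "messages" in params:
            return await client.chat.completions.acreate(**params)
        return await client.completion.acreate(**params)

Either way, the response now flows into process_chat_completion_response together with the originating request, which is what enables the tool-call check below.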
@@ -26,6 +26,7 @@ from llama_stack.apis.common.content_types import (
 )

 from llama_stack.apis.inference import (
+    ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
@@ -170,13 +171,44 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format


 def process_chat_completion_response(
-    response: OpenAICompatCompletionResponse, formatter: ChatFormat
+    response: OpenAICompatCompletionResponse,
+    formatter: ChatFormat,
+    request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
+    from rich.pretty import pprint
+
+    print("!! FIREWORKS NON_STREAM !!")
+    pprint(response)
+
     choice = response.choices[0]

     raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
+
+    # NOTE: If we do not set tools in the chat-completion request, we should not
+    # expect a ToolCall in the response. Instead, we should return the raw
+    # response from the model.
+    if raw_message.tool_calls:
+        if not request.tools:
+            print("Parsed tool calls but no tools provided in request")
+            raw_message.tool_calls = []
+            raw_message.content = text_from_choice(choice)
+        # check if tool_calls is provided in the request
+        elif request.tools:
+            new_tool_calls = []
+            request_tools = {t.tool_name: t for t in request.tools}
+            for t in raw_message.tool_calls:
+                if t.tool_name in request_tools:
+                    new_tool_calls.append(t)
+                else:
+                    print(f"Tool {t.tool_name} not found in request tools")
+
+            if len(new_tool_calls) < len(raw_message.tool_calls):
+                print("Some tool calls were not provided in the request")
+                raw_message.tool_calls = new_tool_calls
+                raw_message.content = text_from_choice(choice)
+
     return ChatCompletionResponse(
         completion_message=CompletionMessage(
             content=raw_message.content,
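
To make the new branch concrete, here is a self-contained exercise of the same filtering logic with throwaway dataclasses (every name below is a hypothetical test double, not a llama-stack type):

    from dataclasses import dataclass, field

    @dataclass
    class FakeToolCall:
        tool_name: str

    @dataclass
    class FakeToolDef:
        tool_name: str

    @dataclass
    class FakeMessage:
        content: str = ""
        tool_calls: list = field(default_factory=list)

    parsed = FakeMessage(tool_calls=[FakeToolCall("search"), FakeToolCall("made_up_tool")])
    request_tools = {t.tool_name: t for t in [FakeToolDef("search")]}

    new_tool_calls = [t for t in parsed.tool_calls if t.tool_name in request_tools]
    assert [t.tool_name for t in new_tool_calls] == ["search"]
    # As in the diff, a shrunken list means the model invented a tool call,
    # so the message falls back to the raw text content.

One small observation on the control flow: since the `if not request.tools:` branch has already handled the empty case, the `elif request.tools:` condition is always true when reached; a bare `else:` would behave identically.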
@@ -236,8 +268,12 @@ async def process_chat_completion_stream_response(
     buffer = ""
     ipython = False
     stop_reason = None
+    from rich.pretty import pprint

     async for chunk in stream:
+        print("!! CHUNK !!")
+        pprint(chunk)
+
         choice = chunk.choices[0]
         finish_reason = choice.finish_reason
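
The debugging lines here lean on `rich.pretty.pprint` (rich is imported by the hunk above), which pretty-prints nested objects. A tiny standalone illustration; the sample chunk dict is made up:

    from rich.pretty import pprint

    chunk = {"choices": [{"delta": {"content": "Hello"}, "finish_reason": None}]}
    print("!! CHUNK !!")
    pprint(chunk)  # renders the nested structure with indentation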
@@ -303,6 +339,8 @@ async def process_chat_completion_stream_response(

     # parse tool calls and report errors
     message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
+    print(f"Parse TOOL CALLS message: {message}")
+
     parsed_tool_calls = len(message.tool_calls) > 0
     if ipython and not parsed_tool_calls:
         yield ChatCompletionResponseStreamChunk(