mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 00:34:30 +00:00
PR tool call followups
This commit is contained in:
parent
1f60c0286d
commit
76e08cfde0
3 changed files with 126 additions and 30 deletions
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import json
|
||||
import warnings
|
||||
from typing import AsyncGenerator, Literal
|
||||
from typing import AsyncGenerator, Literal, Union
|
||||
|
||||
from groq import Stream
|
||||
from groq.types.chat.chat_completion import ChatCompletion
|
||||
|
|
@ -30,6 +30,8 @@ from groq.types.shared.function_definition import FunctionDefinition
|
|||
|
||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
|
|
@ -150,15 +152,26 @@ def convert_chat_completion_response(
|
|||
_convert_groq_tool_call(tool_call)
|
||||
for tool_call in choice.message.tool_calls
|
||||
]
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
tool_calls=tool_calls,
|
||||
stop_reason=StopReason.end_of_message,
|
||||
# Content is not optional
|
||||
content="",
|
||||
),
|
||||
logprobs=None,
|
||||
)
|
||||
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
|
||||
# If we couldn't parse a tool call, jsonify the tool calls and return them
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
stop_reason=StopReason.end_of_message,
|
||||
content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
|
||||
),
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
# Otherwise, return tool calls as normal
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
tool_calls=tool_calls,
|
||||
stop_reason=StopReason.end_of_message,
|
||||
# Content is not optional
|
||||
content="",
|
||||
),
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
|
|
@ -214,15 +227,27 @@ async def convert_chat_completion_response_stream(
|
|||
|
||||
# We assume Groq produces fully formed tool calls for each chunk
|
||||
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
if isinstance(tool_call, ToolCall):
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Otherwise it's an UnparseableToolCall - return the raw tool call
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call.model_dump_json(),
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
|
@ -234,12 +259,35 @@ async def convert_chat_completion_response_stream(
|
|||
event_type = ChatCompletionResponseEventType.progress
|
||||
|
||||
|
||||
def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
|
||||
class UnparseableToolCall(BaseModel):
|
||||
"""
|
||||
A ToolCall with arguments that are not valid JSON.
|
||||
Mirrors the ToolCall schema, but with arguments as a string.
|
||||
"""
|
||||
|
||||
call_id: str
|
||||
tool_name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
def _convert_groq_tool_call(
|
||||
tool_call: ChatCompletionMessageToolCall,
|
||||
) -> Union[ToolCall, UnparseableToolCall]:
|
||||
"""
|
||||
Convert a Groq tool call to a ToolCall.
|
||||
Returns an UnparseableToolCall if the tool call is not valid JSON.
|
||||
"""
|
||||
try:
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
except Exception as e:
|
||||
return UnparseableToolCall(
|
||||
call_id=tool_call.id,
|
||||
tool_name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
)
|
||||
|
||||
return ToolCall(
|
||||
call_id=tool_call.id,
|
||||
tool_name=tool_call.function.name,
|
||||
# Note that Groq may return a string that is not valid JSON here
|
||||
# So this may raise a 500 error. Going to leave this as is to see
|
||||
# how big of an issue this is and what we can do about it.
|
||||
arguments=json.loads(tool_call.function.arguments),
|
||||
arguments=arguments,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue