mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-13 05:17:26 +00:00

temp fix

Summary:

Test Plan:

This commit is contained in:
parent 270d64007a
commit f389afe024

4 changed files with 250 additions and 166 deletions
@@ -656,30 +656,27 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-            tool_call = message.tool_calls[0]
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepProgressPayload(
-                        step_type=StepType.tool_execution.value,
-                        step_id=step_id,
-                        tool_call=tool_call,
-                        delta=ToolCallDelta(
-                            parse_status=ToolCallParseStatus.in_progress,
-                            tool_call=tool_call,
-                        ),
-                    )
-                )
-            )

-            # If tool is a client tool, yield CompletionMessage and return
-            if tool_call.tool_name in client_tools:
+            # Process all tool calls instead of just the first one
+            tool_responses = []
+            tool_execution_start_time = datetime.now().astimezone().isoformat()
+
+            # Check if any tool is a client tool
+            client_tool_found = False
+            for tool_call in message.tool_calls:
+                if tool_call.tool_name in client_tools:
+                    client_tool_found = True
+                    break
+
+            # If any tool is a client tool, yield CompletionMessage and return
+            if client_tool_found:
                 await self.storage.set_in_progress_tool_call_step(
                     session_id,
                     turn_id,
                     ToolExecutionStep(
                         step_id=step_id,
                         turn_id=turn_id,
-                        tool_calls=[tool_call],
+                        tool_calls=message.tool_calls,
                         tool_responses=[],
                         started_at=datetime.now().astimezone().isoformat(),
                     ),
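
For readers following the hunk above: instead of looking only at message.tool_calls[0], the new code scans every tool call on the message for a client-side tool before suspending the turn. A minimal, runnable sketch of that check; the ToolCall dataclass and the client_tools set below are simplified stand-ins for illustration, not the Llama Stack types:

from dataclasses import dataclass


@dataclass
class ToolCall:
    call_id: str
    tool_name: str
    arguments: dict


# Hypothetical registry of tools that the client executes itself
client_tools = {"get_weather"}

tool_calls = [
    ToolCall("1", "knowledge_search", {"query": "llamas"}),
    ToolCall("2", "get_weather", {"city": "Lima"}),
]

# Equivalent to the loop-and-break used in the hunk above
client_tool_found = any(tc.tool_name in client_tools for tc in tool_calls)
print(client_tool_found)  # True: the turn is suspended and all tool calls go back to the client

The any() expression behaves the same as the explicit for/break in the diff; the important change is that every call is inspected, not just the first.
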
@@ -687,41 +684,86 @@ class ChatAgent(ShieldRunnerMixin):
                 yield message
                 return

-            # If tool is a builtin server tool, execute it
-            tool_name = tool_call.tool_name
-            if isinstance(tool_name, BuiltinTool):
-                tool_name = tool_name.value
-            with tracing.span(
-                "tool_execution",
-                {
-                    "tool_name": tool_name,
-                    "input": message.model_dump_json(),
-                },
-            ) as span:
-                tool_execution_start_time = datetime.now().astimezone().isoformat()
-                tool_call = message.tool_calls[0]
-                tool_result = await execute_tool_call_maybe(
-                    self.tool_runtime_api,
-                    session_id,
-                    tool_call,
-                    toolgroup_args,
-                    tool_to_group,
-                )
-                if tool_result.content is None:
-                    raise ValueError(
-                        f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
-                    )
-                result_messages = [
-                    ToolResponseMessage(
-                        call_id=tool_call.call_id,
-                        tool_name=tool_call.tool_name,
-                        content=tool_result.content,
-                    )
-                ]
-                assert len(result_messages) == 1, "Currently not supporting multiple messages"
-                result_message = result_messages[0]
-                span.set_attribute("output", result_message.model_dump_json())
+            # Add the original message with tool calls to input_messages before processing tool calls
+            input_messages.append(message)
+
+            # Process all tool calls
+            for tool_call in message.tool_calls:
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepProgressPayload(
+                            step_type=StepType.tool_execution.value,
+                            step_id=step_id,
+                            tool_call=tool_call,
+                            delta=ToolCallDelta(
+                                parse_status=ToolCallParseStatus.in_progress,
+                                tool_call=tool_call,
+                            ),
+                        )
+                    )
+                )
+
+                # Execute the tool call
+                tool_name = tool_call.tool_name
+                if isinstance(tool_name, BuiltinTool):
+                    tool_name = tool_name.value
+                with tracing.span(
+                    "tool_execution",
+                    {
+                        "tool_name": tool_name,
+                        "input": tool_call.model_dump_json()
+                        if hasattr(tool_call, "model_dump_json")
+                        else str(tool_call),
+                    },
+                ) as span:
+                    tool_result = await execute_tool_call_maybe(
+                        self.tool_runtime_api,
+                        session_id,
+                        tool_call,
+                        toolgroup_args,
+                        tool_to_group,
+                    )
+                    if tool_result.content is None:
+                        raise ValueError(
+                            f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
+                        )
+
+                    result_message = ToolResponseMessage(
+                        call_id=tool_call.call_id,
+                        tool_name=tool_call.tool_name,
+                        content=tool_result.content,
+                    )
+
+                    tool_responses.append(
+                        ToolResponse(
+                            call_id=result_message.call_id,
+                            tool_name=result_message.tool_name,
+                            content=result_message.content,
+                            metadata=tool_result.metadata,
+                        )
+                    )
+
+                    span.set_attribute(
+                        "output",
+                        result_message.model_dump_json()
+                        if hasattr(result_message, "model_dump_json")
+                        else str(result_message),
+                    )
+
+                    # TODO: add tool-input touchpoint and a "start" event for this step also
+                    # but that needs a lot more refactoring of Tool code potentially
+                    if (type(result_message.content) is str) and (
+                        out_attachment := _interpret_content_as_attachment(result_message.content)
+                    ):
+                        # NOTE: when we push this message back to the model, the model may ignore the
+                        # attached file path etc. since the model is trained to only provide a user message
+                        # with the summary. We keep all generated attachments and then attach them to final message
+                        output_attachments.append(out_attachment)
+
+                    # Add the result message to input_messages
+                    input_messages.append(result_message)
+
+            # Complete the tool execution step
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
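
The hunk above replaces the single-tool path with a loop that executes every server-side tool call and accumulates one response per call before the step-complete event. A small, self-contained sketch of that accumulation pattern; the execute_tool stub and dict-shaped responses are illustrative stand-ins, not the execute_tool_call_maybe signature from the codebase:

import asyncio


async def execute_tool(name: str, arguments: dict) -> str:
    # Stand-in for the real tool runtime; pretend every tool echoes its arguments
    return f"{name} ran with {arguments}"


async def run_all(tool_calls: list[dict]) -> list[dict]:
    tool_responses = []
    for tc in tool_calls:
        content = await execute_tool(tc["tool_name"], tc["arguments"])
        if content is None:
            raise ValueError(f"Tool call {tc['call_id']} returned no content")
        # One response per call, matched by call_id, mirroring the diff's ToolResponse list
        tool_responses.append({"call_id": tc["call_id"], "content": content})
    return tool_responses


calls = [
    {"call_id": "1", "tool_name": "knowledge_search", "arguments": {"query": "llamas"}},
    {"call_id": "2", "tool_name": "calculator", "arguments": {"expr": "2+2"}},
]
print(asyncio.run(run_all(calls)))
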
@@ -730,15 +772,8 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ToolExecutionStep(
                             step_id=step_id,
                             turn_id=turn_id,
-                            tool_calls=[tool_call],
-                            tool_responses=[
-                                ToolResponse(
-                                    call_id=result_message.call_id,
-                                    tool_name=result_message.tool_name,
-                                    content=result_message.content,
-                                    metadata=tool_result.metadata,
-                                )
-                            ],
+                            tool_calls=message.tool_calls,
+                            tool_responses=tool_responses,
                             started_at=tool_execution_start_time,
                             completed_at=datetime.now().astimezone().isoformat(),
                         ),
@@ -746,18 +781,6 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )

-            # TODO: add tool-input touchpoint and a "start" event for this step also
-            # but that needs a lot more refactoring of Tool code potentially
-            if (type(result_message.content) is str) and (
-                out_attachment := _interpret_content_as_attachment(result_message.content)
-            ):
-                # NOTE: when we push this message back to the model, the model may ignore the
-                # attached file path etc. since the model is trained to only provide a user message
-                # with the summary. We keep all generated attachments and then attach them to final message
-                output_attachments.append(out_attachment)
-
-            input_messages = input_messages + [message, result_message]
-
             n_iter += 1

     async def _get_tool_defs(
@@ -127,7 +127,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         tokens = 0
         picked = [
             TextContentItem(
-                text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
+                text=f"knowledge_search tool found {len(chunks)} chunks for query:\n{query}\nBEGIN of knowledge_search tool results.\n"
             )
         ]
         for i, c in enumerate(chunks):
@@ -99,6 +99,10 @@ class LiteLLMOpenAIMixin(
         params = await self._get_params(request)
         # unfortunately, we need to use synchronous litellm.completion here because litellm
         # caches various httpx.client objects in a non-eventloop aware manner
+
+        from rich.pretty import pprint
+
+        pprint(params)
         response = litellm.completion(**params)
         if stream:
             return self._stream_chat_completion(response)
@@ -523,10 +523,11 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC
     ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
         # Llama Stack and OpenAI spec match for str and text input
         if isinstance(content, str):
-            return OpenAIChatCompletionContentPartTextParam(
-                type="text",
-                text=content,
-            )
+            # return OpenAIChatCompletionContentPartTextParam(
+            #     type="text",
+            #     text=content,
+            # )
+            return content
         elif isinstance(content, TextContentItem):
             return OpenAIChatCompletionContentPartTextParam(
                 type="text",
@@ -568,12 +569,12 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC
         out = OpenAIChatCompletionToolMessage(
             role="tool",
             tool_call_id=message.call_id,
-            content=message.content,
+            content=await _convert_user_message_content(message.content),
         )
     elif isinstance(message, SystemMessage):
         out = OpenAIChatCompletionSystemMessage(
             role="system",
-            content=message.content,
+            content=await _convert_user_message_content(message.content),
         )
     else:
         raise ValueError(f"Unsupported message type: {type(message)}")
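
Both hunks above funnel tool and system message content through the same converter used for user content, while plain strings now pass through unchanged. A rough sketch of that conversion behavior, using a simplified async converter rather than the actual _convert_user_message_content implementation:

import asyncio


async def convert_content(content):
    # Plain strings pass through untouched (matches the new early return in the diff)
    if isinstance(content, str):
        return content
    # A single structured text item becomes an OpenAI-style text part
    if isinstance(content, dict) and "text" in content:
        return {"type": "text", "text": content["text"]}
    # A list of items is converted element by element
    if isinstance(content, list):
        return [await convert_content(item) for item in content]
    raise ValueError(f"Unsupported content type: {type(content)}")


async def main():
    print(await convert_content("just a string"))
    print(await convert_content({"text": "tool output"}))
    print(await convert_content(["a", {"text": "b"}]))


asyncio.run(main())
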
@@ -831,18 +832,26 @@ async def convert_openai_chat_completion_stream(
     Convert a stream of OpenAI chat completion chunks into a stream
     of ChatCompletionResponseStreamChunk.
     """

-    # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
-    def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]:
-        yield ChatCompletionResponseEventType.start
-        while True:
-            yield ChatCompletionResponseEventType.progress
-
-    event_type = _event_type_generator()
-
     stop_reason = None
-    toolcall_buffer = {}
+    # Use a dictionary to track multiple tool calls by their index
+    toolcall_buffers = {}
+    # Track which tool calls have been completed
+    completed_tool_indices = set()
+    # Track the highest index seen so far
+    highest_index_seen = -1
+
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.start,
+            delta=TextDelta(text=""),
+            stop_reason=stop_reason,
+        )
+    )
+
     async for chunk in stream:
+        from rich.pretty import pprint
+
+        pprint(chunk)
         choice = chunk.choices[0]  # assuming only one choice per chunk

         # we assume there's only one finish_reason in the stream
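
The hunk above replaces the single toolcall_buffer with one buffer per OpenAI tool-call index, plus bookkeeping for completed indices. The field names below come from the diff; the values are purely illustrative of what each buffer accumulates as fragments arrive:

# One entry per `index` reported by the OpenAI streaming API
toolcall_buffers = {
    0: {
        "call_id": "call_abc123",               # id from the first fragment that carried it
        "name": "get_weather",                  # set when the name fragment arrives
        "content": 'get_weather({"city": "Li',  # name + argument text accumulated so far
        "arguments": '{"city": "Li',            # raw JSON argument fragments
        "complete": False,
    },
}
completed_tool_indices = set()  # indices already flushed as complete events
highest_index_seen = -1         # a higher index means earlier calls can be finalized
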
@@ -851,112 +860,108 @@ async def convert_openai_chat_completion_stream(

         # if there's a tool call, emit an event for each tool in the list
         # if tool call and content, emit both separately

         if choice.delta.tool_calls:
             # the call may have content and a tool call. ChatCompletionResponseEvent
             # does not support both, so we emit the content first
             if choice.delta.content:
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
+                        event_type=ChatCompletionResponseEventType.progress,
                         delta=TextDelta(text=choice.delta.content),
                         logprobs=_convert_openai_logprobs(logprobs),
                     )
                 )

-            # it is possible to have parallel tool calls in stream, but
-            # ChatCompletionResponseEvent only supports one per stream
-            if len(choice.delta.tool_calls) > 1:
-                warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest")
-
-            if not enable_incremental_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
-                        delta=ToolCallDelta(
-                            tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
-                            parse_status=ToolCallParseStatus.succeeded,
-                        ),
-                        logprobs=_convert_openai_logprobs(logprobs),
-                    )
-                )
-            else:
-                tool_call = choice.delta.tool_calls[0]
-                if "name" not in toolcall_buffer:
-                    toolcall_buffer["call_id"] = tool_call.id
-                    toolcall_buffer["name"] = None
-                    toolcall_buffer["content"] = ""
-                if "arguments" not in toolcall_buffer:
-                    toolcall_buffer["arguments"] = ""
-
-                if tool_call.function.name:
-                    toolcall_buffer["name"] = tool_call.function.name
-                    delta = f"{toolcall_buffer['name']}("
-                if tool_call.function.arguments:
-                    toolcall_buffer["arguments"] += tool_call.function.arguments
-                    delta = toolcall_buffer["arguments"]
-
-                toolcall_buffer["content"] += delta
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
-                        delta=ToolCallDelta(
-                            tool_call=delta,
-                            parse_status=ToolCallParseStatus.in_progress,
-                        ),
-                        logprobs=_convert_openai_logprobs(logprobs),
-                    )
-                )
+            # Process each tool call in the delta
+            for tool_call in choice.delta.tool_calls:
+                # Get the tool call index
+                tool_index = getattr(tool_call, "index", 0)
+
+                # If we see a new higher index, complete all previous tool calls
+                if tool_index > highest_index_seen:
+                    # Complete all previous tool calls
+                    for prev_index in range(highest_index_seen + 1):
+                        if prev_index in toolcall_buffers and prev_index not in completed_tool_indices:
+                            # Complete this tool call
+                            async for event in _complete_tool_call(
+                                toolcall_buffers[prev_index],
+                                logprobs,
+                                None,  # No stop_reason for intermediate tool calls
+                            ):
+                                yield event
+                            completed_tool_indices.add(prev_index)
+
+                    highest_index_seen = tool_index
+
+                # Skip if this tool call has already been completed
+                if tool_index in completed_tool_indices:
+                    continue
+
+                # Initialize buffer for this tool call if it doesn't exist
+                if tool_index not in toolcall_buffers:
+                    toolcall_buffers[tool_index] = {
+                        "call_id": tool_call.id,
+                        "name": None,
+                        "content": "",
+                        "arguments": "",
+                        "complete": False,
+                    }
+
+                buffer = toolcall_buffers[tool_index]
+
+                # Handle function name
+                if tool_call.function and tool_call.function.name:
+                    buffer["name"] = tool_call.function.name
+                    delta = f"{buffer['name']}("
+                    buffer["content"] += delta
+
+                    # Emit the function name
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                tool_call=delta,
+                                parse_status=ToolCallParseStatus.in_progress,
+                            ),
+                            logprobs=_convert_openai_logprobs(logprobs),
+                        )
+                    )
+
+                # Handle function arguments
+                if tool_call.function and tool_call.function.arguments:
+                    delta = tool_call.function.arguments
+                    buffer["arguments"] += delta
+                    buffer["content"] += delta
+
+                    # Emit the argument fragment
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                tool_call=delta,
+                                parse_status=ToolCallParseStatus.in_progress,
+                            ),
+                            logprobs=_convert_openai_logprobs(logprobs),
+                        )
+                    )
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type=next(event_type),
+                    event_type=ChatCompletionResponseEventType.progress,
                     delta=TextDelta(text=choice.delta.content or ""),
                     logprobs=_convert_openai_logprobs(logprobs),
                 )
             )

-    if toolcall_buffer:
-        delta = ")"
-        toolcall_buffer["content"] += delta
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=next(event_type),
-                delta=ToolCallDelta(
-                    tool_call=delta,
-                    parse_status=ToolCallParseStatus.in_progress,
-                ),
-                logprobs=_convert_openai_logprobs(logprobs),
-            )
-        )
-        try:
-            arguments = json.loads(toolcall_buffer["arguments"])
-            tool_call = ToolCall(
-                call_id=toolcall_buffer["call_id"],
-                tool_name=toolcall_buffer["name"],
-                arguments=arguments,
-            )
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta=ToolCallDelta(
-                        tool_call=tool_call,
-                        parse_status=ToolCallParseStatus.succeeded,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-        except json.JSONDecodeError:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta=ToolCallDelta(
-                        tool_call=toolcall_buffer["content"],
-                        parse_status=ToolCallParseStatus.failed,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
+    # Final complete event if no tool calls were processed
+    if toolcall_buffers:
+        # Process all tool calls that haven't been completed yet
+        for tool_index in sorted(toolcall_buffers.keys()):
+            if tool_index not in completed_tool_indices:
+                # Complete this tool call
+                async for event in _complete_tool_call(toolcall_buffers[tool_index], logprobs, stop_reason):
+                    yield event
+                completed_tool_indices.add(tool_index)

     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
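
To make the index bookkeeping in the hunk above concrete, here is a small, runnable simulation: fragments arrive tagged with a tool-call index, one buffer is kept per index, and seeing a higher index flushes every earlier buffer. The fragment tuples and the flush step are simplified stand-ins, not the ChatCompletionResponseStreamChunk events emitted by the real code:

import json

# (index, call_id, name_fragment, arguments_fragment): a toy stand-in for OpenAI deltas
fragments = [
    (0, "call_0", "get_weather", ""),
    (0, None, None, '{"city": '),
    (0, None, None, '"Lima"}'),
    (1, "call_1", "calculator", '{"expr": "2+2"}'),
]

buffers: dict[int, dict] = {}
completed: set[int] = set()
highest_seen = -1


def flush(index: int) -> None:
    buf = buffers[index]
    args = json.loads(buf["arguments"])  # may raise on malformed JSON, like the real code path
    print(f"completed {buf['name']}({args}) as {buf['call_id']}")


for index, call_id, name, args in fragments:
    if index > highest_seen:
        # A new index means all earlier tool calls are done streaming
        for prev in range(highest_seen + 1):
            if prev in buffers and prev not in completed:
                flush(prev)
                completed.add(prev)
        highest_seen = index
    buf = buffers.setdefault(index, {"call_id": call_id, "name": None, "arguments": ""})
    if name:
        buf["name"] = name
    if args:
        buf["arguments"] += args

# Anything still buffered at the end of the stream is flushed last
for index in sorted(buffers):
    if index not in completed:
        flush(index)
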
@@ -965,3 +970,55 @@ async def convert_openai_chat_completion_stream(
             stop_reason=stop_reason,
         )
     )
+
+
+async def _complete_tool_call(buffer, logprobs, stop_reason):
+    """Helper function to complete a tool call and yield the appropriate events."""
+    # Add closing parenthesis
+    delta = ")"
+    buffer["content"] += delta
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.progress,
+            delta=ToolCallDelta(
+                tool_call=delta,
+                parse_status=ToolCallParseStatus.in_progress,
+            ),
+            logprobs=_convert_openai_logprobs(logprobs),
+        )
+    )
+
+    try:
+        # Parse the arguments
+        arguments = json.loads(buffer["arguments"])
+        tool_call = ToolCall(
+            call_id=buffer["call_id"],
+            tool_name=buffer["name"],
+            arguments=arguments,
+        )
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.progress,
+                delta=ToolCallDelta(
+                    tool_call=tool_call,
+                    parse_status=ToolCallParseStatus.succeeded,
+                ),
+                stop_reason=stop_reason,
+            )
+        )
+    except json.JSONDecodeError:
+        print(f"Failed to parse tool call arguments: {buffer['arguments']}")
+
+        event_type_to_use = ChatCompletionResponseEventType.complete
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=event_type_to_use,
+                delta=ToolCallDelta(
+                    tool_call=buffer["content"],
+                    parse_status=ToolCallParseStatus.failed,
+                ),
+                stop_reason=stop_reason,
+            )
+        )
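
The new _complete_tool_call helper is an async generator, so each call site drains it with async for and re-yields its events, as the hunks above do. A stripped-down usage sketch with plain dict "events" in place of ChatCompletionResponseStreamChunk; everything except the buffer field names is illustrative:

import asyncio
import json


async def complete_tool_call(buffer, stop_reason):
    # Close the human-readable "name(" prefix, as the helper in the diff does
    buffer["content"] += ")"
    yield {"type": "progress", "delta": ")"}
    try:
        arguments = json.loads(buffer["arguments"])
        yield {
            "type": "progress",
            "parse_status": "succeeded",
            "tool_call": {"call_id": buffer["call_id"], "name": buffer["name"], "arguments": arguments},
            "stop_reason": stop_reason,
        }
    except json.JSONDecodeError:
        yield {
            "type": "complete",
            "parse_status": "failed",
            "tool_call": buffer["content"],
            "stop_reason": stop_reason,
        }


async def main():
    good = {"call_id": "call_0", "name": "get_weather", "content": 'get_weather({"city": "Lima"}', "arguments": '{"city": "Lima"}'}
    bad = {"call_id": "call_1", "name": "calculator", "content": "calculator({broken", "arguments": "{broken"}
    for buf in (good, bad):
        async for event in complete_tool_call(buf, stop_reason="end_of_turn"):
            print(event)


asyncio.run(main())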