Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 13:00:39 +00:00)
temp fix

Summary:

Test Plan:

This commit is contained in:
parent 270d64007a
commit f389afe024

4 changed files with 250 additions and 166 deletions
@@ -656,30 +656,27 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-            tool_call = message.tool_calls[0]
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepProgressPayload(
-                        step_type=StepType.tool_execution.value,
-                        step_id=step_id,
-                        tool_call=tool_call,
-                        delta=ToolCallDelta(
-                            parse_status=ToolCallParseStatus.in_progress,
-                            tool_call=tool_call,
-                        ),
-                    )
-                )
-            )

-            # If tool is a client tool, yield CompletionMessage and return
-            if tool_call.tool_name in client_tools:
+            # Process all tool calls instead of just the first one
+            tool_responses = []
+            tool_execution_start_time = datetime.now().astimezone().isoformat()
+
+            # Check if any tool is a client tool
+            client_tool_found = False
+            for tool_call in message.tool_calls:
+                if tool_call.tool_name in client_tools:
+                    client_tool_found = True
+                    break
+
+            # If any tool is a client tool, yield CompletionMessage and return
+            if client_tool_found:
                 await self.storage.set_in_progress_tool_call_step(
                     session_id,
                     turn_id,
                     ToolExecutionStep(
                         step_id=step_id,
                         turn_id=turn_id,
-                        tool_calls=[tool_call],
+                        tool_calls=message.tool_calls,
                         tool_responses=[],
                         started_at=datetime.now().astimezone().isoformat(),
                     ),
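Note: the detect-then-break loop added above is equivalent to a single any() over the tool calls. A minimal standalone sketch (FakeToolCall is a hypothetical stand-in, not the Llama Stack ToolCall type):

from dataclasses import dataclass


@dataclass
class FakeToolCall:
    # Hypothetical stand-in for the tool-call objects iterated above.
    tool_name: str


def has_client_tool(tool_calls: list[FakeToolCall], client_tools: set[str]) -> bool:
    # Same effect as the for/break loop in the hunk: True if any call
    # targets a tool that must be executed on the client side.
    return any(tc.tool_name in client_tools for tc in tool_calls)


assert has_client_tool([FakeToolCall("web_search"), FakeToolCall("ui_tool")], {"ui_tool"})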
@@ -687,41 +684,86 @@ class ChatAgent(ShieldRunnerMixin):
                 yield message
                 return

-            # If tool is a builtin server tool, execute it
-            tool_name = tool_call.tool_name
-            if isinstance(tool_name, BuiltinTool):
-                tool_name = tool_name.value
-            with tracing.span(
-                "tool_execution",
-                {
-                    "tool_name": tool_name,
-                    "input": message.model_dump_json(),
-                },
-            ) as span:
-                tool_execution_start_time = datetime.now().astimezone().isoformat()
-                tool_call = message.tool_calls[0]
-                tool_result = await execute_tool_call_maybe(
-                    self.tool_runtime_api,
-                    session_id,
-                    tool_call,
-                    toolgroup_args,
-                    tool_to_group,
-                )
-                if tool_result.content is None:
-                    raise ValueError(
-                        f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
-                    )
-                result_messages = [
-                    ToolResponseMessage(
-                        call_id=tool_call.call_id,
-                        tool_name=tool_call.tool_name,
-                        content=tool_result.content,
-                    )
-                ]
-                assert len(result_messages) == 1, "Currently not supporting multiple messages"
-                result_message = result_messages[0]
-                span.set_attribute("output", result_message.model_dump_json())
+            # Add the original message with tool calls to input_messages before processing tool calls
+            input_messages.append(message)
+
+            # Process all tool calls
+            for tool_call in message.tool_calls:
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepProgressPayload(
+                            step_type=StepType.tool_execution.value,
+                            step_id=step_id,
+                            tool_call=tool_call,
+                            delta=ToolCallDelta(
+                                parse_status=ToolCallParseStatus.in_progress,
+                                tool_call=tool_call,
+                            ),
+                        )
+                    )
+                )
+
+                # Execute the tool call
+                tool_name = tool_call.tool_name
+                if isinstance(tool_name, BuiltinTool):
+                    tool_name = tool_name.value
+                with tracing.span(
+                    "tool_execution",
+                    {
+                        "tool_name": tool_name,
+                        "input": tool_call.model_dump_json()
+                        if hasattr(tool_call, "model_dump_json")
+                        else str(tool_call),
+                    },
+                ) as span:
+                    tool_result = await execute_tool_call_maybe(
+                        self.tool_runtime_api,
+                        session_id,
+                        tool_call,
+                        toolgroup_args,
+                        tool_to_group,
+                    )
+                    if tool_result.content is None:
+                        raise ValueError(
+                            f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
+                        )
+
+                    result_message = ToolResponseMessage(
+                        call_id=tool_call.call_id,
+                        tool_name=tool_call.tool_name,
+                        content=tool_result.content,
+                    )
+
+                    tool_responses.append(
+                        ToolResponse(
+                            call_id=result_message.call_id,
+                            tool_name=result_message.tool_name,
+                            content=result_message.content,
+                            metadata=tool_result.metadata,
+                        )
+                    )
+
+                    span.set_attribute(
+                        "output",
+                        result_message.model_dump_json()
+                        if hasattr(result_message, "model_dump_json")
+                        else str(result_message),
+                    )
+
+                # TODO: add tool-input touchpoint and a "start" event for this step also
+                # but that needs a lot more refactoring of Tool code potentially
+                if (type(result_message.content) is str) and (
+                    out_attachment := _interpret_content_as_attachment(result_message.content)
+                ):
+                    # NOTE: when we push this message back to the model, the model may ignore the
+                    # attached file path etc. since the model is trained to only provide a user message
+                    # with the summary. We keep all generated attachments and then attach them to final message
+                    output_attachments.append(out_attachment)
+
+                # Add the result message to input_messages
+                input_messages.append(result_message)

+            # Complete the tool execution step
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
@@ -730,15 +772,8 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ToolExecutionStep(
                             step_id=step_id,
                             turn_id=turn_id,
-                            tool_calls=[tool_call],
-                            tool_responses=[
-                                ToolResponse(
-                                    call_id=result_message.call_id,
-                                    tool_name=result_message.tool_name,
-                                    content=result_message.content,
-                                    metadata=tool_result.metadata,
-                                )
-                            ],
+                            tool_calls=message.tool_calls,
+                            tool_responses=tool_responses,
                             started_at=tool_execution_start_time,
                             completed_at=datetime.now().astimezone().isoformat(),
                         ),
@@ -746,18 +781,6 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )
-
-            # TODO: add tool-input touchpoint and a "start" event for this step also
-            # but that needs a lot more refactoring of Tool code potentially
-            if (type(result_message.content) is str) and (
-                out_attachment := _interpret_content_as_attachment(result_message.content)
-            ):
-                # NOTE: when we push this message back to the model, the model may ignore the
-                # attached file path etc. since the model is trained to only provide a user message
-                # with the summary. We keep all generated attachments and then attach them to final message
-                output_attachments.append(out_attachment)
-
-            input_messages = input_messages + [message, result_message]

             n_iter += 1

     async def _get_tool_defs(
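Note: the net effect of the ChatAgent changes above is that a turn now executes every tool call in the assistant message and accumulates one ToolResponse per call, instead of handling only tool_calls[0]. A minimal sketch of that control flow, with hypothetical Call/Response stand-ins for the Llama Stack types:

import asyncio
from dataclasses import dataclass


@dataclass
class Call:
    call_id: str
    tool_name: str


@dataclass
class Response:
    call_id: str
    content: str


async def run_tool(call: Call) -> Response:
    # Placeholder executor; the real code awaits execute_tool_call_maybe().
    return Response(call.call_id, f"result of {call.tool_name}")


async def run_all(calls: list[Call]) -> list[Response]:
    # Every call is executed and its response appended, mirroring the new loop.
    responses = []
    for call in calls:
        responses.append(await run_tool(call))
    return responses


print(asyncio.run(run_all([Call("1", "web_search"), Call("2", "calculator")])))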
@@ -127,7 +127,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         tokens = 0
         picked = [
             TextContentItem(
-                text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
+                text=f"knowledge_search tool found {len(chunks)} chunks for query:\n{query}\nBEGIN of knowledge_search tool results.\n"
             )
         ]
         for i, c in enumerate(chunks):
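Note: this one-line change threads the original query into the header of the knowledge_search results, so the model can see what the chunks were retrieved for. A quick illustration of the old vs. new header, with placeholder values for chunks and query:

chunks = ["chunk-a", "chunk-b", "chunk-c"]  # placeholder for retrieved chunks
query = "What is Llama Stack?"              # placeholder for the user query

old_header = f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
new_header = f"knowledge_search tool found {len(chunks)} chunks for query:\n{query}\nBEGIN of knowledge_search tool results.\n"
print(new_header)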
@@ -99,6 +99,10 @@ class LiteLLMOpenAIMixin(
         params = await self._get_params(request)
         # unfortunately, we need to use synchronous litellm.completion here because litellm
         # caches various httpx.client objects in a non-eventloop aware manner
+
+        from rich.pretty import pprint
+
+        pprint(params)
         response = litellm.completion(**params)
         if stream:
             return self._stream_chat_completion(response)
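Note: the comment in this hunk explains why litellm.completion is called synchronously: litellm caches httpx client objects in a way that is not event-loop aware, so a shared async client is unsafe. The cost is that the blocking call stalls the event loop while it runs. A minimal sketch of the usual mitigation, offloading a blocking call to a worker thread (this is only an illustration of the trade-off, not what the commit does):

import asyncio
import time


def blocking_completion(params: dict) -> str:
    # Stand-in for a synchronous SDK call such as litellm.completion(**params).
    time.sleep(0.1)
    return f"completed with {params}"


async def handler(params: dict) -> str:
    # The event loop keeps serving other tasks while the call runs in a thread.
    return await asyncio.to_thread(blocking_completion, params)


print(asyncio.run(handler({"model": "test-model"})))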
@@ -523,10 +523,11 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC
     ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
         # Llama Stack and OpenAI spec match for str and text input
         if isinstance(content, str):
-            return OpenAIChatCompletionContentPartTextParam(
-                type="text",
-                text=content,
-            )
+            # return OpenAIChatCompletionContentPartTextParam(
+            #     type="text",
+            #     text=content,
+            # )
+            return content
         elif isinstance(content, TextContentItem):
             return OpenAIChatCompletionContentPartTextParam(
                 type="text",
@@ -568,12 +569,12 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC
         out = OpenAIChatCompletionToolMessage(
             role="tool",
             tool_call_id=message.call_id,
-            content=message.content,
+            content=await _convert_user_message_content(message.content),
         )
     elif isinstance(message, SystemMessage):
         out = OpenAIChatCompletionSystemMessage(
             role="system",
-            content=message.content,
+            content=await _convert_user_message_content(message.content),
         )
     else:
         raise ValueError(f"Unsupported message type: {type(message)}")
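Note: both edits above route tool and system message content through _convert_user_message_content instead of passing it along raw, so list-style content (for example, multiple text items) is converted into OpenAI content parts. A simplified sketch of that conversion shape (hypothetical; the real helper also handles image items and nested content):

from typing import Union


def convert_content(content: Union[str, list]) -> Union[str, list[dict]]:
    # Strings pass through; lists become OpenAI-style text parts.
    if isinstance(content, str):
        return content
    return [{"type": "text", "text": str(item)} for item in content]


print(convert_content("plain string"))
print(convert_content(["part one", "part two"]))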
@@ -831,18 +832,26 @@ async def convert_openai_chat_completion_stream(
     Convert a stream of OpenAI chat completion chunks into a stream
     of ChatCompletionResponseStreamChunk.
     """

-    # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
-    def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]:
-        yield ChatCompletionResponseEventType.start
-        while True:
-            yield ChatCompletionResponseEventType.progress
-
-    event_type = _event_type_generator()
-
     stop_reason = None
-    toolcall_buffer = {}
+    # Use a dictionary to track multiple tool calls by their index
+    toolcall_buffers = {}
+    # Track which tool calls have been completed
+    completed_tool_indices = set()
+    # Track the highest index seen so far
+    highest_index_seen = -1
+
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.start,
+            delta=TextDelta(text=""),
+            stop_reason=stop_reason,
+        )
+    )

     async for chunk in stream:
+        from rich.pretty import pprint
+
+        pprint(chunk)
         choice = chunk.choices[0]  # assuming only one choice per chunk

         # we assume there's only one finish_reason in the stream
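Note: this hunk replaces the single toolcall_buffer with one buffer per OpenAI tool-call index, each tracking the call id, name, streamed arguments, and a printable content string. A small sketch of how one buffer accumulates as deltas arrive (the delta sequence below is made up for illustration):

buffer = {"call_id": "call_0", "name": None, "content": "", "arguments": "", "complete": False}

# A name fragment opens "name(", then argument fragments append verbatim.
deltas = [("name", "get_weather"), ("args", '{"city": '), ("args", '"Paris"}')]
for kind, piece in deltas:
    if kind == "name":
        buffer["name"] = piece
        buffer["content"] += f"{piece}("
    else:
        buffer["arguments"] += piece
        buffer["content"] += piece

print(buffer["content"])  # get_weather({"city": "Paris"} -- the ")" is added later by _complete_tool_call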
@@ -851,112 +860,108 @@ async def convert_openai_chat_completion_stream(

         # if there's a tool call, emit an event for each tool in the list
         # if tool call and content, emit both separately

         if choice.delta.tool_calls:
             # the call may have content and a tool call. ChatCompletionResponseEvent
             # does not support both, so we emit the content first
             if choice.delta.content:
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
+                        event_type=ChatCompletionResponseEventType.progress,
                         delta=TextDelta(text=choice.delta.content),
                         logprobs=_convert_openai_logprobs(logprobs),
                     )
                 )

-            # it is possible to have parallel tool calls in stream, but
-            # ChatCompletionResponseEvent only supports one per stream
-            if len(choice.delta.tool_calls) > 1:
-                warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest")
-
-            if not enable_incremental_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
-                        delta=ToolCallDelta(
-                            tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
-                            parse_status=ToolCallParseStatus.succeeded,
-                        ),
-                        logprobs=_convert_openai_logprobs(logprobs),
-                    )
-                )
-            else:
-                tool_call = choice.delta.tool_calls[0]
-                if "name" not in toolcall_buffer:
-                    toolcall_buffer["call_id"] = tool_call.id
-                    toolcall_buffer["name"] = None
-                    toolcall_buffer["content"] = ""
-                if "arguments" not in toolcall_buffer:
-                    toolcall_buffer["arguments"] = ""
-
-                if tool_call.function.name:
-                    toolcall_buffer["name"] = tool_call.function.name
-                    delta = f"{toolcall_buffer['name']}("
-                if tool_call.function.arguments:
-                    toolcall_buffer["arguments"] += tool_call.function.arguments
-                    delta = toolcall_buffer["arguments"]
-
-                toolcall_buffer["content"] += delta
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
-                        delta=ToolCallDelta(
-                            tool_call=delta,
-                            parse_status=ToolCallParseStatus.in_progress,
-                        ),
-                        logprobs=_convert_openai_logprobs(logprobs),
-                    )
-                )
+            # Process each tool call in the delta
+            for tool_call in choice.delta.tool_calls:
+                # Get the tool call index
+                tool_index = getattr(tool_call, "index", 0)
+
+                # If we see a new higher index, complete all previous tool calls
+                if tool_index > highest_index_seen:
+                    # Complete all previous tool calls
+                    for prev_index in range(highest_index_seen + 1):
+                        if prev_index in toolcall_buffers and prev_index not in completed_tool_indices:
+                            # Complete this tool call
+                            async for event in _complete_tool_call(
+                                toolcall_buffers[prev_index],
+                                logprobs,
+                                None,  # No stop_reason for intermediate tool calls
+                            ):
+                                yield event
+                            completed_tool_indices.add(prev_index)
+
+                    highest_index_seen = tool_index
+
+                # Skip if this tool call has already been completed
+                if tool_index in completed_tool_indices:
+                    continue
+
+                # Initialize buffer for this tool call if it doesn't exist
+                if tool_index not in toolcall_buffers:
+                    toolcall_buffers[tool_index] = {
+                        "call_id": tool_call.id,
+                        "name": None,
+                        "content": "",
+                        "arguments": "",
+                        "complete": False,
+                    }
+
+                buffer = toolcall_buffers[tool_index]
+
+                # Handle function name
+                if tool_call.function and tool_call.function.name:
+                    buffer["name"] = tool_call.function.name
+                    delta = f"{buffer['name']}("
+                    buffer["content"] += delta
+
+                    # Emit the function name
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                tool_call=delta,
+                                parse_status=ToolCallParseStatus.in_progress,
+                            ),
+                            logprobs=_convert_openai_logprobs(logprobs),
+                        )
+                    )
+
+                # Handle function arguments
+                if tool_call.function and tool_call.function.arguments:
+                    delta = tool_call.function.arguments
+                    buffer["arguments"] += delta
+                    buffer["content"] += delta
+
+                    # Emit the argument fragment
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                tool_call=delta,
+                                parse_status=ToolCallParseStatus.in_progress,
+                            ),
+                            logprobs=_convert_openai_logprobs(logprobs),
+                        )
+                    )
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type=next(event_type),
+                    event_type=ChatCompletionResponseEventType.progress,
                     delta=TextDelta(text=choice.delta.content or ""),
                     logprobs=_convert_openai_logprobs(logprobs),
                 )
             )

-    if toolcall_buffer:
-        delta = ")"
-        toolcall_buffer["content"] += delta
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=next(event_type),
-                delta=ToolCallDelta(
-                    tool_call=delta,
-                    parse_status=ToolCallParseStatus.in_progress,
-                ),
-                logprobs=_convert_openai_logprobs(logprobs),
-            )
-        )
-        try:
-            arguments = json.loads(toolcall_buffer["arguments"])
-            tool_call = ToolCall(
-                call_id=toolcall_buffer["call_id"],
-                tool_name=toolcall_buffer["name"],
-                arguments=arguments,
-            )
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta=ToolCallDelta(
-                        tool_call=tool_call,
-                        parse_status=ToolCallParseStatus.succeeded,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-        except json.JSONDecodeError:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta=ToolCallDelta(
-                        tool_call=toolcall_buffer["content"],
-                        parse_status=ToolCallParseStatus.failed,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
+    # Final complete event if no tool calls were processed
+    if toolcall_buffers:
+        # Process all tool calls that haven't been completed yet
+        for tool_index in sorted(toolcall_buffers.keys()):
+            if tool_index not in completed_tool_indices:
+                # Complete this tool call
+                async for event in _complete_tool_call(toolcall_buffers[tool_index], logprobs, stop_reason):
+                    yield event
+                completed_tool_indices.add(tool_index)

     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
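Note: the rewritten loop above completes buffered tool calls in index order: a delta for a new, higher index flushes every lower index that is still open, and end of stream flushes whatever remains. A self-contained simulation of that ordering rule:

def flush_order(indices_seen: list[int]) -> list[int]:
    # Mirrors the hunk's bookkeeping: buffers keyed by index, a completed
    # set, and the highest index seen so far.
    buffers, completed, highest, order = set(), set(), -1, []
    for idx in indices_seen:
        if idx > highest:
            # A higher index completes all open lower indices.
            for prev in range(highest + 1):
                if prev in buffers and prev not in completed:
                    order.append(prev)
                    completed.add(prev)
            highest = idx
        if idx in completed:
            continue
        buffers.add(idx)
    # End of stream: complete the rest in sorted order.
    for idx in sorted(buffers):
        if idx not in completed:
            order.append(idx)
            completed.add(idx)
    return order


# Index 0 streams two deltas, then index 1 appears (flushing 0), then the stream ends.
assert flush_order([0, 0, 1, 1]) == [0, 1]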
@@ -965,3 +970,55 @@ async def convert_openai_chat_completion_stream(
             stop_reason=stop_reason,
         )
     )
+
+
+async def _complete_tool_call(buffer, logprobs, stop_reason):
+    """Helper function to complete a tool call and yield the appropriate events."""
+    # Add closing parenthesis
+    delta = ")"
+    buffer["content"] += delta
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.progress,
+            delta=ToolCallDelta(
+                tool_call=delta,
+                parse_status=ToolCallParseStatus.in_progress,
+            ),
+            logprobs=_convert_openai_logprobs(logprobs),
+        )
+    )
+
+    try:
+        # Parse the arguments
+        arguments = json.loads(buffer["arguments"])
+        tool_call = ToolCall(
+            call_id=buffer["call_id"],
+            tool_name=buffer["name"],
+            arguments=arguments,
+        )
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.progress,
+                delta=ToolCallDelta(
+                    tool_call=tool_call,
+                    parse_status=ToolCallParseStatus.succeeded,
+                ),
+                stop_reason=stop_reason,
+            )
+        )
+    except json.JSONDecodeError:
+        print(f"Failed to parse tool call arguments: {buffer['arguments']}")
+
+        event_type_to_use = ChatCompletionResponseEventType.complete
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=event_type_to_use,
+                delta=ToolCallDelta(
+                    tool_call=buffer["content"],
+                    parse_status=ToolCallParseStatus.failed,
+                ),
+                stop_reason=stop_reason,
+            )
+        )
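Note: _complete_tool_call is an async generator, which is why the call sites above re-emit its chunks with "async for event in ...: yield event". A minimal sketch of that consumption pattern, using a stubbed-down version of the helper that yields plain tuples instead of the real stream chunks:

import asyncio
import json


async def complete_tool_call(buffer, stop_reason):
    # Stub of the helper's control flow: close the paren, then yield either
    # a parsed (succeeded) event or the raw content as a failed event.
    buffer["content"] += ")"
    try:
        args = json.loads(buffer["arguments"])
        yield ("succeeded", buffer["name"], args, stop_reason)
    except json.JSONDecodeError:
        yield ("failed", buffer["content"], None, stop_reason)


async def main():
    buf = {"call_id": "c1", "name": "get_weather", "content": "get_weather(", "arguments": '{"city": "Paris"}'}
    async for event in complete_tool_call(buf, "end_of_turn"):
        print(event)


asyncio.run(main())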