fix: agents with non-llama model (#1550)

# Summary:
Includes fixes to get test_agents working with OpenAI models, e.g. tool
parsing and message conversion.
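
To make the "tool parsing" part concrete, here is a minimal standalone sketch (not llama-stack code; `BuiltinTool` below is a stand-in enum and `openai_tool_name` a hypothetical helper) of the normalization the diff applies when building OpenAI tool calls:

```python
# Minimal sketch, assuming a stand-in BuiltinTool enum: OpenAI's tool-call
# schema expects the function name as a plain string, while Llama builtin
# tools are represented by an enum member. This mirrors the isinstance check
# in the diff; it is not the llama-stack implementation itself.
from enum import Enum


class BuiltinTool(Enum):
    brave_search = "brave_search"
    code_interpreter = "code_interpreter"


def openai_tool_name(tool_name) -> str:
    # Builtin tools arrive as enum members; custom client tools as plain strings.
    return tool_name.value if isinstance(tool_name, BuiltinTool) else tool_name


assert openai_tool_name(BuiltinTool.code_interpreter) == "code_interpreter"
assert openai_tool_name("get_boiling_point") == "get_boiling_point"
```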

# Test Plan:
```
LLAMA_STACK_CONFIG=dev pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --text-model openai/gpt-4o-mini
```

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1550).
* #1556
* __->__ #1550
ehhuang 2025-03-17 22:11:06 -07:00 committed by GitHub
parent 0bdfc71f8d
commit c23a7af5d6
4 changed files with 161 additions and 125 deletions


```diff
@@ -891,16 +891,14 @@ class ChatAgent(ShieldRunnerMixin):
             if memory_tool and code_interpreter_tool:
                 # if both memory and code_interpreter are available, we download the URLs
                 # and attach the data to the last message.
-                msg = await attachment_message(self.tempdir, url_items)
-                input_messages.append(msg)
+                await attachment_message(self.tempdir, url_items, input_messages[-1])
                 # Since memory is present, add all the data to the memory bank
                 await self.add_to_session_vector_db(session_id, documents)
             elif code_interpreter_tool:
                 # if only code_interpreter is available, we download the URLs to a tempdir
                 # and attach the path to them as a message to inference with the
                 # assumption that the model invokes the code_interpreter tool with the path
-                msg = await attachment_message(self.tempdir, url_items)
-                input_messages.append(msg)
+                await attachment_message(self.tempdir, url_items, input_messages[-1])
             elif memory_tool:
                 # if only memory is available, we load the data from the URLs and content items to the memory bank
                 await self.add_to_session_vector_db(session_id, documents)
@@ -967,8 +965,8 @@ async def load_data_from_urls(urls: List[URL]) -> List[str]:
     return data


-async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
-    content = []
+async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
+    contents = []

     for url in urls:
         uri = url.uri
@@ -988,16 +986,19 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
         else:
             raise ValueError(f"Unsupported URL {url}")

-        content.append(
+        contents.append(
             TextContentItem(
                 text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
             )
         )

-    return ToolResponseMessage(
-        call_id="",
-        content=content,
-    )
+    if isinstance(message.content, list):
+        message.content.extend(contents)
+    else:
+        if isinstance(message.content, str):
+            message.content = [TextContentItem(text=message.content)] + contents
+        else:
+            message.content = [message.content] + contents


 def _interpret_content_as_attachment(
```


```diff
@@ -192,7 +192,11 @@ class LiteLLMOpenAIMixin(
         if request.tools:
             input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
             if request.tool_config.tool_choice:
-                input_dict["tool_choice"] = request.tool_config.tool_choice.value
+                input_dict["tool_choice"] = (
+                    request.tool_config.tool_choice.value
+                    if isinstance(request.tool_config.tool_choice, ToolChoice)
+                    else request.tool_config.tool_choice
+                )

         provider_data = self.get_request_provider_data()
         key_field = self.provider_data_api_key_field
```


```diff
@@ -527,26 +527,30 @@ async def convert_message_to_openai_dict_new(
     async def _convert_message_content(
         content: InterleavedContent,
     ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
-        async def impl():
+        async def impl(
+            content_: InterleavedContent,
+        ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
             # Llama Stack and OpenAI spec match for str and text input
-            if isinstance(content, str):
-                return content
-            elif isinstance(content, TextContentItem):
+            if isinstance(content_, str):
+                return content_
+            elif isinstance(content_, TextContentItem):
                 return OpenAIChatCompletionContentPartTextParam(
                     type="text",
-                    text=content.text,
+                    text=content_.text,
                 )
-            elif isinstance(content, ImageContentItem):
+            elif isinstance(content_, ImageContentItem):
                 return OpenAIChatCompletionContentPartImageParam(
                     type="image_url",
-                    image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
+                    image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)),
                 )
-            elif isinstance(content, list):
-                return [await _convert_message_content(item) for item in content]
+            elif isinstance(content_, list):
+                return [await impl(item) for item in content_]
             else:
-                raise ValueError(f"Unsupported content type: {type(content)}")
+                raise ValueError(f"Unsupported content type: {type(content_)}")

-        ret = await impl()
+        ret = await impl(content)
+
+        # OpenAI*Message expects a str or list
         if isinstance(ret, str) or isinstance(ret, list):
             return ret
         else:
@@ -566,13 +570,14 @@ async def convert_message_to_openai_dict_new(
                 OpenAIChatCompletionMessageToolCall(
                     id=tool.call_id,
                     function=OpenAIFunction(
-                        name=tool.tool_name,
+                        name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
                         arguments=json.dumps(tool.arguments),
                     ),
                     type="function",
                 )
                 for tool in message.tool_calls
-            ],
+            ]
+            or None,
         )
     elif isinstance(message, ToolResponseMessage):
         out = OpenAIChatCompletionToolMessage(
@@ -858,7 +863,8 @@ async def convert_openai_chat_completion_stream(
     event_type = ChatCompletionResponseEventType.progress
     stop_reason = None
-    toolcall_buffer = {}
+    tool_call_idx_to_buffer = {}
+
     async for chunk in stream:
         choice = chunk.choices[0]  # assuming only one choice per chunk
@@ -868,7 +874,6 @@ async def convert_openai_chat_completion_stream(
         # if there's a tool call, emit an event for each tool in the list
         # if tool call and content, emit both separately
         if choice.delta.tool_calls:
-
             # the call may have content and a tool call. ChatCompletionResponseEvent
             # does not support both, so we emit the content first
@@ -889,33 +894,42 @@ async def convert_openai_chat_completion_stream(
                 )

             if not enable_incremental_tool_calls:
+                for tool_call in choice.delta.tool_calls:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
-                            event_type=next(event_type),
+                            event_type=event_type,
                             delta=ToolCallDelta(
-                                tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
+                                tool_call=_convert_openai_tool_calls([tool_call])[0],
                                 parse_status=ToolCallParseStatus.succeeded,
                             ),
                             logprobs=_convert_openai_logprobs(logprobs),
                         )
                     )
             else:
-                tool_call = choice.delta.tool_calls[0]
-                if "name" not in toolcall_buffer:
-                    toolcall_buffer["call_id"] = tool_call.id
-                    toolcall_buffer["name"] = None
-                    toolcall_buffer["content"] = ""
-                if "arguments" not in toolcall_buffer:
-                    toolcall_buffer["arguments"] = ""
-
-                if tool_call.function.name:
-                    toolcall_buffer["name"] = tool_call.function.name
-                    delta = f"{toolcall_buffer['name']}("
-                if tool_call.function.arguments:
-                    toolcall_buffer["arguments"] += tool_call.function.arguments
-                    delta = toolcall_buffer["arguments"]
-
-                toolcall_buffer["content"] += delta
+                for tool_call in choice.delta.tool_calls:
+                    idx = tool_call.index if hasattr(tool_call, "index") else 0
+                    if idx not in tool_call_idx_to_buffer:
+                        tool_call_idx_to_buffer[idx] = {
+                            "call_id": tool_call.id,
+                            "name": None,
+                            "arguments": "",
+                            "content": "",
+                        }
+                    buffer = tool_call_idx_to_buffer[idx]
+                    if tool_call.function:
+                        if tool_call.function.name:
+                            buffer["name"] = tool_call.function.name
+                            delta = f"{buffer['name']}("
+                            buffer["content"] += delta
+
+                        if tool_call.function.arguments:
+                            delta = tool_call.function.arguments
+                            buffer["arguments"] += delta
+                            buffer["content"] += delta
                         yield ChatCompletionResponseStreamChunk(
                             event=ChatCompletionResponseEvent(
                                 event_type=event_type,
@@ -926,7 +940,7 @@ async def convert_openai_chat_completion_stream(
                                 logprobs=_convert_openai_logprobs(logprobs),
                             )
                         )
-        else:
+        elif choice.delta.content:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=event_type,
@@ -935,9 +949,11 @@ async def convert_openai_chat_completion_stream(
                 )
             )

-    if toolcall_buffer:
-        delta = ")"
-        toolcall_buffer["content"] += delta
+    for idx, buffer in tool_call_idx_to_buffer.items():
+        logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
+        if buffer["name"]:
+            delta = ")"
+            buffer["content"] += delta
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=event_type,
@@ -945,14 +961,15 @@ async def convert_openai_chat_completion_stream(
                         tool_call=delta,
                         parse_status=ToolCallParseStatus.in_progress,
                     ),
-                    logprobs=_convert_openai_logprobs(logprobs),
+                    logprobs=None,
                 )
             )

             try:
-                arguments = json.loads(toolcall_buffer["arguments"])
+                arguments = json.loads(buffer["arguments"])
                 tool_call = ToolCall(
-                    call_id=toolcall_buffer["call_id"],
-                    tool_name=toolcall_buffer["name"],
+                    call_id=buffer["call_id"],
+                    tool_name=buffer["name"],
                     arguments=arguments,
                 )
                 yield ChatCompletionResponseStreamChunk(
@@ -965,12 +982,13 @@ async def convert_openai_chat_completion_stream(
                         stop_reason=stop_reason,
                     )
                 )
-            except json.JSONDecodeError:
+            except json.JSONDecodeError as e:
+                print(f"Failed to parse arguments: {e}")
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.complete,
+                        event_type=ChatCompletionResponseEventType.progress,
                         delta=ToolCallDelta(
-                            tool_call=toolcall_buffer["content"],
+                            tool_call=buffer["content"],
                             parse_status=ToolCallParseStatus.failed,
                         ),
                         stop_reason=stop_reason,
```


```diff
@@ -271,7 +271,7 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
-        "tools": ["builtin::websearch", client_tool],
+        "tools": [client_tool],
     }

     agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
@@ -320,17 +320,39 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config):
     assert num_tool_calls <= 5


-def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
-    def run_agent(tool_choice):
+def test_tool_choice_required(llama_stack_client_with_mocked_inference, agent_config):
+    tool_execution_steps = run_agent_with_tool_choice(
+        llama_stack_client_with_mocked_inference, agent_config, "required"
+    )
+    assert len(tool_execution_steps) > 0
+
+
+def test_tool_choice_none(llama_stack_client_with_mocked_inference, agent_config):
+    tool_execution_steps = run_agent_with_tool_choice(llama_stack_client_with_mocked_inference, agent_config, "none")
+    assert len(tool_execution_steps) == 0
+
+
+def test_tool_choice_get_boiling_point(llama_stack_client_with_mocked_inference, agent_config):
+    if "llama" not in agent_config["model"].lower():
+        pytest.xfail("NotImplemented for non-llama models")
+
+    tool_execution_steps = run_agent_with_tool_choice(
+        llama_stack_client_with_mocked_inference, agent_config, "get_boiling_point"
+    )
+    assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
+
+
+def run_agent_with_tool_choice(client, agent_config, tool_choice):
     client_tool = get_boiling_point
     test_agent_config = {
         **agent_config,
         "tool_config": {"tool_choice": tool_choice},
         "tools": [client_tool],
+        "max_infer_iters": 2,
     }

-    agent = Agent(llama_stack_client_with_mocked_inference, **test_agent_config)
+    agent = Agent(client, **test_agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
     response = agent.create_turn(
@@ -346,15 +368,6 @@ def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
     return [step for step in response.steps if step.step_type == "tool_execution"]

-    tool_execution_steps = run_agent("required")
-    assert len(tool_execution_steps) > 0
-
-    tool_execution_steps = run_agent("none")
-    assert len(tool_execution_steps) == 0
-
-    tool_execution_steps = run_agent("get_boiling_point")
-    assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
-

 @pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
 def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
```