forked from phoenix-oss/llama-stack-mirror
fix: agents with non-llama model (#1550)
# Summary:
Includes fixes to get test_agents working with an OpenAI model, e.g. tool parsing and message conversion.

# Test Plan:
```
LLAMA_STACK_CONFIG=dev pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --text-model openai/gpt-4o-mini
```

---

[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1550).
* #1556
* __->__ #1550
This commit is contained in:
parent 0bdfc71f8d
commit c23a7af5d6

4 changed files with 161 additions and 125 deletions
@@ -891,16 +891,14 @@ class ChatAgent(ShieldRunnerMixin):
         if memory_tool and code_interpreter_tool:
             # if both memory and code_interpreter are available, we download the URLs
             # and attach the data to the last message.
-            msg = await attachment_message(self.tempdir, url_items)
-            input_messages.append(msg)
+            await attachment_message(self.tempdir, url_items, input_messages[-1])
             # Since memory is present, add all the data to the memory bank
             await self.add_to_session_vector_db(session_id, documents)
         elif code_interpreter_tool:
             # if only code_interpreter is available, we download the URLs to a tempdir
             # and attach the path to them as a message to inference with the
             # assumption that the model invokes the code_interpreter tool with the path
-            msg = await attachment_message(self.tempdir, url_items)
-            input_messages.append(msg)
+            await attachment_message(self.tempdir, url_items, input_messages[-1])
         elif memory_tool:
             # if only memory is available, we load the data from the URLs and content items to the memory bank
             await self.add_to_session_vector_db(session_id, documents)
@@ -967,8 +965,8 @@ async def load_data_from_urls(urls: List[URL]) -> List[str]:
     return data


-async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
-    content = []
+async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
+    contents = []

     for url in urls:
         uri = url.uri
@@ -988,16 +986,19 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
         else:
             raise ValueError(f"Unsupported URL {url}")

-        content.append(
+        contents.append(
             TextContentItem(
                 text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
             )
         )

-    return ToolResponseMessage(
-        call_id="",
-        content=content,
-    )
+    if isinstance(message.content, list):
+        message.content.extend(contents)
+    else:
+        if isinstance(message.content, str):
+            message.content = [TextContentItem(text=message.content)] + contents
+        else:
+            message.content = [message.content] + contents


 def _interpret_content_as_attachment(
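Reviewer note: the hunks above change `attachment_message` from returning a synthetic `ToolResponseMessage` (with an empty `call_id`) to mutating the last user message in place; a tool-role message that does not answer a prior tool call id is typically rejected by OpenAI-compatible endpoints. A minimal sketch of the same merge logic, using simplified stand-in types rather than the actual llama-stack `UserMessage`/`TextContentItem` models:

```python
# Minimal sketch of the new attachment flow, with simplified stand-in types.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class TextItem:
    text: str


@dataclass
class UserMsg:
    content: Union[str, TextItem, List]


def attach_file_notes(filepaths: List[str], message: UserMsg) -> None:
    """Append 'file available at ...' notes to the last user message in place."""
    notes = [TextItem(text=f'# User provided a file accessible to you at "{p}"') for p in filepaths]
    if isinstance(message.content, list):
        message.content.extend(notes)
    elif isinstance(message.content, str):
        # promote a plain string to a list of content items, then append
        message.content = [TextItem(text=message.content)] + notes
    else:
        message.content = [message.content] + notes


msg = UserMsg(content="Please inspect the attached CSV.")
attach_file_notes(["/tmp/data.csv"], msg)
print(msg.content)  # original text plus one note, all inside a single user message
```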
@@ -192,7 +192,11 @@ class LiteLLMOpenAIMixin(
         if request.tools:
             input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
             if request.tool_config.tool_choice:
-                input_dict["tool_choice"] = request.tool_config.tool_choice.value
+                input_dict["tool_choice"] = (
+                    request.tool_config.tool_choice.value
+                    if isinstance(request.tool_config.tool_choice, ToolChoice)
+                    else request.tool_config.tool_choice
+                )

         provider_data = self.get_request_provider_data()
         key_field = self.provider_data_api_key_field
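Reviewer note: `tool_choice` can now arrive either as the `ToolChoice` enum or as a plain string naming a specific tool, so `.value` is only called on enum members. A rough sketch of that conversion, with a stand-in enum rather than llama-stack's actual `ToolChoice`:

```python
# Sketch: passing tool_choice through to an OpenAI-style request body.
from enum import Enum
from typing import Union


class ToolChoice(Enum):  # stand-in for llama-stack's ToolChoice enum
    auto = "auto"
    required = "required"
    none = "none"


def to_openai_tool_choice(tool_choice: Union[ToolChoice, str]) -> str:
    # enum members carry the wire value; a bare string is assumed to name a
    # specific tool and is passed through unchanged, as in the hunk above
    return tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice


print(to_openai_tool_choice(ToolChoice.required))  # "required"
print(to_openai_tool_choice("get_boiling_point"))  # "get_boiling_point"
```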
@@ -527,26 +527,30 @@ async def convert_message_to_openai_dict_new(
     async def _convert_message_content(
         content: InterleavedContent,
     ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
-        async def impl():
+        async def impl(
+            content_: InterleavedContent,
+        ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
             # Llama Stack and OpenAI spec match for str and text input
-            if isinstance(content, str):
-                return content
-            elif isinstance(content, TextContentItem):
+            if isinstance(content_, str):
+                return content_
+            elif isinstance(content_, TextContentItem):
                 return OpenAIChatCompletionContentPartTextParam(
                     type="text",
-                    text=content.text,
+                    text=content_.text,
                 )
-            elif isinstance(content, ImageContentItem):
+            elif isinstance(content_, ImageContentItem):
                 return OpenAIChatCompletionContentPartImageParam(
                     type="image_url",
-                    image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
+                    image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)),
                 )
-            elif isinstance(content, list):
-                return [await _convert_message_content(item) for item in content]
+            elif isinstance(content_, list):
+                return [await impl(item) for item in content_]
             else:
-                raise ValueError(f"Unsupported content type: {type(content)}")
+                raise ValueError(f"Unsupported content type: {type(content_)}")

-        ret = await impl()
+        ret = await impl(content)

+        # OpenAI*Message expects a str or list
         if isinstance(ret, str) or isinstance(ret, list):
             return ret
         else:
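Reviewer note: the inner `impl` now takes the content item explicitly and recurses through itself, so a list of items is converted element by element into OpenAI-style content parts. A self-contained sketch of the same shape, with simplified stand-in item classes and plain dicts in place of the OpenAI param types:

```python
# Sketch: converting "interleaved" message content into OpenAI-style content parts.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class TextItem:  # stand-in for TextContentItem
    text: str


@dataclass
class ImageItem:  # stand-in for ImageContentItem
    url: str


def convert_content(content) -> Union[str, dict, List]:
    # plain strings pass through unchanged
    if isinstance(content, str):
        return content
    # single items become OpenAI-style content-part dicts
    if isinstance(content, TextItem):
        return {"type": "text", "text": content.text}
    if isinstance(content, ImageItem):
        return {"type": "image_url", "image_url": {"url": content.url}}
    # lists are converted element by element, mirroring the recursion above
    if isinstance(content, list):
        return [convert_content(item) for item in content]
    raise ValueError(f"Unsupported content type: {type(content)}")


print(convert_content([TextItem(text="describe this image"), ImageItem(url="https://example.com/cat.png")]))
```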
@@ -566,13 +570,14 @@ async def convert_message_to_openai_dict_new(
                 OpenAIChatCompletionMessageToolCall(
                     id=tool.call_id,
                     function=OpenAIFunction(
-                        name=tool.tool_name,
+                        name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
                         arguments=json.dumps(tool.arguments),
                     ),
                     type="function",
                 )
                 for tool in message.tool_calls
-            ],
+            ]
+            or None,
         )
     elif isinstance(message, ToolResponseMessage):
         out = OpenAIChatCompletionToolMessage(
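Reviewer note: two small compatibility fixes here — `tool_name` may be a `BuiltinTool` enum rather than a string, and an empty `tool_calls` list is turned into `None` so the field is effectively omitted (OpenAI-style endpoints typically reject an empty array). A sketch with stand-in types, not the actual llama-stack or OpenAI client models:

```python
# Sketch: serializing an assistant message's tool calls for an OpenAI-style API.
import json
from enum import Enum
from typing import List, Optional, Union


class BuiltinTool(Enum):  # stand-in for llama-stack's BuiltinTool enum
    brave_search = "brave_search"
    code_interpreter = "code_interpreter"


def tool_name_to_str(name: Union[str, BuiltinTool]) -> str:
    # builtin tools are enum members; their wire name is the enum value
    return name.value if isinstance(name, BuiltinTool) else name


def to_openai_tool_calls(tool_calls: List[dict]) -> Optional[List[dict]]:
    converted = [
        {
            "id": call["call_id"],
            "type": "function",
            "function": {
                "name": tool_name_to_str(call["tool_name"]),
                "arguments": json.dumps(call["arguments"]),
            },
        }
        for call in tool_calls
    ]
    # an empty list becomes None so the field can be left out of the request
    return converted or None


print(to_openai_tool_calls([{"call_id": "1", "tool_name": BuiltinTool.brave_search, "arguments": {"query": "llamas"}}]))
print(to_openai_tool_calls([]))  # None
```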
@@ -858,7 +863,8 @@ async def convert_openai_chat_completion_stream(
     event_type = ChatCompletionResponseEventType.progress

     stop_reason = None
-    toolcall_buffer = {}
+    tool_call_idx_to_buffer = {}

     async for chunk in stream:
         choice = chunk.choices[0]  # assuming only one choice per chunk
@@ -868,7 +874,6 @@ async def convert_openai_chat_completion_stream(

         # if there's a tool call, emit an event for each tool in the list
         # if tool call and content, emit both separately
-
         if choice.delta.tool_calls:
             # the call may have content and a tool call. ChatCompletionResponseEvent
             # does not support both, so we emit the content first
@@ -889,33 +894,42 @@ async def convert_openai_chat_completion_stream(
                     )

             if not enable_incremental_tool_calls:
+                for tool_call in choice.delta.tool_calls:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
-                            event_type=next(event_type),
+                            event_type=event_type,
                             delta=ToolCallDelta(
-                                tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
+                                tool_call=_convert_openai_tool_calls([tool_call])[0],
                                 parse_status=ToolCallParseStatus.succeeded,
                             ),
                             logprobs=_convert_openai_logprobs(logprobs),
                         )
                     )
             else:
-                tool_call = choice.delta.tool_calls[0]
-                if "name" not in toolcall_buffer:
-                    toolcall_buffer["call_id"] = tool_call.id
-                    toolcall_buffer["name"] = None
-                    toolcall_buffer["content"] = ""
-                if "arguments" not in toolcall_buffer:
-                    toolcall_buffer["arguments"] = ""
+                for tool_call in choice.delta.tool_calls:
+                    idx = tool_call.index if hasattr(tool_call, "index") else 0
+
+                    if idx not in tool_call_idx_to_buffer:
+                        tool_call_idx_to_buffer[idx] = {
+                            "call_id": tool_call.id,
+                            "name": None,
+                            "arguments": "",
+                            "content": "",
+                        }
+
+                    buffer = tool_call_idx_to_buffer[idx]

+                    if tool_call.function:
                         if tool_call.function.name:
-                    toolcall_buffer["name"] = tool_call.function.name
-                    delta = f"{toolcall_buffer['name']}("
-                if tool_call.function.arguments:
-                    toolcall_buffer["arguments"] += tool_call.function.arguments
-                    delta = toolcall_buffer["arguments"]
+                            buffer["name"] = tool_call.function.name
+                            delta = f"{buffer['name']}("
+                            buffer["content"] += delta

-                toolcall_buffer["content"] += delta
+                        if tool_call.function.arguments:
+                            delta = tool_call.function.arguments
+                            buffer["arguments"] += delta
+                            buffer["content"] += delta
+
                         yield ChatCompletionResponseStreamChunk(
                             event=ChatCompletionResponseEvent(
                                 event_type=event_type,
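Reviewer note: OpenAI-compatible providers stream a tool call as many small deltas, and a single chunk can carry deltas for several calls distinguished by an `index` field, so the rewrite keeps one buffer per index instead of assuming a single call. A self-contained sketch of that accumulation, using fake delta objects rather than the real OpenAI client types:

```python
# Sketch: accumulating streamed tool-call deltas into complete calls, keyed by index.
import json
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class FunctionDelta:  # fake stand-in for the streamed function fragment
    name: Optional[str] = None
    arguments: Optional[str] = None


@dataclass
class ToolCallDelta:  # fake stand-in for one streamed tool-call delta
    index: int
    id: Optional[str] = None
    function: Optional[FunctionDelta] = None


def accumulate(deltas) -> Dict[int, dict]:
    buffers: Dict[int, dict] = {}
    for d in deltas:
        # one buffer per tool-call index; the first delta carries the call id
        buf = buffers.setdefault(d.index, {"call_id": d.id, "name": None, "arguments": ""})
        if d.function:
            if d.function.name:
                buf["name"] = d.function.name
            if d.function.arguments:
                buf["arguments"] += d.function.arguments  # arguments arrive in fragments
    return buffers


stream = [
    ToolCallDelta(index=0, id="call_1", function=FunctionDelta(name="get_boiling_point")),
    ToolCallDelta(index=0, function=FunctionDelta(arguments='{"liquid_na')),
    ToolCallDelta(index=0, function=FunctionDelta(arguments='me": "water"}')),
]
for idx, buf in accumulate(stream).items():
    print(idx, buf["name"], json.loads(buf["arguments"]))
```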
@@ -926,7 +940,7 @@ async def convert_openai_chat_completion_stream(
                                 logprobs=_convert_openai_logprobs(logprobs),
                             )
                         )
-        else:
+        elif choice.delta.content:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=event_type,
@@ -935,9 +949,11 @@ async def convert_openai_chat_completion_stream(
                 )
             )

-    if toolcall_buffer:
+    for idx, buffer in tool_call_idx_to_buffer.items():
+        logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
+        if buffer["name"]:
             delta = ")"
-        toolcall_buffer["content"] += delta
+            buffer["content"] += delta
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=event_type,
@@ -945,14 +961,15 @@ async def convert_openai_chat_completion_stream(
                         tool_call=delta,
                         parse_status=ToolCallParseStatus.in_progress,
                     ),
-                    logprobs=_convert_openai_logprobs(logprobs),
+                    logprobs=None,
                 )
             )

             try:
-                arguments = json.loads(toolcall_buffer["arguments"])
+                arguments = json.loads(buffer["arguments"])
                 tool_call = ToolCall(
-                    call_id=toolcall_buffer["call_id"],
-                    tool_name=toolcall_buffer["name"],
+                    call_id=buffer["call_id"],
+                    tool_name=buffer["name"],
                     arguments=arguments,
                 )
                 yield ChatCompletionResponseStreamChunk(
@@ -965,12 +982,13 @@ async def convert_openai_chat_completion_stream(
                         stop_reason=stop_reason,
                     )
                 )
-            except json.JSONDecodeError:
+            except json.JSONDecodeError as e:
+                print(f"Failed to parse arguments: {e}")
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.complete,
+                        event_type=ChatCompletionResponseEventType.progress,
                         delta=ToolCallDelta(
-                            tool_call=toolcall_buffer["content"],
+                            tool_call=buffer["content"],
                             parse_status=ToolCallParseStatus.failed,
                         ),
                         stop_reason=stop_reason,
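Reviewer note: when the stream ends, each buffer's accumulated `arguments` string is parsed as JSON; on failure the chunk is emitted with `parse_status=failed` and the raw buffered text, and the event type stays `progress` (presumably so the final `complete` event still comes from the normal end-of-stream path). A small sketch of that flush step:

```python
# Sketch: turning a buffered tool call into either a parsed call or a "failed"
# delta that still carries the raw text the model produced.
import json
from typing import Tuple


def finalize_tool_call(buffer: dict) -> Tuple[str, object]:
    """Return (parse_status, payload) for one buffered tool call."""
    try:
        arguments = json.loads(buffer["arguments"])
        return "succeeded", {
            "call_id": buffer["call_id"],
            "tool_name": buffer["name"],
            "arguments": arguments,
        }
    except json.JSONDecodeError:
        # keep the accumulated text so the caller can surface what failed to parse
        return "failed", buffer["content"]


ok = {"call_id": "1", "name": "get_boiling_point", "arguments": '{"liquid_name": "water"}', "content": ""}
bad = {"call_id": "2", "name": "get_boiling_point", "arguments": '{"liquid_name": ', "content": 'get_boiling_point({"liquid_name": '}
print(finalize_tool_call(ok))
print(finalize_tool_call(bad))
```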
@@ -271,7 +271,7 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
-        "tools": ["builtin::websearch", client_tool],
+        "tools": [client_tool],
     }

     agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
@@ -320,17 +320,39 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, age
     assert num_tool_calls <= 5


-def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
-    def run_agent(tool_choice):
+def test_tool_choice_required(llama_stack_client_with_mocked_inference, agent_config):
+    tool_execution_steps = run_agent_with_tool_choice(
+        llama_stack_client_with_mocked_inference, agent_config, "required"
+    )
+    assert len(tool_execution_steps) > 0
+
+
+def test_tool_choice_none(llama_stack_client_with_mocked_inference, agent_config):
+    tool_execution_steps = run_agent_with_tool_choice(llama_stack_client_with_mocked_inference, agent_config, "none")
+    assert len(tool_execution_steps) == 0
+
+
+def test_tool_choice_get_boiling_point(llama_stack_client_with_mocked_inference, agent_config):
+    if "llama" not in agent_config["model"].lower():
+        pytest.xfail("NotImplemented for non-llama models")
+
+    tool_execution_steps = run_agent_with_tool_choice(
+        llama_stack_client_with_mocked_inference, agent_config, "get_boiling_point"
+    )
+    assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
+
+
+def run_agent_with_tool_choice(client, agent_config, tool_choice):
     client_tool = get_boiling_point

     test_agent_config = {
         **agent_config,
         "tool_config": {"tool_choice": tool_choice},
         "tools": [client_tool],
+        "max_infer_iters": 2,
     }

-    agent = Agent(llama_stack_client_with_mocked_inference, **test_agent_config)
+    agent = Agent(client, **test_agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
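Reviewer note: the single `test_tool_choice` test is split into three so each `tool_choice` value ("required", "none", a specific tool name) passes or xfails independently on non-Llama models, with the shared setup pulled into `run_agent_with_tool_choice`. A rough pytest-style sketch of the same pattern, using a hypothetical `run_fake_agent` helper rather than the llama-stack client:

```python
# Sketch of the split-test pattern: one helper runs the agent with a given
# tool_choice, and each test asserts on the returned tool-execution steps.
import pytest


def run_fake_agent(tool_choice: str) -> list:
    # hypothetical stand-in: pretend "none" disables tool use and anything
    # else forces exactly one call to the named/required tool
    return [] if tool_choice == "none" else [{"tool_name": tool_choice}]


def test_tool_choice_required():
    steps = run_fake_agent("required")
    assert len(steps) > 0


def test_tool_choice_none():
    steps = run_fake_agent("none")
    assert len(steps) == 0


def test_tool_choice_named_tool():
    model = "openai/gpt-4o-mini"  # hypothetical value for illustration
    if "llama" not in model.lower():
        pytest.xfail("forcing a specific tool is only exercised for llama models here")
    steps = run_fake_agent("get_boiling_point")
    assert steps[0]["tool_name"] == "get_boiling_point"
```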
@@ -346,15 +368,6 @@ def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):

     return [step for step in response.steps if step.step_type == "tool_execution"]

-    tool_execution_steps = run_agent("required")
-    assert len(tool_execution_steps) > 0
-
-    tool_execution_steps = run_agent("none")
-    assert len(tool_execution_steps) == 0
-
-    tool_execution_steps = run_agent("get_boiling_point")
-    assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
-

 @pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
 def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):