Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
fix: agents with non-llama model (#1550)
# Summary:
Includes fixes to get test_agents working with an OpenAI model, e.g. tool parsing and message conversion.

# Test Plan:
```
LLAMA_STACK_CONFIG=dev pytest -s -v tests/integration/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B --text-model openai/gpt-4o-mini
```

---

Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1550).
* #1556
* __->__ #1550
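For context on what the fix enables, here is a minimal sketch of an agent using a client-side tool with a non-llama model, in the spirit of the tests changed below. The tool body, base URL, and model ID are illustrative assumptions and are not taken from this commit; in the actual suite, `get_boiling_point` comes from the integration-test fixtures.

```
# Minimal sketch, not part of this commit. Assumes a llama-stack server on
# localhost:8321 with an OpenAI provider exposing "openai/gpt-4o-mini"; the
# tool body below is a toy stand-in for the fixture's get_boiling_point.
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent


def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
    """Returns the boiling point of a liquid in Celcius or Fahrenheit.

    :param liquid_name: name of the liquid
    :param celcius: whether to return the boiling point in Celcius
    """
    if liquid_name.lower() == "polyjuice":
        return -100 if celcius else -148
    return -1


client = LlamaStackClient(base_url="http://localhost:8321")
agent = Agent(
    client,
    model="openai/gpt-4o-mini",  # non-llama model, as in the test plan above
    instructions="You are a helpful assistant.",
    tools=[get_boiling_point],  # client tool passed as a callable, as in the tests
)
session_id = agent.create_session("demo-session")
response = agent.create_turn(
    messages=[{"role": "user", "content": "What is the boiling point of polyjuice?"}],
    session_id=session_id,
    stream=False,
)
tool_steps = [step for step in response.steps if step.step_type == "tool_execution"]
print(tool_steps)
```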
This commit is contained in:
parent 0bdfc71f8d
commit c23a7af5d6

4 changed files with 161 additions and 125 deletions
tests/integration/agents/test_agents.py

@@ -271,7 +271,7 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
-        "tools": ["builtin::websearch", client_tool],
+        "tools": [client_tool],
     }
 
     agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
@@ -320,42 +320,55 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config):
     assert num_tool_calls <= 5
 
 
-def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
-    def run_agent(tool_choice):
-        client_tool = get_boiling_point
-
-        test_agent_config = {
-            **agent_config,
-            "tool_config": {"tool_choice": tool_choice},
-            "tools": [client_tool],
-        }
-
-        agent = Agent(llama_stack_client_with_mocked_inference, **test_agent_config)
-        session_id = agent.create_session(f"test-session-{uuid4()}")
-
-        response = agent.create_turn(
-            messages=[
-                {
-                    "role": "user",
-                    "content": "What is the boiling point of polyjuice?",
-                },
-            ],
-            session_id=session_id,
-            stream=False,
-        )
-
-        return [step for step in response.steps if step.step_type == "tool_execution"]
-
-    tool_execution_steps = run_agent("required")
+def test_tool_choice_required(llama_stack_client_with_mocked_inference, agent_config):
+    tool_execution_steps = run_agent_with_tool_choice(
+        llama_stack_client_with_mocked_inference, agent_config, "required"
+    )
     assert len(tool_execution_steps) > 0
 
-    tool_execution_steps = run_agent("none")
+
+def test_tool_choice_none(llama_stack_client_with_mocked_inference, agent_config):
+    tool_execution_steps = run_agent_with_tool_choice(llama_stack_client_with_mocked_inference, agent_config, "none")
     assert len(tool_execution_steps) == 0
 
-    tool_execution_steps = run_agent("get_boiling_point")
+
+def test_tool_choice_get_boiling_point(llama_stack_client_with_mocked_inference, agent_config):
+    if "llama" not in agent_config["model"].lower():
+        pytest.xfail("NotImplemented for non-llama models")
+
+    tool_execution_steps = run_agent_with_tool_choice(
+        llama_stack_client_with_mocked_inference, agent_config, "get_boiling_point"
+    )
     assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
 
 
+def run_agent_with_tool_choice(client, agent_config, tool_choice):
+    client_tool = get_boiling_point
+
+    test_agent_config = {
+        **agent_config,
+        "tool_config": {"tool_choice": tool_choice},
+        "tools": [client_tool],
+        "max_infer_iters": 2,
+    }
+
+    agent = Agent(client, **test_agent_config)
+    session_id = agent.create_session(f"test-session-{uuid4()}")
+
+    response = agent.create_turn(
+        messages=[
+            {
+                "role": "user",
+                "content": "What is the boiling point of polyjuice?",
+            },
+        ],
+        session_id=session_id,
+        stream=False,
+    )
+
+    return [step for step in response.steps if step.step_type == "tool_execution"]
+
+
 @pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
 def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
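A condensed restatement of what the new tool_choice tests assert, for quick reference. This is a sketch, not code from this commit; it assumes `client`, `agent_config`, and the `run_agent_with_tool_choice` helper from the diff above are in scope (e.g. inside tests/integration/agents/test_agents.py).

```
def check_tool_choice_behaviors(client, agent_config):
    # "required": the model must invoke a tool at least once.
    assert len(run_agent_with_tool_choice(client, agent_config, "required")) > 0

    # "none": tool calling is disabled, so no tool_execution steps appear.
    assert len(run_agent_with_tool_choice(client, agent_config, "none")) == 0

    # A specific tool name forces that tool; this path is xfailed for
    # non-llama models in this commit.
    steps = run_agent_with_tool_choice(client, agent_config, "get_boiling_point")
    assert steps[0].tool_calls[0].tool_name == "get_boiling_point"
```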