docs: update test_agents to use new Agent SDK API (#1402)

# Summary:
new Agent SDK API is added in
https://github.com/meta-llama/llama-stack-client-python/pull/178

Update docs and test to reflect this.

Closes https://github.com/meta-llama/llama-stack/issues/1365

# Test Plan:
```bash
py.test -v -s --nbval-lax ./docs/getting_started.ipynb

LLAMA_STACK_CONFIG=fireworks \
   pytest -s -v tests/integration/agents/test_agents.py \
  --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct
```
This commit is contained in:
ehhuang 2025-03-06 15:21:12 -08:00 committed by GitHub
parent 3d71e5a036
commit ca2910d27a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 121 additions and 206 deletions

View file

@ -64,7 +64,7 @@ def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> D
def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
available_shields = available_shields[:1]
agent_config = AgentConfig(
agent_config = dict(
model=text_model_id,
instructions="You are a helpful assistant",
sampling_params={
@ -74,7 +74,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
"top_p": 0.9,
},
},
toolgroups=[],
tools=[],
input_shields=available_shields,
output_shields=available_shields,
enable_session_persistence=False,
@ -83,7 +83,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
simple_hello = agent.create_turn(
@ -137,7 +137,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
agent_config = AgentConfig(
**common_params,
)
Server__AgentConfig(**agent_config)
Server__AgentConfig(**common_params)
agent_config = AgentConfig(
**common_params,
@ -179,11 +179,11 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
"builtin::websearch",
],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -209,11 +209,11 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
"builtin::code_interpreter",
],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -238,12 +238,12 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, a
def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
"builtin::code_interpreter",
],
}
codex_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
codex_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = codex_agent.create_session(f"test-session-{uuid4()}")
inflation_doc = AgentDocument(
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@ -275,11 +275,11 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
client_tool = get_boiling_point
agent_config = {
**agent_config,
"toolgroups": ["builtin::websearch"],
"tools": ["builtin::websearch", client_tool],
"client_tools": [client_tool.get_tool_definition()],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -303,11 +303,11 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, age
agent_config = {
**agent_config,
"instructions": "You are a helpful assistant Always respond with tool calls no matter what. ",
"client_tools": [client_tool.get_tool_definition()],
"tools": [client_tool],
"max_infer_iters": 5,
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -332,10 +332,10 @@ def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
test_agent_config = {
**agent_config,
"tool_config": {"tool_choice": tool_choice},
"client_tools": [client_tool.get_tool_definition()],
"tools": [client_tool],
}
agent = Agent(llama_stack_client_with_mocked_inference, test_agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **test_agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
@ -387,7 +387,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
)
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
dict(
name=rag_tool_name,
args={
@ -396,7 +396,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
)
],
}
rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(
@ -422,7 +422,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
@pytest.mark.parametrize(
"toolgroup",
"tool",
[
dict(
name="builtin::rag/knowledge_search",
@ -433,7 +433,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
"builtin::rag/knowledge_search",
],
)
def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, toolgroup):
def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(
@ -446,9 +446,9 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
]
agent_config = {
**agent_config,
"toolgroups": [toolgroup],
"tools": [tool],
}
rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(
@ -521,7 +521,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
)
agent_config = {
**agent_config,
"toolgroups": [
"tools": [
dict(
name="builtin::rag/knowledge_search",
args={"vector_db_ids": [vector_db_id]},
@ -529,7 +529,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_conf
"builtin::code_interpreter",
],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
inflation_doc = Document(
document_id="test_csv",
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
@ -578,10 +578,10 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_co
**agent_config,
"input_shields": [],
"output_shields": [],
"client_tools": [client_tool.get_tool_definition()],
"tools": [client_tool],
}
agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(