Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
docs: update test_agents to use new Agent SDK API (#1402)
# Summary:

A new Agent SDK API was added in https://github.com/meta-llama/llama-stack-client-python/pull/178. Update the docs and tests to reflect this.

Closes https://github.com/meta-llama/llama-stack/issues/1365

# Test Plan:

```bash
py.test -v -s --nbval-lax ./docs/getting_started.ipynb

LLAMA_STACK_CONFIG=fireworks \
  pytest -s -v tests/integration/agents/test_agents.py \
  --safety-shield meta-llama/Llama-Guard-3-8B --text-model meta-llama/Llama-3.1-8B-Instruct
```
parent 3d71e5a036
commit ca2910d27a

13 changed files with 121 additions and 206 deletions
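The gist of the change: the client's `Agent` constructor now takes its configuration as plain keyword arguments instead of a wrapped `AgentConfig` object, and client tools go into the same `tools` list as builtin toolgroups. A minimal before/after sketch, assuming the import paths from llama-stack-client-python; the model ID and server URL are illustrative:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

client = LlamaStackClient(base_url="http://localhost:8321")  # illustrative URL

# Old API (pre-#178): configuration wrapped in an AgentConfig object.
# agent = Agent(client, AgentConfig(model=..., instructions=..., toolgroups=[...]))

# New API: the same settings as keyword arguments; `tools` replaces `toolgroups`.
agent = Agent(
    client,
    model="meta-llama/Llama-3.1-8B-Instruct",  # any model registered with the stack
    instructions="You are a helpful assistant",
    tools=["builtin::websearch"],
)
```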
tests/integration/agents/test_agents.py

```diff
@@ -64,7 +64,7 @@ def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> D
 def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
     available_shields = [shield.identifier for shield in llama_stack_client_with_mocked_inference.shields.list()]
     available_shields = available_shields[:1]
-    agent_config = AgentConfig(
+    agent_config = dict(
         model=text_model_id,
         instructions="You are a helpful assistant",
         sampling_params={
@@ -74,7 +74,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):
                 "top_p": 0.9,
             },
         },
-        toolgroups=[],
+        tools=[],
         input_shields=available_shields,
         output_shields=available_shields,
         enable_session_persistence=False,
```
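Since the fixture now builds a plain `dict` rather than an `AgentConfig`, each test below can override individual keys with dict-merge syntax and splat the result into the constructor. A sketch of the pattern, assuming the fixture values shown above:

```python
# Base config mirroring the fixture: a plain dict, not an AgentConfig.
agent_config = dict(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model ID
    instructions="You are a helpful assistant",
    tools=[],
)

# Per-test override via dict merge, then keyword expansion into Agent:
web_config = {**agent_config, "tools": ["builtin::websearch"]}
# agent = Agent(client, **web_config)
```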
```diff
@@ -83,7 +83,7 @@ def agent_config(llama_stack_client_with_mocked_inference, text_model_id):


 def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     simple_hello = agent.create_turn(
```
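For orientation, the turn-level flow around this hunk is unchanged by the migration: create a session, then drive turns against it. A sketch under the new constructor; the prompt text is illustrative, not copied from the file:

```python
from uuid import uuid4

# Sessions scope conversation history; each test creates a fresh one.
session_id = agent.create_session(f"test-session-{uuid4()}")

# A turn sends user messages into the session and yields the agent's response.
simple_hello = agent.create_turn(
    messages=[{"role": "user", "content": "Say hello"}],  # illustrative prompt
    session_id=session_id,
)
```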
```diff
@@ -137,7 +137,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = AgentConfig(
         **common_params,
     )
-    Server__AgentConfig(**agent_config)
+    Server__AgentConfig(**common_params)

     agent_config = AgentConfig(
         **common_params,
```
```diff
@@ -179,11 +179,11 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
 def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             "builtin::websearch",
         ],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
```
```diff
@@ -209,11 +209,11 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
 def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             "builtin::code_interpreter",
         ],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
```
```diff
@@ -238,12 +238,12 @@ def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
 def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             "builtin::code_interpreter",
         ],
     }

-    codex_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    codex_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = codex_agent.create_session(f"test-session-{uuid4()}")
     inflation_doc = AgentDocument(
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
```
```diff
@@ -275,11 +275,11 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
-        "toolgroups": ["builtin::websearch"],
-        "client_tools": [client_tool.get_tool_definition()],
+        "tools": ["builtin::websearch", client_tool],
     }

-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
```
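This hunk carries the main API simplification: a client tool no longer needs `get_tool_definition()` plus a separate `client_tools` constructor argument; the decorated Python function itself goes into `tools`, next to builtin toolgroup IDs. A sketch of such a tool, assuming the `client_tool` decorator from llama-stack-client-python; the body is illustrative, reconstructed from the fixture's name:

```python
from llama_stack_client.lib.agents.client_tool import client_tool


@client_tool
def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
    """Return the boiling point of a liquid (illustrative stub).

    :param liquid_name: name of the liquid
    :param celcius: return Celsius if True, Fahrenheit otherwise
    """
    # Hypothetical behavior for the test's made-up liquid.
    if liquid_name.lower() == "polyjuice":
        return -100 if celcius else -212
    return -1


# New style: pass the function directly, mixed with builtin toolgroups.
# agent = Agent(client, model=..., instructions=...,
#               tools=["builtin::websearch", get_boiling_point])
```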
```diff
@@ -303,11 +303,11 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
         "instructions": "You are a helpful assistant Always respond with tool calls no matter what. ",
-        "client_tools": [client_tool.get_tool_definition()],
+        "tools": [client_tool],
         "max_infer_iters": 5,
     }

-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
```
```diff
@@ -332,10 +332,10 @@ def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
     test_agent_config = {
         **agent_config,
         "tool_config": {"tool_choice": tool_choice},
-        "client_tools": [client_tool.get_tool_definition()],
+        "tools": [client_tool],
     }

-    agent = Agent(llama_stack_client_with_mocked_inference, test_agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **test_agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
```
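Note that `tool_config` itself is untouched by the migration; only the way the client tool is supplied moves into `tools`. Forcing tool use for one parametrized case might look like this sketch, assuming the standard "required" option for `tool_choice`:

```python
# Force a tool call on every inference step (sketch; "required" is an assumption).
test_agent_config = {
    **agent_config,
    "tool_config": {"tool_choice": "required"},
    "tools": [get_boiling_point],
}
# agent = Agent(client, **test_agent_config)
```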
```diff
@@ -387,7 +387,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
     )
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             dict(
                 name=rag_tool_name,
                 args={
@@ -396,7 +396,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
             )
         ],
     }
-    rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         (
```
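As the hunk above shows, server-side tools that take arguments are specified as dicts with `name` and `args` keys inside the same `tools` list. A sketch for the RAG tool, with an illustrative vector DB ID:

```python
# Parameterized server-side tool spec: the tool name plus its arguments.
rag_config = {
    **agent_config,
    "tools": [
        dict(
            name="builtin::rag/knowledge_search",
            args={"vector_db_ids": ["test-vector-db"]},  # illustrative DB ID
        )
    ],
}
# rag_agent = Agent(client, **rag_config)
```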
```diff
@@ -422,7 +422,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):


 @pytest.mark.parametrize(
-    "toolgroup",
+    "tool",
     [
         dict(
             name="builtin::rag/knowledge_search",
@@ -433,7 +433,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
         "builtin::rag/knowledge_search",
     ],
 )
-def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, toolgroup):
+def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
```
```diff
@@ -446,9 +446,9 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, toolgroup):
     ]
     agent_config = {
         **agent_config,
-        "toolgroups": [toolgroup],
+        "tools": [tool],
     }
-    rag_agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         (
```
```diff
@@ -521,7 +521,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config):
     )
     agent_config = {
         **agent_config,
-        "toolgroups": [
+        "tools": [
             dict(
                 name="builtin::rag/knowledge_search",
                 args={"vector_db_ids": [vector_db_id]},
@@ -529,7 +529,7 @@ def test_rag_and_code_agent(llama_stack_client_with_mocked_inference, agent_config):
             "builtin::code_interpreter",
         ],
     }
-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config)
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     inflation_doc = Document(
         document_id="test_csv",
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
```
```diff
@@ -578,10 +578,10 @@ def test_create_turn_response(llama_stack_client_with_mocked_inference, agent_config):
         **agent_config,
         "input_shields": [],
         "output_shields": [],
-        "client_tools": [client_tool.get_tool_definition()],
+        "tools": [client_tool],
     }

-    agent = Agent(llama_stack_client_with_mocked_inference, agent_config, client_tools=(client_tool,))
+    agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(
```