Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-28 02:53:30 +00:00
fix: update agents test (#1796)
# What does this PR do?

- We no longer query the vector DB when uploading documents as attachments.

## Test Plan

```
pytest --stack-config="http://localhost:8321" -v tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.3-70B-Instruct
```

```
pytest --stack-config=fireworks -v tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.3-70B-Instruct --record-responses
```

<img width="1160" alt="image" src="https://github.com/user-attachments/assets/90700f79-c002-4474-bb41-7bc0a39dc91c" />
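For context, a minimal sketch of the attachment flow this change reflects: documents ride along with the turn itself instead of being indexed into and queried from a vector DB. This assumes a Llama Stack server at http://localhost:8321 and the model from the test plan; the prompt and document URL are illustrative, not taken from the test.

```python
# Minimal sketch: uploading documents as turn attachments (no vector DB query).
# Assumes a server at http://localhost:8321; the prompt and URL are illustrative.
from uuid import uuid4

from llama_stack_client import Agent, Document, LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
agent = Agent(
    client,
    model="meta-llama/Llama-3.3-70B-Instruct",
    instructions="You are a helpful assistant.",
)
session_id = agent.create_session(f"attachments-session-{uuid4()}")

response = agent.create_turn(
    messages=[{"role": "user", "content": "Summarize the attached document."}],
    # Attachments are passed directly on the turn; after this change the
    # server no longer queries a vector DB to serve them.
    documents=[
        Document(
            content="https://example.com/docs/lora_finetune.rst",  # illustrative URL
            mime_type="text/plain",
        )
    ],
    session_id=session_id,
    stream=False,
)
print(response.output_message.content)
```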
This commit is contained in:

parent 193e531216
commit cfd30d2ad5

1 changed file with 13 additions and 31 deletions
tests/integration/agents/test_agents.py

```diff
@@ -8,15 +8,13 @@ from typing import Any, Dict
 from uuid import uuid4
 
 import pytest
-from llama_stack_client import Agent, AgentEventLogger, Document
-from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 from llama_stack.apis.agents.agents import (
     AgentConfig as Server__AgentConfig,
-)
-from llama_stack.apis.agents.agents import (
     ToolChoice,
 )
+from llama_stack_client import Agent, AgentEventLogger, Document
+from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 
 def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
@@ -173,6 +171,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
 def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
+        "instructions": "You are a helpful assistant that can use web search to answer questions.",
         "tools": [
             "builtin::websearch",
         ],
@@ -184,20 +183,20 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
         messages=[
             {
                 "role": "user",
-                "content": "Search the web and tell me who the founder of Meta is.",
+                "content": "Search the web and tell me what is the local time in Tokyo currently.",
             }
         ],
         session_id=session_id,
+        stream=False,
     )
 
-    logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
-    logs_str = "".join(logs)
-
-    assert "tool_execution>" in logs_str
-    assert "Tool:brave_search Response:" in logs_str
-    assert "mark zuckerberg" in logs_str.lower()
-    if len(agent_config["output_shields"]) > 0:
-        assert "No Violation" in logs_str
+    found_tool_execution = False
+    for step in response.steps:
+        if step.step_type == "tool_execution":
+            assert step.tool_calls[0].tool_name == "brave_search"
+            found_tool_execution = True
+            break
+    assert found_tool_execution
 
 
 def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
@@ -427,19 +426,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
         assert expected_kw in response.output_message.content.lower()
 
 
-@pytest.mark.parametrize(
-    "tool",
-    [
-        dict(
-            name="builtin::rag/knowledge_search",
-            args={
-                "vector_db_ids": [],
-            },
-        ),
-        "builtin::rag/knowledge_search",
-    ],
-)
-def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
+def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -452,7 +439,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
     ]
     agent_config = {
         **agent_config,
-        "tools": [tool],
     }
     rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
@@ -486,10 +472,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
         stream=False,
     )
 
-    # rag is called
-    tool_execution_step = [step for step in response.steps if step.step_type == "tool_execution"]
-    assert len(tool_execution_step) >= 1
-    assert tool_execution_step[0].tool_calls[0].tool_name == "knowledge_search"
     assert "lora" in response.output_message.content.lower()
 
 
```
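For contrast, the parametrization deleted above configured the agent with the builtin RAG tool, whose execution queries registered vector DBs at turn time; that is the path this PR stops exercising for plain attachments. A sketch of that old configuration, using only names from the deleted code (the base config values are illustrative stand-ins for the test fixture):

```python
# Sketch of what the deleted @pytest.mark.parametrize cases set up. The
# knowledge_search tool queries registered vector DBs when the turn runs;
# with this PR, documents attached to a turn no longer take that path.
base_agent_config: dict = {
    "model": "meta-llama/Llama-3.3-70B-Instruct",  # illustrative base config
    "instructions": "You are a helpful assistant.",
}

# The deleted cases passed the tool either as a dict with args or as the
# plain string "builtin::rag/knowledge_search".
rag_tool = {
    "name": "builtin::rag/knowledge_search",
    "args": {"vector_db_ids": []},  # the deleted case used an empty list
}

agent_config = {
    **base_agent_config,
    "tools": [rag_tool],  # tool execution queries vector DBs at turn time
}
```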