mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-27 18:50:41 +00:00
fix: update agents test (#1796)
# What does this PR do?

- We no longer query the vector DB when uploading documents as attachments.

## Test Plan

```
pytest --stack-config="http://localhost:8321" -v tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.3-70B-Instruct
```

```
pytest --stack-config=fireworks -v tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.3-70B-Instruct --record-responses
```

<img width="1160" alt="image" src="https://github.com/user-attachments/assets/90700f79-c002-4474-bb41-7bc0a39dc91c" />
This commit is contained in:
parent 193e531216
commit cfd30d2ad5

1 changed file with 13 additions and 31 deletions
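For context, "documents as attachments" means passing `Document` objects directly on a turn instead of registering them in a vector DB first. A minimal sketch of that call shape (the names follow the `llama_stack_client` API exercised by the test below; the prompt is illustrative):

```python
# Attachments ride along with the turn itself; after this PR the agent
# consumes them directly instead of issuing a vector DB query for them.
response = rag_agent.create_turn(
    messages=[{"role": "user", "content": "Summarize the attached docs."}],  # illustrative prompt
    documents=documents,  # list of llama_stack_client.Document
    session_id=session_id,
    stream=False,
)
```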
tests/integration/agents/test_agents.py:

```diff
@@ -8,15 +8,13 @@ from typing import Any, Dict
 from uuid import uuid4
 
 import pytest
 from llama_stack_client import Agent, AgentEventLogger, Document
 from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 from llama_stack.apis.agents.agents import (
     AgentConfig as Server__AgentConfig,
 )
 from llama_stack.apis.agents.agents import (
     ToolChoice,
 )
-from llama_stack_client import Agent, AgentEventLogger, Document
-from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 
 
 def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
@@ -173,6 +171,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
 def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
     agent_config = {
         **agent_config,
+        "instructions": "You are a helpful assistant that can use web search to answer questions.",
         "tools": [
             "builtin::websearch",
         ],
```
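For reference, a minimal sketch of how a config like the one updated above is consumed (assuming the `llama_stack_client` Python API; the base URL and model name are taken from the test plan, not from this diff):

```python
from uuid import uuid4

from llama_stack_client import Agent, LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # endpoint from the test plan

# The updated config pairs explicit instructions with the web search tool.
agent_config = {
    "model": "meta-llama/Llama-3.3-70B-Instruct",  # model from the test plan (assumed)
    "instructions": "You are a helpful assistant that can use web search to answer questions.",
    "tools": ["builtin::websearch"],
}
agent = Agent(client, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
```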
```diff
@@ -184,20 +183,20 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
         messages=[
             {
                 "role": "user",
-                "content": "Search the web and tell me who the founder of Meta is.",
+                "content": "Search the web and tell me what is the local time in Tokyo currently.",
             }
         ],
         session_id=session_id,
         stream=False,
     )
 
-    logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
-    logs_str = "".join(logs)
-    assert "tool_execution>" in logs_str
-    assert "Tool:brave_search Response:" in logs_str
-    assert "mark zuckerberg" in logs_str.lower()
-    if len(agent_config["output_shields"]) > 0:
-        assert "No Violation" in logs_str
+    found_tool_execution = False
+    for step in response.steps:
+        if step.step_type == "tool_execution":
+            assert step.tool_calls[0].tool_name == "brave_search"
+            found_tool_execution = True
+            break
+    assert found_tool_execution
 
 
 def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
```
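The replacement assertions inspect the typed `response.steps` list rather than string-matching the rendered event log, which is less brittle across log-format changes. An equivalent, slightly more compact form of the same check (a sketch, assuming a completed non-streaming turn `response`):

```python
# Locate the first tool_execution step via the structured steps API.
tool_step = next(
    (step for step in response.steps if step.step_type == "tool_execution"),
    None,
)
assert tool_step is not None, "expected at least one tool_execution step"
assert tool_step.tool_calls[0].tool_name == "brave_search"
```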
```diff
@@ -427,19 +426,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
     assert expected_kw in response.output_message.content.lower()
 
 
-@pytest.mark.parametrize(
-    "tool",
-    [
-        dict(
-            name="builtin::rag/knowledge_search",
-            args={
-                "vector_db_ids": [],
-            },
-        ),
-        "builtin::rag/knowledge_search",
-    ],
-)
-def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
+def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(
@@ -452,7 +439,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
     ]
     agent_config = {
         **agent_config,
-        "tools": [tool],
     }
     rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
@@ -486,10 +472,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
         stream=False,
     )
 
-    # rag is called
-    tool_execution_step = [step for step in response.steps if step.step_type == "tool_execution"]
-    assert len(tool_execution_step) >= 1
-    assert tool_execution_step[0].tool_calls[0].tool_name == "knowledge_search"
     assert "lora" in response.output_message.content.lower()
```
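Taken together, the updated attachment test now exercises roughly the flow below; this is a sketch, with the document URL and prompt as illustrative stand-ins for the values in the test file:

```python
from uuid import uuid4

from llama_stack_client import Agent, Document, LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # endpoint from the test plan

# No rag tool in the agent config anymore: documents attached to the turn
# are consumed directly rather than indexed into and queried from a vector DB.
rag_agent = Agent(
    client,
    model="meta-llama/Llama-3.3-70B-Instruct",  # model from the test plan (assumed)
    instructions="You are a helpful assistant",
)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")

documents = [
    Document(
        document_id="num-0",
        content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/lora_finetune.rst",  # illustrative URL
        mime_type="text/plain",
        metadata={},
    )
]
response = rag_agent.create_turn(
    messages=[{"role": "user", "content": "Tell me how to use LoRA"}],  # illustrative prompt
    documents=documents,
    session_id=session_id,
    stream=False,
)
assert "lora" in response.output_message.content.lower()
```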