fix: update agents test (#1796)

# What does this PR do?
- We no longer query the vector DB when documents are uploaded as attachments; the agents test is updated to match.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
```
pytest --stack-config="http://localhost:8321" -v tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.3-70B-Instruct
```

```
pytest --stack-config=fireworks -v tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.3-70B-Instruct --record-responses
```
<img width="1160" alt="image"
src="https://github.com/user-attachments/assets/90700f79-c002-4474-bb41-7bc0a39dc91c"
/>


[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-03-26 22:00:43 -07:00 committed by GitHub
parent 193e531216
commit cfd30d2ad5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -8,15 +8,13 @@ from typing import Any, Dict
from uuid import uuid4
import pytest
from llama_stack_client import Agent, AgentEventLogger, Document
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
from llama_stack.apis.agents.agents import (
AgentConfig as Server__AgentConfig,
)
from llama_stack.apis.agents.agents import (
ToolChoice,
)
from llama_stack_client import Agent, AgentEventLogger, Document
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
@ -173,6 +171,7 @@ def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent_config):
agent_config = {
**agent_config,
"instructions": "You are a helpful assistant that can use web search to answer questions.",
"tools": [
"builtin::websearch",
],
@ -184,20 +183,20 @@ def test_builtin_tool_web_search(llama_stack_client_with_mocked_inference, agent
messages=[
{
"role": "user",
"content": "Search the web and tell me who the founder of Meta is.",
"content": "Search the web and tell me what is the local time in Tokyo currently.",
}
],
session_id=session_id,
stream=False,
)
logs = [str(log) for log in AgentEventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "tool_execution>" in logs_str
assert "Tool:brave_search Response:" in logs_str
assert "mark zuckerberg" in logs_str.lower()
if len(agent_config["output_shields"]) > 0:
assert "No Violation" in logs_str
found_tool_execution = False
for step in response.steps:
if step.step_type == "tool_execution":
assert step.tool_calls[0].tool_name == "brave_search"
found_tool_execution = True
break
assert found_tool_execution
def test_builtin_tool_code_execution(llama_stack_client_with_mocked_inference, agent_config):
@ -427,19 +426,7 @@ def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_t
assert expected_kw in response.output_message.content.lower()
@pytest.mark.parametrize(
"tool",
[
dict(
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": [],
},
),
"builtin::rag/knowledge_search",
],
)
def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config, tool):
def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, agent_config):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(
@ -452,7 +439,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
]
agent_config = {
**agent_config,
"tools": [tool],
}
rag_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
@ -486,10 +472,6 @@ def test_rag_agent_with_attachments(llama_stack_client_with_mocked_inference, ag
stream=False,
)
# rag is called
tool_execution_step = [step for step in response.steps if step.step_type == "tool_execution"]
assert len(tool_execution_step) >= 1
assert tool_execution_step[0].tool_calls[0].tool_name == "knowledge_search"
assert "lora" in response.output_message.content.lower()