fix agent test_rag

This commit is contained in:
Hardik Shah 2025-07-14 13:35:57 -07:00
parent 3d83322003
commit 4c53db9f50

View file

@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id):
return agent_config return agent_config
@pytest.fixture(scope="session")
def agent_config_without_safety(text_model_id):
agent_config = dict(
model=text_model_id,
instructions="You are a helpful assistant",
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 0.0001,
"top_p": 0.9,
},
},
tools=[],
enable_session_persistence=False,
)
return agent_config
def test_agent_simple(llama_stack_client, agent_config): def test_agent_simple(llama_stack_client, agent_config):
agent = Agent(llama_stack_client, **agent_config) agent = Agent(llama_stack_client, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}") session_id = agent.create_session(f"test-session-{uuid4()}")
@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
assert expected_kw in response.output_message.content.lower() assert expected_kw in response.output_message.content.lower()
def test_rag_agent_with_attachments(llama_stack_client, agent_config): def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
urls = ["llama3.rst", "lora_finetune.rst"] urls = ["llama3.rst", "lora_finetune.rst"]
documents = [ documents = [
# passign as url # passign as url
@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
metadata={}, metadata={},
), ),
] ]
rag_agent = Agent(llama_stack_client, **agent_config) rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
session_id = rag_agent.create_session(f"test-session-{uuid4()}") session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
(
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
"grouped",
),
]
user_prompts = [ user_prompts = [
( (
"I am attaching some documentation for Torchtune. Help me answer questions I will ask next.", "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
assert "lora" in response.output_message.content.lower() assert "lora" in response.output_message.content.lower()
@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
def test_rag_and_code_agent(llama_stack_client, agent_config):
if "llama-4" in agent_config["model"].lower():
pytest.xfail("Not working for llama4")
documents = []
documents.append(
Document(
document_id="nba_wiki",
content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
metadata={},
)
)
documents.append(
Document(
document_id="perplexity_wiki",
content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
Srinivas, the CEO, worked at OpenAI as an AI researcher.
Konwinski was among the founding team at Databricks.
Yarats, the CTO, was an AI research scientist at Meta.
Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
metadata={},
)
)
vector_db_id = f"test-vector-db-{uuid4()}"
llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
)
llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=128,
)
agent_config = {
**agent_config,
"tools": [
dict(
name="builtin::rag/knowledge_search",
args={"vector_db_ids": [vector_db_id]},
),
"builtin::code_interpreter",
],
}
agent = Agent(llama_stack_client, **agent_config)
user_prompts = [
(
"when was Perplexity the company founded?",
[],
"knowledge_search",
"2022",
),
(
"when was the nba created?",
[],
"knowledge_search",
"1949",
),
]
for prompt, docs, tool_name, expected_kw in user_prompts:
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
documents=docs,
stream=False,
)
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
if expected_kw:
assert expected_kw in response.output_message.content.lower()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"client_tools", "client_tools",
[(get_boiling_point, False), (get_boiling_point_with_metadata, True)], [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],