Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-23 09:09:43 +00:00
fix agent test_rag
This commit is contained in:
parent 3d83322003
commit 4c53db9f50
1 changed file with 20 additions and 84 deletions
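In short: the change adds a session-scoped `agent_config_without_safety` fixture, switches `test_rag_agent_with_attachments` over to it (presumably so the attachment test runs without safety shields), deletes a leftover duplicate `user_prompts` block, and removes the skipped `test_rag_and_code_agent` test entirely.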
@@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id):
     return agent_config


+@pytest.fixture(scope="session")
+def agent_config_without_safety(text_model_id):
+    agent_config = dict(
+        model=text_model_id,
+        instructions="You are a helpful assistant",
+        sampling_params={
+            "strategy": {
+                "type": "top_p",
+                "temperature": 0.0001,
+                "top_p": 0.9,
+            },
+        },
+        tools=[],
+        enable_session_persistence=False,
+    )
+    return agent_config
+
+
 def test_agent_simple(llama_stack_client, agent_config):
     agent = Agent(llama_stack_client, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
         assert expected_kw in response.output_message.content.lower()


-def test_rag_agent_with_attachments(llama_stack_client, agent_config):
+def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
     urls = ["llama3.rst", "lora_finetune.rst"]
     documents = [
         # passign as url
@@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
             metadata={},
         ),
     ]
-    rag_agent = Agent(llama_stack_client, **agent_config)
+    rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
-    user_prompts = [
-        (
-            "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
-            "grouped",
-        ),
-    ]
     user_prompts = [
         (
             "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
@@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
             assert "lora" in response.output_message.content.lower()


-@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
-def test_rag_and_code_agent(llama_stack_client, agent_config):
-    if "llama-4" in agent_config["model"].lower():
-        pytest.xfail("Not working for llama4")
-
-    documents = []
-    documents.append(
-        Document(
-            document_id="nba_wiki",
-            content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
-            metadata={},
-        )
-    )
-    documents.append(
-        Document(
-            document_id="perplexity_wiki",
-            content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
-
-Srinivas, the CEO, worked at OpenAI as an AI researcher.
-Konwinski was among the founding team at Databricks.
-Yarats, the CTO, was an AI research scientist at Meta.
-Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
-            metadata={},
-        )
-    )
-    vector_db_id = f"test-vector-db-{uuid4()}"
-    llama_stack_client.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model="all-MiniLM-L6-v2",
-        embedding_dimension=384,
-    )
-    llama_stack_client.tool_runtime.rag_tool.insert(
-        documents=documents,
-        vector_db_id=vector_db_id,
-        chunk_size_in_tokens=128,
-    )
-    agent_config = {
-        **agent_config,
-        "tools": [
-            dict(
-                name="builtin::rag/knowledge_search",
-                args={"vector_db_ids": [vector_db_id]},
-            ),
-            "builtin::code_interpreter",
-        ],
-    }
-    agent = Agent(llama_stack_client, **agent_config)
-    user_prompts = [
-        (
-            "when was Perplexity the company founded?",
-            [],
-            "knowledge_search",
-            "2022",
-        ),
-        (
-            "when was the nba created?",
-            [],
-            "knowledge_search",
-            "1949",
-        ),
-    ]
-
-    for prompt, docs, tool_name, expected_kw in user_prompts:
-        session_id = agent.create_session(f"test-session-{uuid4()}")
-        response = agent.create_turn(
-            messages=[{"role": "user", "content": prompt}],
-            session_id=session_id,
-            documents=docs,
-            stream=False,
-        )
-        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
-        assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
-        if expected_kw:
-            assert expected_kw in response.output_message.content.lower()
-
-
 @pytest.mark.parametrize(
     "client_tools",
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
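For orientation, here is a minimal sketch (not part of the commit) of how the new session-scoped fixture is consumed. The `Agent`, `create_session`, and `create_turn` calls mirror the diff above; the import path, the test name, and the prompt are assumptions for illustration, not code from this change.

# Sketch only: exercising the new agent_config_without_safety fixture.
# The import path below is an assumption; the real test module's imports may differ.
from uuid import uuid4

from llama_stack_client.lib.agents.agent import Agent


def test_agent_without_safety_sketch(llama_stack_client, agent_config_without_safety):
    # The fixture configures no safety shields, so the turn runs unguarded.
    agent = Agent(llama_stack_client, **agent_config_without_safety)
    session_id = agent.create_session(f"test-session-{uuid4()}")
    response = agent.create_turn(
        messages=[{"role": "user", "content": "Summarize the attached docs."}],
        session_id=session_id,
        stream=False,
    )
    assert response.output_message.content  # hypothetical check

Because the fixture is `scope="session"`, the config dict is built once and shared across tests, which keeps the attachment test independent of whatever safety setup the original `agent_config` fixture carries.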