mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
agents unit test
This commit is contained in:
parent
4667c1f542
commit
6d7f341ecf
2 changed files with 102 additions and 1 deletions
|
@ -31,4 +31,4 @@ providers:
|
||||||
persistence_store:
|
persistence_store:
|
||||||
namespace: null
|
namespace: null
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: /Users/ashwin/.llama/runtime/kvstore.db
|
db_path: ~/.llama/runtime/kvstore.db
|
||||||
|
|
|
@ -64,6 +64,24 @@ def search_query_messages():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def attachment_message():
|
||||||
|
return [
|
||||||
|
UserMessage(
|
||||||
|
content="I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def query_attachment_messages():
|
||||||
|
return [
|
||||||
|
UserMessage(
|
||||||
|
content="What are the top 5 topics that were explained? Only list succinct bullet points."
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_agent_turn(agents_settings, sample_messages):
|
async def test_create_agent_turn(agents_settings, sample_messages):
|
||||||
agents_impl = agents_settings["impl"]
|
agents_impl = agents_settings["impl"]
|
||||||
|
@ -123,6 +141,89 @@ async def test_create_agent_turn(agents_settings, sample_messages):
|
||||||
assert len(final_event.turn.output_message.content) > 0
|
assert len(final_event.turn.output_message.content) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rag_agent_as_attachments(
|
||||||
|
agents_settings, attachment_message, query_attachment_messages
|
||||||
|
):
|
||||||
|
urls = [
|
||||||
|
"memory_optimizations.rst",
|
||||||
|
"chat.rst",
|
||||||
|
"llama3.rst",
|
||||||
|
"datasets.rst",
|
||||||
|
"qat_finetune.rst",
|
||||||
|
"lora_finetune.rst",
|
||||||
|
]
|
||||||
|
|
||||||
|
attachments = [
|
||||||
|
Attachment(
|
||||||
|
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
for i, url in enumerate(urls)
|
||||||
|
]
|
||||||
|
|
||||||
|
agents_impl = agents_settings["impl"]
|
||||||
|
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
model=agents_settings["common_params"]["model"],
|
||||||
|
instructions=agents_settings["common_params"]["instructions"],
|
||||||
|
enable_session_persistence=True,
|
||||||
|
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||||
|
input_shields=[],
|
||||||
|
output_shields=[],
|
||||||
|
tools=[
|
||||||
|
MemoryToolDefinition(
|
||||||
|
memory_bank_configs=[],
|
||||||
|
query_generator_config={
|
||||||
|
"type": "default",
|
||||||
|
"sep": " ",
|
||||||
|
},
|
||||||
|
max_tokens_in_context=4096,
|
||||||
|
max_chunks=10,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
max_infer_iters=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
create_response = await agents_impl.create_agent(agent_config)
|
||||||
|
agent_id = create_response.agent_id
|
||||||
|
|
||||||
|
# Create a session
|
||||||
|
session_create_response = await agents_impl.create_agent_session(
|
||||||
|
agent_id, "Test Session"
|
||||||
|
)
|
||||||
|
session_id = session_create_response.session_id
|
||||||
|
|
||||||
|
# Create and execute a turn
|
||||||
|
turn_request = dict(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
messages=attachment_message,
|
||||||
|
attachments=attachments,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
turn_response = [
|
||||||
|
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(turn_response) > 0
|
||||||
|
|
||||||
|
# Create a second turn querying the agent
|
||||||
|
turn_request = dict(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
messages=query_attachment_messages,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
turn_response = [
|
||||||
|
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(turn_response) > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_agent_turn_with_brave_search(
|
async def test_create_agent_turn_with_brave_search(
|
||||||
agents_settings, search_query_messages
|
agents_settings, search_query_messages
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue