mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:19:49 +00:00
fix agent server tests
This commit is contained in:
parent
439f52b067
commit
18d9937500
4 changed files with 41 additions and 63 deletions
|
|
@ -8,19 +8,13 @@ import os
|
|||
from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentTool,
|
||||
AgentTurnResponseEventType,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseStreamChunk,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
Attachment,
|
||||
MemoryToolDefinition,
|
||||
SearchEngineType,
|
||||
SearchToolDefinition,
|
||||
ShieldCallStep,
|
||||
StepType,
|
||||
ToolChoice,
|
||||
|
|
@ -228,7 +222,7 @@ class TestAgents:
|
|||
check_turn_complete_event(turn_response, session_id, sample_messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_agent_as_attachments(
|
||||
async def test_rag_agent(
|
||||
self,
|
||||
agents_stack,
|
||||
attachment_message,
|
||||
|
|
@ -236,6 +230,8 @@ class TestAgents:
|
|||
common_params,
|
||||
):
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
memory_banks_impl = agents_stack.impls[Api.memory_banks]
|
||||
memory_impl = agents_stack.impls[Api.memory]
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
|
|
@ -244,14 +240,28 @@ class TestAgents:
|
|||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
|
||||
attachments = [
|
||||
Attachment(
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
document_id=f"num-{i}",
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
metadata={},
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
await memory_banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank",
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
provider_id="faiss",
|
||||
)
|
||||
memory_impl.insert_documents(
|
||||
bank_id="test_bank",
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
agent_config = AgentConfig(
|
||||
**{
|
||||
|
|
@ -266,7 +276,6 @@ class TestAgents:
|
|||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=attachment_message,
|
||||
attachments=attachments,
|
||||
stream=True,
|
||||
)
|
||||
turn_response = [
|
||||
|
|
@ -290,11 +299,11 @@ class TestAgents:
|
|||
assert len(turn_response) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_brave_search(
|
||||
async def test_create_agent_turn_with_tavily_search(
|
||||
self, agents_stack, search_query_messages, common_params
|
||||
):
|
||||
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
||||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
await create_agent_turn_with_search_tool(
|
||||
agents_stack,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue