fix agent server tests

This commit is contained in:
Dinesh Yeduguru 2024-12-26 18:24:27 -08:00
parent 439f52b067
commit 18d9937500
4 changed files with 41 additions and 63 deletions

View file

@ -8,19 +8,13 @@ import os
from typing import Dict, List
import pytest
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentTurnResponseEventType,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnCompletePayload,
Attachment,
MemoryToolDefinition,
SearchEngineType,
SearchToolDefinition,
ShieldCallStep,
StepType,
ToolChoice,
@ -228,7 +222,7 @@ class TestAgents:
check_turn_complete_event(turn_response, session_id, sample_messages)
@pytest.mark.asyncio
async def test_rag_agent_as_attachments(
async def test_rag_agent(
self,
agents_stack,
attachment_message,
@ -236,6 +230,8 @@ class TestAgents:
common_params,
):
agents_impl = agents_stack.impls[Api.agents]
memory_banks_impl = agents_stack.impls[Api.memory_banks]
memory_impl = agents_stack.impls[Api.memory]
urls = [
"memory_optimizations.rst",
"chat.rst",
@ -244,14 +240,28 @@ class TestAgents:
"qat_finetune.rst",
"lora_finetune.rst",
]
attachments = [
Attachment(
documents = [
MemoryBankDocument(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
await memory_banks_impl.register_memory_bank(
memory_bank_id="test_bank",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
provider_id="faiss",
)
memory_impl.insert_documents(
bank_id="test_bank",
documents=documents,
)
agent_config = AgentConfig(
**{
@ -266,7 +276,6 @@ class TestAgents:
agent_id=agent_id,
session_id=session_id,
messages=attachment_message,
attachments=attachments,
stream=True,
)
turn_response = [
@ -290,11 +299,11 @@ class TestAgents:
assert len(turn_response) > 0
@pytest.mark.asyncio
async def test_create_agent_turn_with_brave_search(
async def test_create_agent_turn_with_tavily_search(
self, agents_stack, search_query_messages, common_params
):
if "BRAVE_SEARCH_API_KEY" not in os.environ:
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
await create_agent_turn_with_search_tool(
agents_stack,