fix agent server tests

This commit is contained in:
Dinesh Yeduguru 2024-12-26 18:24:27 -08:00
parent 439f52b067
commit 18d9937500
4 changed files with 41 additions and 63 deletions

View file

@ -6,7 +6,7 @@
from typing import Any, Dict, List, Optional
from pydantic import parse_obj_as
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
@ -39,7 +39,6 @@ from llama_stack.distribution.datatypes import (
RoutableObjectWithProvider,
RoutedProtocol,
)
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
@ -361,7 +360,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
memory_bank_data["embedding_dimension"] = model.metadata[
"embedding_dimension"
]
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
await self.register_object(memory_bank)
return memory_bank
@ -525,7 +524,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
provider_id = list(self.impls_by_provider_id.keys())[0]
# parse tool group to the type if dict
tool_group = parse_obj_as(ToolGroupDef, tool_group)
tool_group = TypeAdapter(ToolGroupDef).validate_python(tool_group)
if isinstance(tool_group, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group

View file

@ -19,7 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory",
"tool_runtime": "memory_and_search",
},
id="meta_reference",
marks=pytest.mark.meta_reference,
@ -30,7 +30,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory",
"tool_runtime": "memory_and_search",
},
id="ollama",
marks=pytest.mark.ollama,
@ -42,7 +42,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
# make this work with Weaviate which is what the together distro supports
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory",
"tool_runtime": "memory_and_search",
},
id="together",
marks=pytest.mark.together,
@ -53,7 +53,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory",
"tool_runtime": "memory_and_search",
},
id="fireworks",
marks=pytest.mark.fireworks,
@ -64,7 +64,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "remote",
"memory": "remote",
"agents": "remote",
"tool_runtime": "memory",
"tool_runtime": "memory_and_search",
},
id="remote",
marks=pytest.mark.remote,

View file

@ -64,7 +64,7 @@ def agents_meta_reference() -> ProviderFixture:
@pytest.fixture(scope="session")
def tool_runtime_memory() -> ProviderFixture:
def tool_runtime_memory_and_search() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
@ -72,31 +72,19 @@ def tool_runtime_memory() -> ProviderFixture:
provider_type="inline::memory-runtime",
config={},
),
Provider(
provider_id="brave-search",
provider_type="inline::brave-search",
config={
"api_key": os.environ["BRAVE_SEARCH_API_KEY"],
},
),
Provider(
provider_id="tavily-search",
provider_type="inline::tavily-search",
provider_type="remote::tavily-search",
config={
"api_key": os.environ["TAVILY_SEARCH_API_KEY"],
},
),
Provider(
provider_id="code-interpreter",
provider_type="inline::code-interpreter",
config={},
),
],
)
AGENTS_FIXTURES = ["meta_reference", "remote"]
TOOL_RUNTIME_FIXTURES = ["memory"]
TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
@pytest_asyncio.fixture(scope="session")
@ -173,43 +161,25 @@ async def agents_stack(request, inference_model, safety_shield):
name="memory",
description="memory",
parameters=[
ToolParameter(
name="session_id",
description="session id",
parameter_type="string",
required=True,
),
ToolParameter(
name="input_messages",
description="messages",
parameter_type="list",
required=True,
),
ToolParameter(
name="attachments",
description="attachments",
parameter_type="list",
required=False,
),
],
metadata={},
metadata={
"config": {
"memory_bank_configs": [
{"bank_id": "test_bank", "type": "vector"}
]
}
},
)
],
),
provider_id="memory-runtime",
),
ToolGroupInput(
tool_group_id="code_interpreter_group",
tool_group=UserDefinedToolGroupDef(
tools=[
BuiltInToolDef(
built_in_type=BuiltinTool.code_interpreter,
metadata={},
)
],
),
provider_id="code-interpreter",
),
]
test_stack = await construct_stack_for_test(

View file

@ -8,19 +8,13 @@ import os
from typing import Dict, List
import pytest
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentTurnResponseEventType,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnCompletePayload,
Attachment,
MemoryToolDefinition,
SearchEngineType,
SearchToolDefinition,
ShieldCallStep,
StepType,
ToolChoice,
@ -228,7 +222,7 @@ class TestAgents:
check_turn_complete_event(turn_response, session_id, sample_messages)
@pytest.mark.asyncio
async def test_rag_agent_as_attachments(
async def test_rag_agent(
self,
agents_stack,
attachment_message,
@ -236,6 +230,8 @@ class TestAgents:
common_params,
):
agents_impl = agents_stack.impls[Api.agents]
memory_banks_impl = agents_stack.impls[Api.memory_banks]
memory_impl = agents_stack.impls[Api.memory]
urls = [
"memory_optimizations.rst",
"chat.rst",
@ -244,14 +240,28 @@ class TestAgents:
"qat_finetune.rst",
"lora_finetune.rst",
]
attachments = [
Attachment(
documents = [
MemoryBankDocument(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
await memory_banks_impl.register_memory_bank(
memory_bank_id="test_bank",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
provider_id="faiss",
)
memory_impl.insert_documents(
bank_id="test_bank",
documents=documents,
)
agent_config = AgentConfig(
**{
@ -266,7 +276,6 @@ class TestAgents:
agent_id=agent_id,
session_id=session_id,
messages=attachment_message,
attachments=attachments,
stream=True,
)
turn_response = [
@ -290,11 +299,11 @@ class TestAgents:
assert len(turn_response) > 0
@pytest.mark.asyncio
async def test_create_agent_turn_with_brave_search(
async def test_create_agent_turn_with_tavily_search(
self, agents_stack, search_query_messages, common_params
):
if "BRAVE_SEARCH_API_KEY" not in os.environ:
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
await create_agent_turn_with_search_tool(
agents_stack,