diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 2aff0f3a2..45708649b 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional -from pydantic import parse_obj_as +from pydantic import TypeAdapter from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType @@ -39,7 +39,6 @@ from llama_stack.distribution.datatypes import ( RoutableObjectWithProvider, RoutedProtocol, ) - from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable @@ -361,7 +360,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): memory_bank_data["embedding_dimension"] = model.metadata[ "embedding_dimension" ] - memory_bank = parse_obj_as(MemoryBank, memory_bank_data) + memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data) await self.register_object(memory_bank) return memory_bank @@ -525,7 +524,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): provider_id = list(self.impls_by_provider_id.keys())[0] # parse tool group to the type if dict - tool_group = parse_obj_as(ToolGroupDef, tool_group) + tool_group = TypeAdapter(ToolGroupDef).validate_python(tool_group) if isinstance(tool_group, MCPToolGroupDef): tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( tool_group diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index d80013fae..f805fbbbb 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -19,7 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "safety": "llama_guard", "memory": "faiss", "agents": "meta_reference", - "tool_runtime": "memory", + "tool_runtime": "memory_and_search", }, 
id="meta_reference", marks=pytest.mark.meta_reference, @@ -30,7 +30,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "safety": "llama_guard", "memory": "faiss", "agents": "meta_reference", - "tool_runtime": "memory", + "tool_runtime": "memory_and_search", }, id="ollama", marks=pytest.mark.ollama, @@ -42,7 +42,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ # make this work with Weaviate which is what the together distro supports "memory": "faiss", "agents": "meta_reference", - "tool_runtime": "memory", + "tool_runtime": "memory_and_search", }, id="together", marks=pytest.mark.together, @@ -53,7 +53,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "safety": "llama_guard", "memory": "faiss", "agents": "meta_reference", - "tool_runtime": "memory", + "tool_runtime": "memory_and_search", }, id="fireworks", marks=pytest.mark.fireworks, @@ -64,7 +64,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ "safety": "remote", "memory": "remote", "agents": "remote", - "tool_runtime": "memory", + "tool_runtime": "memory_and_search", }, id="remote", marks=pytest.mark.remote, diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 97d0d47e6..71e98102e 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -64,7 +64,7 @@ def agents_meta_reference() -> ProviderFixture: @pytest.fixture(scope="session") -def tool_runtime_memory() -> ProviderFixture: +def tool_runtime_memory_and_search() -> ProviderFixture: return ProviderFixture( providers=[ Provider( @@ -72,31 +72,19 @@ def tool_runtime_memory() -> ProviderFixture: provider_type="inline::memory-runtime", config={}, ), - Provider( - provider_id="brave-search", - provider_type="inline::brave-search", - config={ - "api_key": os.environ["BRAVE_SEARCH_API_KEY"], - }, - ), Provider( provider_id="tavily-search", - provider_type="inline::tavily-search", + provider_type="remote::tavily-search", config={ "api_key": os.environ["TAVILY_SEARCH_API_KEY"], }, ), - 
Provider( - provider_id="code-interpreter", - provider_type="inline::code-interpreter", - config={}, - ), ], ) AGENTS_FIXTURES = ["meta_reference", "remote"] -TOOL_RUNTIME_FIXTURES = ["memory"] +TOOL_RUNTIME_FIXTURES = ["memory_and_search"] @pytest_asyncio.fixture(scope="session") @@ -173,43 +161,25 @@ async def agents_stack(request, inference_model, safety_shield): name="memory", description="memory", parameters=[ - ToolParameter( - name="session_id", - description="session id", - parameter_type="string", - required=True, - ), ToolParameter( name="input_messages", description="messages", parameter_type="list", required=True, ), - ToolParameter( - name="attachments", - description="attachments", - parameter_type="list", - required=False, - ), ], - metadata={}, + metadata={ + "config": { + "memory_bank_configs": [ + {"bank_id": "test_bank", "type": "vector"} + ] + } + }, ) ], ), provider_id="memory-runtime", ), - ToolGroupInput( - tool_group_id="code_interpreter_group", - tool_group=UserDefinedToolGroupDef( - tools=[ - BuiltInToolDef( - built_in_type=BuiltinTool.code_interpreter, - metadata={}, - ) - ], - ), - provider_id="code-interpreter", - ), ] test_stack = await construct_stack_for_test( diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index a8c472da4..3534e0f84 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -8,19 +8,13 @@ import os from typing import Dict, List import pytest -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.agents import ( AgentConfig, - AgentTool, AgentTurnResponseEventType, AgentTurnResponseStepCompletePayload, AgentTurnResponseStreamChunk, AgentTurnResponseTurnCompletePayload, - Attachment, - MemoryToolDefinition, - SearchEngineType, - SearchToolDefinition, ShieldCallStep, StepType, ToolChoice, @@ -228,7 +222,7 @@ class TestAgents: check_turn_complete_event(turn_response, 
session_id, sample_messages) @pytest.mark.asyncio - async def test_rag_agent_as_attachments( + async def test_rag_agent( self, agents_stack, attachment_message, @@ -236,6 +230,8 @@ class TestAgents: common_params, ): agents_impl = agents_stack.impls[Api.agents] + memory_banks_impl = agents_stack.impls[Api.memory_banks] + memory_impl = agents_stack.impls[Api.memory] urls = [ "memory_optimizations.rst", "chat.rst", @@ -244,14 +240,28 @@ class TestAgents: "qat_finetune.rst", "lora_finetune.rst", ] - - attachments = [ - Attachment( + documents = [ + MemoryBankDocument( + document_id=f"num-{i}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", mime_type="text/plain", + metadata={}, ) for i, url in enumerate(urls) ] + await memory_banks_impl.register_memory_bank( + memory_bank_id="test_bank", + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + provider_id="faiss", + ) + await memory_impl.insert_documents( + bank_id="test_bank", + documents=documents, + ) agent_config = AgentConfig( **{ @@ -266,7 +276,6 @@ class TestAgents: agent_id=agent_id, session_id=session_id, messages=attachment_message, - attachments=attachments, stream=True, ) turn_response = [ @@ -290,11 +299,11 @@ class TestAgents: assert len(turn_response) > 0 @pytest.mark.asyncio - async def test_create_agent_turn_with_brave_search( + async def test_create_agent_turn_with_tavily_search( self, agents_stack, search_query_messages, common_params ): - if "BRAVE_SEARCH_API_KEY" not in os.environ: - pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test") + if "TAVILY_SEARCH_API_KEY" not in os.environ: + pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") await create_agent_turn_with_search_tool( agents_stack,