diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py index fca335bf5..0d334fdad 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py +++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py @@ -24,6 +24,7 @@ from termcolor import cprint from llama_stack.apis.agents import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 from llama_stack.providers.utils.kvstore import KVStore @@ -56,6 +57,7 @@ class ChatAgent(ShieldRunnerMixin): agent_config: AgentConfig, inference_api: Inference, memory_api: Memory, + memory_banks_api: MemoryBanks, safety_api: Safety, persistence_store: KVStore, ): @@ -63,6 +65,7 @@ class ChatAgent(ShieldRunnerMixin): self.agent_config = agent_config self.inference_api = inference_api self.memory_api = memory_api + self.memory_banks_api = memory_banks_api self.safety_api = safety_api self.storage = AgentPersistence(agent_id, persistence_store) @@ -643,7 +646,7 @@ class ChatAgent(ShieldRunnerMixin): embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, ) - await self.memory_api.register_memory_bank(memory_bank) + await self.memory_banks_api.register_memory_bank(memory_bank) await self.storage.add_memory_bank_to_session(session_id, bank_id) else: bank_id = session_info.memory_bank_id diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index e6fa1744d..4dbc71dfa 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -11,6 +11,7 @@ from typing import AsyncGenerator from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory +from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.safety import Safety from llama_stack.apis.agents import * # noqa: F403 @@ -30,11 +31,14 @@ class MetaReferenceAgentsImpl(Agents): inference_api: Inference, memory_api: Memory, safety_api: Safety, + memory_banks_api: MemoryBanks, ): self.config = config self.inference_api = inference_api self.memory_api = memory_api self.safety_api = safety_api + self.memory_banks_api = memory_banks_api + self.in_memory_store = InmemoryKVStoreImpl() async def initialize(self) -> None: @@ -81,6 +85,7 @@ class MetaReferenceAgentsImpl(Agents): inference_api=self.inference_api, safety_api=self.safety_api, memory_api=self.memory_api, + memory_banks_api=self.memory_banks_api, persistence_store=( self.persistence_store if agent_config.enable_session_persistence diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 2603b5faf..8f4d3a03e 100644 --- a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -28,6 +28,7 @@ def available_providers() -> List[ProviderSpec]: Api.inference, Api.safety, Api.memory, + Api.memory_banks, ], ), remote_provider_spec(