Pass memory bank API to agent impl

This commit is contained in:
Ashwin Bharambe 2024-10-09 21:16:57 -07:00
parent 6788173ffc
commit 2d94ca71a9
3 changed files with 10 additions and 1 deletions

View file

@ -24,6 +24,7 @@ from termcolor import cprint
from llama_stack.apis.agents import * # noqa: F403 from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
@ -56,6 +57,7 @@ class ChatAgent(ShieldRunnerMixin):
agent_config: AgentConfig, agent_config: AgentConfig,
inference_api: Inference, inference_api: Inference,
memory_api: Memory, memory_api: Memory,
memory_banks_api: MemoryBanks,
safety_api: Safety, safety_api: Safety,
persistence_store: KVStore, persistence_store: KVStore,
): ):
@ -63,6 +65,7 @@ class ChatAgent(ShieldRunnerMixin):
self.agent_config = agent_config self.agent_config = agent_config
self.inference_api = inference_api self.inference_api = inference_api
self.memory_api = memory_api self.memory_api = memory_api
self.memory_banks_api = memory_banks_api
self.safety_api = safety_api self.safety_api = safety_api
self.storage = AgentPersistence(agent_id, persistence_store) self.storage = AgentPersistence(agent_id, persistence_store)
@ -643,7 +646,7 @@ class ChatAgent(ShieldRunnerMixin):
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
) )
await self.memory_api.register_memory_bank(memory_bank) await self.memory_banks_api.register_memory_bank(memory_bank)
await self.storage.add_memory_bank_to_session(session_id, bank_id) await self.storage.add_memory_bank_to_session(session_id, bank_id)
else: else:
bank_id = session_info.memory_bank_id bank_id = session_info.memory_bank_id

View file

@ -11,6 +11,7 @@ from typing import AsyncGenerator
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403 from llama_stack.apis.agents import * # noqa: F403
@ -30,11 +31,14 @@ class MetaReferenceAgentsImpl(Agents):
inference_api: Inference, inference_api: Inference,
memory_api: Memory, memory_api: Memory,
safety_api: Safety, safety_api: Safety,
memory_banks_api: MemoryBanks,
): ):
self.config = config self.config = config
self.inference_api = inference_api self.inference_api = inference_api
self.memory_api = memory_api self.memory_api = memory_api
self.safety_api = safety_api self.safety_api = safety_api
self.memory_banks_api = memory_banks_api
self.in_memory_store = InmemoryKVStoreImpl() self.in_memory_store = InmemoryKVStoreImpl()
async def initialize(self) -> None: async def initialize(self) -> None:
@ -81,6 +85,7 @@ class MetaReferenceAgentsImpl(Agents):
inference_api=self.inference_api, inference_api=self.inference_api,
safety_api=self.safety_api, safety_api=self.safety_api,
memory_api=self.memory_api, memory_api=self.memory_api,
memory_banks_api=self.memory_banks_api,
persistence_store=( persistence_store=(
self.persistence_store self.persistence_store
if agent_config.enable_session_persistence if agent_config.enable_session_persistence

View file

@ -28,6 +28,7 @@ def available_providers() -> List[ProviderSpec]:
Api.inference, Api.inference,
Api.safety, Api.safety,
Api.memory, Api.memory,
Api.memory_banks,
], ],
), ),
remote_provider_spec( remote_provider_spec(