mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Pass memory bank API to agent impl
This commit is contained in:
parent
6788173ffc
commit
2d94ca71a9
3 changed files with 10 additions and 1 deletions
|
@ -24,6 +24,7 @@ from termcolor import cprint
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
@ -56,6 +57,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
inference_api: Inference,
|
inference_api: Inference,
|
||||||
memory_api: Memory,
|
memory_api: Memory,
|
||||||
|
memory_banks_api: MemoryBanks,
|
||||||
safety_api: Safety,
|
safety_api: Safety,
|
||||||
persistence_store: KVStore,
|
persistence_store: KVStore,
|
||||||
):
|
):
|
||||||
|
@ -63,6 +65,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
self.agent_config = agent_config
|
self.agent_config = agent_config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
self.memory_api = memory_api
|
self.memory_api = memory_api
|
||||||
|
self.memory_banks_api = memory_banks_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||||
|
|
||||||
|
@ -643,7 +646,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
)
|
)
|
||||||
await self.memory_api.register_memory_bank(memory_bank)
|
await self.memory_banks_api.register_memory_bank(memory_bank)
|
||||||
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
||||||
else:
|
else:
|
||||||
bank_id = session_info.memory_bank_id
|
bank_id = session_info.memory_bank_id
|
||||||
|
|
|
@ -11,6 +11,7 @@ from typing import AsyncGenerator
|
||||||
|
|
||||||
from llama_stack.apis.inference import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.memory import Memory
|
from llama_stack.apis.memory import Memory
|
||||||
|
from llama_stack.apis.memory_banks import MemoryBanks
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
|
|
||||||
|
@ -30,11 +31,14 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
inference_api: Inference,
|
inference_api: Inference,
|
||||||
memory_api: Memory,
|
memory_api: Memory,
|
||||||
safety_api: Safety,
|
safety_api: Safety,
|
||||||
|
memory_banks_api: MemoryBanks,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
self.memory_api = memory_api
|
self.memory_api = memory_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
|
self.memory_banks_api = memory_banks_api
|
||||||
|
|
||||||
self.in_memory_store = InmemoryKVStoreImpl()
|
self.in_memory_store = InmemoryKVStoreImpl()
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
@ -81,6 +85,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
inference_api=self.inference_api,
|
inference_api=self.inference_api,
|
||||||
safety_api=self.safety_api,
|
safety_api=self.safety_api,
|
||||||
memory_api=self.memory_api,
|
memory_api=self.memory_api,
|
||||||
|
memory_banks_api=self.memory_banks_api,
|
||||||
persistence_store=(
|
persistence_store=(
|
||||||
self.persistence_store
|
self.persistence_store
|
||||||
if agent_config.enable_session_persistence
|
if agent_config.enable_session_persistence
|
||||||
|
|
|
@ -28,6 +28,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
Api.inference,
|
Api.inference,
|
||||||
Api.safety,
|
Api.safety,
|
||||||
Api.memory,
|
Api.memory,
|
||||||
|
Api.memory_banks,
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue