diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py index 7cf86b1f3..6e9fc7c63 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py +++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py @@ -6,7 +6,6 @@ import asyncio import copy -import json import os import secrets import shutil @@ -28,6 +27,7 @@ from llama_stack.apis.safety import * # noqa: F403 from llama_stack.providers.utils.kvstore import KVStore +from .persistence import AgentPersistence from .rag.context_retriever import generate_rag_query from .safety import SafetyException, ShieldRunnerMixin from .tools.base import BaseTool @@ -47,13 +47,6 @@ def make_random_string(length: int = 8): ) -class AgentSessionInfo(BaseModel): - session_id: str - session_name: str - memory_bank_id: Optional[str] = None - started_at: datetime - - class ChatAgent(ShieldRunnerMixin): def __init__( self, @@ -69,7 +62,7 @@ class ChatAgent(ShieldRunnerMixin): self.inference_api = inference_api self.memory_api = memory_api self.safety_api = safety_api - self.persistence_store = persistence_store + self.storage = AgentPersistence(agent_id, persistence_store) self.tempdir = tempfile.mkdtemp() @@ -143,68 +136,16 @@ class ChatAgent(ShieldRunnerMixin): return messages async def create_session(self, name: str) -> str: - session_id = str(uuid.uuid4()) - session_info = AgentSessionInfo( - session_id=session_id, - session_name=name, - started_at=datetime.now(), - ) - await self.persistence_store.set( - key=f"session:{self.agent_id}:{session_id}", - value=session_info.json(), - ) - return session_id - - async def _get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]: - value = await self.persistence_store.get( - key=f"session:{self.agent_id}:{session_id}", - ) - if not value: - return None - - return AgentSessionInfo(**json.loads(value)) - - async def _add_memory_bank_to_session(self, session_id: str, bank_id: str): - session_info = await self._get_session_info(session_id) - if session_info is None: - raise ValueError(f"Session {session_id} not found") - - session_info.memory_bank_id = bank_id - await self.persistence_store.set( - key=f"session:{self.agent_id}:{session_id}", - value=session_info.json(), - ) - - async def _add_turn_to_session(self, session_id: str, turn: Turn): - await self.persistence_store.set( - key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", - value=turn.json(), - ) - - async def _get_session_turns(self, session_id: str) -> List[Turn]: - values = await self.persistence_store.range( - start_key=f"session:{self.agent_id}:{session_id}:", - end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff", - ) - turns = [] - for value in values: - try: - turn = Turn(**json.loads(value)) - turns.append(turn) - except Exception as e: - print(f"Error parsing turn: {e}") - continue - - return turns + return await self.storage.create_session(name) async def create_and_execute_turn( self, request: AgentTurnCreateRequest ) -> AsyncGenerator: - session_info = await self._get_session_info(request.session_id) + session_info = await self.storage.get_session_info(request.session_id) if session_info is None: raise ValueError(f"Session {request.session_id} not found") - turns = await self._get_session_turns(request.session_id) + turns = await self.storage.get_session_turns(request.session_id) messages = [] if len(turns) == 0 and self.agent_config.instructions != "": @@ -267,7 +208,7 @@ class ChatAgent(ShieldRunnerMixin): completed_at=datetime.now(), steps=steps, ) - await self._add_turn_to_session(request.session_id, turn) + await self.storage.add_turn_to_session(request.session_id, turn) chunk = AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -665,7 +606,7 @@ class ChatAgent(ShieldRunnerMixin): n_iter += 1 async def _ensure_memory_bank(self, session_id: str) -> str: - session_info = await self._get_session_info(session_id) + session_info = await self.storage.get_session_info(session_id) if session_info is None: raise ValueError(f"Session {session_id} not found") @@ -678,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin): ), ) bank_id = memory_bank.bank_id - await self._add_memory_bank_to_session(session_id, bank_id) + await self.storage.add_memory_bank_to_session(session_id, bank_id) else: bank_id = session_info.memory_bank_id @@ -730,7 +671,7 @@ class ChatAgent(ShieldRunnerMixin): ] await self.memory_api.insert_documents(bank_id, documents) else: - session_info = await self._get_session_info(session_id) + session_info = await self.storage.get_session_info(session_id) if session_info.memory_bank_id: bank_ids.append(session_info.memory_bank_id) diff --git a/llama_stack/providers/impls/meta_reference/agents/persistence.py b/llama_stack/providers/impls/meta_reference/agents/persistence.py new file mode 100644 index 000000000..37ac75d6a --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/agents/persistence.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json + +import uuid +from datetime import datetime + +from typing import List, Optional +from llama_stack.apis.agents import * # noqa: F403 +from pydantic import BaseModel + +from llama_stack.providers.utils.kvstore import KVStore + + +class AgentSessionInfo(BaseModel): + session_id: str + session_name: str + memory_bank_id: Optional[str] = None + started_at: datetime + + +class AgentPersistence: + def __init__(self, agent_id: str, kvstore: KVStore): + self.agent_id = agent_id + self.kvstore = kvstore + + async def create_session(self, name: str) -> str: + session_id = str(uuid.uuid4()) + session_info = AgentSessionInfo( + session_id=session_id, + session_name=name, + started_at=datetime.now(), + ) + await self.kvstore.set( + key=f"session:{self.agent_id}:{session_id}", + value=session_info.json(), + ) + return session_id + + async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]: + value = await self.kvstore.get( + key=f"session:{self.agent_id}:{session_id}", + ) + if not value: + return None + + return AgentSessionInfo(**json.loads(value)) + + async def add_memory_bank_to_session(self, session_id: str, bank_id: str): + session_info = await self.get_session_info(session_id) + if session_info is None: + raise ValueError(f"Session {session_id} not found") + + session_info.memory_bank_id = bank_id + await self.kvstore.set( + key=f"session:{self.agent_id}:{session_id}", + value=session_info.json(), + ) + + async def add_turn_to_session(self, session_id: str, turn: Turn): + await self.kvstore.set( + key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", + value=turn.json(), + ) + + async def get_session_turns(self, session_id: str) -> List[Turn]: + values = await self.kvstore.range( + start_key=f"session:{self.agent_id}:{session_id}:", + end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff", + ) + turns = [] + for value in values: + try: + turn = Turn(**json.loads(value)) + turns.append(turn) + except Exception as e: + print(f"Error parsing turn: {e}") + continue + + return turns