refactor persistence into another file

This commit is contained in:
Ashwin Bharambe 2024-09-21 23:21:54 -07:00
parent cd4880126b
commit b153a67a3e
2 changed files with 93 additions and 68 deletions

View file

@ -6,7 +6,6 @@
import asyncio
import copy
import json
import os
import secrets
import shutil
@ -28,6 +27,7 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.kvstore import KVStore
from .persistence import AgentPersistence
from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
from .tools.base import BaseTool
@ -47,13 +47,6 @@ def make_random_string(length: int = 8):
)
class AgentSessionInfo(BaseModel):
session_id: str
session_name: str
memory_bank_id: Optional[str] = None
started_at: datetime
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
@ -69,7 +62,7 @@ class ChatAgent(ShieldRunnerMixin):
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.persistence_store = persistence_store
self.storage = AgentPersistence(agent_id, persistence_store)
self.tempdir = tempfile.mkdtemp()
@ -143,68 +136,16 @@ class ChatAgent(ShieldRunnerMixin):
return messages
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(),
)
await self.persistence_store.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(),
)
return session_id
async def _get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
value = await self.persistence_store.get(
key=f"session:{self.agent_id}:{session_id}",
)
if not value:
return None
return AgentSessionInfo(**json.loads(value))
async def _add_memory_bank_to_session(self, session_id: str, bank_id: str):
session_info = await self._get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
session_info.memory_bank_id = bank_id
await self.persistence_store.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(),
)
async def _add_turn_to_session(self, session_id: str, turn: Turn):
await self.persistence_store.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.json(),
)
async def _get_session_turns(self, session_id: str) -> List[Turn]:
values = await self.persistence_store.range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
turns = []
for value in values:
try:
turn = Turn(**json.loads(value))
turns.append(turn)
except Exception as e:
print(f"Error parsing turn: {e}")
continue
return turns
return await self.storage.create_session(name)
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
session_info = await self._get_session_info(request.session_id)
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self._get_session_turns(request.session_id)
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if len(turns) == 0 and self.agent_config.instructions != "":
@ -267,7 +208,7 @@ class ChatAgent(ShieldRunnerMixin):
completed_at=datetime.now(),
steps=steps,
)
await self._add_turn_to_session(request.session_id, turn)
await self.storage.add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -665,7 +606,7 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1
async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self._get_session_info(session_id)
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
@ -678,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin):
),
)
bank_id = memory_bank.bank_id
await self._add_memory_bank_to_session(session_id, bank_id)
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id
@ -730,7 +671,7 @@ class ChatAgent(ShieldRunnerMixin):
]
await self.memory_api.insert_documents(bank_id, documents)
else:
session_info = await self._get_session_info(session_id)
session_info = await self.storage.get_session_info(session_id)
if session_info.memory_bank_id:
bank_ids.append(session_info.memory_bank_id)

View file

@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from datetime import datetime
from typing import List, Optional
from llama_stack.apis.agents import * # noqa: F403
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStore
class AgentSessionInfo(BaseModel):
    """Metadata record persisted for a single agent session."""

    # Unique identifier of the session.
    session_id: str
    # Name supplied when the session was created.
    session_name: str
    # Memory bank attached to the session; stays None until one is attached.
    memory_bank_id: Optional[str] = None
    # Timestamp recorded when the session was created.
    started_at: datetime
class AgentPersistence:
    """Persist one agent's sessions and turns in a key-value store.

    Session metadata is stored under ``session:<agent_id>:<session_id>`` and
    every turn under ``session:<agent_id>:<session_id>:<turn_id>``. Values are
    the JSON strings produced by the models' ``.json()`` method.
    """

    def __init__(self, agent_id: str, kvstore: KVStore):
        """Bind the persistence layer to ``agent_id`` and its backing store."""
        self.agent_id = agent_id
        self.kvstore = kvstore

    def _session_key(self, session_id: str) -> str:
        """Return the key under which the metadata of ``session_id`` lives."""
        return f"session:{self.agent_id}:{session_id}"

    async def create_session(self, name: str) -> str:
        """Create and store a new session called ``name``.

        Returns:
            The freshly generated (UUID4) session id.
        """
        session_id = str(uuid.uuid4())
        session_info = AgentSessionInfo(
            session_id=session_id,
            session_name=name,
            started_at=datetime.now(),
        )
        await self.kvstore.set(
            key=self._session_key(session_id),
            value=session_info.json(),
        )
        return session_id

    async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
        """Load the metadata of ``session_id``.

        Returns:
            The stored ``AgentSessionInfo``, or ``None`` when the store holds
            no (or an empty) value for this session.
        """
        value = await self.kvstore.get(key=self._session_key(session_id))
        if not value:
            return None
        return AgentSessionInfo(**json.loads(value))

    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
        """Record ``bank_id`` as the memory bank of ``session_id``.

        Raises:
            ValueError: If the session does not exist.
        """
        session_info = await self.get_session_info(session_id)
        if session_info is None:
            raise ValueError(f"Session {session_id} not found")

        # Read-modify-write: the whole metadata record is stored again.
        session_info.memory_bank_id = bank_id
        await self.kvstore.set(
            key=self._session_key(session_id),
            value=session_info.json(),
        )

    async def add_turn_to_session(self, session_id: str, turn: Turn):
        """Store ``turn`` under ``session_id``, keyed by its ``turn_id``."""
        await self.kvstore.set(
            key=f"{self._session_key(session_id)}:{turn.turn_id}",
            value=turn.json(),
        )

    async def get_session_turns(self, session_id: str) -> List[Turn]:
        """Return the stored turns of ``session_id`` in chronological order.

        Entries that cannot be parsed into a ``Turn`` are reported on stdout
        and skipped, so one corrupt record does not hide the remaining turns.
        """
        # The trailing ":" restricts the range to turn keys and leaves out the
        # session metadata record, whose key has no suffix.
        prefix = f"{self._session_key(session_id)}:"
        values = await self.kvstore.range(
            start_key=prefix,
            end_key=f"{prefix}\xff\xff\xff\xff",
        )

        turns = []
        for value in values:
            try:
                turns.append(Turn(**json.loads(value)))
            except Exception as e:
                print(f"Error parsing turn: {e}")

        # Turn keys end in the turn id, so the order the range query yields
        # says nothing about time. Callers rebuild the conversation from this
        # list, so order it by completion time; turns without one sort first.
        turns.sort(key=lambda turn: turn.completed_at or datetime.min)
        return turns