mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
refactor persistence into another file
This commit is contained in:
parent
cd4880126b
commit
b153a67a3e
2 changed files with 93 additions and 68 deletions
|
@ -6,7 +6,6 @@
|
|||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
|
@ -28,6 +27,7 @@ from llama_stack.apis.safety import * # noqa: F403
|
|||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
from .persistence import AgentPersistence
|
||||
from .rag.context_retriever import generate_rag_query
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
from .tools.base import BaseTool
|
||||
|
@ -47,13 +47,6 @@ def make_random_string(length: int = 8):
|
|||
)
|
||||
|
||||
|
||||
class AgentSessionInfo(BaseModel):
|
||||
session_id: str
|
||||
session_name: str
|
||||
memory_bank_id: Optional[str] = None
|
||||
started_at: datetime
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -69,7 +62,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self.inference_api = inference_api
|
||||
self.memory_api = memory_api
|
||||
self.safety_api = safety_api
|
||||
self.persistence_store = persistence_store
|
||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
|
||||
|
@ -143,68 +136,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_session(self, name: str) -> str:
|
||||
session_id = str(uuid.uuid4())
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(),
|
||||
)
|
||||
await self.persistence_store.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.json(),
|
||||
)
|
||||
return session_id
|
||||
|
||||
async def _get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
|
||||
value = await self.persistence_store.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
if not value:
|
||||
return None
|
||||
|
||||
return AgentSessionInfo(**json.loads(value))
|
||||
|
||||
async def _add_memory_bank_to_session(self, session_id: str, bank_id: str):
|
||||
session_info = await self._get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
session_info.memory_bank_id = bank_id
|
||||
await self.persistence_store.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.json(),
|
||||
)
|
||||
|
||||
async def _add_turn_to_session(self, session_id: str, turn: Turn):
|
||||
await self.persistence_store.set(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||
value=turn.json(),
|
||||
)
|
||||
|
||||
async def _get_session_turns(self, session_id: str) -> List[Turn]:
|
||||
values = await self.persistence_store.range(
|
||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||
)
|
||||
turns = []
|
||||
for value in values:
|
||||
try:
|
||||
turn = Turn(**json.loads(value))
|
||||
turns.append(turn)
|
||||
except Exception as e:
|
||||
print(f"Error parsing turn: {e}")
|
||||
continue
|
||||
|
||||
return turns
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
async def create_and_execute_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
session_info = await self._get_session_info(request.session_id)
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
turns = await self._get_session_turns(request.session_id)
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
|
||||
messages = []
|
||||
if len(turns) == 0 and self.agent_config.instructions != "":
|
||||
|
@ -267,7 +208,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
completed_at=datetime.now(),
|
||||
steps=steps,
|
||||
)
|
||||
await self._add_turn_to_session(request.session_id, turn)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -665,7 +606,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
n_iter += 1
|
||||
|
||||
async def _ensure_memory_bank(self, session_id: str) -> str:
|
||||
session_info = await self._get_session_info(session_id)
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
|
@ -678,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
),
|
||||
)
|
||||
bank_id = memory_bank.bank_id
|
||||
await self._add_memory_bank_to_session(session_id, bank_id)
|
||||
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
||||
else:
|
||||
bank_id = session_info.memory_bank_id
|
||||
|
||||
|
@ -730,7 +671,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
]
|
||||
await self.memory_api.insert_documents(bank_id, documents)
|
||||
else:
|
||||
session_info = await self._get_session_info(session_id)
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
if session_info.memory_bank_id:
|
||||
bank_ids.append(session_info.memory_bank_id)
|
||||
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from typing import List, Optional
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
|
||||
class AgentSessionInfo(BaseModel):
|
||||
session_id: str
|
||||
session_name: str
|
||||
memory_bank_id: Optional[str] = None
|
||||
started_at: datetime
|
||||
|
||||
|
||||
class AgentPersistence:
|
||||
def __init__(self, agent_id: str, kvstore: KVStore):
|
||||
self.agent_id = agent_id
|
||||
self.kvstore = kvstore
|
||||
|
||||
async def create_session(self, name: str) -> str:
|
||||
session_id = str(uuid.uuid4())
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(),
|
||||
)
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.json(),
|
||||
)
|
||||
return session_id
|
||||
|
||||
async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
if not value:
|
||||
return None
|
||||
|
||||
return AgentSessionInfo(**json.loads(value))
|
||||
|
||||
async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
session_info.memory_bank_id = bank_id
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.json(),
|
||||
)
|
||||
|
||||
async def add_turn_to_session(self, session_id: str, turn: Turn):
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||
value=turn.json(),
|
||||
)
|
||||
|
||||
async def get_session_turns(self, session_id: str) -> List[Turn]:
|
||||
values = await self.kvstore.range(
|
||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||
)
|
||||
turns = []
|
||||
for value in values:
|
||||
try:
|
||||
turn = Turn(**json.loads(value))
|
||||
turns.append(turn)
|
||||
except Exception as e:
|
||||
print(f"Error parsing turn: {e}")
|
||||
continue
|
||||
|
||||
return turns
|
Loading…
Add table
Add a link
Reference in a new issue