mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
refactor persistence into another file
This commit is contained in:
parent
cd4880126b
commit
b153a67a3e
2 changed files with 93 additions and 68 deletions
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
import shutil
|
import shutil
|
||||||
|
@ -28,6 +27,7 @@ from llama_stack.apis.safety import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
|
from .persistence import AgentPersistence
|
||||||
from .rag.context_retriever import generate_rag_query
|
from .rag.context_retriever import generate_rag_query
|
||||||
from .safety import SafetyException, ShieldRunnerMixin
|
from .safety import SafetyException, ShieldRunnerMixin
|
||||||
from .tools.base import BaseTool
|
from .tools.base import BaseTool
|
||||||
|
@ -47,13 +47,6 @@ def make_random_string(length: int = 8):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AgentSessionInfo(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
session_name: str
|
|
||||||
memory_bank_id: Optional[str] = None
|
|
||||||
started_at: datetime
|
|
||||||
|
|
||||||
|
|
||||||
class ChatAgent(ShieldRunnerMixin):
|
class ChatAgent(ShieldRunnerMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -69,7 +62,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
self.memory_api = memory_api
|
self.memory_api = memory_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.persistence_store = persistence_store
|
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||||
|
|
||||||
self.tempdir = tempfile.mkdtemp()
|
self.tempdir = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
@ -143,68 +136,16 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def create_session(self, name: str) -> str:
|
async def create_session(self, name: str) -> str:
|
||||||
session_id = str(uuid.uuid4())
|
return await self.storage.create_session(name)
|
||||||
session_info = AgentSessionInfo(
|
|
||||||
session_id=session_id,
|
|
||||||
session_name=name,
|
|
||||||
started_at=datetime.now(),
|
|
||||||
)
|
|
||||||
await self.persistence_store.set(
|
|
||||||
key=f"session:{self.agent_id}:{session_id}",
|
|
||||||
value=session_info.json(),
|
|
||||||
)
|
|
||||||
return session_id
|
|
||||||
|
|
||||||
async def _get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
|
|
||||||
value = await self.persistence_store.get(
|
|
||||||
key=f"session:{self.agent_id}:{session_id}",
|
|
||||||
)
|
|
||||||
if not value:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return AgentSessionInfo(**json.loads(value))
|
|
||||||
|
|
||||||
async def _add_memory_bank_to_session(self, session_id: str, bank_id: str):
|
|
||||||
session_info = await self._get_session_info(session_id)
|
|
||||||
if session_info is None:
|
|
||||||
raise ValueError(f"Session {session_id} not found")
|
|
||||||
|
|
||||||
session_info.memory_bank_id = bank_id
|
|
||||||
await self.persistence_store.set(
|
|
||||||
key=f"session:{self.agent_id}:{session_id}",
|
|
||||||
value=session_info.json(),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _add_turn_to_session(self, session_id: str, turn: Turn):
|
|
||||||
await self.persistence_store.set(
|
|
||||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
|
||||||
value=turn.json(),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_session_turns(self, session_id: str) -> List[Turn]:
|
|
||||||
values = await self.persistence_store.range(
|
|
||||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
|
||||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
|
||||||
)
|
|
||||||
turns = []
|
|
||||||
for value in values:
|
|
||||||
try:
|
|
||||||
turn = Turn(**json.loads(value))
|
|
||||||
turns.append(turn)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error parsing turn: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
return turns
|
|
||||||
|
|
||||||
async def create_and_execute_turn(
|
async def create_and_execute_turn(
|
||||||
self, request: AgentTurnCreateRequest
|
self, request: AgentTurnCreateRequest
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
session_info = await self._get_session_info(request.session_id)
|
session_info = await self.storage.get_session_info(request.session_id)
|
||||||
if session_info is None:
|
if session_info is None:
|
||||||
raise ValueError(f"Session {request.session_id} not found")
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
turns = await self._get_session_turns(request.session_id)
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
if len(turns) == 0 and self.agent_config.instructions != "":
|
if len(turns) == 0 and self.agent_config.instructions != "":
|
||||||
|
@ -267,7 +208,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
completed_at=datetime.now(),
|
completed_at=datetime.now(),
|
||||||
steps=steps,
|
steps=steps,
|
||||||
)
|
)
|
||||||
await self._add_turn_to_session(request.session_id, turn)
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
||||||
chunk = AgentTurnResponseStreamChunk(
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
|
@ -665,7 +606,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
n_iter += 1
|
n_iter += 1
|
||||||
|
|
||||||
async def _ensure_memory_bank(self, session_id: str) -> str:
|
async def _ensure_memory_bank(self, session_id: str) -> str:
|
||||||
session_info = await self._get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
if session_info is None:
|
if session_info is None:
|
||||||
raise ValueError(f"Session {session_id} not found")
|
raise ValueError(f"Session {session_id} not found")
|
||||||
|
|
||||||
|
@ -678,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
bank_id = memory_bank.bank_id
|
bank_id = memory_bank.bank_id
|
||||||
await self._add_memory_bank_to_session(session_id, bank_id)
|
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
||||||
else:
|
else:
|
||||||
bank_id = session_info.memory_bank_id
|
bank_id = session_info.memory_bank_id
|
||||||
|
|
||||||
|
@ -730,7 +671,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
]
|
]
|
||||||
await self.memory_api.insert_documents(bank_id, documents)
|
await self.memory_api.insert_documents(bank_id, documents)
|
||||||
else:
|
else:
|
||||||
session_info = await self._get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
if session_info.memory_bank_id:
|
if session_info.memory_bank_id:
|
||||||
bank_ids.append(session_info.memory_bank_id)
|
bank_ids.append(session_info.memory_bank_id)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,84 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
import logging
import uuid
from datetime import datetime

from typing import List, Optional

from llama_stack.apis.agents import *  # noqa: F403
from pydantic import BaseModel

from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSessionInfo(BaseModel):
    """Metadata record persisted for one agent session.

    Instances are serialized with ``.json()`` and rebuilt from the parsed JSON
    by ``AgentPersistence``.
    """

    # Session identifier; ``AgentPersistence.create_session`` fills this with
    # ``str(uuid.uuid4())``.
    session_id: str
    # Name passed by the caller when the session was created.
    session_name: str
    # Memory bank attached to the session; stays ``None`` until
    # ``AgentPersistence.add_memory_bank_to_session`` sets it.
    memory_bank_id: Optional[str] = None
    # Creation timestamp; ``create_session`` uses ``datetime.now()``.
    started_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class AgentPersistence:
    """Store and load one agent's sessions and turns in a key-value store.

    Key layout (all keys are plain strings):

    * ``session:<agent_id>:<session_id>`` -> JSON of ``AgentSessionInfo``
    * ``session:<agent_id>:<session_id>:<turn_id>`` -> JSON of ``Turn``
    """

    def __init__(self, agent_id: str, kvstore: KVStore):
        """Bind the persistence helper to one agent.

        Args:
            agent_id: Identifier embedded in every key written by this object.
            kvstore: Store used for all reads and writes; this class awaits
                its ``set``, ``get`` and ``range`` methods.
        """
        self.agent_id = agent_id
        self.kvstore = kvstore

    def _session_key(self, session_id: str) -> str:
        """Return the key under which a session's metadata is stored."""
        return f"session:{self.agent_id}:{session_id}"

    def _turn_key_prefix(self, session_id: str) -> str:
        """Return the common prefix of every turn key of a session."""
        return f"{self._session_key(session_id)}:"

    async def create_session(self, name: str) -> str:
        """Create and persist a new session.

        Args:
            name: Human-readable session name stored in the session record.

        Returns:
            The newly generated session id (a UUID4 string).
        """
        session_id = str(uuid.uuid4())
        session_info = AgentSessionInfo(
            session_id=session_id,
            session_name=name,
            started_at=datetime.now(),
        )
        await self.kvstore.set(
            key=self._session_key(session_id),
            value=session_info.json(),
        )
        return session_id

    async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
        """Load the metadata of a session.

        Returns:
            The stored ``AgentSessionInfo``, or ``None`` when the store returns
            a falsy value for the session key.
        """
        value = await self.kvstore.get(key=self._session_key(session_id))
        if not value:
            return None

        return AgentSessionInfo(**json.loads(value))

    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
        """Record ``bank_id`` as the memory bank of an existing session.

        Raises:
            ValueError: If no session record exists for ``session_id``.
        """
        session_info = await self.get_session_info(session_id)
        if session_info is None:
            raise ValueError(f"Session {session_id} not found")

        # Read-modify-write of the whole record: the store only holds the
        # serialized JSON, so the updated model is written back in full.
        session_info.memory_bank_id = bank_id
        await self.kvstore.set(
            key=self._session_key(session_id),
            value=session_info.json(),
        )

    async def add_turn_to_session(self, session_id: str, turn: Turn):
        """Persist ``turn`` under the session, keyed by ``turn.turn_id``.

        Writing a turn with an already-used ``turn_id`` targets the same key.
        """
        await self.kvstore.set(
            key=f"{self._turn_key_prefix(session_id)}{turn.turn_id}",
            value=turn.json(),
        )

    async def get_session_turns(self, session_id: str) -> List[Turn]:
        """Load every turn stored for a session.

        Entries that cannot be decoded into a ``Turn`` are logged and skipped
        so that one corrupt record does not make the whole session unreadable.

        Returns:
            The successfully parsed turns, in the order the store yields them.
        """
        prefix = self._turn_key_prefix(session_id)
        # The range starts at the trailing-colon prefix, so the session
        # metadata key (which has no trailing colon) is not requested.
        values = await self.kvstore.range(
            start_key=prefix,
            end_key=f"{prefix}\xff\xff\xff\xff",
        )
        # NOTE(review): ordering is whatever ``kvstore.range`` returns; since
        # keys end in the turn id this may not be chronological - confirm
        # against the KVStore implementations before relying on turn order.
        turns = []
        for value in values:
            try:
                turns.append(Turn(**json.loads(value)))
            except Exception as e:
                # Deliberately best-effort: report through logging (with the
                # traceback) instead of print(), then move on to the next turn.
                logging.getLogger(__name__).exception("Error parsing turn: %s", e)

        return turns
|
Loading…
Add table
Add a link
Reference in a new issue