initial cut at using kvstores for agent persistence

This commit is contained in:
Ashwin Bharambe 2024-09-21 21:16:26 -07:00
parent 61974e337f
commit 4eb0f30891
10 changed files with 153 additions and 120 deletions

View file

@ -25,10 +25,19 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.api import KVStore
from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
from .tools.base import BaseTool
from .tools.builtin import interpret_content_as_attachment, SingleMessageBuiltinTool
from .tools.builtin import (
CodeInterpreterTool,
interpret_content_as_attachment,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import with_safety
def make_random_string(length: int = 8):
@ -40,23 +49,44 @@ def make_random_string(length: int = 8):
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
agent_id: str,
agent_config: AgentConfig,
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
builtin_tools: List[SingleMessageBuiltinTool],
max_infer_iters: int = 10,
persistence_store: KVStore,
):
self.agent_id = agent_id
self.agent_config = agent_config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.max_infer_iters = max_infer_iters
self.tools_dict = {t.get_name(): t for t in builtin_tools}
self.persistence_store = persistence_store
self.tempdir = tempfile.mkdtemp()
self.sessions = {}
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
tool = WolframAlphaTool(tool_defn.api_key)
elif isinstance(tool_defn, SearchToolDefinition):
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=self.tempdir)
else:
continue
builtin_tools.append(
with_safety(
tool,
safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
self.tools_dict = {t.get_name(): t for t in builtin_tools}
ShieldRunnerMixin.__init__(
self,
@ -80,7 +110,6 @@ class ChatAgent(ShieldRunnerMixin):
msg.context = None
messages.append(msg)
# messages.extend(turn.input_messages)
for step in turn.steps:
if step.step_type == StepType.inference.value:
messages.append(step.model_response)
@ -105,31 +134,64 @@ class ChatAgent(ShieldRunnerMixin):
# print_dialog(messages)
return messages
def create_session(self, name: str) -> Session:
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
session = Session(
session = dict(
session_id=session_id,
session_name=name,
turns=[],
started_at=datetime.now(),
started_at=str(datetime.now()),
)
self.sessions[session_id] = session
return session
await self.persistence_store.set(
key=f"session:{self.agent_id}:{session_id}",
value=json.dumps(session),
)
return session_id
async def _add_turn_to_session(self, session_id: str, turn: Turn):
await self.persistence_store.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=json.dumps(turn.dict()),
)
async def _get_session_turns(self, session_id: str) -> List[Turn]:
values = await self.persistence_store.range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
turns = []
for value in values:
try:
turn = Turn(**json.loads(value))
turns.append(turn)
except Exception as e:
print(f"Error parsing turn: {e}")
continue
return turns
async def _get_session_info(self, session_id: str) -> Optional[Dict[str, str]]:
value = await self.persistence_store.get(
key=f"session:{self.agent_id}:{session_id}",
)
if not value:
return None
return json.loads(value)
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
assert (
request.session_id in self.sessions
), f"Session {request.session_id} not found"
session_info = await self._get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
session = self.sessions[request.session_id]
turns = await self._get_session_turns(request.session_id)
messages = []
if len(session.turns) == 0 and self.agent_config.instructions != "":
if len(turns) == 0 and self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(session.turns):
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages.extend(request.messages)
@ -186,7 +248,7 @@ class ChatAgent(ShieldRunnerMixin):
completed_at=datetime.now(),
steps=steps,
)
session.turns.append(turn)
await self._add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -455,7 +517,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
if n_iter >= self.max_infer_iters:
if n_iter >= self.agent_config.max_infer_iters:
cprint("Done with MAX iterations, exiting.")
yield message
break

View file

@ -6,7 +6,6 @@
import logging
import tempfile
import uuid
from typing import AsyncGenerator
@ -15,39 +14,15 @@ from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.providers.utils.kvstore import kvstore_impl
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .tools.builtin import (
CodeInterpreterTool,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import with_safety
logger = logging.getLogger()
logger.setLevel(logging.INFO)
AGENT_INSTANCES_BY_ID = {}
class KVStore(Protocol):
def get(self, key: str) -> str:
...
def set(self, key: str, value: str) -> None:
...
def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis:
from .kvstore_impls.redis import RedisKVStoreImpl
return RedisKVStoreImpl(config)
return None
class MetaReferenceAgentsImpl(Agents):
def __init__(
self,
@ -60,10 +35,9 @@ class MetaReferenceAgentsImpl(Agents):
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.kvstore = kvstore_impl(config.kvstore)
async def initialize(self) -> None:
pass
self.persistence_store = await kvstore_impl(self.config.persistence_store)
async def create_agent(
self,
@ -71,38 +45,42 @@ class MetaReferenceAgentsImpl(Agents):
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
tool = WolframAlphaTool(tool_defn.api_key)
elif isinstance(tool_defn, SearchToolDefinition):
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=tempfile.mkdtemp())
else:
continue
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=json.dumps(agent_config.dict()),
)
return AgentCreateResponse(
agent_id=agent_id,
)
builtin_tools.append(
with_safety(
tool,
self.safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
async def get_agent(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
if not agent_config:
raise ValueError(f"Could not find agent config for {agent_id}")
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
try:
agent_config = json.loads(agent_config)
except json.JSONDecodeError as e:
raise ValueError(
f"Could not JSON decode agent config for {agent_id}"
) from e
try:
agent_config = AgentConfig(**agent_config)
except Exception as e:
raise ValueError(
f"Could not validate(?) agent config for {agent_id}"
) from e
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config,
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,
builtin_tools=builtin_tools,
)
return AgentCreateResponse(
agent_id=agent_id,
persistence_store=self.persistence_store,
)
async def create_agent_session(
@ -110,12 +88,11 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id]
agent = await self.get_agent(agent_id)
session = agent.create_session(session_name)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
session_id=session.session_id,
session_id=session_id,
)
async def create_agent_turn(
@ -131,6 +108,8 @@ class MetaReferenceAgentsImpl(Agents):
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
agent = await self.get_agent(agent_id)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = AgentTurnCreateRequest(
agent_id=agent_id,
@ -140,12 +119,5 @@ class MetaReferenceAgentsImpl(Agents):
stream=stream,
)
agent_id = request.agent_id
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id]
assert (
request.session_id in agent.sessions
), f"Session {request.session_id} not found"
async for event in agent.create_and_execute_turn(request):
yield event

View file

@ -8,4 +8,4 @@ from llama_stack.providers.utils.kvstore import KVStoreConfig
class MetaReferenceAgentsImplConfig(BaseModel):
kv_store: KVStoreConfig
persistence_store: KVStoreConfig