mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
initial cut at using kvstores for agent persistence
This commit is contained in:
parent
61974e337f
commit
4eb0f30891
10 changed files with 153 additions and 120 deletions
|
@ -6,7 +6,6 @@
|
|||
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
||||
|
@ -15,39 +14,15 @@ from llama_stack.apis.memory import Memory
|
|||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
from .tools.builtin import (
|
||||
CodeInterpreterTool,
|
||||
PhotogenTool,
|
||||
SearchTool,
|
||||
WolframAlphaTool,
|
||||
)
|
||||
from .tools.safety import with_safety
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
AGENT_INSTANCES_BY_ID = {}
|
||||
|
||||
|
||||
class KVStore(Protocol):
|
||||
def get(self, key: str) -> str:
|
||||
...
|
||||
|
||||
def set(self, key: str, value: str) -> None:
|
||||
...
|
||||
|
||||
def kvstore_impl(config: KVStoreConfig) -> KVStore:
|
||||
if config.type == KVStoreType.redis:
|
||||
from .kvstore_impls.redis import RedisKVStoreImpl
|
||||
return RedisKVStoreImpl(config)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MetaReferenceAgentsImpl(Agents):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -60,10 +35,9 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.inference_api = inference_api
|
||||
self.memory_api = memory_api
|
||||
self.safety_api = safety_api
|
||||
self.kvstore = kvstore_impl(config.kvstore)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
|
||||
async def create_agent(
|
||||
self,
|
||||
|
@ -71,38 +45,42 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
) -> AgentCreateResponse:
|
||||
agent_id = str(uuid.uuid4())
|
||||
|
||||
builtin_tools = []
|
||||
for tool_defn in agent_config.tools:
|
||||
if isinstance(tool_defn, WolframAlphaToolDefinition):
|
||||
tool = WolframAlphaTool(tool_defn.api_key)
|
||||
elif isinstance(tool_defn, SearchToolDefinition):
|
||||
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
|
||||
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
||||
tool = CodeInterpreterTool()
|
||||
elif isinstance(tool_defn, PhotogenToolDefinition):
|
||||
tool = PhotogenTool(dump_dir=tempfile.mkdtemp())
|
||||
else:
|
||||
continue
|
||||
await self.persistence_store.set(
|
||||
key=f"agent:{agent_id}",
|
||||
value=json.dumps(agent_config.dict()),
|
||||
)
|
||||
return AgentCreateResponse(
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
builtin_tools.append(
|
||||
with_safety(
|
||||
tool,
|
||||
self.safety_api,
|
||||
tool_defn.input_shields,
|
||||
tool_defn.output_shields,
|
||||
)
|
||||
)
|
||||
async def get_agent(self, agent_id: str) -> ChatAgent:
|
||||
agent_config = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
if not agent_config:
|
||||
raise ValueError(f"Could not find agent config for {agent_id}")
|
||||
|
||||
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
|
||||
try:
|
||||
agent_config = json.loads(agent_config)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Could not JSON decode agent config for {agent_id}"
|
||||
) from e
|
||||
|
||||
try:
|
||||
agent_config = AgentConfig(**agent_config)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Could not validate(?) agent config for {agent_id}"
|
||||
) from e
|
||||
|
||||
return ChatAgent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_config,
|
||||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
memory_api=self.memory_api,
|
||||
builtin_tools=builtin_tools,
|
||||
)
|
||||
|
||||
return AgentCreateResponse(
|
||||
agent_id=agent_id,
|
||||
persistence_store=self.persistence_store,
|
||||
)
|
||||
|
||||
async def create_agent_session(
|
||||
|
@ -110,12 +88,11 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
||||
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
||||
agent = await self.get_agent(agent_id)
|
||||
|
||||
session = agent.create_session(session_name)
|
||||
session_id = await agent.create_session(session_name)
|
||||
return AgentSessionCreateResponse(
|
||||
session_id=session.session_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def create_agent_turn(
|
||||
|
@ -131,6 +108,8 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
attachments: Optional[List[Attachment]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(agent_id)
|
||||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
|
@ -140,12 +119,5 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
stream=stream,
|
||||
)
|
||||
|
||||
agent_id = request.agent_id
|
||||
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
||||
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
||||
|
||||
assert (
|
||||
request.session_id in agent.sessions
|
||||
), f"Session {request.session_id} not found"
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue