llama-stack/llama_stack/providers/inline/agents/meta_reference/agents.py
Ashwin Bharambe 1a7490470a
[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)
See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

Third part:
- we need to make `tool_runtime.rag_tool.query_context()` and
`tool_runtime.rag_tool.insert_documents()` methods work smoothly with
complete type safety. To that end, we introduce a sub-resource path
`tool-runtime/rag-tool/` and make changes to the resolver to make things
work.
- the PR updates the agents implementation to directly call these typed
APIs for memory accesses rather than going through the complex, untyped
"invoke_tool" API. The code looks much nicer and simpler, as expected.
- there are still a number of hacks in the server resolver implementation;
we will live with some and fix others.

Note that we must make sure the client SDKs are able to handle this
subresource complexity also. Stainless has support for subresources, so
this should be possible but beware.

## Test Plan

Our RAG test is sad (doesn't actually test for actual RAG output) but I
verified that the implementation works. I will work on fixing the RAG
test afterwards.

```bash
pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B
```
2025-01-22 10:04:16 -08:00

223 lines
7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import shutil
import tempfile
import uuid
from typing import AsyncGenerator, List, Optional, Union
from termcolor import colored
from llama_stack.apis.agents import (
AgentConfig,
AgentCreateResponse,
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
Document,
Session,
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
# Module-level logger.
# NOTE(review): calling getLogger with no name returns the *root* logger, so
# the level set below applies process-wide; `logging.getLogger(__name__)`
# would scope it to this module — confirm before changing.
logger = logging.getLogger(name=None)
logger.setLevel(level=logging.INFO)
class MetaReferenceAgentsImpl(Agents):
    """Reference implementation of the `Agents` API.

    State is kept as JSON strings in a key-value store:

    - ``agent:{agent_id}`` holds a serialized ``AgentConfig`` (written by
      ``create_agent``).
    - ``session:{agent_id}:{session_id}`` is read back as a ``Session``.
    - ``session:{agent_id}:{session_id}:{turn_id}`` is read back as a ``Turn``.

    Turn execution is delegated to a ``ChatAgent`` that is rebuilt from the
    stored config on each request.
    """

    def __init__(
        self,
        config: MetaReferenceAgentsImplConfig,
        inference_api: Inference,
        vector_io_api: VectorIO,
        safety_api: Safety,
        tool_runtime_api: ToolRuntime,
        tool_groups_api: ToolGroups,
    ):
        self.config = config
        self.inference_api = inference_api
        self.vector_io_api = vector_io_api
        self.safety_api = safety_api
        self.tool_runtime_api = tool_runtime_api
        self.tool_groups_api = tool_groups_api

        # Store handed to agents whose config disables session persistence.
        self.in_memory_store = InmemoryKVStoreImpl()
        # Scratch directory shared by every ChatAgent built by this instance.
        # NOTE(review): created with mkdtemp() and never removed by this
        # class; the caller/provider lifecycle must clean it up.
        self.tempdir = tempfile.mkdtemp()

    async def initialize(self) -> None:
        """Open the configured persistence store and check for `bwrap`."""
        self.persistence_store = await kvstore_impl(self.config.persistence_store)

        # The code interpreter tool relies on the `bwrap` executable; only a
        # warning is emitted so the rest of the provider keeps working.
        if not shutil.which("bwrap"):
            print(
                colored(
                    "Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
                    "yellow",
                )
            )

    async def create_agent(
        self,
        agent_config: AgentConfig,
    ) -> AgentCreateResponse:
        """Persist ``agent_config`` under a freshly generated UUID agent id."""
        agent_id = str(uuid.uuid4())
        await self.persistence_store.set(
            key=f"agent:{agent_id}",
            value=agent_config.model_dump_json(),
        )
        return AgentCreateResponse(
            agent_id=agent_id,
        )

    async def get_agent(self, agent_id: str) -> ChatAgent:
        """Rebuild a ``ChatAgent`` from the config stored for ``agent_id``.

        Raises:
            ValueError: if the config is missing, is not valid JSON, or does
                not validate as an ``AgentConfig``.
        """
        agent_config = await self.persistence_store.get(
            key=f"agent:{agent_id}",
        )
        if not agent_config:
            raise ValueError(f"Could not find agent config for {agent_id}")
        try:
            agent_config = json.loads(agent_config)
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Could not JSON decode agent config for {agent_id}"
            ) from e
        try:
            agent_config = AgentConfig(**agent_config)
        except Exception as e:
            raise ValueError(
                f"Could not validate(?) agent config for {agent_id}"
            ) from e
        return ChatAgent(
            agent_id=agent_id,
            agent_config=agent_config,
            tempdir=self.tempdir,
            inference_api=self.inference_api,
            safety_api=self.safety_api,
            vector_io_api=self.vector_io_api,
            tool_runtime_api=self.tool_runtime_api,
            tool_groups_api=self.tool_groups_api,
            # Agents without session persistence get the process-local store.
            persistence_store=(
                self.persistence_store
                if agent_config.enable_session_persistence
                else self.in_memory_store
            ),
        )

    async def create_agent_session(
        self,
        agent_id: str,
        session_name: str,
    ) -> AgentSessionCreateResponse:
        """Create a new session named ``session_name`` on the given agent."""
        agent = await self.get_agent(agent_id)
        session_id = await agent.create_session(session_name)
        return AgentSessionCreateResponse(
            session_id=session_id,
        )

    async def create_agent_turn(
        self,
        agent_id: str,
        session_id: str,
        messages: List[
            Union[
                UserMessage,
                ToolResponseMessage,
            ]
        ],
        toolgroups: Optional[List[AgentToolGroup]] = None,
        documents: Optional[List[Document]] = None,
        stream: Optional[bool] = False,
    ) -> AsyncGenerator:
        """Start a turn and return an async generator of its events.

        Raises:
            NotImplementedError: if ``stream`` is falsy; only streaming turns
                are supported.
        """
        # Only streaming is implemented, so the request always carries
        # stream=True; the non-streaming case is rejected below.
        request = AgentTurnCreateRequest(
            agent_id=agent_id,
            session_id=session_id,
            messages=messages,
            stream=True,
            toolgroups=toolgroups,
            documents=documents,
        )
        if stream:
            return self._create_agent_turn_streaming(request)
        else:
            raise NotImplementedError("Non-streaming agent turns not yet implemented")

    async def _create_agent_turn_streaming(
        self,
        request: AgentTurnCreateRequest,
    ) -> AsyncGenerator:
        """Resolve the agent lazily and relay every event of the turn."""
        agent = await self.get_agent(request.agent_id)
        async for event in agent.create_and_execute_turn(request):
            yield event

    async def _get_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
        """Load and validate one stored turn.

        NOTE(review): this always reads ``self.persistence_store``, whereas
        ``get_agent`` gives agents with session persistence disabled the
        in-memory store, so turns of such agents are likely not found here.
        Confirm against ``ChatAgent``.

        Raises:
            ValueError: if nothing is stored for the turn key.
        """
        turn_json = await self.persistence_store.get(
            f"session:{agent_id}:{session_id}:{turn_id}"
        )
        # Mirror get_agent: a missing key is reported as ValueError instead of
        # letting json.loads(None) fail with an opaque TypeError.
        if not turn_json:
            raise ValueError(
                f"Could not find turn {turn_id} in session {session_id} for agent {agent_id}"
            )
        return Turn(**json.loads(turn_json))

    async def get_agents_turn(
        self, agent_id: str, session_id: str, turn_id: str
    ) -> Turn:
        """Return the stored turn ``turn_id``.

        Raises:
            ValueError: if the turn does not exist.
        """
        return await self._get_turn(agent_id, session_id, turn_id)

    async def get_agents_step(
        self, agent_id: str, session_id: str, turn_id: str, step_id: str
    ) -> AgentStepResponse:
        """Return the step ``step_id`` of the stored turn ``turn_id``.

        Raises:
            ValueError: if the turn or the step does not exist.
        """
        turn = await self._get_turn(agent_id, session_id, turn_id)
        for step in turn.steps:
            if step.step_id == step_id:
                return AgentStepResponse(step=step)
        raise ValueError(f"Provided step_id {step_id} could not be found")

    async def get_agents_session(
        self,
        agent_id: str,
        session_id: str,
        turn_ids: Optional[List[str]] = None,
    ) -> Session:
        """Return a session, populated only with the turns listed in ``turn_ids``.

        When ``turn_ids`` is ``None`` or empty, the returned session has no turns.

        Raises:
            ValueError: if the session or any requested turn does not exist.
        """
        session_json = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
        if not session_json:
            raise ValueError(
                f"Could not find session {session_id} for agent {agent_id}"
            )
        session = Session(**json.loads(session_json), turns=[])
        # Turns are loaded one at a time, in the order requested.
        turns = [
            await self._get_turn(agent_id, session_id, turn_id)
            for turn_id in turn_ids or []
        ]
        return Session(
            session_name=session.session_name,
            session_id=session_id,
            turns=turns,
            started_at=session.started_at,
        )

    async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
        """Delete the session record.

        NOTE(review): only the ``session:{agent_id}:{session_id}`` key is
        removed; per-turn keys under that prefix are left in the store.
        """
        await self.persistence_store.delete(f"session:{agent_id}:{session_id}")

    async def delete_agent(self, agent_id: str) -> None:
        """Delete the stored agent config.

        NOTE(review): sessions and turns of the agent are not removed here.
        """
        await self.persistence_store.delete(f"agent:{agent_id}")