[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

Third part:
- we need to make `tool_runtime.rag_tool.query_context()` and
`tool_runtime.rag_tool.insert_documents()` methods work smoothly with
complete type safety. To that end, we introduce a sub-resource path
`tool-runtime/rag-tool/` and make changes to the resolver to make things
work.
- the PR updates the agents implementation to directly call these typed
APIs for memory accesses rather than going through the complex, untyped
"invoke_tool" API. the code looks much nicer and simpler (expectedly.)
- there are a number of hacks in the server resolver implementation
still, we will live with some and fix some

Note that we must make sure the client SDKs are able to handle this
subresource complexity also. Stainless has support for subresources, so
this should be possible but beware.

## Test Plan

Our RAG test is sad (doesn't actually test for actual RAG output) but I
verified that the implementation works. I will work on fixing the RAG
test afterwards.

```bash
pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B
```
This commit is contained in:
Ashwin Bharambe 2025-01-22 10:04:16 -08:00 committed by GitHub
parent 78a481bb22
commit 1a7490470a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1648 additions and 1345 deletions

View file

@ -333,6 +333,8 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)

View file

@ -36,7 +36,14 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolDef, ToolRuntime
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import RoutingTable
@ -400,22 +407,55 @@ class EvalRouter(Eval):
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def query_context(
self,
content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str],
) -> RAGQueryResult:
return await self.routing_table.get_provider_impl(
"rag_tool.query_context"
).query_context(content, query_config, vector_db_ids)
async def insert_documents(
self,
documents: List[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
return await self.routing_table.get_provider_impl(
"rag_tool.insert_documents"
).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
# TODO: make sure rag_tool vs builtin::memory is correct everywhere
self.rag_tool = self.RagToolImpl(routing_table)
setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
args=args,
kwargs=kwargs,
)
async def list_runtime_tools(

View file

@ -9,6 +9,8 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
@ -22,21 +24,39 @@ class ApiEndpoint(BaseModel):
name: str
def toolgroup_protocol_map():
return {
SpecialToolGroup.rag_tool: RAGToolRuntime,
}
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = api_protocol_map()
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(
sub_protocol, predicate=inspect.isfunction
)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":

View file

@ -29,7 +29,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import StackRunConfig
@ -62,6 +62,7 @@ class LlamaStack(
Inspect,
ToolGroups,
ToolRuntime,
RAGToolRuntime,
):
pass

View file

@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v5"
KEY_VERSION = "v6"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"