[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

Third part:
- we need to make `tool_runtime.rag_tool.query_context()` and
`tool_runtime.rag_tool.insert_documents()` methods work smoothly with
complete type safety. To that end, we introduce a sub-resource path
`tool-runtime/rag-tool/` and make changes to the resolver to make things
work.
- the PR updates the agents implementation to directly call these typed
APIs for memory accesses rather than going through the complex, untyped
"invoke_tool" API. the code looks much nicer and simpler (expectedly.)
- there are a number of hacks in the server resolver implementation
still, we will live with some and fix some

Note that we must make sure the client SDKs are able to handle this
subresource complexity also. Stainless has support for subresources, so
this should be possible but beware.

## Test Plan

Our RAG test is sad (doesn't actually test for actual RAG output) but I
verified that the implementation works. I will work on fixing the RAG
test afterwards.

```bash
pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B
```
This commit is contained in:
Ashwin Bharambe 2025-01-22 10:04:16 -08:00 committed by GitHub
parent 78a481bb22
commit 1a7490470a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1648 additions and 1345 deletions

View file

@ -172,10 +172,16 @@ def _get_endpoint_functions(
def _get_defining_class(member_fn: str, derived_cls: type) -> type:
"Find the class in which a member function is first defined in a class inheritance hierarchy."
# This import must be dynamic here
from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime
# iterate in reverse member resolution order to find most specific class first
for cls in reversed(inspect.getmro(derived_cls)):
for name, _ in inspect.getmembers(cls, inspect.isfunction):
if name == member_fn:
# HACK ALERT
if cls == RAGToolRuntime:
return ToolRuntime
return cls
raise ValidationError(

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -5,3 +5,4 @@
# the root directory of this source tree.
from .tools import * # noqa: F401 F403
from .rag_tool import * # noqa: F401 F403

View file

@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Protocol, runtime_checkable
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@json_schema_type
class RAGDocument(BaseModel):
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
metadata: Dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryResult(BaseModel):
content: Optional[InterleavedContent] = None
@json_schema_type
class RAGQueryGenerator(Enum):
default = "default"
llm = "llm"
custom = "custom"
@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
type: Literal["default"] = "default"
separator: str = " "
@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
type: Literal["llm"] = "llm"
model: str
template: str
RAGQueryGeneratorConfig = register_schema(
Annotated[
Union[
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
],
Field(discriminator="type"),
],
name="RAGQueryGeneratorConfig",
)
@json_schema_type
class RAGQueryConfig(BaseModel):
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field(
default=DefaultRAGQueryGeneratorConfig()
)
max_tokens_in_context: int = 4096
max_chunks: int = 5
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
async def insert_documents(
self,
documents: List[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
"""Index documents so they can be used by the RAG system"""
...
@webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
async def query_context(
self,
content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str],
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent"""
...

View file

@ -15,6 +15,8 @@ from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from .rag_tool import RAGToolRuntime
@json_schema_type
class ToolParameter(BaseModel):
@ -130,11 +132,17 @@ class ToolGroups(Protocol):
...
class SpecialToolGroup(Enum):
rag_tool = "rag_tool"
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
tool_store: ToolStore
rag_tool: RAGToolRuntime
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
@ -143,7 +151,7 @@ class ToolRuntime(Protocol):
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
...

View file

@ -333,6 +333,8 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)

View file

@ -36,7 +36,14 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolDef, ToolRuntime
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import RoutingTable
@ -400,22 +407,55 @@ class EvalRouter(Eval):
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def query_context(
self,
content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str],
) -> RAGQueryResult:
return await self.routing_table.get_provider_impl(
"rag_tool.query_context"
).query_context(content, query_config, vector_db_ids)
async def insert_documents(
self,
documents: List[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
return await self.routing_table.get_provider_impl(
"rag_tool.insert_documents"
).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
# TODO: make sure rag_tool vs builtin::memory is correct everywhere
self.rag_tool = self.RagToolImpl(routing_table)
setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
args=args,
kwargs=kwargs,
)
async def list_runtime_tools(

View file

@ -9,6 +9,8 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
@ -22,21 +24,39 @@ class ApiEndpoint(BaseModel):
name: str
def toolgroup_protocol_map():
return {
SpecialToolGroup.rag_tool: RAGToolRuntime,
}
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = api_protocol_map()
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(
sub_protocol, predicate=inspect.isfunction
)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":

View file

@ -29,7 +29,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import StackRunConfig
@ -62,6 +62,7 @@ class LlamaStack(
Inspect,
ToolGroups,
ToolRuntime,
RAGToolRuntime,
):
pass

View file

@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v5"
KEY_VERSION = "v6"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@ -19,9 +19,8 @@ async def get_provider_impl(
impl = MetaReferenceAgentsImpl(
config,
deps[Api.inference],
deps[Api.memory],
deps[Api.vector_io],
deps[Api.safety],
deps[Api.memory_banks],
deps[Api.tool_runtime],
deps[Api.tool_groups],
)

View file

@ -59,13 +59,18 @@ from llama_stack.apis.inference import (
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.memory import Memory, MemoryBankDocument
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.tools import (
DefaultRAGQueryGeneratorConfig,
RAGDocument,
RAGQueryConfig,
ToolGroups,
ToolRuntime,
)
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
from .safety import SafetyException, ShieldRunnerMixin
@ -79,7 +84,7 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "query_memory"
MEMORY_QUERY_TOOL = "rag_tool.query_context"
WEB_SEARCH_TOOL = "web_search"
MEMORY_GROUP = "builtin::memory"
@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin):
agent_config: AgentConfig,
tempdir: str,
inference_api: Inference,
memory_api: Memory,
memory_banks_api: MemoryBanks,
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
vector_io_api: VectorIO,
persistence_store: KVStore,
):
self.agent_id = agent_id
self.agent_config = agent_config
self.tempdir = tempdir
self.inference_api = inference_api
self.memory_api = memory_api
self.memory_banks_api = memory_banks_api
self.safety_api = safety_api
self.vector_io_api = vector_io_api
self.storage = AgentPersistence(agent_id, persistence_store)
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
@ -370,24 +373,30 @@ class ChatAgent(ShieldRunnerMixin):
documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
# TODO: simplify all of this code, it can be simpler
toolgroup_args = {}
toolgroups = set()
for toolgroup in self.agent_config.toolgroups:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
else:
toolgroups.add(toolgroup)
if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
else:
toolgroups.add(toolgroup)
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(
session_id, documents, input_messages, tool_defs
)
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
if memory_tool_group is None:
raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
if MEMORY_GROUP in toolgroups and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -398,17 +407,15 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
query_args = {
"messages": [msg.content for msg in input_messages],
**toolgroup_args.get(memory_tool_group, {}),
}
args = toolgroup_args.get(MEMORY_GROUP, {})
vector_db_ids = args.get("vector_db_ids", [])
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
if "memory_bank_ids" not in query_args:
query_args["memory_bank_ids"] = []
query_args["memory_bank_ids"].append(session_info.memory_bank_id)
vector_db_ids.append(session_info.memory_bank_id)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
@ -425,10 +432,18 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
result = await self.tool_runtime_api.invoke_tool(
tool_name=MEMORY_QUERY_TOOL,
args=query_args,
result = await self.tool_runtime_api.rag_tool.query_context(
content=concat_interleaved_content(
[msg.content for msg in input_messages]
),
query_config=RAGQueryConfig(
query_generator_config=DefaultRAGQueryGeneratorConfig(),
max_tokens_in_context=4096,
max_chunks=5,
),
vector_db_ids=vector_db_ids,
)
retrieved_context = result.content
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -449,7 +464,7 @@ class ChatAgent(ShieldRunnerMixin):
ToolResponse(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
content=result.content,
content=retrieved_context or [],
)
],
),
@ -459,13 +474,11 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("output", result.content)
span.set_attribute("error_code", result.error_code)
span.set_attribute("error_message", result.error_message)
span.set_attribute("output", retrieved_context)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
if result.error_code == 0:
if retrieved_context:
last_message = input_messages[-1]
last_message.context = result.content
last_message.context = retrieved_context
output_attachments = []
@ -842,12 +855,13 @@ class ChatAgent(ShieldRunnerMixin):
if session_info.memory_bank_id is None:
bank_id = f"memory_bank_{session_id}"
await self.memory_banks_api.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
# TODO: the semantic for registration is definitely not "creation"
# so we need to fix it if we expect the agent to create a new vector db
# for each session
await self.vector_io_api.register_vector_db(
vector_db_id=bank_id,
embedding_model="all-MiniLM-L6-v2",
)
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
@ -858,9 +872,9 @@ class ChatAgent(ShieldRunnerMixin):
async def add_to_session_memory_bank(
self, session_id: str, data: List[Document]
) -> None:
bank_id = await self._ensure_memory_bank(session_id)
vector_db_id = await self._ensure_memory_bank(session_id)
documents = [
MemoryBankDocument(
RAGDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
@ -868,9 +882,10 @@ class ChatAgent(ShieldRunnerMixin):
)
for a in data
]
await self.memory_api.insert_documents(
bank_id=bank_id,
await self.tool_runtime_api.rag_tool.insert_documents(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
)
@ -955,7 +970,7 @@ async def execute_tool_call_maybe(
result = await tool_runtime_api.invoke_tool(
tool_name=name,
args=dict(
kwargs=dict(
session_id=session_id,
**tool_call_args,
),

View file

@ -26,10 +26,9 @@ from llama_stack.apis.agents import (
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent
@ -44,17 +43,15 @@ class MetaReferenceAgentsImpl(Agents):
self,
config: MetaReferenceAgentsImplConfig,
inference_api: Inference,
memory_api: Memory,
vector_io_api: VectorIO,
safety_api: Safety,
memory_banks_api: MemoryBanks,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
):
self.config = config
self.inference_api = inference_api
self.memory_api = memory_api
self.vector_io_api = vector_io_api
self.safety_api = safety_api
self.memory_banks_api = memory_banks_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
@ -114,8 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
tempdir=self.tempdir,
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,
memory_banks_api=self.memory_banks_api,
vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(

View file

@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
script = args["code"]
script = kwargs["code"]
req = CodeExecutionRequest(scripts=[script])
res = self.code_executor.execute(req)
pieces = [res["process_status"]]

View file

@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl
async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
impl = MemoryToolRuntimeImpl(
config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
)
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
await impl.initialize()
return impl

View file

@ -4,87 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, List, Literal, Union
from pydantic import BaseModel, Field
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class VectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["vector"] = "vector"
class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyvalue"] = "keyvalue"
keys: List[str] # what keys to focus on
class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyword"] = "keyword"
class GraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["graph"] = "graph"
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class MemoryQueryGenerator(Enum):
default = "default"
llm = "llm"
custom = "custom"
class DefaultMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.default.value] = (
MemoryQueryGenerator.default.value
)
sep: str = " "
class LLMMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
model: str
template: str
class CustomMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
MemoryQueryGeneratorConfig = Annotated[
Union[
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
CustomMemoryQueryGeneratorConfig,
],
Field(discriminator="type"),
]
class MemoryToolConfig(BaseModel):
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
from pydantic import BaseModel
class MemoryToolRuntimeConfig(BaseModel):
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: MemoryQueryGeneratorConfig = Field(
default=DefaultMemoryQueryGeneratorConfig()
)
max_tokens_in_context: int = 4096
max_chunks: int = 5
pass

View file

@ -5,68 +5,64 @@
# the root directory of this source tree.
from typing import List
from jinja2 import Template
from pydantic import BaseModel
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
RAGQueryGenerator,
RAGQueryGeneratorConfig,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
async def generate_rag_query(
config: MemoryQueryGeneratorConfig,
messages: List[InterleavedContent],
config: RAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
"""
Generates a query that will be used for
retrieving relevant information from the memory bank.
"""
if config.type == MemoryQueryGenerator.default.value:
query = await default_rag_query_generator(config, messages, **kwargs)
elif config.type == MemoryQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, messages, **kwargs)
if config.type == RAGQueryGenerator.default.value:
query = await default_rag_query_generator(config, content, **kwargs)
elif config.type == RAGQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, content, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query
async def default_rag_query_generator(
config: DefaultMemoryQueryGeneratorConfig,
messages: List[InterleavedContent],
config: DefaultRAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
return config.sep.join(interleaved_content_as_str(m) for m in messages)
return interleaved_content_as_str(content, sep=config.separator)
async def llm_rag_query_generator(
config: LLMMemoryQueryGeneratorConfig,
messages: List[InterleavedContent],
config: LLMRAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
m_dict = {
"messages": [
message.model_dump() if isinstance(message, BaseModel) else message
for message in messages
]
}
messages = []
if isinstance(content, list):
messages = [interleaved_content_as_str(m) for m in content]
else:
messages = [interleaved_content_as_str(content)]
template = Template(config.template)
content = template.render(m_dict)
content = template.render({"messages": messages})
model = config.model
message = UserMessage(content=content)

View file

@ -10,20 +10,29 @@ import secrets
import string
from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.common.content_types import (
InterleavedContent,
TextContentItem,
URL,
)
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
make_overlapped_chunks,
)
from .config import MemoryToolConfig, MemoryToolRuntimeConfig
from .config import MemoryToolRuntimeConfig
from .context_retriever import generate_rag_query
log = logging.getLogger(__name__)
@ -35,65 +44,79 @@ def make_random_string(length: int = 8):
)
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
config: MemoryToolRuntimeConfig,
memory_api: Memory,
memory_banks_api: MemoryBanks,
vector_io_api: VectorIO,
inference_api: Inference,
):
self.config = config
self.memory_api = memory_api
self.memory_banks_api = memory_banks_api
self.vector_io_api = vector_io_api
self.inference_api = inference_api
async def initialize(self):
pass
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="query_memory",
description="Retrieve context from memory",
parameters=[
ToolParameter(
name="messages",
description="The input messages to search for",
parameter_type="array",
),
],
)
]
async def shutdown(self):
pass
async def insert_documents(
self,
documents: List[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
chunks = []
for doc in documents:
content = await content_from_doc(doc)
chunks.extend(
make_overlapped_chunks(
doc.document_id,
content,
chunk_size_in_tokens,
chunk_size_in_tokens // 4,
)
)
if not chunks:
return
await self.vector_io_api.insert_chunks(
chunks=chunks,
vector_db_id=vector_db_id,
)
async def query_context(
self,
content: InterleavedContent,
query_config: RAGQueryConfig,
vector_db_ids: List[str],
) -> RAGQueryResult:
if not vector_db_ids:
return RAGQueryResult(content=None)
async def _retrieve_context(
self, input_messages: List[InterleavedContent], bank_ids: List[str]
) -> Optional[List[InterleavedContent]]:
if not bank_ids:
return None
query = await generate_rag_query(
self.config.query_generator_config,
input_messages,
query_config.query_generator_config,
content,
inference_api=self.inference_api,
)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
self.vector_io_api.query_chunks(
vector_db_id=vector_db_id,
query=query,
params={
"max_chunks": self.config.max_chunks,
"max_chunks": query_config.max_chunks,
},
)
for bank_id in bank_ids
for vector_db_id in vector_db_ids
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
if not chunks:
return None
return RAGQueryResult(content=None)
# sort by score
chunks, scores = zip(
@ -102,45 +125,52 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
tokens = 0
picked = []
for c in chunks[: self.config.max_chunks]:
tokens += c.token_count
if tokens > self.config.max_tokens_in_context:
for c in chunks[: query_config.max_chunks]:
metadata = c.metadata
tokens += metadata["token_count"]
if tokens > query_config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
picked.append(
TextContentItem(
text=f"id:{metadata['document_id']}; content:{c.content}",
)
)
return RAGQueryResult(
content=[
TextContentItem(
text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
),
*picked,
TextContentItem(
text="\n=== END-RETRIEVED-CONTEXT ===\n",
),
],
)
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
# Parameters are not listed since these methods are not yet invoked automatically
# by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals.
return [
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
ToolDef(
name="rag_tool.query_context",
description="Retrieve context from memory",
),
ToolDef(
name="rag_tool.insert_documents",
description="Insert documents into memory",
),
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
final_args = tool_group.args or {}
final_args.update(args)
config = MemoryToolConfig()
if tool.metadata and tool.metadata.get("config") is not None:
config = MemoryToolConfig(**tool.metadata["config"])
if "memory_bank_ids" in final_args:
bank_ids = final_args["memory_bank_ids"]
else:
bank_ids = [
bank_config.bank_id for bank_config in config.memory_bank_configs
]
if "messages" not in final_args:
raise ValueError("messages are required")
context = await self._retrieve_context(
final_args["messages"],
bank_ids,
)
if context is None:
context = []
return ToolInvocationResult(
content=concat_interleaved_content(context), error_code=0
raise RuntimeError(
"This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
)

View file

@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.memory",
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference],
api_dependencies=[Api.vector_io, Api.inference],
),
InlineProviderSpec(
api=Api.tool_runtime,

View file

@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl(
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
headers = {
@ -78,7 +78,7 @@ class BingSearchToolRuntimeImpl(
"count": self.config.top_k,
"textDecorations": True,
"textFormat": "HTML",
"q": args["query"],
"q": kwargs["query"],
}
response = requests.get(

View file

@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl(
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
url = "https://api.search.brave.com/res/v1/web/search"
@ -77,7 +77,7 @@ class BraveSearchToolRuntimeImpl(
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": args["query"]}
payload = {"q": kwargs["query"]}
response = requests.get(url=url, params=payload, headers=headers)
response.raise_for_status()
results = self._clean_brave_response(response.json())

View file

@ -65,7 +65,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
return tools
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
@ -77,7 +77,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async with sse_client(endpoint) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(tool.identifier, args)
result = await session.call_tool(tool.identifier, kwargs)
return ToolInvocationResult(
content="\n".join([result.model_dump_json() for result in result.content]),

View file

@ -67,12 +67,12 @@ class TavilySearchToolRuntimeImpl(
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
response = requests.post(
"https://api.tavily.com/search",
json={"api_key": api_key, "query": args["query"]},
json={"api_key": api_key, "query": kwargs["query"]},
)
return ToolInvocationResult(

View file

@ -68,11 +68,11 @@ class WolframAlphaToolRuntimeImpl(
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
params = {
"input": args["query"],
"input": kwargs["query"],
"appid": api_key,
"format": "plaintext",
"output": "json",

View file

@ -12,10 +12,10 @@ from ..conftest import (
get_test_config_for_api,
)
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
from ..vector_io.fixtures import VECTOR_IO_FIXTURES
from .fixtures import AGENTS_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
"vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "ollama",
"safety": "llama_guard",
"memory": "faiss",
"vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"inference": "together",
"safety": "llama_guard",
# make this work with Weaviate which is what the together distro supports
"memory": "faiss",
"vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "fireworks",
"safety": "llama_guard",
"memory": "faiss",
"vector_io": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
{
"inference": "remote",
"safety": "remote",
"memory": "remote",
"vector_io": "remote",
"agents": "remote",
"tool_runtime": "memory_and_search",
},
@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc):
available_fixtures = {
"inference": INFERENCE_FIXTURES,
"safety": SAFETY_FIXTURES,
"memory": MEMORY_FIXTURES,
"vector_io": VECTOR_IO_FIXTURES,
"agents": AGENTS_FIXTURES,
"tool_runtime": TOOL_RUNTIME_FIXTURES,
}

View file

@ -69,7 +69,7 @@ async def agents_stack(
providers = {}
provider_data = {}
for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
@ -118,7 +118,7 @@ async def agents_stack(
)
test_stack = await construct_stack_for_test(
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
[Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime],
providers,
provider_data,
models=models,

View file

@ -214,9 +214,11 @@ class TestAgents:
turn_response = [
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
# FIXME: we need to check the content of the turn response and ensure
# RAG actually worked
@pytest.mark.asyncio
async def test_create_agent_turn_with_tavily_search(
self, agents_stack, search_query_messages, common_params

View file

@ -8,13 +8,12 @@ import uuid
import pytest
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.providers.utils.memory.vector_store import (
make_overlapped_chunks,
MemoryBankDocument,
)
from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
# How to run this test:
#
@ -26,22 +25,22 @@ from llama_stack.providers.utils.memory.vector_store import (
@pytest.fixture(scope="session")
def sample_chunks():
docs = [
MemoryBankDocument(
RAGDocument(
document_id="doc1",
content="Python is a high-level programming language.",
metadata={"category": "programming", "difficulty": "beginner"},
),
MemoryBankDocument(
RAGDocument(
document_id="doc2",
content="Machine learning is a subset of artificial intelligence.",
metadata={"category": "AI", "difficulty": "advanced"},
),
MemoryBankDocument(
RAGDocument(
document_id="doc3",
content="Data structures are fundamental to computer science.",
metadata={"category": "computer science", "difficulty": "intermediate"},
),
MemoryBankDocument(
RAGDocument(
document_id="doc4",
content="Neural networks are inspired by biological neural networks.",
metadata={"category": "AI", "difficulty": "advanced"},

View file

@ -19,7 +19,6 @@ import numpy as np
from llama_models.llama3.api.tokenizer import Tokenizer
from numpy.typing import NDArray
from pydantic import BaseModel, Field
from pypdf import PdfReader
from llama_stack.apis.common.content_types import (
@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import (
TextContentItem,
URL,
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.datatypes import Api
@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
log = logging.getLogger(__name__)
class MemoryBankDocument(BaseModel):
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
metadata: Dict[str, Any] = Field(default_factory=dict)
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
pdf_bytes = io.BytesIO(data)
@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> Interleaved
return ret
async def content_from_doc(doc: MemoryBankDocument) -> str:
async def content_from_doc(doc: RAGDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
@ -161,7 +153,13 @@ def make_overlapped_chunks(
chunk = tokenizer.decode(toks)
# chunk is a string
chunks.append(
Chunk(content=chunk, token_count=len(toks), document_id=document_id)
Chunk(
content=chunk,
metadata={
"token_count": len(toks),
"document_id": document_id,
},
)
)
return chunks

View file

@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import List
import pytest
import requests
from pydantic import TypeAdapter
from llama_stack.apis.tools import (
DefaultRAGQueryGeneratorConfig,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
)
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.utils.memory.vector_store import interleaved_content_as_str
class TestRAGToolEndpoints:
@pytest.fixture
def base_url(self) -> str:
return "http://localhost:8321/v1" # Adjust port if needed
@pytest.fixture
def sample_documents(self) -> List[RAGDocument]:
return [
RAGDocument(
document_id="doc1",
content="Python is a high-level programming language.",
metadata={"category": "programming", "difficulty": "beginner"},
),
RAGDocument(
document_id="doc2",
content="Machine learning is a subset of artificial intelligence.",
metadata={"category": "AI", "difficulty": "advanced"},
),
RAGDocument(
document_id="doc3",
content="Data structures are fundamental to computer science.",
metadata={"category": "computer science", "difficulty": "intermediate"},
),
]
@pytest.mark.asyncio
async def test_rag_workflow(
self, base_url: str, sample_documents: List[RAGDocument]
):
vector_db_payload = {
"vector_db_id": "test_vector_db",
"embedding_model": "all-MiniLM-L6-v2",
"embedding_dimension": 384,
}
response = requests.post(f"{base_url}/vector-dbs", json=vector_db_payload)
assert response.status_code == 200
vector_db = VectorDB(**response.json())
insert_payload = {
"documents": [
json.loads(doc.model_dump_json()) for doc in sample_documents
],
"vector_db_id": vector_db.identifier,
"chunk_size_in_tokens": 512,
}
response = requests.post(
f"{base_url}/tool-runtime/rag-tool/insert-documents",
json=insert_payload,
)
assert response.status_code == 200
query = "What is Python?"
query_config = RAGQueryConfig(
query_generator_config=DefaultRAGQueryGeneratorConfig(),
max_tokens_in_context=4096,
max_chunks=2,
)
query_payload = {
"content": query,
"query_config": json.loads(query_config.model_dump_json()),
"vector_db_ids": [vector_db.identifier],
}
response = requests.post(
f"{base_url}/tool-runtime/rag-tool/query-context",
json=query_payload,
)
assert response.status_code == 200
result = response.json()
result = TypeAdapter(RAGQueryResult).validate_python(result)
content_str = interleaved_content_as_str(result.content)
print(f"content: {content_str}")
assert len(content_str) > 0
assert "Python" in content_str
# Clean up: Delete the vector DB
response = requests.delete(f"{base_url}/vector-dbs/{vector_db.identifier}")
assert response.status_code == 200

View file

@ -4,7 +4,7 @@ distribution_spec:
providers:
inference:
- remote::together
memory:
vector_io:
- inline::faiss
- remote::chromadb
- remote::pgvector

View file

@ -5,7 +5,7 @@ apis:
- datasetio
- eval
- inference
- memory
- vector_io
- safety
- scoring
- telemetry
@ -20,7 +20,7 @@ providers:
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
memory:
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
@ -145,7 +145,6 @@ models:
model_type: embedding
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []