[memory refactor][3/n] Introduce RAGToolRuntime as a specialized sub-protocol (#832)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

Third part:
- We need the `tool_runtime.rag_tool.query_context()` and
  `tool_runtime.rag_tool.insert_documents()` methods to work smoothly with
  complete type safety. To that end, we introduce a sub-resource path
  `tool-runtime/rag-tool/` and make the necessary changes to the resolver
  (see the usage sketch after this list).
- The PR updates the agents implementation to call these typed APIs directly
  for memory accesses rather than going through the complex, untyped
  `invoke_tool` API. The code looks much nicer and simpler (expectedly).
- There are still a number of hacks in the server resolver implementation;
  we will live with some and fix others.
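
For illustration, here is a sketch of how a memory lookup changes at the call site once the typed sub-resource exists. The names (`tool_runtime_api`, `vector_db_ids`, `messages`) follow the diff below; the wrapper function is just scaffolding for the example, not code from this PR.

```python
# Sketch only: shows the shape of the old untyped call vs the new typed sub-resource call.
from typing import List

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.tools import (
    DefaultRAGQueryGeneratorConfig,
    RAGQueryConfig,
    ToolRuntime,
)
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content


async def query_memory(
    tool_runtime_api: ToolRuntime,
    messages: List[InterleavedContent],
    vector_db_ids: List[str],
):
    # Before: one untyped entry point, arguments passed as a plain dict
    # result = await tool_runtime_api.invoke_tool(
    #     tool_name="query_memory",
    #     args={"messages": messages, "memory_bank_ids": vector_db_ids},
    # )

    # After: a typed call on the rag_tool sub-resource
    result = await tool_runtime_api.rag_tool.query_context(
        content=concat_interleaved_content(messages),
        query_config=RAGQueryConfig(
            query_generator_config=DefaultRAGQueryGeneratorConfig(),
            max_tokens_in_context=4096,
            max_chunks=5,
        ),
        vector_db_ids=vector_db_ids,
    )
    return result.content
```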

Note that we must make sure the client SDKs can handle this sub-resource
complexity as well. Stainless has support for sub-resources, so this should
be possible, but beware.
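
If the generated SDKs expose the sub-resource the same way the server does, client code would look roughly like this. This is hypothetical: the actual generated method names depend on the Stainless configuration, and the document ids and payload values here are made up for the example.

```python
# Hypothetical client-side shape once the SDK exposes the rag_tool sub-resource.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

documents = [
    {
        "document_id": "doc-1",
        "content": "Llama Stack exposes a rag_tool sub-resource on tool_runtime.",
        "mime_type": "text/plain",
        "metadata": {},
    }
]

client.tool_runtime.rag_tool.insert_documents(
    documents=documents,
    vector_db_id="my_docs",
    chunk_size_in_tokens=512,
)
result = client.tool_runtime.rag_tool.query_context(
    content="What does the design doc say about sub-resources?",
    query_config={"query_generator_config": {"type": "default", "separator": " "}},
    vector_db_ids=["my_docs"],
)
```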

## Test Plan

Our RAG test is lacking (it doesn't actually check the RAG output), but I
verified that the implementation works. I will fix the RAG test afterwards.

```bash
pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B
```
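
For a quick manual sanity check against a locally running stack, a request of roughly this shape should hit the new route. This is a sketch mirroring the payloads in the new `test_rag_tool.py`; the port and the vector DB id are assumptions.

```python
import requests

base_url = "http://localhost:8321/v1"  # assumed local server address

response = requests.post(
    f"{base_url}/tool-runtime/rag-tool/query-context",
    json={
        "content": "What is Python?",
        "query_config": {
            "query_generator_config": {"type": "default", "separator": " "},
            "max_tokens_in_context": 4096,
            "max_chunks": 5,
        },
        "vector_db_ids": ["test_vector_db"],
    },
)
response.raise_for_status()
print(response.json()["content"])
```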
33 changed files with 1648 additions and 1345 deletions.

```diff
@@ -172,10 +172,16 @@ def _get_endpoint_functions(
 def _get_defining_class(member_fn: str, derived_cls: type) -> type:
     "Find the class in which a member function is first defined in a class inheritance hierarchy."
 
+    # This import must be dynamic here
+    from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime
+
     # iterate in reverse member resolution order to find most specific class first
     for cls in reversed(inspect.getmro(derived_cls)):
         for name, _ in inspect.getmembers(cls, inspect.isfunction):
             if name == member_fn:
+                # HACK ALERT
+                if cls == RAGToolRuntime:
+                    return ToolRuntime
                 return cls
 
     raise ValidationError(
```

(Two file diffs suppressed because they are too large.)

```diff
@@ -5,3 +5,4 @@
 # the root directory of this source tree.
 
 from .tools import *  # noqa: F401 F403
+from .rag_tool import *  # noqa: F401 F403
```

```diff
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated, Protocol, runtime_checkable
+
+from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
+
+
+@json_schema_type
+class RAGDocument(BaseModel):
+    document_id: str
+    content: InterleavedContent | URL
+    mime_type: str | None = None
+    metadata: Dict[str, Any] = Field(default_factory=dict)
+
+
+@json_schema_type
+class RAGQueryResult(BaseModel):
+    content: Optional[InterleavedContent] = None
+
+
+@json_schema_type
+class RAGQueryGenerator(Enum):
+    default = "default"
+    llm = "llm"
+    custom = "custom"
+
+
+@json_schema_type
+class DefaultRAGQueryGeneratorConfig(BaseModel):
+    type: Literal["default"] = "default"
+    separator: str = " "
+
+
+@json_schema_type
+class LLMRAGQueryGeneratorConfig(BaseModel):
+    type: Literal["llm"] = "llm"
+    model: str
+    template: str
+
+
+RAGQueryGeneratorConfig = register_schema(
+    Annotated[
+        Union[
+            DefaultRAGQueryGeneratorConfig,
+            LLMRAGQueryGeneratorConfig,
+        ],
+        Field(discriminator="type"),
+    ],
+    name="RAGQueryGeneratorConfig",
+)
+
+
+@json_schema_type
+class RAGQueryConfig(BaseModel):
+    # This config defines how a query is generated using the messages
+    # for memory bank retrieval.
+    query_generator_config: RAGQueryGeneratorConfig = Field(
+        default=DefaultRAGQueryGeneratorConfig()
+    )
+    max_tokens_in_context: int = 4096
+    max_chunks: int = 5
+
+
+@runtime_checkable
+@trace_protocol
+class RAGToolRuntime(Protocol):
+    @webmethod(route="/tool-runtime/rag-tool/insert-documents", method="POST")
+    async def insert_documents(
+        self,
+        documents: List[RAGDocument],
+        vector_db_id: str,
+        chunk_size_in_tokens: int = 512,
+    ) -> None:
+        """Index documents so they can be used by the RAG system"""
+        ...
+
+    @webmethod(route="/tool-runtime/rag-tool/query-context", method="POST")
+    async def query_context(
+        self,
+        content: InterleavedContent,
+        query_config: RAGQueryConfig,
+        vector_db_ids: List[str],
+    ) -> RAGQueryResult:
+        """Query the RAG system for context; typically invoked by the agent"""
+        ...
```

```diff
@@ -15,6 +15,8 @@ from llama_stack.apis.common.content_types import InterleavedContent, URL
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
+from .rag_tool import RAGToolRuntime
+
 
 @json_schema_type
 class ToolParameter(BaseModel):
@@ -130,11 +132,17 @@ class ToolGroups(Protocol):
         ...
 
 
+class SpecialToolGroup(Enum):
+    rag_tool = "rag_tool"
+
+
 @runtime_checkable
 @trace_protocol
 class ToolRuntime(Protocol):
     tool_store: ToolStore
 
+    rag_tool: RAGToolRuntime
+
     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
     async def list_runtime_tools(
@@ -143,7 +151,7 @@ class ToolRuntime(Protocol):
 
     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         """Run a tool with the given arguments"""
         ...
```

```diff
@@ -333,6 +333,8 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config
 
+    # TODO: check compliance for special tool groups
+    # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])
 
     if (
         not isinstance(provider_spec, AutoRoutedProviderSpec)
```

```diff
@@ -36,7 +36,14 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import ToolDef, ToolRuntime
+from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
+    ToolDef,
+    ToolRuntime,
+)
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
@@ -400,22 +407,55 @@ class EvalRouter(Eval):
 class ToolRuntimeRouter(ToolRuntime):
+    class RagToolImpl(RAGToolRuntime):
+        def __init__(
+            self,
+            routing_table: RoutingTable,
+        ) -> None:
+            self.routing_table = routing_table
+
+        async def query_context(
+            self,
+            content: InterleavedContent,
+            query_config: RAGQueryConfig,
+            vector_db_ids: List[str],
+        ) -> RAGQueryResult:
+            return await self.routing_table.get_provider_impl(
+                "rag_tool.query_context"
+            ).query_context(content, query_config, vector_db_ids)
+
+        async def insert_documents(
+            self,
+            documents: List[RAGDocument],
+            vector_db_id: str,
+            chunk_size_in_tokens: int = 512,
+        ) -> None:
+            return await self.routing_table.get_provider_impl(
+                "rag_tool.insert_documents"
+            ).insert_documents(documents, vector_db_id, chunk_size_in_tokens)
+
     def __init__(
         self,
         routing_table: RoutingTable,
     ) -> None:
         self.routing_table = routing_table
 
+        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
+        # TODO: make sure rag_tool vs builtin::memory is correct everywhere
+        self.rag_tool = self.RagToolImpl(routing_table)
+        setattr(self, "rag_tool.query_context", self.rag_tool.query_context)
+        setattr(self, "rag_tool.insert_documents", self.rag_tool.insert_documents)
+
     async def initialize(self) -> None:
         pass
 
     async def shutdown(self) -> None:
         pass
 
-    async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
+    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
-            args=args,
+            kwargs=kwargs,
         )
 
     async def list_runtime_tools(
```

```diff
@@ -9,6 +9,8 @@ from typing import Dict, List
 
 from pydantic import BaseModel
 
+from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
+
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
 
 from llama_stack.distribution.resolver import api_protocol_map
@@ -22,21 +24,39 @@ class ApiEndpoint(BaseModel):
     name: str
 
 
+def toolgroup_protocol_map():
+    return {
+        SpecialToolGroup.rag_tool: RAGToolRuntime,
+    }
+
+
 def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     apis = {}
 
     protocols = api_protocol_map()
+    toolgroup_protocols = toolgroup_protocol_map()
     for api, protocol in protocols.items():
         endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
 
+        # HACK ALERT
+        if api == Api.tool_runtime:
+            for tool_group in SpecialToolGroup:
+                sub_protocol = toolgroup_protocols[tool_group]
+                sub_protocol_methods = inspect.getmembers(
+                    sub_protocol, predicate=inspect.isfunction
+                )
+                for name, method in sub_protocol_methods:
+                    if not hasattr(method, "__webmethod__"):
+                        continue
+                    protocol_methods.append((f"{tool_group.value}.{name}", method))
+
         for name, method in protocol_methods:
             if not hasattr(method, "__webmethod__"):
                 continue
 
             webmethod = method.__webmethod__
             route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
             if webmethod.method == "GET":
                 method = "get"
             elif webmethod.method == "DELETE":
```

```diff
@@ -29,7 +29,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.telemetry import Telemetry
-from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.distribution.datatypes import StackRunConfig
@@ -62,6 +62,7 @@ class LlamaStack(
     Inspect,
     ToolGroups,
     ToolRuntime,
+    RAGToolRuntime,
 ):
     pass
```

```diff
@@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
 
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v5"
+KEY_VERSION = "v6"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
```

```diff
@@ -19,9 +19,8 @@ async def get_provider_impl(
     impl = MetaReferenceAgentsImpl(
         config,
         deps[Api.inference],
-        deps[Api.memory],
+        deps[Api.vector_io],
         deps[Api.safety],
-        deps[Api.memory_banks],
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
     )
```

```diff
@@ -59,13 +59,18 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import Memory, MemoryBankDocument
-from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.tools import (
+    DefaultRAGQueryGeneratorConfig,
+    RAGDocument,
+    RAGQueryConfig,
+    ToolGroups,
+    ToolRuntime,
+)
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence
 from .safety import SafetyException, ShieldRunnerMixin
@@ -79,7 +84,7 @@ def make_random_string(length: int = 8):
 
 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-MEMORY_QUERY_TOOL = "query_memory"
+MEMORY_QUERY_TOOL = "rag_tool.query_context"
 WEB_SEARCH_TOOL = "web_search"
 MEMORY_GROUP = "builtin::memory"
@@ -91,20 +96,18 @@ class ChatAgent(ShieldRunnerMixin):
         agent_config: AgentConfig,
         tempdir: str,
         inference_api: Inference,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
         safety_api: Safety,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
+        vector_io_api: VectorIO,
         persistence_store: KVStore,
     ):
         self.agent_id = agent_id
         self.agent_config = agent_config
         self.tempdir = tempdir
         self.inference_api = inference_api
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
         self.safety_api = safety_api
+        self.vector_io_api = vector_io_api
         self.storage = AgentPersistence(agent_id, persistence_store)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
@@ -370,24 +373,30 @@ class ChatAgent(ShieldRunnerMixin):
         documents: Optional[List[Document]] = None,
         toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
     ) -> AsyncGenerator:
+        # TODO: simplify all of this code, it can be simpler
         toolgroup_args = {}
+        toolgroups = set()
         for toolgroup in self.agent_config.toolgroups:
             if isinstance(toolgroup, AgentToolGroupWithArgs):
+                toolgroups.add(toolgroup.name)
                 toolgroup_args[toolgroup.name] = toolgroup.args
+            else:
+                toolgroups.add(toolgroup)
         if toolgroups_for_turn:
             for toolgroup in toolgroups_for_turn:
                 if isinstance(toolgroup, AgentToolGroupWithArgs):
+                    toolgroups.add(toolgroup.name)
                     toolgroup_args[toolgroup.name] = toolgroup.args
+                else:
+                    toolgroups.add(toolgroup)
 
         tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
         if documents:
             await self.handle_documents(
                 session_id, documents, input_messages, tool_defs
             )
-        if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
-            memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
-            if memory_tool_group is None:
-                raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
+
+        if MEMORY_GROUP in toolgroups and len(input_messages) > 0:
             with tracing.span(MEMORY_QUERY_TOOL) as span:
                 step_id = str(uuid.uuid4())
                 yield AgentTurnResponseStreamChunk(
@@ -398,17 +407,15 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-            query_args = {
-                "messages": [msg.content for msg in input_messages],
-                **toolgroup_args.get(memory_tool_group, {}),
-            }
-
+            args = toolgroup_args.get(MEMORY_GROUP, {})
+            vector_db_ids = args.get("vector_db_ids", [])
             session_info = await self.storage.get_session_info(session_id)
             # if the session has a memory bank id, let the memory tool use it
             if session_info.memory_bank_id:
-                if "memory_bank_ids" not in query_args:
-                    query_args["memory_bank_ids"] = []
-                query_args["memory_bank_ids"].append(session_info.memory_bank_id)
+                vector_db_ids.append(session_info.memory_bank_id)
+
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepProgressPayload(
@@ -425,10 +432,18 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-            result = await self.tool_runtime_api.invoke_tool(
-                tool_name=MEMORY_QUERY_TOOL,
-                args=query_args,
+            result = await self.tool_runtime_api.rag_tool.query_context(
+                content=concat_interleaved_content(
+                    [msg.content for msg in input_messages]
+                ),
+                query_config=RAGQueryConfig(
+                    query_generator_config=DefaultRAGQueryGeneratorConfig(),
+                    max_tokens_in_context=4096,
+                    max_chunks=5,
+                ),
+                vector_db_ids=vector_db_ids,
             )
+            retrieved_context = result.content
 
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
@@ -449,7 +464,7 @@ class ChatAgent(ShieldRunnerMixin):
                             ToolResponse(
                                 call_id="",
                                 tool_name=MEMORY_QUERY_TOOL,
-                                content=result.content,
+                                content=retrieved_context or [],
                             )
                         ],
                     ),
@@ -459,13 +474,11 @@ class ChatAgent(ShieldRunnerMixin):
                 span.set_attribute(
                     "input", [m.model_dump_json() for m in input_messages]
                 )
-                span.set_attribute("output", result.content)
-                span.set_attribute("error_code", result.error_code)
-                span.set_attribute("error_message", result.error_message)
+                span.set_attribute("output", retrieved_context)
                 span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
-            if result.error_code == 0:
+
+            if retrieved_context:
                 last_message = input_messages[-1]
-                last_message.context = result.content
+                last_message.context = retrieved_context
 
         output_attachments = []
@@ -842,12 +855,13 @@ class ChatAgent(ShieldRunnerMixin):
         if session_info.memory_bank_id is None:
             bank_id = f"memory_bank_{session_id}"
-            await self.memory_banks_api.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
-            )
+
+            # TODO: the semantic for registration is definitely not "creation"
+            # so we need to fix it if we expect the agent to create a new vector db
+            # for each session
+            await self.vector_io_api.register_vector_db(
+                vector_db_id=bank_id,
+                embedding_model="all-MiniLM-L6-v2",
+            )
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
@@ -858,9 +872,9 @@ class ChatAgent(ShieldRunnerMixin):
     async def add_to_session_memory_bank(
         self, session_id: str, data: List[Document]
     ) -> None:
-        bank_id = await self._ensure_memory_bank(session_id)
+        vector_db_id = await self._ensure_memory_bank(session_id)
         documents = [
-            MemoryBankDocument(
+            RAGDocument(
                 document_id=str(uuid.uuid4()),
                 content=a.content,
                 mime_type=a.mime_type,
@@ -868,9 +882,10 @@ class ChatAgent(ShieldRunnerMixin):
             )
             for a in data
         ]
-        await self.memory_api.insert_documents(
-            bank_id=bank_id,
+        await self.tool_runtime_api.rag_tool.insert_documents(
             documents=documents,
+            vector_db_id=vector_db_id,
+            chunk_size_in_tokens=512,
         )
@@ -955,7 +970,7 @@ async def execute_tool_call_maybe(
     result = await tool_runtime_api.invoke_tool(
         tool_name=name,
-        args=dict(
+        kwargs=dict(
             session_id=session_id,
             **tool_call_args,
         ),
```

```diff
@@ -26,10 +26,9 @@ from llama_stack.apis.agents import (
     Turn,
 )
 from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 
 from .agent_instance import ChatAgent
@@ -44,17 +43,15 @@ class MetaReferenceAgentsImpl(Agents):
         self,
         config: MetaReferenceAgentsImplConfig,
         inference_api: Inference,
-        memory_api: Memory,
+        vector_io_api: VectorIO,
         safety_api: Safety,
-        memory_banks_api: MemoryBanks,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
     ):
         self.config = config
         self.inference_api = inference_api
-        self.memory_api = memory_api
+        self.vector_io_api = vector_io_api
         self.safety_api = safety_api
-        self.memory_banks_api = memory_banks_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
@@ -114,8 +111,7 @@ class MetaReferenceAgentsImpl(Agents):
             tempdir=self.tempdir,
             inference_api=self.inference_api,
             safety_api=self.safety_api,
-            memory_api=self.memory_api,
-            memory_banks_api=self.memory_banks_api,
+            vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
             persistence_store=(
```

```diff
@@ -60,9 +60,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         ]
 
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        script = args["code"]
+        script = kwargs["code"]
         req = CodeExecutionRequest(scripts=[script])
         res = self.code_executor.execute(req)
         pieces = [res["process_status"]]
```

```diff
@@ -13,8 +13,6 @@ from .memory import MemoryToolRuntimeImpl
 
 async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
-    impl = MemoryToolRuntimeImpl(
-        config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
-    )
+    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
     await impl.initialize()
     return impl
```

```diff
@@ -4,87 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from enum import Enum
-from typing import Annotated, List, Literal, Union
-
-from pydantic import BaseModel, Field
-
-
-class _MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-
-
-class VectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["vector"] = "vector"
-
-
-class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyvalue"] = "keyvalue"
-    keys: List[str]  # what keys to focus on
-
-
-class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyword"] = "keyword"
-
-
-class GraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["graph"] = "graph"
-    entities: List[str]  # what entities to focus on
-
-
-MemoryBankConfig = Annotated[
-    Union[
-        VectorMemoryBankConfig,
-        KeyValueMemoryBankConfig,
-        KeywordMemoryBankConfig,
-        GraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryQueryGenerator(Enum):
-    default = "default"
-    llm = "llm"
-    custom = "custom"
-
-
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.default.value] = (
-        MemoryQueryGenerator.default.value
-    )
-    sep: str = " "
-
-
-class LLMMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
-    model: str
-    template: str
-
-
-class CustomMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-
-
-MemoryQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultMemoryQueryGeneratorConfig,
-        LLMMemoryQueryGeneratorConfig,
-        CustomMemoryQueryGeneratorConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryToolConfig(BaseModel):
-    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
+from pydantic import BaseModel
 
 
 class MemoryToolRuntimeConfig(BaseModel):
-    # This config defines how a query is generated using the messages
-    # for memory bank retrieval.
-    query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig()
-    )
-    max_tokens_in_context: int = 4096
-    max_chunks: int = 5
+    pass
```

```diff
@@ -5,68 +5,64 @@
 # the root directory of this source tree.
 
-from typing import List
-
 from jinja2 import Template
-from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.tools.rag_tool import (
+    DefaultRAGQueryGeneratorConfig,
+    LLMRAGQueryGeneratorConfig,
+    RAGQueryGenerator,
+    RAGQueryGeneratorConfig,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 
-from .config import (
-    DefaultMemoryQueryGeneratorConfig,
-    LLMMemoryQueryGeneratorConfig,
-    MemoryQueryGenerator,
-    MemoryQueryGeneratorConfig,
-)
-
 
 async def generate_rag_query(
-    config: MemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: RAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
     """
     Generates a query that will be used for
     retrieving relevant information from the memory bank.
     """
-    if config.type == MemoryQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, messages, **kwargs)
-    elif config.type == MemoryQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, messages, **kwargs)
+    if config.type == RAGQueryGenerator.default.value:
+        query = await default_rag_query_generator(config, content, **kwargs)
+    elif config.type == RAGQueryGenerator.llm.value:
+        query = await llm_rag_query_generator(config, content, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query
 
 
 async def default_rag_query_generator(
-    config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: DefaultRAGQueryGeneratorConfig,
+    content: InterleavedContent,
     **kwargs,
 ):
-    return config.sep.join(interleaved_content_as_str(m) for m in messages)
+    return interleaved_content_as_str(content, sep=config.separator)
 
 
 async def llm_rag_query_generator(
-    config: LLMMemoryQueryGeneratorConfig,
-    messages: List[InterleavedContent],
+    config: LLMRAGQueryGeneratorConfig,
+    content: InterleavedContent,
    **kwargs,
 ):
     assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
     inference_api = kwargs["inference_api"]
 
-    m_dict = {
-        "messages": [
-            message.model_dump() if isinstance(message, BaseModel) else message
-            for message in messages
-        ]
-    }
+    messages = []
+    if isinstance(content, list):
+        messages = [interleaved_content_as_str(m) for m in content]
+    else:
+        messages = [interleaved_content_as_str(content)]
 
     template = Template(config.template)
-    content = template.render(m_dict)
+    content = template.render({"messages": messages})
 
     model = config.model
     message = UserMessage(content=content)
```

```diff
@@ -10,20 +10,29 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional
 
-from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.memory import Memory, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
+from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+    RAGToolRuntime,
     ToolDef,
     ToolInvocationResult,
-    ToolParameter,
     ToolRuntime,
 )
+from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
+from llama_stack.providers.utils.memory.vector_store import (
+    content_from_doc,
+    make_overlapped_chunks,
+)
 
-from .config import MemoryToolConfig, MemoryToolRuntimeConfig
+from .config import MemoryToolRuntimeConfig
 from .context_retriever import generate_rag_query
 
 log = logging.getLogger(__name__)
@@ -35,65 +44,79 @@ def make_random_string(length: int = 8):
     )
 
 
-class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
+class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,
         config: MemoryToolRuntimeConfig,
-        memory_api: Memory,
-        memory_banks_api: MemoryBanks,
+        vector_io_api: VectorIO,
         inference_api: Inference,
     ):
         self.config = config
-        self.memory_api = memory_api
-        self.memory_banks_api = memory_banks_api
+        self.vector_io_api = vector_io_api
         self.inference_api = inference_api
 
     async def initialize(self):
         pass
 
-    async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
-    ) -> List[ToolDef]:
-        return [
-            ToolDef(
-                name="query_memory",
-                description="Retrieve context from memory",
-                parameters=[
-                    ToolParameter(
-                        name="messages",
-                        description="The input messages to search for",
-                        parameter_type="array",
-                    ),
-                ],
-            )
-        ]
+    async def shutdown(self):
+        pass
 
-    async def _retrieve_context(
-        self, input_messages: List[InterleavedContent], bank_ids: List[str]
-    ) -> Optional[List[InterleavedContent]]:
-        if not bank_ids:
-            return None
+    async def insert_documents(
+        self,
+        documents: List[RAGDocument],
+        vector_db_id: str,
+        chunk_size_in_tokens: int = 512,
+    ) -> None:
+        chunks = []
+        for doc in documents:
+            content = await content_from_doc(doc)
+            chunks.extend(
+                make_overlapped_chunks(
+                    doc.document_id,
+                    content,
+                    chunk_size_in_tokens,
+                    chunk_size_in_tokens // 4,
+                )
+            )
+
+        if not chunks:
+            return
+
+        await self.vector_io_api.insert_chunks(
+            chunks=chunks,
+            vector_db_id=vector_db_id,
+        )
+
+    async def query_context(
+        self,
+        content: InterleavedContent,
+        query_config: RAGQueryConfig,
+        vector_db_ids: List[str],
+    ) -> RAGQueryResult:
+        if not vector_db_ids:
+            return RAGQueryResult(content=None)
+
         query = await generate_rag_query(
-            self.config.query_generator_config,
-            input_messages,
+            query_config.query_generator_config,
+            content,
             inference_api=self.inference_api,
         )
         tasks = [
-            self.memory_api.query_documents(
-                bank_id=bank_id,
+            self.vector_io_api.query_chunks(
+                vector_db_id=vector_db_id,
                 query=query,
                 params={
-                    "max_chunks": self.config.max_chunks,
+                    "max_chunks": query_config.max_chunks,
                 },
             )
-            for bank_id in bank_ids
+            for vector_db_id in vector_db_ids
         ]
-        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
+        results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
         chunks = [c for r in results for c in r.chunks]
         scores = [s for r in results for s in r.scores]
 
         if not chunks:
-            return None
+            return RAGQueryResult(content=None)
 
         # sort by score
         chunks, scores = zip(
@@ -102,45 +125,52 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         tokens = 0
         picked = []
-        for c in chunks[: self.config.max_chunks]:
-            tokens += c.token_count
-            if tokens > self.config.max_tokens_in_context:
+        for c in chunks[: query_config.max_chunks]:
+            metadata = c.metadata
+            tokens += metadata["token_count"]
+            if tokens > query_config.max_tokens_in_context:
                 log.error(
                     f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                 )
                 break
-            picked.append(f"id:{c.document_id}; content:{c.content}")
+            picked.append(
+                TextContentItem(
+                    text=f"id:{metadata['document_id']}; content:{c.content}",
+                )
+            )
 
-        return [
-            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-            *picked,
-            "\n=== END-RETRIEVED-CONTEXT ===\n",
-        ]
+        return RAGQueryResult(
+            content=[
+                TextContentItem(
+                    text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+                ),
+                *picked,
+                TextContentItem(
+                    text="\n=== END-RETRIEVED-CONTEXT ===\n",
+                ),
+            ],
+        )
+
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        # Parameters are not listed since these methods are not yet invoked automatically
+        # by the LLM. The method is only implemented so things like /tools can list without
+        # encountering fatals.
+        return [
+            ToolDef(
+                name="rag_tool.query_context",
+                description="Retrieve context from memory",
+            ),
+            ToolDef(
+                name="rag_tool.insert_documents",
+                description="Insert documents into memory",
+            ),
+        ]
 
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
-        tool = await self.tool_store.get_tool(tool_name)
-        tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
-        final_args = tool_group.args or {}
-        final_args.update(args)
-        config = MemoryToolConfig()
-        if tool.metadata and tool.metadata.get("config") is not None:
-            config = MemoryToolConfig(**tool.metadata["config"])
-        if "memory_bank_ids" in final_args:
-            bank_ids = final_args["memory_bank_ids"]
-        else:
-            bank_ids = [
-                bank_config.bank_id for bank_config in config.memory_bank_configs
-            ]
-        if "messages" not in final_args:
-            raise ValueError("messages are required")
-        context = await self._retrieve_context(
-            final_args["messages"],
-            bank_ids,
-        )
-        if context is None:
-            context = []
-        return ToolInvocationResult(
-            content=concat_interleaved_content(context), error_code=0
+        raise RuntimeError(
+            "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
         )
```

```diff
@@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.inline.tool_runtime.memory",
             config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
-            api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference],
+            api_dependencies=[Api.vector_io, Api.inference],
         ),
         InlineProviderSpec(
             api=Api.tool_runtime,
```

```diff
@@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl(
         ]
 
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         headers = {
@@ -78,7 +78,7 @@ class BingSearchToolRuntimeImpl(
             "count": self.config.top_k,
             "textDecorations": True,
             "textFormat": "HTML",
-            "q": args["query"],
+            "q": kwargs["query"],
         }
         response = requests.get(
```

```diff
@@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl(
         ]
 
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         url = "https://api.search.brave.com/res/v1/web/search"
@@ -77,7 +77,7 @@ class BraveSearchToolRuntimeImpl(
             "Accept-Encoding": "gzip",
             "Accept": "application/json",
         }
-        payload = {"q": args["query"]}
+        payload = {"q": kwargs["query"]}
         response = requests.get(url=url, params=payload, headers=headers)
         response.raise_for_status()
         results = self._clean_brave_response(response.json())
```

```diff
@@ -65,7 +65,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         return tools
 
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
         if tool.metadata is None or tool.metadata.get("endpoint") is None:
@@ -77,7 +77,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         async with sse_client(endpoint) as streams:
             async with ClientSession(*streams) as session:
                 await session.initialize()
-                result = await session.call_tool(tool.identifier, args)
+                result = await session.call_tool(tool.identifier, kwargs)
 
         return ToolInvocationResult(
             content="\n".join([result.model_dump_json() for result in result.content]),
```

```diff
@@ -67,12 +67,12 @@ class TavilySearchToolRuntimeImpl(
         ]
 
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         response = requests.post(
             "https://api.tavily.com/search",
-            json={"api_key": api_key, "query": args["query"]},
+            json={"api_key": api_key, "query": kwargs["query"]},
         )
 
         return ToolInvocationResult(
```

```diff
@@ -68,11 +68,11 @@ class WolframAlphaToolRuntimeImpl(
         ]
 
     async def invoke_tool(
-        self, tool_name: str, args: Dict[str, Any]
+        self, tool_name: str, kwargs: Dict[str, Any]
     ) -> ToolInvocationResult:
         api_key = self._get_api_key()
         params = {
-            "input": args["query"],
+            "input": kwargs["query"],
             "appid": api_key,
             "format": "plaintext",
             "output": "json",
```

```diff
@@ -12,10 +12,10 @@ from ..conftest import (
     get_test_config_for_api,
 )
 from ..inference.fixtures import INFERENCE_FIXTURES
-from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
 from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
+from ..vector_io.fixtures import VECTOR_IO_FIXTURES
 from .fixtures import AGENTS_FIXTURES
 
 DEFAULT_PROVIDER_COMBINATIONS = [
@@ -23,7 +23,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "meta_reference",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },
@@ -34,7 +34,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "ollama",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },
@@ -46,7 +46,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "inference": "together",
             "safety": "llama_guard",
             # make this work with Weaviate which is what the together distro supports
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },
@@ -57,7 +57,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "fireworks",
             "safety": "llama_guard",
-            "memory": "faiss",
+            "vector_io": "faiss",
             "agents": "meta_reference",
             "tool_runtime": "memory_and_search",
         },
@@ -68,7 +68,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         {
             "inference": "remote",
             "safety": "remote",
-            "memory": "remote",
+            "vector_io": "remote",
             "agents": "remote",
             "tool_runtime": "memory_and_search",
         },
@@ -115,7 +115,7 @@ def pytest_generate_tests(metafunc):
     available_fixtures = {
         "inference": INFERENCE_FIXTURES,
         "safety": SAFETY_FIXTURES,
-        "memory": MEMORY_FIXTURES,
+        "vector_io": VECTOR_IO_FIXTURES,
         "agents": AGENTS_FIXTURES,
         "tool_runtime": TOOL_RUNTIME_FIXTURES,
     }
```

```diff
@@ -69,7 +69,7 @@ async def agents_stack(
     providers = {}
     provider_data = {}
-    for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
+    for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
         if key == "inference":
@@ -118,7 +118,7 @@ async def agents_stack(
     )
 
     test_stack = await construct_stack_for_test(
-        [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
+        [Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime],
         providers,
         provider_data,
         models=models,
```

```diff
@@ -214,9 +214,11 @@ class TestAgents:
         turn_response = [
             chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
         ]
         assert len(turn_response) > 0
 
+        # FIXME: we need to check the content of the turn response and ensure
+        # RAG actually worked
+
     @pytest.mark.asyncio
     async def test_create_agent_turn_with_tavily_search(
         self, agents_stack, search_query_messages, common_params
```

```diff
@@ -8,13 +8,12 @@ import uuid
 
 import pytest
 
+from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB
 from llama_stack.apis.vector_io import QueryChunksResponse
-from llama_stack.providers.utils.memory.vector_store import (
-    make_overlapped_chunks,
-    MemoryBankDocument,
-)
+from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks
 
 # How to run this test:
 #
@@ -26,22 +25,22 @@ from llama_stack.providers.utils.memory.vector_store import (
 @pytest.fixture(scope="session")
 def sample_chunks():
     docs = [
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc1",
             content="Python is a high-level programming language.",
             metadata={"category": "programming", "difficulty": "beginner"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc2",
             content="Machine learning is a subset of artificial intelligence.",
             metadata={"category": "AI", "difficulty": "advanced"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc3",
             content="Data structures are fundamental to computer science.",
             metadata={"category": "computer science", "difficulty": "intermediate"},
         ),
-        MemoryBankDocument(
+        RAGDocument(
             document_id="doc4",
             content="Neural networks are inspired by biological neural networks.",
             metadata={"category": "AI", "difficulty": "advanced"},
```

```diff
@@ -19,7 +19,6 @@ import numpy as np
 from llama_models.llama3.api.tokenizer import Tokenizer
 from numpy.typing import NDArray
-from pydantic import BaseModel, Field
 from pypdf import PdfReader
 
 from llama_stack.apis.common.content_types import (
@@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
     URL,
 )
+from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
 from llama_stack.providers.datatypes import Api
@@ -34,17 +34,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 
 log = logging.getLogger(__name__)
 
 
-class MemoryBankDocument(BaseModel):
-    document_id: str
-    content: InterleavedContent | URL
-    mime_type: str | None = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
-
-
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
     pdf_bytes = io.BytesIO(data)
@@ -122,7 +114,7 @@ def concat_interleaved_content(content: List[InterleavedContent]) -> Interleave
     return ret
 
 
-async def content_from_doc(doc: MemoryBankDocument) -> str:
+async def content_from_doc(doc: RAGDocument) -> str:
     if isinstance(doc.content, URL):
         if doc.content.uri.startswith("data:"):
             return content_from_data(doc.content.uri)
@@ -161,7 +153,13 @@ def make_overlapped_chunks(
         chunk = tokenizer.decode(toks)
         # chunk is a string
         chunks.append(
-            Chunk(content=chunk, token_count=len(toks), document_id=document_id)
+            Chunk(
+                content=chunk,
+                metadata={
+                    "token_count": len(toks),
+                    "document_id": document_id,
+                },
+            )
         )
 
     return chunks
```

```diff
@@ -0,0 +1,105 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from typing import List
+
+import pytest
+import requests
+from pydantic import TypeAdapter
+
+from llama_stack.apis.tools import (
+    DefaultRAGQueryGeneratorConfig,
+    RAGDocument,
+    RAGQueryConfig,
+    RAGQueryResult,
+)
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.providers.utils.memory.vector_store import interleaved_content_as_str
+
+
+class TestRAGToolEndpoints:
+    @pytest.fixture
+    def base_url(self) -> str:
+        return "http://localhost:8321/v1"  # Adjust port if needed
+
+    @pytest.fixture
+    def sample_documents(self) -> List[RAGDocument]:
+        return [
+            RAGDocument(
+                document_id="doc1",
+                content="Python is a high-level programming language.",
+                metadata={"category": "programming", "difficulty": "beginner"},
+            ),
+            RAGDocument(
+                document_id="doc2",
+                content="Machine learning is a subset of artificial intelligence.",
+                metadata={"category": "AI", "difficulty": "advanced"},
+            ),
+            RAGDocument(
+                document_id="doc3",
+                content="Data structures are fundamental to computer science.",
+                metadata={"category": "computer science", "difficulty": "intermediate"},
+            ),
+        ]
+
+    @pytest.mark.asyncio
+    async def test_rag_workflow(
+        self, base_url: str, sample_documents: List[RAGDocument]
+    ):
+        vector_db_payload = {
+            "vector_db_id": "test_vector_db",
+            "embedding_model": "all-MiniLM-L6-v2",
+            "embedding_dimension": 384,
+        }
+
+        response = requests.post(f"{base_url}/vector-dbs", json=vector_db_payload)
+        assert response.status_code == 200
+        vector_db = VectorDB(**response.json())
+
+        insert_payload = {
+            "documents": [
+                json.loads(doc.model_dump_json()) for doc in sample_documents
+            ],
+            "vector_db_id": vector_db.identifier,
+            "chunk_size_in_tokens": 512,
+        }
+        response = requests.post(
+            f"{base_url}/tool-runtime/rag-tool/insert-documents",
+            json=insert_payload,
+        )
+        assert response.status_code == 200
+
+        query = "What is Python?"
+        query_config = RAGQueryConfig(
+            query_generator_config=DefaultRAGQueryGeneratorConfig(),
+            max_tokens_in_context=4096,
+            max_chunks=2,
+        )
+
+        query_payload = {
+            "content": query,
+            "query_config": json.loads(query_config.model_dump_json()),
+            "vector_db_ids": [vector_db.identifier],
+        }
+        response = requests.post(
+            f"{base_url}/tool-runtime/rag-tool/query-context",
+            json=query_payload,
+        )
+        assert response.status_code == 200
+        result = response.json()
+        result = TypeAdapter(RAGQueryResult).validate_python(result)
+
+        content_str = interleaved_content_as_str(result.content)
+        print(f"content: {content_str}")
+        assert len(content_str) > 0
+        assert "Python" in content_str
+
+        # Clean up: Delete the vector DB
+        response = requests.delete(f"{base_url}/vector-dbs/{vector_db.identifier}")
+        assert response.status_code == 200
```

```diff
@@ -4,7 +4,7 @@ distribution_spec:
   providers:
     inference:
     - remote::together
-    memory:
+    vector_io:
     - inline::faiss
     - remote::chromadb
     - remote::pgvector
```

```diff
@@ -5,7 +5,7 @@ apis:
 - datasetio
 - eval
 - inference
-- memory
+- vector_io
 - safety
 - scoring
 - telemetry
@@ -20,7 +20,7 @@ providers:
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
-  memory:
+  vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
     config:
@@ -145,7 +145,6 @@ models:
   model_type: embedding
 shields:
 - shield_id: meta-llama/Llama-Guard-3-8B
-memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
```