revert: "chore(cleanup)!: remove tool_runtime.rag_tool" (#3877)

Reverts llamastack/llama-stack#3871

This PR broke RAG (even from Responses -- there _is_ a dependency)
This commit is contained in:
Ashwin Bharambe 2025-10-21 11:22:06 -07:00 committed by GitHub
parent eb3e9b85f9
commit bd3c473208
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
55 changed files with 3114 additions and 17 deletions

View file

@ -4,4 +4,5 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .rag_tool import *
from .tools import *

View file

@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field, field_validator
from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class RRFRanker(BaseModel):
"""
Reciprocal Rank Fusion (RRF) ranker configuration.
:param type: The type of ranker, always "rrf"
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
Must be greater than 0
"""
type: Literal["rrf"] = "rrf"
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
@json_schema_type
class WeightedRanker(BaseModel):
"""
Weighted ranker configuration that combines vector and keyword scores.
:param type: The type of ranker, always "weighted"
:param alpha: Weight factor between 0 and 1.
0 means only use keyword scores,
1 means only use vector scores,
values in between blend both scores.
"""
type: Literal["weighted"] = "weighted"
alpha: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
)
Ranker = Annotated[
RRFRanker | WeightedRanker,
Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")
@json_schema_type
class RAGDocument(BaseModel):
"""
A document to be used for document ingestion in the RAG Tool.
:param document_id: The unique identifier for the document.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
:param metadata: Additional metadata for the document.
"""
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryResult(BaseModel):
"""Result of a RAG query containing retrieved content and metadata.
:param content: (Optional) The retrieved content from the query
:param metadata: Additional metadata about the query result
"""
content: InterleavedContent | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryGenerator(Enum):
"""Types of query generators for RAG systems.
:cvar default: Default query generator using simple text processing
:cvar llm: LLM-based query generator for enhanced query understanding
:cvar custom: Custom query generator implementation
"""
default = "default"
llm = "llm"
custom = "custom"
@json_schema_type
class RAGSearchMode(StrEnum):
"""
Search modes for RAG query retrieval:
- VECTOR: Uses vector similarity search for semantic matching
- KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
"""
VECTOR = "vector"
KEYWORD = "keyword"
HYBRID = "hybrid"
@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the default RAG query generator.
:param type: Type of query generator, always 'default'
:param separator: String separator used to join query terms
"""
type: Literal["default"] = "default"
separator: str = " "
@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the LLM-based RAG query generator.
:param type: Type of query generator, always 'llm'
:param model: Name of the language model to use for query generation
:param template: Template string for formatting the query generation prompt
"""
type: Literal["llm"] = "llm"
model: str
template: str
RAGQueryGeneratorConfig = Annotated[
DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type
class RAGQueryConfig(BaseModel):
"""
Configuration for the RAG query generation.
:param query_generator_config: Configuration for the query generator.
:param max_tokens_in_context: Maximum number of tokens in the context.
:param max_chunks: Maximum number of chunks to retrieve.
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
:param mode: Search mode for retrievaleither "vector", "keyword", or "hybrid". Default "vector".
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
"""
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
max_tokens_in_context: int = 4096
max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: RAGSearchMode | None = RAGSearchMode.VECTOR
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:
if "{chunk.content}" not in v:
raise ValueError("chunk_template must contain {chunk.content}")
if "{index}" not in v:
raise ValueError("chunk_template must contain {index}")
if len(v) == 0:
raise ValueError("chunk_template must not be empty")
return v
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
"""Index documents so they can be used by the RAG system.
:param documents: List of documents to index in the RAG system
:param vector_db_id: ID of the vector database to store the document embeddings
:param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
"""
...
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent.
:param content: The query content to search for in the indexed documents
:param vector_db_ids: List of vector database IDs to search within
:param query_config: (Optional) Configuration parameters for the query operation
:returns: RAGQueryResult containing the retrieved content and metadata
"""
...

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Literal, Protocol
from pydantic import BaseModel
@ -15,6 +16,8 @@ from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime
@json_schema_type
class ToolDef(BaseModel):
@ -178,11 +181,22 @@ class ToolGroups(Protocol):
...
class SpecialToolGroup(Enum):
"""Special tool groups with predefined functionality.
:cvar rag_tool: Retrieval-Augmented Generation tool group for document search and retrieval
"""
rag_tool = "rag_tool"
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_runtime_tools(

View file

@ -8,8 +8,16 @@ from typing import Any
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
)
from llama_stack.apis.tools import ListToolDefsResponse, ToolRuntime
from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
@ -18,6 +26,36 @@ logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_store_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_store_ids, query_config)
async def insert(
self,
documents: list[RAGDocument],
vector_store_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
@ -25,6 +63,11 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass

View file

@ -13,6 +13,7 @@ from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod
@ -24,16 +25,33 @@ RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map():
return {
SpecialToolGroup.rag_tool: RAGToolRuntime,
}
def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {}
protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods:
# Get all webmethods for this method (supports multiple decorators)
webmethods = getattr(method, "__webmethods__", [])

View file

@ -32,7 +32,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
@ -80,6 +80,7 @@ class LlamaStack(
Inspect,
ToolGroups,
ToolRuntime,
RAGToolRuntime,
Files,
Prompts,
Conversations,

View file

@ -48,6 +48,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference

View file

@ -216,6 +216,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
@ -261,6 +263,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -26,6 +26,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
image_type: venv
additional_pip_packages:
- aiosqlite

View file

@ -45,6 +45,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
],
}
name = "dell"
@ -97,6 +98,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="brave-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
return DistributionTemplate(

View file

@ -87,6 +87,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
storage:
backends:
kv_default:
@ -131,6 +133,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -83,6 +83,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
storage:
backends:
kv_default:
@ -122,6 +124,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -24,6 +24,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
image_type: venv
additional_pip_packages:

View file

@ -47,6 +47,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
}
@ -91,6 +92,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
return DistributionTemplate(

View file

@ -98,6 +98,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -144,6 +146,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -88,6 +88,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -129,6 +131,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -19,7 +19,8 @@ distribution_spec:
- provider_type: remote::nvidia
scoring:
- provider_type: inline::basic
tool_runtime: []
tool_runtime:
- provider_type: inline::rag-runtime
files:
- provider_type: inline::localfs
image_type: venv

View file

@ -28,7 +28,7 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
BuildProvider(provider_type="remote::nvidia"),
],
"scoring": [BuildProvider(provider_type="inline::basic")],
"tool_runtime": [],
"tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")],
"files": [BuildProvider(provider_type="inline::localfs")],
}
@ -66,7 +66,12 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
provider_id="nvidia",
)
default_tool_groups: list[ToolGroupInput] = []
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
return DistributionTemplate(
name=name,

View file

@ -80,7 +80,9 @@ providers:
scoring:
- provider_id: basic
provider_type: inline::basic
tool_runtime: []
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@ -126,7 +128,9 @@ registered_resources:
datasets: []
scoring_fns: []
benchmarks: []
tool_groups: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -69,7 +69,9 @@ providers:
scoring:
- provider_id: basic
provider_type: inline::basic
tool_runtime: []
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@ -105,7 +107,9 @@ registered_resources:
datasets: []
scoring_fns: []
benchmarks: []
tool_groups: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -28,6 +28,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
image_type: venv
additional_pip_packages:

View file

@ -118,6 +118,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
}
@ -153,6 +154,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
models, _ = get_model_registry(available_models)

View file

@ -118,6 +118,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -242,6 +244,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -14,6 +14,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
image_type: venv
additional_pip_packages:

View file

@ -45,6 +45,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
}
@ -65,6 +66,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
default_models = [

View file

@ -54,6 +54,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -105,6 +107,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -49,6 +49,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference

View file

@ -219,6 +219,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
@ -264,6 +266,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -49,6 +49,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference

View file

@ -216,6 +216,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
@ -261,6 +263,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -140,6 +140,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"batches": [
@ -161,6 +162,10 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
default_shields = [
# if the

View file

@ -23,6 +23,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
files:
- provider_type: inline::localfs

View file

@ -83,6 +83,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
files:
@ -123,6 +125,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -33,6 +33,7 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
@ -49,6 +50,10 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
files_provider = Provider(

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.providers.datatypes import Api
from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
await impl.initialize()
return impl

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
class RagToolRuntimeConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
RAGQueryGenerator,
RAGQueryGeneratorConfig,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
async def generate_rag_query(
config: RAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
"""
Generates a query that will be used for
retrieving relevant information from the memory bank.
"""
if config.type == RAGQueryGenerator.default.value:
query = await default_rag_query_generator(config, content, **kwargs)
elif config.type == RAGQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, content, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query
async def default_rag_query_generator(
config: DefaultRAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
return interleaved_content_as_str(content, sep=config.separator)
async def llm_rag_query_generator(
config: LLMRAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
messages = []
if isinstance(content, list):
messages = [interleaved_content_as_str(m) for m in content]
else:
messages = [interleaved_content_as_str(content)]
template = Template(config.template)
rendered_content: str = template.render({"messages": messages})
model = config.model
message = OpenAIUserMessageParam(content=rendered_content)
params = OpenAIChatCompletionRequestWithExtraBody(
model=model,
messages=[message],
stream=False,
)
response = await inference_api.openai_chat_completion(params)
query = response.choices[0].message.content
return query

View file

@ -0,0 +1,332 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
import mimetypes
from typing import Any
import httpx
from fastapi import UploadFile
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
VectorStoreChunkingStrategyStatic,
VectorStoreChunkingStrategyStaticConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
log = get_logger(name=__name__, category="tool_runtime")
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
r.raise_for_status()
mime_type = r.headers.get("content-type", "application/octet-stream")
return r.content, mime_type
else:
if isinstance(doc.content, str):
content_str = doc.content
else:
content_str = interleaved_content_as_str(doc.content)
if content_str.startswith("data:"):
parts = parse_data_url(content_str)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
return content_str.encode("utf-8"), "text/plain"
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
config: RagToolRuntimeConfig,
vector_io_api: VectorIO,
inference_api: Inference,
files_api: Files,
):
self.config = config
self.vector_io_api = vector_io_api
self.inference_api = inference_api
self.files_api = files_api
async def initialize(self):
pass
async def shutdown(self):
pass
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
if not documents:
return
for doc in documents:
try:
try:
file_data, mime_type = await raw_data_from_doc(doc)
except Exception as e:
log.error(f"Failed to extract content from document {doc.document_id}: {e}")
continue
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_obj = io.BytesIO(file_data)
file_obj.name = filename
upload_file = UploadFile(file=file_obj, filename=filename)
try:
created_file = await self.files_api.openai_upload_file(
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
except Exception as e:
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
continue
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4,
)
)
try:
await self.vector_io_api.openai_attach_file_to_vector_store(
vector_store_id=vector_db_id,
file_id=created_file.id,
attributes=doc.metadata,
chunking_strategy=chunking_strategy,
)
except Exception as e:
log.error(
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
)
continue
except Exception as e:
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
continue
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
if not vector_db_ids:
raise ValueError(
"No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
)
query_config = query_config or RAGQueryConfig()
query = await generate_rag_query(
query_config.query_generator_config,
content,
inference_api=self.inference_api,
)
tasks = [
self.vector_io_api.query_chunks(
vector_db_id=vector_db_id,
query=query,
params={
"mode": query_config.mode,
"max_chunks": query_config.max_chunks,
"score_threshold": 0.0,
"ranker": query_config.ranker,
},
)
for vector_db_id in vector_db_ids
]
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
chunks = []
scores = []
for vector_db_id, result in zip(vector_db_ids, results, strict=False):
for chunk, score in zip(result.chunks, result.scores, strict=False):
if not hasattr(chunk, "metadata") or chunk.metadata is None:
chunk.metadata = {}
chunk.metadata["vector_db_id"] = vector_db_id
chunks.append(chunk)
scores.append(score)
if not chunks:
return RAGQueryResult(content=None)
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) # type: ignore
chunks = chunks[: query_config.max_chunks]
tokens = 0
picked: list[InterleavedContentItem] = [
TextContentItem(
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
]
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
tokens += metadata.get("token_count", 0)
tokens += metadata.get("metadata_token_count", 0)
if tokens > query_config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
# Add useful keys from chunk_metadata to metadata and remove some from metadata
chunk_metadata_keys_to_include_from_context = [
"chunk_id",
"document_id",
"source",
]
metadata_keys_to_exclude_from_context = [
"token_count",
"metadata_token_count",
"vector_db_id",
]
metadata_for_context = {}
for k in chunk_metadata_keys_to_include_from_context:
metadata_for_context[k] = getattr(chunk.chunk_metadata, k)
for k in metadata:
if k not in metadata_keys_to_exclude_from_context:
metadata_for_context[k] = metadata[k]
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_for_context)
picked.append(TextContentItem(text=text_content))
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
picked.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
)
)
return RAGQueryResult(
content=picked,
metadata={
"document_ids": [c.document_id for c in chunks[: len(picked)]],
"chunks": [c.content for c in chunks[: len(picked)]],
"scores": scores[: len(picked)],
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
},
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
# Parameters are not listed since these methods are not yet invoked automatically
# by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals.
return ListToolDefsResponse(
data=[
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
),
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for. Can be a natural language sentence or keywords.",
}
},
"required": ["query"],
},
),
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
vector_db_ids = kwargs.get("vector_db_ids", [])
query_config = kwargs.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
query_config = RAGQueryConfig()
query = kwargs["query"]
result = await self.query(
content=query,
vector_db_ids=vector_db_ids,
query_config=query_config,
)
return ToolInvocationResult(
content=result.content or [],
metadata={
**(result.metadata or {}),
"citation_files": getattr(result, "citation_files", None),
},
)

View file

@ -42,7 +42,6 @@ def available_providers() -> list[ProviderSpec]:
# CrossEncoder depends on torchao.quantization
pip_packages=[
"torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu",
"numpy tqdm transformers",
"sentence-transformers --no-deps",
# required by some SentenceTransformers architectures for tensor rearrange/merge ops
"einops",

View file

@ -7,13 +7,33 @@
from llama_stack.providers.datatypes import (
Api,
InlineProviderSpec,
ProviderSpec,
RemoteProviderSpec,
)
from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=DEFAULT_VECTOR_IO_DEPS
+ [
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
],
module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
),
RemoteProviderSpec(
api=Api.tool_runtime,
adapter_type="brave-search",

View file

@ -119,7 +119,7 @@ Datasets that can fit in memory, frequent reads | Faiss | Optimized for speed, i
#### Empirical Example
Consider the histogram below in which 10,000 randomly generated strings were inserted
in batches of 100 into both Faiss and sqlite-vec.
in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`.
```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png
:alt: Comparison of SQLite-Vec and Faiss write times

View file

@ -12,14 +12,17 @@ from dataclasses import dataclass
from typing import Any
from urllib.parse import unquote
import httpx
import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
)
from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
@ -126,6 +129,31 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en
return ""
async def content_from_doc(doc: RAGDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
return r.text
elif isinstance(doc.content, str):
pattern = re.compile("^(https?://|file://|data:)")
if pattern.match(doc.content):
if doc.content.startswith("data:"):
return content_from_data(doc.content)
async with httpx.AsyncClient() as client:
r = await client.get(doc.content)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
return r.text
return doc.content
else:
# will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent
return interleaved_content_as_str(doc.content)
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
) -> list[Chunk]: