revert: "chore(cleanup)!: remove tool_runtime.rag_tool" (#3877)

Reverts llamastack/llama-stack#3871

This PR broke RAG (even from Responses -- there _is_ a dependency)
This commit is contained in:
Ashwin Bharambe 2025-10-21 11:22:06 -07:00 committed by GitHub
parent eb3e9b85f9
commit bd3c473208
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
55 changed files with 3114 additions and 17 deletions

View file

@ -2039,6 +2039,69 @@ paths:
schema:
$ref: '#/components/schemas/URL'
deprecated: false
/v1/tool-runtime/rag-tool/insert:
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Index documents so they can be used by the RAG system.
description: >-
Index documents so they can be used by the RAG system.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/InsertRequest'
required: true
deprecated: false
/v1/tool-runtime/rag-tool/query:
post:
responses:
'200':
description: >-
RAGQueryResult containing the retrieved content and metadata
content:
application/json:
schema:
$ref: '#/components/schemas/RAGQueryResult'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Query the RAG system for context; typically invoked by the agent.
description: >-
Query the RAG system for context; typically invoked by the agent.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/QueryRequest'
required: true
deprecated: false
/v1/toolgroups:
get:
responses:
@ -9858,6 +9921,274 @@ components:
title: ListToolDefsResponse
description: >-
Response containing a list of tool definitions.
RAGDocument:
type: object
properties:
document_id:
type: string
description: The unique identifier for the document.
content:
oneOf:
- type: string
- $ref: '#/components/schemas/InterleavedContentItem'
- type: array
items:
$ref: '#/components/schemas/InterleavedContentItem'
- $ref: '#/components/schemas/URL'
description: The content of the document.
mime_type:
type: string
description: The MIME type of the document.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Additional metadata for the document.
additionalProperties: false
required:
- document_id
- content
- metadata
title: RAGDocument
description: >-
A document to be used for document ingestion in the RAG Tool.
InsertRequest:
type: object
properties:
documents:
type: array
items:
$ref: '#/components/schemas/RAGDocument'
description: >-
List of documents to index in the RAG system
vector_db_id:
type: string
description: >-
ID of the vector database to store the document embeddings
chunk_size_in_tokens:
type: integer
description: >-
(Optional) Size in tokens for document chunking during indexing
additionalProperties: false
required:
- documents
- vector_db_id
- chunk_size_in_tokens
title: InsertRequest
DefaultRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: default
default: default
description: >-
Type of query generator, always 'default'
separator:
type: string
default: ' '
description: >-
String separator used to join query terms
additionalProperties: false
required:
- type
- separator
title: DefaultRAGQueryGeneratorConfig
description: >-
Configuration for the default RAG query generator.
LLMRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: llm
default: llm
description: Type of query generator, always 'llm'
model:
type: string
description: >-
Name of the language model to use for query generation
template:
type: string
description: >-
Template string for formatting the query generation prompt
additionalProperties: false
required:
- type
- model
- template
title: LLMRAGQueryGeneratorConfig
description: >-
Configuration for the LLM-based RAG query generator.
RAGQueryConfig:
type: object
properties:
query_generator_config:
oneOf:
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
- $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
discriminator:
propertyName: type
mapping:
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
description: Configuration for the query generator.
max_tokens_in_context:
type: integer
default: 4096
description: Maximum number of tokens in the context.
max_chunks:
type: integer
default: 5
description: Maximum number of chunks to retrieve.
chunk_template:
type: string
default: >
Result {index}
Content: {chunk.content}
Metadata: {metadata}
description: >-
Template for formatting each retrieved chunk in the context. Available
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
{chunk.content}\nMetadata: {metadata}\n"
mode:
$ref: '#/components/schemas/RAGSearchMode'
default: vector
description: >-
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
"vector".
ranker:
$ref: '#/components/schemas/Ranker'
description: >-
Configuration for the ranker to use in hybrid search. Defaults to RRF
ranker.
additionalProperties: false
required:
- query_generator_config
- max_tokens_in_context
- max_chunks
- chunk_template
title: RAGQueryConfig
description: >-
Configuration for the RAG query generation.
RAGSearchMode:
type: string
enum:
- vector
- keyword
- hybrid
title: RAGSearchMode
description: >-
Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
for semantic matching - KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
RRFRanker:
type: object
properties:
type:
type: string
const: rrf
default: rrf
description: The type of ranker, always "rrf"
impact_factor:
type: number
default: 60.0
description: >-
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
results. Must be greater than 0
additionalProperties: false
required:
- type
- impact_factor
title: RRFRanker
description: >-
Reciprocal Rank Fusion (RRF) ranker configuration.
Ranker:
oneOf:
- $ref: '#/components/schemas/RRFRanker'
- $ref: '#/components/schemas/WeightedRanker'
discriminator:
propertyName: type
mapping:
rrf: '#/components/schemas/RRFRanker'
weighted: '#/components/schemas/WeightedRanker'
WeightedRanker:
type: object
properties:
type:
type: string
const: weighted
default: weighted
description: The type of ranker, always "weighted"
alpha:
type: number
default: 0.5
description: >-
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
only use vector scores, values in between blend both scores.
additionalProperties: false
required:
- type
- alpha
title: WeightedRanker
description: >-
Weighted ranker configuration that combines vector and keyword scores.
QueryRequest:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The query content to search for in the indexed documents
vector_db_ids:
type: array
items:
type: string
description: >-
List of vector database IDs to search within
query_config:
$ref: '#/components/schemas/RAGQueryConfig'
description: >-
(Optional) Configuration parameters for the query operation
additionalProperties: false
required:
- content
- vector_db_ids
title: QueryRequest
RAGQueryResult:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
(Optional) The retrieved content from the query
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
Additional metadata about the query result
additionalProperties: false
required:
- metadata
title: RAGQueryResult
description: >-
Result of a RAG query containing retrieved content and metadata.
ToolGroup:
type: object
properties:

View file

@ -21,7 +21,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
| inference | `inline::meta-reference` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::model-context-protocol` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -16,7 +16,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
| scoring | `inline::basic` |
| tool_runtime | |
| tool_runtime | `inline::rag-runtime` |
| vector_io | `inline::faiss` |

View file

@ -28,7 +28,7 @@ description: |
#### Empirical Example
Consider the histogram below in which 10,000 randomly generated strings were inserted
in batches of 100 into both Faiss and sqlite-vec.
in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`.
```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png
:alt: Comparison of SQLite-Vec and Faiss write times
@ -233,7 +233,7 @@ Datasets that can fit in memory, frequent reads | Faiss | Optimized for speed, i
#### Empirical Example
Consider the histogram below in which 10,000 randomly generated strings were inserted
in batches of 100 into both Faiss and sqlite-vec.
in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`.
```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png
:alt: Comparison of SQLite-Vec and Faiss write times

View file

@ -196,10 +196,16 @@ def _get_endpoint_functions(
def _get_defining_class(member_fn: str, derived_cls: type) -> type:
"Find the class in which a member function is first defined in a class inheritance hierarchy."
# This import must be dynamic here
from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime
# iterate in reverse member resolution order to find most specific class first
for cls in reversed(inspect.getmro(derived_cls)):
for name, _ in inspect.getmembers(cls, inspect.isfunction):
if name == member_fn:
# HACK ALERT
if cls == RAGToolRuntime:
return ToolRuntime
return cls
raise ValidationError(

View file

@ -2624,6 +2624,89 @@
"deprecated": false
}
},
"/v1/tool-runtime/rag-tool/insert": {
"post": {
"responses": {
"200": {
"description": "OK"
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"ToolRuntime"
],
"summary": "Index documents so they can be used by the RAG system.",
"description": "Index documents so they can be used by the RAG system.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/InsertRequest"
}
}
},
"required": true
},
"deprecated": false
}
},
"/v1/tool-runtime/rag-tool/query": {
"post": {
"responses": {
"200": {
"description": "RAGQueryResult containing the retrieved content and metadata",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RAGQueryResult"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"ToolRuntime"
],
"summary": "Query the RAG system for context; typically invoked by the agent.",
"description": "Query the RAG system for context; typically invoked by the agent.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/QueryRequest"
}
}
},
"required": true
},
"deprecated": false
}
},
"/v1/toolgroups": {
"get": {
"responses": {
@ -11300,6 +11383,346 @@
"title": "ListToolDefsResponse",
"description": "Response containing a list of tool definitions."
},
"RAGDocument": {
"type": "object",
"properties": {
"document_id": {
"type": "string",
"description": "The unique identifier for the document."
},
"content": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/InterleavedContentItem"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContentItem"
}
},
{
"$ref": "#/components/schemas/URL"
}
],
"description": "The content of the document."
},
"mime_type": {
"type": "string",
"description": "The MIME type of the document."
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Additional metadata for the document."
}
},
"additionalProperties": false,
"required": [
"document_id",
"content",
"metadata"
],
"title": "RAGDocument",
"description": "A document to be used for document ingestion in the RAG Tool."
},
"InsertRequest": {
"type": "object",
"properties": {
"documents": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RAGDocument"
},
"description": "List of documents to index in the RAG system"
},
"vector_db_id": {
"type": "string",
"description": "ID of the vector database to store the document embeddings"
},
"chunk_size_in_tokens": {
"type": "integer",
"description": "(Optional) Size in tokens for document chunking during indexing"
}
},
"additionalProperties": false,
"required": [
"documents",
"vector_db_id",
"chunk_size_in_tokens"
],
"title": "InsertRequest"
},
"DefaultRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "default",
"default": "default",
"description": "Type of query generator, always 'default'"
},
"separator": {
"type": "string",
"default": " ",
"description": "String separator used to join query terms"
}
},
"additionalProperties": false,
"required": [
"type",
"separator"
],
"title": "DefaultRAGQueryGeneratorConfig",
"description": "Configuration for the default RAG query generator."
},
"LLMRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "llm",
"default": "llm",
"description": "Type of query generator, always 'llm'"
},
"model": {
"type": "string",
"description": "Name of the language model to use for query generation"
},
"template": {
"type": "string",
"description": "Template string for formatting the query generation prompt"
}
},
"additionalProperties": false,
"required": [
"type",
"model",
"template"
],
"title": "LLMRAGQueryGeneratorConfig",
"description": "Configuration for the LLM-based RAG query generator."
},
"RAGQueryConfig": {
"type": "object",
"properties": {
"query_generator_config": {
"oneOf": [
{
"$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig"
},
{
"$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"default": "#/components/schemas/DefaultRAGQueryGeneratorConfig",
"llm": "#/components/schemas/LLMRAGQueryGeneratorConfig"
}
},
"description": "Configuration for the query generator."
},
"max_tokens_in_context": {
"type": "integer",
"default": 4096,
"description": "Maximum number of tokens in the context."
},
"max_chunks": {
"type": "integer",
"default": 5,
"description": "Maximum number of chunks to retrieve."
},
"chunk_template": {
"type": "string",
"default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
"description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
},
"mode": {
"$ref": "#/components/schemas/RAGSearchMode",
"default": "vector",
"description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
},
"ranker": {
"$ref": "#/components/schemas/Ranker",
"description": "Configuration for the ranker to use in hybrid search. Defaults to RRF ranker."
}
},
"additionalProperties": false,
"required": [
"query_generator_config",
"max_tokens_in_context",
"max_chunks",
"chunk_template"
],
"title": "RAGQueryConfig",
"description": "Configuration for the RAG query generation."
},
"RAGSearchMode": {
"type": "string",
"enum": [
"vector",
"keyword",
"hybrid"
],
"title": "RAGSearchMode",
"description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results"
},
"RRFRanker": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "rrf",
"default": "rrf",
"description": "The type of ranker, always \"rrf\""
},
"impact_factor": {
"type": "number",
"default": 60.0,
"description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0"
}
},
"additionalProperties": false,
"required": [
"type",
"impact_factor"
],
"title": "RRFRanker",
"description": "Reciprocal Rank Fusion (RRF) ranker configuration."
},
"Ranker": {
"oneOf": [
{
"$ref": "#/components/schemas/RRFRanker"
},
{
"$ref": "#/components/schemas/WeightedRanker"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"rrf": "#/components/schemas/RRFRanker",
"weighted": "#/components/schemas/WeightedRanker"
}
}
},
"WeightedRanker": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "weighted",
"default": "weighted",
"description": "The type of ranker, always \"weighted\""
},
"alpha": {
"type": "number",
"default": 0.5,
"description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores."
}
},
"additionalProperties": false,
"required": [
"type",
"alpha"
],
"title": "WeightedRanker",
"description": "Weighted ranker configuration that combines vector and keyword scores."
},
"QueryRequest": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"description": "The query content to search for in the indexed documents"
},
"vector_db_ids": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of vector database IDs to search within"
},
"query_config": {
"$ref": "#/components/schemas/RAGQueryConfig",
"description": "(Optional) Configuration parameters for the query operation"
}
},
"additionalProperties": false,
"required": [
"content",
"vector_db_ids"
],
"title": "QueryRequest"
},
"RAGQueryResult": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"description": "(Optional) The retrieved content from the query"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Additional metadata about the query result"
}
},
"additionalProperties": false,
"required": [
"metadata"
],
"title": "RAGQueryResult",
"description": "Result of a RAG query containing retrieved content and metadata."
},
"ToolGroup": {
"type": "object",
"properties": {

View file

@ -2036,6 +2036,69 @@ paths:
schema:
$ref: '#/components/schemas/URL'
deprecated: false
/v1/tool-runtime/rag-tool/insert:
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Index documents so they can be used by the RAG system.
description: >-
Index documents so they can be used by the RAG system.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/InsertRequest'
required: true
deprecated: false
/v1/tool-runtime/rag-tool/query:
post:
responses:
'200':
description: >-
RAGQueryResult containing the retrieved content and metadata
content:
application/json:
schema:
$ref: '#/components/schemas/RAGQueryResult'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Query the RAG system for context; typically invoked by the agent.
description: >-
Query the RAG system for context; typically invoked by the agent.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/QueryRequest'
required: true
deprecated: false
/v1/toolgroups:
get:
responses:
@ -8645,6 +8708,274 @@ components:
title: ListToolDefsResponse
description: >-
Response containing a list of tool definitions.
RAGDocument:
type: object
properties:
document_id:
type: string
description: The unique identifier for the document.
content:
oneOf:
- type: string
- $ref: '#/components/schemas/InterleavedContentItem'
- type: array
items:
$ref: '#/components/schemas/InterleavedContentItem'
- $ref: '#/components/schemas/URL'
description: The content of the document.
mime_type:
type: string
description: The MIME type of the document.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Additional metadata for the document.
additionalProperties: false
required:
- document_id
- content
- metadata
title: RAGDocument
description: >-
A document to be used for document ingestion in the RAG Tool.
InsertRequest:
type: object
properties:
documents:
type: array
items:
$ref: '#/components/schemas/RAGDocument'
description: >-
List of documents to index in the RAG system
vector_db_id:
type: string
description: >-
ID of the vector database to store the document embeddings
chunk_size_in_tokens:
type: integer
description: >-
(Optional) Size in tokens for document chunking during indexing
additionalProperties: false
required:
- documents
- vector_db_id
- chunk_size_in_tokens
title: InsertRequest
DefaultRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: default
default: default
description: >-
Type of query generator, always 'default'
separator:
type: string
default: ' '
description: >-
String separator used to join query terms
additionalProperties: false
required:
- type
- separator
title: DefaultRAGQueryGeneratorConfig
description: >-
Configuration for the default RAG query generator.
LLMRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: llm
default: llm
description: Type of query generator, always 'llm'
model:
type: string
description: >-
Name of the language model to use for query generation
template:
type: string
description: >-
Template string for formatting the query generation prompt
additionalProperties: false
required:
- type
- model
- template
title: LLMRAGQueryGeneratorConfig
description: >-
Configuration for the LLM-based RAG query generator.
RAGQueryConfig:
type: object
properties:
query_generator_config:
oneOf:
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
- $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
discriminator:
propertyName: type
mapping:
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
description: Configuration for the query generator.
max_tokens_in_context:
type: integer
default: 4096
description: Maximum number of tokens in the context.
max_chunks:
type: integer
default: 5
description: Maximum number of chunks to retrieve.
chunk_template:
type: string
default: >
Result {index}
Content: {chunk.content}
Metadata: {metadata}
description: >-
Template for formatting each retrieved chunk in the context. Available
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
{chunk.content}\nMetadata: {metadata}\n"
mode:
$ref: '#/components/schemas/RAGSearchMode'
default: vector
description: >-
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
"vector".
ranker:
$ref: '#/components/schemas/Ranker'
description: >-
Configuration for the ranker to use in hybrid search. Defaults to RRF
ranker.
additionalProperties: false
required:
- query_generator_config
- max_tokens_in_context
- max_chunks
- chunk_template
title: RAGQueryConfig
description: >-
Configuration for the RAG query generation.
RAGSearchMode:
type: string
enum:
- vector
- keyword
- hybrid
title: RAGSearchMode
description: >-
Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
for semantic matching - KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
RRFRanker:
type: object
properties:
type:
type: string
const: rrf
default: rrf
description: The type of ranker, always "rrf"
impact_factor:
type: number
default: 60.0
description: >-
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
results. Must be greater than 0
additionalProperties: false
required:
- type
- impact_factor
title: RRFRanker
description: >-
Reciprocal Rank Fusion (RRF) ranker configuration.
Ranker:
oneOf:
- $ref: '#/components/schemas/RRFRanker'
- $ref: '#/components/schemas/WeightedRanker'
discriminator:
propertyName: type
mapping:
rrf: '#/components/schemas/RRFRanker'
weighted: '#/components/schemas/WeightedRanker'
WeightedRanker:
type: object
properties:
type:
type: string
const: weighted
default: weighted
description: The type of ranker, always "weighted"
alpha:
type: number
default: 0.5
description: >-
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
only use vector scores, values in between blend both scores.
additionalProperties: false
required:
- type
- alpha
title: WeightedRanker
description: >-
Weighted ranker configuration that combines vector and keyword scores.
QueryRequest:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The query content to search for in the indexed documents
vector_db_ids:
type: array
items:
type: string
description: >-
List of vector database IDs to search within
query_config:
$ref: '#/components/schemas/RAGQueryConfig'
description: >-
(Optional) Configuration parameters for the query operation
additionalProperties: false
required:
- content
- vector_db_ids
title: QueryRequest
RAGQueryResult:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
(Optional) The retrieved content from the query
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
Additional metadata about the query result
additionalProperties: false
required:
- metadata
title: RAGQueryResult
description: >-
Result of a RAG query containing retrieved content and metadata.
ToolGroup:
type: object
properties:

View file

@ -2624,6 +2624,89 @@
"deprecated": false
}
},
"/v1/tool-runtime/rag-tool/insert": {
"post": {
"responses": {
"200": {
"description": "OK"
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"ToolRuntime"
],
"summary": "Index documents so they can be used by the RAG system.",
"description": "Index documents so they can be used by the RAG system.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/InsertRequest"
}
}
},
"required": true
},
"deprecated": false
}
},
"/v1/tool-runtime/rag-tool/query": {
"post": {
"responses": {
"200": {
"description": "RAGQueryResult containing the retrieved content and metadata",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RAGQueryResult"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"ToolRuntime"
],
"summary": "Query the RAG system for context; typically invoked by the agent.",
"description": "Query the RAG system for context; typically invoked by the agent.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/QueryRequest"
}
}
},
"required": true
},
"deprecated": false
}
},
"/v1/toolgroups": {
"get": {
"responses": {
@ -12972,6 +13055,346 @@
"title": "ListToolDefsResponse",
"description": "Response containing a list of tool definitions."
},
"RAGDocument": {
"type": "object",
"properties": {
"document_id": {
"type": "string",
"description": "The unique identifier for the document."
},
"content": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/InterleavedContentItem"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContentItem"
}
},
{
"$ref": "#/components/schemas/URL"
}
],
"description": "The content of the document."
},
"mime_type": {
"type": "string",
"description": "The MIME type of the document."
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Additional metadata for the document."
}
},
"additionalProperties": false,
"required": [
"document_id",
"content",
"metadata"
],
"title": "RAGDocument",
"description": "A document to be used for document ingestion in the RAG Tool."
},
"InsertRequest": {
"type": "object",
"properties": {
"documents": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RAGDocument"
},
"description": "List of documents to index in the RAG system"
},
"vector_db_id": {
"type": "string",
"description": "ID of the vector database to store the document embeddings"
},
"chunk_size_in_tokens": {
"type": "integer",
"description": "(Optional) Size in tokens for document chunking during indexing"
}
},
"additionalProperties": false,
"required": [
"documents",
"vector_db_id",
"chunk_size_in_tokens"
],
"title": "InsertRequest"
},
"DefaultRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "default",
"default": "default",
"description": "Type of query generator, always 'default'"
},
"separator": {
"type": "string",
"default": " ",
"description": "String separator used to join query terms"
}
},
"additionalProperties": false,
"required": [
"type",
"separator"
],
"title": "DefaultRAGQueryGeneratorConfig",
"description": "Configuration for the default RAG query generator."
},
"LLMRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "llm",
"default": "llm",
"description": "Type of query generator, always 'llm'"
},
"model": {
"type": "string",
"description": "Name of the language model to use for query generation"
},
"template": {
"type": "string",
"description": "Template string for formatting the query generation prompt"
}
},
"additionalProperties": false,
"required": [
"type",
"model",
"template"
],
"title": "LLMRAGQueryGeneratorConfig",
"description": "Configuration for the LLM-based RAG query generator."
},
"RAGQueryConfig": {
"type": "object",
"properties": {
"query_generator_config": {
"oneOf": [
{
"$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig"
},
{
"$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"default": "#/components/schemas/DefaultRAGQueryGeneratorConfig",
"llm": "#/components/schemas/LLMRAGQueryGeneratorConfig"
}
},
"description": "Configuration for the query generator."
},
"max_tokens_in_context": {
"type": "integer",
"default": 4096,
"description": "Maximum number of tokens in the context."
},
"max_chunks": {
"type": "integer",
"default": 5,
"description": "Maximum number of chunks to retrieve."
},
"chunk_template": {
"type": "string",
"default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
"description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
},
"mode": {
"$ref": "#/components/schemas/RAGSearchMode",
"default": "vector",
"description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
},
"ranker": {
"$ref": "#/components/schemas/Ranker",
"description": "Configuration for the ranker to use in hybrid search. Defaults to RRF ranker."
}
},
"additionalProperties": false,
"required": [
"query_generator_config",
"max_tokens_in_context",
"max_chunks",
"chunk_template"
],
"title": "RAGQueryConfig",
"description": "Configuration for the RAG query generation."
},
"RAGSearchMode": {
"type": "string",
"enum": [
"vector",
"keyword",
"hybrid"
],
"title": "RAGSearchMode",
"description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results"
},
"RRFRanker": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "rrf",
"default": "rrf",
"description": "The type of ranker, always \"rrf\""
},
"impact_factor": {
"type": "number",
"default": 60.0,
"description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0"
}
},
"additionalProperties": false,
"required": [
"type",
"impact_factor"
],
"title": "RRFRanker",
"description": "Reciprocal Rank Fusion (RRF) ranker configuration."
},
"Ranker": {
"oneOf": [
{
"$ref": "#/components/schemas/RRFRanker"
},
{
"$ref": "#/components/schemas/WeightedRanker"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"rrf": "#/components/schemas/RRFRanker",
"weighted": "#/components/schemas/WeightedRanker"
}
}
},
"WeightedRanker": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "weighted",
"default": "weighted",
"description": "The type of ranker, always \"weighted\""
},
"alpha": {
"type": "number",
"default": 0.5,
"description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores."
}
},
"additionalProperties": false,
"required": [
"type",
"alpha"
],
"title": "WeightedRanker",
"description": "Weighted ranker configuration that combines vector and keyword scores."
},
"QueryRequest": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"description": "The query content to search for in the indexed documents"
},
"vector_db_ids": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of vector database IDs to search within"
},
"query_config": {
"$ref": "#/components/schemas/RAGQueryConfig",
"description": "(Optional) Configuration parameters for the query operation"
}
},
"additionalProperties": false,
"required": [
"content",
"vector_db_ids"
],
"title": "QueryRequest"
},
"RAGQueryResult": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"description": "(Optional) The retrieved content from the query"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Additional metadata about the query result"
}
},
"additionalProperties": false,
"required": [
"metadata"
],
"title": "RAGQueryResult",
"description": "Result of a RAG query containing retrieved content and metadata."
},
"ToolGroup": {
"type": "object",
"properties": {

View file

@ -2039,6 +2039,69 @@ paths:
schema:
$ref: '#/components/schemas/URL'
deprecated: false
/v1/tool-runtime/rag-tool/insert:
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Index documents so they can be used by the RAG system.
description: >-
Index documents so they can be used by the RAG system.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/InsertRequest'
required: true
deprecated: false
/v1/tool-runtime/rag-tool/query:
post:
responses:
'200':
description: >-
RAGQueryResult containing the retrieved content and metadata
content:
application/json:
schema:
$ref: '#/components/schemas/RAGQueryResult'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Query the RAG system for context; typically invoked by the agent.
description: >-
Query the RAG system for context; typically invoked by the agent.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/QueryRequest'
required: true
deprecated: false
/v1/toolgroups:
get:
responses:
@ -9858,6 +9921,274 @@ components:
title: ListToolDefsResponse
description: >-
Response containing a list of tool definitions.
RAGDocument:
type: object
properties:
document_id:
type: string
description: The unique identifier for the document.
content:
oneOf:
- type: string
- $ref: '#/components/schemas/InterleavedContentItem'
- type: array
items:
$ref: '#/components/schemas/InterleavedContentItem'
- $ref: '#/components/schemas/URL'
description: The content of the document.
mime_type:
type: string
description: The MIME type of the document.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Additional metadata for the document.
additionalProperties: false
required:
- document_id
- content
- metadata
title: RAGDocument
description: >-
A document to be used for document ingestion in the RAG Tool.
InsertRequest:
type: object
properties:
documents:
type: array
items:
$ref: '#/components/schemas/RAGDocument'
description: >-
List of documents to index in the RAG system
vector_db_id:
type: string
description: >-
ID of the vector database to store the document embeddings
chunk_size_in_tokens:
type: integer
description: >-
(Optional) Size in tokens for document chunking during indexing
additionalProperties: false
required:
- documents
- vector_db_id
- chunk_size_in_tokens
title: InsertRequest
DefaultRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: default
default: default
description: >-
Type of query generator, always 'default'
separator:
type: string
default: ' '
description: >-
String separator used to join query terms
additionalProperties: false
required:
- type
- separator
title: DefaultRAGQueryGeneratorConfig
description: >-
Configuration for the default RAG query generator.
LLMRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: llm
default: llm
description: Type of query generator, always 'llm'
model:
type: string
description: >-
Name of the language model to use for query generation
template:
type: string
description: >-
Template string for formatting the query generation prompt
additionalProperties: false
required:
- type
- model
- template
title: LLMRAGQueryGeneratorConfig
description: >-
Configuration for the LLM-based RAG query generator.
RAGQueryConfig:
type: object
properties:
query_generator_config:
oneOf:
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
- $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
discriminator:
propertyName: type
mapping:
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
description: Configuration for the query generator.
max_tokens_in_context:
type: integer
default: 4096
description: Maximum number of tokens in the context.
max_chunks:
type: integer
default: 5
description: Maximum number of chunks to retrieve.
chunk_template:
type: string
default: >
Result {index}
Content: {chunk.content}
Metadata: {metadata}
description: >-
Template for formatting each retrieved chunk in the context. Available
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
{chunk.content}\nMetadata: {metadata}\n"
mode:
$ref: '#/components/schemas/RAGSearchMode'
default: vector
description: >-
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
"vector".
ranker:
$ref: '#/components/schemas/Ranker'
description: >-
Configuration for the ranker to use in hybrid search. Defaults to RRF
ranker.
additionalProperties: false
required:
- query_generator_config
- max_tokens_in_context
- max_chunks
- chunk_template
title: RAGQueryConfig
description: >-
Configuration for the RAG query generation.
RAGSearchMode:
type: string
enum:
- vector
- keyword
- hybrid
title: RAGSearchMode
description: >-
Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
for semantic matching - KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
RRFRanker:
type: object
properties:
type:
type: string
const: rrf
default: rrf
description: The type of ranker, always "rrf"
impact_factor:
type: number
default: 60.0
description: >-
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
results. Must be greater than 0
additionalProperties: false
required:
- type
- impact_factor
title: RRFRanker
description: >-
Reciprocal Rank Fusion (RRF) ranker configuration.
Ranker:
oneOf:
- $ref: '#/components/schemas/RRFRanker'
- $ref: '#/components/schemas/WeightedRanker'
discriminator:
propertyName: type
mapping:
rrf: '#/components/schemas/RRFRanker'
weighted: '#/components/schemas/WeightedRanker'
WeightedRanker:
type: object
properties:
type:
type: string
const: weighted
default: weighted
description: The type of ranker, always "weighted"
alpha:
type: number
default: 0.5
description: >-
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
only use vector scores, values in between blend both scores.
additionalProperties: false
required:
- type
- alpha
title: WeightedRanker
description: >-
Weighted ranker configuration that combines vector and keyword scores.
QueryRequest:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The query content to search for in the indexed documents
vector_db_ids:
type: array
items:
type: string
description: >-
List of vector database IDs to search within
query_config:
$ref: '#/components/schemas/RAGQueryConfig'
description: >-
(Optional) Configuration parameters for the query operation
additionalProperties: false
required:
- content
- vector_db_ids
title: QueryRequest
RAGQueryResult:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
(Optional) The retrieved content from the query
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
Additional metadata about the query result
additionalProperties: false
required:
- metadata
title: RAGQueryResult
description: >-
Result of a RAG query containing retrieved content and metadata.
ToolGroup:
type: object
properties:

View file

@ -4,4 +4,5 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .rag_tool import *
from .tools import *

View file

@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field, field_validator
from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class RRFRanker(BaseModel):
"""
Reciprocal Rank Fusion (RRF) ranker configuration.
:param type: The type of ranker, always "rrf"
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
Must be greater than 0
"""
type: Literal["rrf"] = "rrf"
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
@json_schema_type
class WeightedRanker(BaseModel):
"""
Weighted ranker configuration that combines vector and keyword scores.
:param type: The type of ranker, always "weighted"
:param alpha: Weight factor between 0 and 1.
0 means only use keyword scores,
1 means only use vector scores,
values in between blend both scores.
"""
type: Literal["weighted"] = "weighted"
alpha: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
)
Ranker = Annotated[
RRFRanker | WeightedRanker,
Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")
@json_schema_type
class RAGDocument(BaseModel):
"""
A document to be used for document ingestion in the RAG Tool.
:param document_id: The unique identifier for the document.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
:param metadata: Additional metadata for the document.
"""
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryResult(BaseModel):
"""Result of a RAG query containing retrieved content and metadata.
:param content: (Optional) The retrieved content from the query
:param metadata: Additional metadata about the query result
"""
content: InterleavedContent | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryGenerator(Enum):
"""Types of query generators for RAG systems.
:cvar default: Default query generator using simple text processing
:cvar llm: LLM-based query generator for enhanced query understanding
:cvar custom: Custom query generator implementation
"""
default = "default"
llm = "llm"
custom = "custom"
@json_schema_type
class RAGSearchMode(StrEnum):
"""
Search modes for RAG query retrieval:
- VECTOR: Uses vector similarity search for semantic matching
- KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
"""
VECTOR = "vector"
KEYWORD = "keyword"
HYBRID = "hybrid"
@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the default RAG query generator.
:param type: Type of query generator, always 'default'
:param separator: String separator used to join query terms
"""
type: Literal["default"] = "default"
separator: str = " "
@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the LLM-based RAG query generator.
:param type: Type of query generator, always 'llm'
:param model: Name of the language model to use for query generation
:param template: Template string for formatting the query generation prompt
"""
type: Literal["llm"] = "llm"
model: str
template: str
RAGQueryGeneratorConfig = Annotated[
DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type
class RAGQueryConfig(BaseModel):
"""
Configuration for the RAG query generation.
:param query_generator_config: Configuration for the query generator.
:param max_tokens_in_context: Maximum number of tokens in the context.
:param max_chunks: Maximum number of chunks to retrieve.
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
:param mode: Search mode for retrievaleither "vector", "keyword", or "hybrid". Default "vector".
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
"""
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
max_tokens_in_context: int = 4096
max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: RAGSearchMode | None = RAGSearchMode.VECTOR
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:
if "{chunk.content}" not in v:
raise ValueError("chunk_template must contain {chunk.content}")
if "{index}" not in v:
raise ValueError("chunk_template must contain {index}")
if len(v) == 0:
raise ValueError("chunk_template must not be empty")
return v
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
"""Index documents so they can be used by the RAG system.
:param documents: List of documents to index in the RAG system
:param vector_db_id: ID of the vector database to store the document embeddings
:param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
"""
...
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent.
:param content: The query content to search for in the indexed documents
:param vector_db_ids: List of vector database IDs to search within
:param query_config: (Optional) Configuration parameters for the query operation
:returns: RAGQueryResult containing the retrieved content and metadata
"""
...

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Literal, Protocol
from pydantic import BaseModel
@ -15,6 +16,8 @@ from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime
@json_schema_type
class ToolDef(BaseModel):
@ -178,11 +181,22 @@ class ToolGroups(Protocol):
...
class SpecialToolGroup(Enum):
"""Special tool groups with predefined functionality.
:cvar rag_tool: Retrieval-Augmented Generation tool group for document search and retrieval
"""
rag_tool = "rag_tool"
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_runtime_tools(

View file

@ -8,8 +8,16 @@ from typing import Any
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
)
from llama_stack.apis.tools import ListToolDefsResponse, ToolRuntime
from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
@ -18,6 +26,36 @@ logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_store_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_store_ids, query_config)
async def insert(
self,
documents: list[RAGDocument],
vector_store_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
@ -25,6 +63,11 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass

View file

@ -13,6 +13,7 @@ from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod
@ -24,16 +25,33 @@ RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map():
return {
SpecialToolGroup.rag_tool: RAGToolRuntime,
}
def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {}
protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods:
# Get all webmethods for this method (supports multiple decorators)
webmethods = getattr(method, "__webmethods__", [])

View file

@ -32,7 +32,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
@ -80,6 +80,7 @@ class LlamaStack(
Inspect,
ToolGroups,
ToolRuntime,
RAGToolRuntime,
Files,
Prompts,
Conversations,

View file

@ -48,6 +48,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference

View file

@ -216,6 +216,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
@ -261,6 +263,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -26,6 +26,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
image_type: venv
additional_pip_packages:
- aiosqlite

View file

@ -45,6 +45,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
],
}
name = "dell"
@ -97,6 +98,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="brave-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
return DistributionTemplate(

View file

@ -87,6 +87,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
storage:
backends:
kv_default:
@ -131,6 +133,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -83,6 +83,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
storage:
backends:
kv_default:
@ -122,6 +124,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -24,6 +24,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
image_type: venv
additional_pip_packages:

View file

@ -47,6 +47,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
}
@ -91,6 +92,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
return DistributionTemplate(

View file

@ -98,6 +98,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -144,6 +146,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -88,6 +88,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -129,6 +131,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -19,7 +19,8 @@ distribution_spec:
- provider_type: remote::nvidia
scoring:
- provider_type: inline::basic
tool_runtime: []
tool_runtime:
- provider_type: inline::rag-runtime
files:
- provider_type: inline::localfs
image_type: venv

View file

@ -28,7 +28,7 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
BuildProvider(provider_type="remote::nvidia"),
],
"scoring": [BuildProvider(provider_type="inline::basic")],
"tool_runtime": [],
"tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")],
"files": [BuildProvider(provider_type="inline::localfs")],
}
@ -66,7 +66,12 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
provider_id="nvidia",
)
default_tool_groups: list[ToolGroupInput] = []
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
return DistributionTemplate(
name=name,

View file

@ -80,7 +80,9 @@ providers:
scoring:
- provider_id: basic
provider_type: inline::basic
tool_runtime: []
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@ -126,7 +128,9 @@ registered_resources:
datasets: []
scoring_fns: []
benchmarks: []
tool_groups: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -69,7 +69,9 @@ providers:
scoring:
- provider_id: basic
provider_type: inline::basic
tool_runtime: []
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@ -105,7 +107,9 @@ registered_resources:
datasets: []
scoring_fns: []
benchmarks: []
tool_groups: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -28,6 +28,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
image_type: venv
additional_pip_packages:

View file

@ -118,6 +118,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
}
@ -153,6 +154,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
models, _ = get_model_registry(available_models)

View file

@ -118,6 +118,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -242,6 +244,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -14,6 +14,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
image_type: venv
additional_pip_packages:

View file

@ -45,6 +45,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
}
@ -65,6 +66,10 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
default_models = [

View file

@ -54,6 +54,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -105,6 +107,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -49,6 +49,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference

View file

@ -219,6 +219,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
@ -264,6 +266,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -49,6 +49,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference

View file

@ -216,6 +216,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
@ -261,6 +263,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -140,6 +140,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"batches": [
@ -161,6 +162,10 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
default_shields = [
# if the

View file

@ -23,6 +23,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
files:
- provider_type: inline::localfs

View file

@ -83,6 +83,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
files:
@ -123,6 +125,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:

View file

@ -33,6 +33,7 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
@ -49,6 +50,10 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
files_provider = Provider(

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.providers.datatypes import Api
from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
await impl.initialize()
return impl

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
class RagToolRuntimeConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
RAGQueryGenerator,
RAGQueryGeneratorConfig,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
async def generate_rag_query(
config: RAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
"""
Generates a query that will be used for
retrieving relevant information from the memory bank.
"""
if config.type == RAGQueryGenerator.default.value:
query = await default_rag_query_generator(config, content, **kwargs)
elif config.type == RAGQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, content, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query
async def default_rag_query_generator(
config: DefaultRAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
return interleaved_content_as_str(content, sep=config.separator)
async def llm_rag_query_generator(
config: LLMRAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
messages = []
if isinstance(content, list):
messages = [interleaved_content_as_str(m) for m in content]
else:
messages = [interleaved_content_as_str(content)]
template = Template(config.template)
rendered_content: str = template.render({"messages": messages})
model = config.model
message = OpenAIUserMessageParam(content=rendered_content)
params = OpenAIChatCompletionRequestWithExtraBody(
model=model,
messages=[message],
stream=False,
)
response = await inference_api.openai_chat_completion(params)
query = response.choices[0].message.content
return query

View file

@ -0,0 +1,332 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
import mimetypes
from typing import Any
import httpx
from fastapi import UploadFile
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
VectorStoreChunkingStrategyStatic,
VectorStoreChunkingStrategyStaticConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
log = get_logger(name=__name__, category="tool_runtime")
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
r.raise_for_status()
mime_type = r.headers.get("content-type", "application/octet-stream")
return r.content, mime_type
else:
if isinstance(doc.content, str):
content_str = doc.content
else:
content_str = interleaved_content_as_str(doc.content)
if content_str.startswith("data:"):
parts = parse_data_url(content_str)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
return content_str.encode("utf-8"), "text/plain"
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
config: RagToolRuntimeConfig,
vector_io_api: VectorIO,
inference_api: Inference,
files_api: Files,
):
self.config = config
self.vector_io_api = vector_io_api
self.inference_api = inference_api
self.files_api = files_api
async def initialize(self):
pass
async def shutdown(self):
pass
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
if not documents:
return
for doc in documents:
try:
try:
file_data, mime_type = await raw_data_from_doc(doc)
except Exception as e:
log.error(f"Failed to extract content from document {doc.document_id}: {e}")
continue
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_obj = io.BytesIO(file_data)
file_obj.name = filename
upload_file = UploadFile(file=file_obj, filename=filename)
try:
created_file = await self.files_api.openai_upload_file(
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
except Exception as e:
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
continue
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4,
)
)
try:
await self.vector_io_api.openai_attach_file_to_vector_store(
vector_store_id=vector_db_id,
file_id=created_file.id,
attributes=doc.metadata,
chunking_strategy=chunking_strategy,
)
except Exception as e:
log.error(
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
)
continue
except Exception as e:
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
continue
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
if not vector_db_ids:
raise ValueError(
"No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
)
query_config = query_config or RAGQueryConfig()
query = await generate_rag_query(
query_config.query_generator_config,
content,
inference_api=self.inference_api,
)
tasks = [
self.vector_io_api.query_chunks(
vector_db_id=vector_db_id,
query=query,
params={
"mode": query_config.mode,
"max_chunks": query_config.max_chunks,
"score_threshold": 0.0,
"ranker": query_config.ranker,
},
)
for vector_db_id in vector_db_ids
]
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
chunks = []
scores = []
for vector_db_id, result in zip(vector_db_ids, results, strict=False):
for chunk, score in zip(result.chunks, result.scores, strict=False):
if not hasattr(chunk, "metadata") or chunk.metadata is None:
chunk.metadata = {}
chunk.metadata["vector_db_id"] = vector_db_id
chunks.append(chunk)
scores.append(score)
if not chunks:
return RAGQueryResult(content=None)
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) # type: ignore
chunks = chunks[: query_config.max_chunks]
tokens = 0
picked: list[InterleavedContentItem] = [
TextContentItem(
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
]
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
tokens += metadata.get("token_count", 0)
tokens += metadata.get("metadata_token_count", 0)
if tokens > query_config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
# Add useful keys from chunk_metadata to metadata and remove some from metadata
chunk_metadata_keys_to_include_from_context = [
"chunk_id",
"document_id",
"source",
]
metadata_keys_to_exclude_from_context = [
"token_count",
"metadata_token_count",
"vector_db_id",
]
metadata_for_context = {}
for k in chunk_metadata_keys_to_include_from_context:
metadata_for_context[k] = getattr(chunk.chunk_metadata, k)
for k in metadata:
if k not in metadata_keys_to_exclude_from_context:
metadata_for_context[k] = metadata[k]
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_for_context)
picked.append(TextContentItem(text=text_content))
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
picked.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
)
)
return RAGQueryResult(
content=picked,
metadata={
"document_ids": [c.document_id for c in chunks[: len(picked)]],
"chunks": [c.content for c in chunks[: len(picked)]],
"scores": scores[: len(picked)],
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
},
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
# Parameters are not listed since these methods are not yet invoked automatically
# by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals.
return ListToolDefsResponse(
data=[
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
),
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for. Can be a natural language sentence or keywords.",
}
},
"required": ["query"],
},
),
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
vector_db_ids = kwargs.get("vector_db_ids", [])
query_config = kwargs.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
query_config = RAGQueryConfig()
query = kwargs["query"]
result = await self.query(
content=query,
vector_db_ids=vector_db_ids,
query_config=query_config,
)
return ToolInvocationResult(
content=result.content or [],
metadata={
**(result.metadata or {}),
"citation_files": getattr(result, "citation_files", None),
},
)

View file

@ -42,7 +42,6 @@ def available_providers() -> list[ProviderSpec]:
# CrossEncoder depends on torchao.quantization
pip_packages=[
"torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu",
"numpy tqdm transformers",
"sentence-transformers --no-deps",
# required by some SentenceTransformers architectures for tensor rearrange/merge ops
"einops",

View file

@ -7,13 +7,33 @@
from llama_stack.providers.datatypes import (
Api,
InlineProviderSpec,
ProviderSpec,
RemoteProviderSpec,
)
from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=DEFAULT_VECTOR_IO_DEPS
+ [
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
],
module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
),
RemoteProviderSpec(
api=Api.tool_runtime,
adapter_type="brave-search",

View file

@ -119,7 +119,7 @@ Datasets that can fit in memory, frequent reads | Faiss | Optimized for speed, i
#### Empirical Example
Consider the histogram below in which 10,000 randomly generated strings were inserted
in batches of 100 into both Faiss and sqlite-vec.
in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`.
```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png
:alt: Comparison of SQLite-Vec and Faiss write times

View file

@ -12,14 +12,17 @@ from dataclasses import dataclass
from typing import Any
from urllib.parse import unquote
import httpx
import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
)
from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
@ -126,6 +129,31 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en
return ""
async def content_from_doc(doc: RAGDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
return r.text
elif isinstance(doc.content, str):
pattern = re.compile("^(https?://|file://|data:)")
if pattern.match(doc.content):
if doc.content.startswith("data:"):
return content_from_data(doc.content)
async with httpx.AsyncClient() as client:
r = await client.get(doc.content)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
return r.text
return doc.content
else:
# will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent
return interleaved_content_as_str(doc.content)
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
) -> list[Chunk]:

View file

@ -4,11 +4,138 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type
from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
async def test_content_from_doc_with_url():
"""Test extracting content from RAGDocument with URL content."""
mock_url = URL(uri="https://example.com")
mock_doc = RAGDocument(document_id="foo", content=mock_url)
mock_response = MagicMock()
mock_response.text = "Sample content from URL"
with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance
result = await content_from_doc(mock_doc)
assert result == "Sample content from URL"
mock_instance.get.assert_called_once_with(mock_url.uri)
async def test_content_from_doc_with_pdf_url():
"""Test extracting content from RAGDocument with URL pointing to a PDF."""
mock_url = URL(uri="https://example.com/document.pdf")
mock_doc = RAGDocument(document_id="foo", content=mock_url, mime_type="application/pdf")
mock_response = MagicMock()
mock_response.content = b"PDF binary data"
with (
patch("httpx.AsyncClient") as mock_client,
patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf,
):
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance
mock_parse_pdf.return_value = "Extracted PDF content"
result = await content_from_doc(mock_doc)
assert result == "Extracted PDF content"
mock_instance.get.assert_called_once_with(mock_url.uri)
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
async def test_content_from_doc_with_data_url():
"""Test extracting content from RAGDocument with data URL content."""
data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded
mock_url = URL(uri=data_url)
mock_doc = RAGDocument(document_id="foo", content=mock_url)
with patch("llama_stack.providers.utils.memory.vector_store.content_from_data") as mock_content_from_data:
mock_content_from_data.return_value = "Hello World"
result = await content_from_doc(mock_doc)
assert result == "Hello World"
mock_content_from_data.assert_called_once_with(data_url)
async def test_content_from_doc_with_string():
"""Test extracting content from RAGDocument with string content."""
content_string = "This is plain text content"
mock_doc = RAGDocument(document_id="foo", content=content_string)
result = await content_from_doc(mock_doc)
assert result == content_string
async def test_content_from_doc_with_string_url():
"""Test extracting content from RAGDocument with string URL content."""
url_string = "https://example.com"
mock_doc = RAGDocument(document_id="foo", content=url_string)
mock_response = MagicMock()
mock_response.text = "Sample content from URL string"
with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance
result = await content_from_doc(mock_doc)
assert result == "Sample content from URL string"
mock_instance.get.assert_called_once_with(url_string)
async def test_content_from_doc_with_string_pdf_url():
"""Test extracting content from RAGDocument with string URL pointing to a PDF."""
url_string = "https://example.com/document.pdf"
mock_doc = RAGDocument(document_id="foo", content=url_string, mime_type="application/pdf")
mock_response = MagicMock()
mock_response.content = b"PDF binary data"
with (
patch("httpx.AsyncClient") as mock_client,
patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf,
):
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance
mock_parse_pdf.return_value = "Extracted PDF content from string URL"
result = await content_from_doc(mock_doc)
assert result == "Extracted PDF content from string URL"
mock_instance.get.assert_called_once_with(url_string)
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
async def test_content_from_doc_with_interleaved_content():
"""Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit)."""
interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")]
mock_doc = RAGDocument(document_id="foo", content=interleaved_content)
with patch("llama_stack.providers.utils.memory.vector_store.interleaved_content_as_str") as mock_interleaved:
mock_interleaved.return_value = "First item\nSecond item"
result = await content_from_doc(mock_doc)
assert result == "First item\nSecond item"
mock_interleaved.assert_called_once_with(interleaved_content)
def test_content_from_data_and_mime_type_success_utf8():
@ -51,3 +178,41 @@ def test_content_from_data_and_mime_type_both_encodings_fail():
# Should raise an exception instead of returning empty string
with pytest.raises(UnicodeDecodeError):
content_from_data_and_mime_type(data, mime_type)
async def test_memory_tool_error_handling():
"""Test that memory tool handles various failures gracefully without crashing."""
from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
config = RagToolRuntimeConfig()
memory_tool = MemoryToolRuntimeImpl(
config=config,
vector_io_api=AsyncMock(),
inference_api=AsyncMock(),
files_api=AsyncMock(),
)
docs = [
RAGDocument(document_id="good_doc", content="Good content", metadata={}),
RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}),
RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}),
]
mock_file1 = MagicMock()
mock_file1.id = "file_good1"
mock_file2 = MagicMock()
mock_file2.id = "file_good2"
memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2]
with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.side_effect = Exception("Bad URL")
mock_client.return_value.__aenter__.return_value = mock_instance
# won't raise exception despite one document failing
await memory_tool.insert(docs, "vector_store_123")
# processed 2 documents successfully, skipped 1
assert memory_tool.files_api.openai_upload_file.call_count == 2
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2

View file

@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.tools.rag_tool import RAGQueryConfig
from llama_stack.apis.vector_io import (
Chunk,
ChunkMetadata,
QueryChunksResponse,
)
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
class TestRagQuery:
async def test_query_raises_on_empty_vector_store_ids(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
with pytest.raises(ValueError):
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
async def test_query_chunk_metadata_handling(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
content = "test query content"
vector_db_ids = ["db1"]
chunk_metadata = ChunkMetadata(
document_id="doc1",
chunk_id="chunk1",
source="test_source",
metadata_token_count=5,
)
interleaved_content = MagicMock()
chunk = Chunk(
content=interleaved_content,
metadata={
"key1": "value1",
"token_count": 10,
"metadata_token_count": 5,
# Note this is inserted into `metadata` during MemoryToolRuntimeImpl().insert()
"document_id": "doc1",
},
stored_chunk_id="chunk1",
chunk_metadata=chunk_metadata,
)
query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])
rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)
assert result is not None
expected_metadata_string = (
"Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}"
)
assert expected_metadata_string in result.content[1].text
assert result.content is not None
async def test_query_raises_incorrect_mode(self):
with pytest.raises(ValueError):
RAGQueryConfig(mode="invalid_mode")
async def test_query_accepts_valid_modes(self):
default_config = RAGQueryConfig() # Test default (vector)
assert default_config.mode == "vector"
vector_config = RAGQueryConfig(mode="vector") # Test vector
assert vector_config.mode == "vector"
keyword_config = RAGQueryConfig(mode="keyword") # Test keyword
assert keyword_config.mode == "keyword"
hybrid_config = RAGQueryConfig(mode="hybrid") # Test hybrid
assert hybrid_config.mode == "hybrid"
# Test that invalid mode raises an error
with pytest.raises(ValueError):
RAGQueryConfig(mode="wrong_mode")
async def test_query_adds_vector_store_id_to_chunk_metadata(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(),
vector_io_api=MagicMock(),
inference_api=MagicMock(),
files_api=MagicMock(),
)
vector_db_ids = ["db1", "db2"]
# Fake chunks from each DB
chunk_metadata1 = ChunkMetadata(
document_id="doc1",
chunk_id="chunk1",
source="test_source1",
metadata_token_count=5,
)
chunk1 = Chunk(
content="chunk from db1",
metadata={"vector_db_id": "db1", "document_id": "doc1"},
stored_chunk_id="c1",
chunk_metadata=chunk_metadata1,
)
chunk_metadata2 = ChunkMetadata(
document_id="doc2",
chunk_id="chunk2",
source="test_source2",
metadata_token_count=5,
)
chunk2 = Chunk(
content="chunk from db2",
metadata={"vector_db_id": "db2", "document_id": "doc2"},
stored_chunk_id="c2",
chunk_metadata=chunk_metadata2,
)
rag_tool.vector_io_api.query_chunks = AsyncMock(
side_effect=[
QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
]
)
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
returned_chunks = result.metadata["chunks"]
returned_scores = result.metadata["scores"]
returned_doc_ids = result.metadata["document_ids"]
returned_vector_db_ids = result.metadata["vector_db_ids"]
assert returned_chunks == ["chunk from db1", "chunk from db2"]
assert returned_scores == (0.9, 0.8)
assert returned_doc_ids == ["doc1", "doc2"]
assert returned_vector_db_ids == ["db1", "db2"]

View file

@ -4,6 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import mimetypes
import os
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import numpy as np
@ -13,13 +17,37 @@ from llama_stack.apis.inference.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import (
URL,
VectorStoreWithIndex,
_validate_embedding,
content_from_doc,
make_overlapped_chunks,
)
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
# Depending on the machine, this can get parsed a couple of ways
DUMMY_PDF_TEXT_CHOICES = ["Dummy PDF file", "Dumm y PDF file"]
def read_file(file_path: str) -> bytes:
with open(file_path, "rb") as file:
return file.read()
def data_url_from_file(file_path: str) -> str:
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url
class TestChunk:
def test_chunk(self):
@ -88,6 +116,45 @@ class TestValidateEmbedding:
class TestVectorStore:
async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = RAGDocument(
document_id="dummy",
content=data_uri,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = RAGDocument(
document_id="dummy",
content=url,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = RAGDocument(
document_id="dummy",
content=URL(
uri=url,
),
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.parametrize(
"window_len, overlap_len, expected_chunks",
[