From bd3c4732080627c55f63064a2af795b672919b80 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 21 Oct 2025 11:22:06 -0700 Subject: [PATCH] revert: "chore(cleanup)!: remove tool_runtime.rag_tool" (#3877) Reverts llamastack/llama-stack#3871 This PR broke RAG (even from Responses -- there _is_ a dependency) --- client-sdks/stainless/openapi.yml | 331 ++++++++++++++ .../self_hosted_distro/meta-reference-gpu.md | 2 +- .../self_hosted_distro/nvidia.md | 2 +- .../providers/vector_io/inline_sqlite-vec.mdx | 4 +- .../openapi_generator/pyopenapi/operations.py | 6 + docs/static/llama-stack-spec.html | 423 ++++++++++++++++++ docs/static/llama-stack-spec.yaml | 331 ++++++++++++++ docs/static/stainless-llama-stack-spec.html | 423 ++++++++++++++++++ docs/static/stainless-llama-stack-spec.yaml | 331 ++++++++++++++ llama_stack/apis/tools/__init__.py | 1 + llama_stack/apis/tools/rag_tool.py | 218 +++++++++ llama_stack/apis/tools/tools.py | 14 + llama_stack/core/routers/tool_runtime.py | 45 +- llama_stack/core/server/routes.py | 18 + llama_stack/core/stack.py | 3 +- llama_stack/distributions/ci-tests/build.yaml | 1 + llama_stack/distributions/ci-tests/run.yaml | 4 + llama_stack/distributions/dell/build.yaml | 1 + llama_stack/distributions/dell/dell.py | 5 + .../distributions/dell/run-with-safety.yaml | 4 + llama_stack/distributions/dell/run.yaml | 4 + .../meta-reference-gpu/build.yaml | 1 + .../meta-reference-gpu/meta_reference.py | 5 + .../meta-reference-gpu/run-with-safety.yaml | 4 + .../distributions/meta-reference-gpu/run.yaml | 4 + llama_stack/distributions/nvidia/build.yaml | 3 +- llama_stack/distributions/nvidia/nvidia.py | 9 +- .../distributions/nvidia/run-with-safety.yaml | 8 +- llama_stack/distributions/nvidia/run.yaml | 8 +- .../distributions/open-benchmark/build.yaml | 1 + .../open-benchmark/open_benchmark.py | 5 + .../distributions/open-benchmark/run.yaml | 4 + .../distributions/postgres-demo/build.yaml | 1 + .../postgres-demo/postgres_demo.py | 5 + .../distributions/postgres-demo/run.yaml | 4 + .../distributions/starter-gpu/build.yaml | 1 + .../distributions/starter-gpu/run.yaml | 4 + llama_stack/distributions/starter/build.yaml | 1 + llama_stack/distributions/starter/run.yaml | 4 + llama_stack/distributions/starter/starter.py | 5 + llama_stack/distributions/watsonx/build.yaml | 1 + llama_stack/distributions/watsonx/run.yaml | 4 + llama_stack/distributions/watsonx/watsonx.py | 5 + .../providers/inline/tool_runtime/__init__.py | 5 + .../inline/tool_runtime/rag/__init__.py | 19 + .../inline/tool_runtime/rag/config.py | 15 + .../tool_runtime/rag/context_retriever.py | 77 ++++ .../inline/tool_runtime/rag/memory.py | 332 ++++++++++++++ llama_stack/providers/registry/inference.py | 1 - .../providers/registry/tool_runtime.py | 20 + llama_stack/providers/registry/vector_io.py | 2 +- .../providers/utils/memory/vector_store.py | 28 ++ .../utils/memory/test_vector_store.py | 169 ++++++- tests/unit/rag/test_rag_query.py | 138 ++++++ tests/unit/rag/test_vector_store.py | 67 +++ 55 files changed, 3114 insertions(+), 17 deletions(-) create mode 100644 llama_stack/apis/tools/rag_tool.py create mode 100644 llama_stack/providers/inline/tool_runtime/__init__.py create mode 100644 llama_stack/providers/inline/tool_runtime/rag/__init__.py create mode 100644 llama_stack/providers/inline/tool_runtime/rag/config.py create mode 100644 llama_stack/providers/inline/tool_runtime/rag/context_retriever.py create mode 100644 llama_stack/providers/inline/tool_runtime/rag/memory.py create mode 100644 
tests/unit/rag/test_rag_query.py diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index 98a309f12..93049a14a 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -2039,6 +2039,69 @@ paths: schema: $ref: '#/components/schemas/URL' deprecated: false + /v1/tool-runtime/rag-tool/insert: + post: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ToolRuntime + summary: >- + Index documents so they can be used by the RAG system. + description: >- + Index documents so they can be used by the RAG system. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/InsertRequest' + required: true + deprecated: false + /v1/tool-runtime/rag-tool/query: + post: + responses: + '200': + description: >- + RAGQueryResult containing the retrieved content and metadata + content: + application/json: + schema: + $ref: '#/components/schemas/RAGQueryResult' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ToolRuntime + summary: >- + Query the RAG system for context; typically invoked by the agent. + description: >- + Query the RAG system for context; typically invoked by the agent. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/QueryRequest' + required: true + deprecated: false /v1/toolgroups: get: responses: @@ -9858,6 +9921,274 @@ components: title: ListToolDefsResponse description: >- Response containing a list of tool definitions. + RAGDocument: + type: object + properties: + document_id: + type: string + description: The unique identifier for the document. + content: + oneOf: + - type: string + - $ref: '#/components/schemas/InterleavedContentItem' + - type: array + items: + $ref: '#/components/schemas/InterleavedContentItem' + - $ref: '#/components/schemas/URL' + description: The content of the document. + mime_type: + type: string + description: The MIME type of the document. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Additional metadata for the document. + additionalProperties: false + required: + - document_id + - content + - metadata + title: RAGDocument + description: >- + A document to be used for document ingestion in the RAG Tool. 
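+    # Illustrative RAGDocument value (hypothetical id and content, shown only as a comment):
+    #   {"document_id": "doc-1", "content": "Llama Stack overview ...",
+    #    "mime_type": "text/plain", "metadata": {"source": "docs"}}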
+ InsertRequest: + type: object + properties: + documents: + type: array + items: + $ref: '#/components/schemas/RAGDocument' + description: >- + List of documents to index in the RAG system + vector_db_id: + type: string + description: >- + ID of the vector database to store the document embeddings + chunk_size_in_tokens: + type: integer + description: >- + (Optional) Size in tokens for document chunking during indexing + additionalProperties: false + required: + - documents + - vector_db_id + - chunk_size_in_tokens + title: InsertRequest + DefaultRAGQueryGeneratorConfig: + type: object + properties: + type: + type: string + const: default + default: default + description: >- + Type of query generator, always 'default' + separator: + type: string + default: ' ' + description: >- + String separator used to join query terms + additionalProperties: false + required: + - type + - separator + title: DefaultRAGQueryGeneratorConfig + description: >- + Configuration for the default RAG query generator. + LLMRAGQueryGeneratorConfig: + type: object + properties: + type: + type: string + const: llm + default: llm + description: Type of query generator, always 'llm' + model: + type: string + description: >- + Name of the language model to use for query generation + template: + type: string + description: >- + Template string for formatting the query generation prompt + additionalProperties: false + required: + - type + - model + - template + title: LLMRAGQueryGeneratorConfig + description: >- + Configuration for the LLM-based RAG query generator. + RAGQueryConfig: + type: object + properties: + query_generator_config: + oneOf: + - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig' + - $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig' + discriminator: + propertyName: type + mapping: + default: '#/components/schemas/DefaultRAGQueryGeneratorConfig' + llm: '#/components/schemas/LLMRAGQueryGeneratorConfig' + description: Configuration for the query generator. + max_tokens_in_context: + type: integer + default: 4096 + description: Maximum number of tokens in the context. + max_chunks: + type: integer + default: 5 + description: Maximum number of chunks to retrieve. + chunk_template: + type: string + default: > + Result {index} + + Content: {chunk.content} + + Metadata: {metadata} + description: >- + Template for formatting each retrieved chunk in the context. Available + placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk + content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: + {chunk.content}\nMetadata: {metadata}\n" + mode: + $ref: '#/components/schemas/RAGSearchMode' + default: vector + description: >- + Search mode for retrieval—either "vector", "keyword", or "hybrid". Default + "vector". + ranker: + $ref: '#/components/schemas/Ranker' + description: >- + Configuration for the ranker to use in hybrid search. Defaults to RRF + ranker. + additionalProperties: false + required: + - query_generator_config + - max_tokens_in_context + - max_chunks + - chunk_template + title: RAGQueryConfig + description: >- + Configuration for the RAG query generation. 
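+    # Illustrative RAGQueryConfig value (hypothetical settings, shown only as a comment)
+    # selecting hybrid retrieval with the RRF ranker described below:
+    #   {"query_generator_config": {"type": "default", "separator": " "},
+    #    "max_tokens_in_context": 4096, "max_chunks": 5,
+    #    "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+    #    "mode": "hybrid", "ranker": {"type": "rrf", "impact_factor": 60.0}}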
+ RAGSearchMode: + type: string + enum: + - vector + - keyword + - hybrid + title: RAGSearchMode + description: >- + Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search + for semantic matching - KEYWORD: Uses keyword-based search for exact matching + - HYBRID: Combines both vector and keyword search for better results + RRFRanker: + type: object + properties: + type: + type: string + const: rrf + default: rrf + description: The type of ranker, always "rrf" + impact_factor: + type: number + default: 60.0 + description: >- + The impact factor for RRF scoring. Higher values give more weight to higher-ranked + results. Must be greater than 0 + additionalProperties: false + required: + - type + - impact_factor + title: RRFRanker + description: >- + Reciprocal Rank Fusion (RRF) ranker configuration. + Ranker: + oneOf: + - $ref: '#/components/schemas/RRFRanker' + - $ref: '#/components/schemas/WeightedRanker' + discriminator: + propertyName: type + mapping: + rrf: '#/components/schemas/RRFRanker' + weighted: '#/components/schemas/WeightedRanker' + WeightedRanker: + type: object + properties: + type: + type: string + const: weighted + default: weighted + description: The type of ranker, always "weighted" + alpha: + type: number + default: 0.5 + description: >- + Weight factor between 0 and 1. 0 means only use keyword scores, 1 means + only use vector scores, values in between blend both scores. + additionalProperties: false + required: + - type + - alpha + title: WeightedRanker + description: >- + Weighted ranker configuration that combines vector and keyword scores. + QueryRequest: + type: object + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + The query content to search for in the indexed documents + vector_db_ids: + type: array + items: + type: string + description: >- + List of vector database IDs to search within + query_config: + $ref: '#/components/schemas/RAGQueryConfig' + description: >- + (Optional) Configuration parameters for the query operation + additionalProperties: false + required: + - content + - vector_db_ids + title: QueryRequest + RAGQueryResult: + type: object + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + (Optional) The retrieved content from the query + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + Additional metadata about the query result + additionalProperties: false + required: + - metadata + title: RAGQueryResult + description: >- + Result of a RAG query containing retrieved content and metadata. 
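+    # Illustrative request/response pair for POST /v1/tool-runtime/rag-tool/query
+    # (hypothetical ids, content, and metadata keys, shown only as a comment):
+    #   request:  {"content": "What mode does hybrid search use?", "vector_db_ids": ["my-db"]}
+    #   response: {"content": "Result 1\nContent: ...\nMetadata: ...", "metadata": {"document_ids": ["doc-1"]}}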
ToolGroup: type: object properties: diff --git a/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md b/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md index 666850976..b7134b3e1 100644 --- a/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md +++ b/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md @@ -21,7 +21,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo | inference | `inline::meta-reference` | | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | -| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::model-context-protocol` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | diff --git a/docs/docs/distributions/self_hosted_distro/nvidia.md b/docs/docs/distributions/self_hosted_distro/nvidia.md index b1de9ddb8..4a7d99ff5 100644 --- a/docs/docs/distributions/self_hosted_distro/nvidia.md +++ b/docs/docs/distributions/self_hosted_distro/nvidia.md @@ -16,7 +16,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov | post_training | `remote::nvidia` | | safety | `remote::nvidia` | | scoring | `inline::basic` | -| tool_runtime | | +| tool_runtime | `inline::rag-runtime` | | vector_io | `inline::faiss` | diff --git a/docs/docs/providers/vector_io/inline_sqlite-vec.mdx b/docs/docs/providers/vector_io/inline_sqlite-vec.mdx index 459498a59..98a372250 100644 --- a/docs/docs/providers/vector_io/inline_sqlite-vec.mdx +++ b/docs/docs/providers/vector_io/inline_sqlite-vec.mdx @@ -28,7 +28,7 @@ description: | #### Empirical Example Consider the histogram below in which 10,000 randomly generated strings were inserted - in batches of 100 into both Faiss and sqlite-vec. + in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`. ```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png :alt: Comparison of SQLite-Vec and Faiss write times @@ -233,7 +233,7 @@ Datasets that can fit in memory, frequent reads | Faiss | Optimized for speed, i #### Empirical Example Consider the histogram below in which 10,000 randomly generated strings were inserted -in batches of 100 into both Faiss and sqlite-vec. +in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`. ```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png :alt: Comparison of SQLite-Vec and Faiss write times diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py index e5f33f13d..2970d7e53 100644 --- a/docs/openapi_generator/pyopenapi/operations.py +++ b/docs/openapi_generator/pyopenapi/operations.py @@ -196,10 +196,16 @@ def _get_endpoint_functions( def _get_defining_class(member_fn: str, derived_cls: type) -> type: "Find the class in which a member function is first defined in a class inheritance hierarchy." 
+ # This import must be dynamic here + from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime + # iterate in reverse member resolution order to find most specific class first for cls in reversed(inspect.getmro(derived_cls)): for name, _ in inspect.getmembers(cls, inspect.isfunction): if name == member_fn: + # HACK ALERT + if cls == RAGToolRuntime: + return ToolRuntime return cls raise ValidationError( diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html index 7dfb2ed13..61deaec1e 100644 --- a/docs/static/llama-stack-spec.html +++ b/docs/static/llama-stack-spec.html @@ -2624,6 +2624,89 @@ "deprecated": false } }, + "/v1/tool-runtime/rag-tool/insert": { + "post": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "ToolRuntime" + ], + "summary": "Index documents so they can be used by the RAG system.", + "description": "Index documents so they can be used by the RAG system.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InsertRequest" + } + } + }, + "required": true + }, + "deprecated": false + } + }, + "/v1/tool-runtime/rag-tool/query": { + "post": { + "responses": { + "200": { + "description": "RAGQueryResult containing the retrieved content and metadata", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RAGQueryResult" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "ToolRuntime" + ], + "summary": "Query the RAG system for context; typically invoked by the agent.", + "description": "Query the RAG system for context; typically invoked by the agent.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/QueryRequest" + } + } + }, + "required": true + }, + "deprecated": false + } + }, "/v1/toolgroups": { "get": { "responses": { @@ -11300,6 +11383,346 @@ "title": "ListToolDefsResponse", "description": "Response containing a list of tool definitions." }, + "RAGDocument": { + "type": "object", + "properties": { + "document_id": { + "type": "string", + "description": "The unique identifier for the document." + }, + "content": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/InterleavedContentItem" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/InterleavedContentItem" + } + }, + { + "$ref": "#/components/schemas/URL" + } + ], + "description": "The content of the document." + }, + "mime_type": { + "type": "string", + "description": "The MIME type of the document." + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + }, + "description": "Additional metadata for the document." 
+ } + }, + "additionalProperties": false, + "required": [ + "document_id", + "content", + "metadata" + ], + "title": "RAGDocument", + "description": "A document to be used for document ingestion in the RAG Tool." + }, + "InsertRequest": { + "type": "object", + "properties": { + "documents": { + "type": "array", + "items": { + "$ref": "#/components/schemas/RAGDocument" + }, + "description": "List of documents to index in the RAG system" + }, + "vector_db_id": { + "type": "string", + "description": "ID of the vector database to store the document embeddings" + }, + "chunk_size_in_tokens": { + "type": "integer", + "description": "(Optional) Size in tokens for document chunking during indexing" + } + }, + "additionalProperties": false, + "required": [ + "documents", + "vector_db_id", + "chunk_size_in_tokens" + ], + "title": "InsertRequest" + }, + "DefaultRAGQueryGeneratorConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "default", + "default": "default", + "description": "Type of query generator, always 'default'" + }, + "separator": { + "type": "string", + "default": " ", + "description": "String separator used to join query terms" + } + }, + "additionalProperties": false, + "required": [ + "type", + "separator" + ], + "title": "DefaultRAGQueryGeneratorConfig", + "description": "Configuration for the default RAG query generator." + }, + "LLMRAGQueryGeneratorConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "llm", + "default": "llm", + "description": "Type of query generator, always 'llm'" + }, + "model": { + "type": "string", + "description": "Name of the language model to use for query generation" + }, + "template": { + "type": "string", + "description": "Template string for formatting the query generation prompt" + } + }, + "additionalProperties": false, + "required": [ + "type", + "model", + "template" + ], + "title": "LLMRAGQueryGeneratorConfig", + "description": "Configuration for the LLM-based RAG query generator." + }, + "RAGQueryConfig": { + "type": "object", + "properties": { + "query_generator_config": { + "oneOf": [ + { + "$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig" + }, + { + "$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "default": "#/components/schemas/DefaultRAGQueryGeneratorConfig", + "llm": "#/components/schemas/LLMRAGQueryGeneratorConfig" + } + }, + "description": "Configuration for the query generator." + }, + "max_tokens_in_context": { + "type": "integer", + "default": 4096, + "description": "Maximum number of tokens in the context." + }, + "max_chunks": { + "type": "integer", + "default": 5, + "description": "Maximum number of chunks to retrieve." + }, + "chunk_template": { + "type": "string", + "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n", + "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" + }, + "mode": { + "$ref": "#/components/schemas/RAGSearchMode", + "default": "vector", + "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"." + }, + "ranker": { + "$ref": "#/components/schemas/Ranker", + "description": "Configuration for the ranker to use in hybrid search. 
Defaults to RRF ranker." + } + }, + "additionalProperties": false, + "required": [ + "query_generator_config", + "max_tokens_in_context", + "max_chunks", + "chunk_template" + ], + "title": "RAGQueryConfig", + "description": "Configuration for the RAG query generation." + }, + "RAGSearchMode": { + "type": "string", + "enum": [ + "vector", + "keyword", + "hybrid" + ], + "title": "RAGSearchMode", + "description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results" + }, + "RRFRanker": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "rrf", + "default": "rrf", + "description": "The type of ranker, always \"rrf\"" + }, + "impact_factor": { + "type": "number", + "default": 60.0, + "description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0" + } + }, + "additionalProperties": false, + "required": [ + "type", + "impact_factor" + ], + "title": "RRFRanker", + "description": "Reciprocal Rank Fusion (RRF) ranker configuration." + }, + "Ranker": { + "oneOf": [ + { + "$ref": "#/components/schemas/RRFRanker" + }, + { + "$ref": "#/components/schemas/WeightedRanker" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "rrf": "#/components/schemas/RRFRanker", + "weighted": "#/components/schemas/WeightedRanker" + } + } + }, + "WeightedRanker": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "weighted", + "default": "weighted", + "description": "The type of ranker, always \"weighted\"" + }, + "alpha": { + "type": "number", + "default": 0.5, + "description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores." + } + }, + "additionalProperties": false, + "required": [ + "type", + "alpha" + ], + "title": "WeightedRanker", + "description": "Weighted ranker configuration that combines vector and keyword scores." + }, + "QueryRequest": { + "type": "object", + "properties": { + "content": { + "$ref": "#/components/schemas/InterleavedContent", + "description": "The query content to search for in the indexed documents" + }, + "vector_db_ids": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of vector database IDs to search within" + }, + "query_config": { + "$ref": "#/components/schemas/RAGQueryConfig", + "description": "(Optional) Configuration parameters for the query operation" + } + }, + "additionalProperties": false, + "required": [ + "content", + "vector_db_ids" + ], + "title": "QueryRequest" + }, + "RAGQueryResult": { + "type": "object", + "properties": { + "content": { + "$ref": "#/components/schemas/InterleavedContent", + "description": "(Optional) The retrieved content from the query" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + }, + "description": "Additional metadata about the query result" + } + }, + "additionalProperties": false, + "required": [ + "metadata" + ], + "title": "RAGQueryResult", + "description": "Result of a RAG query containing retrieved content and metadata." 
+ }, "ToolGroup": { "type": "object", "properties": { diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 1b0fefe55..c6197b36f 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -2036,6 +2036,69 @@ paths: schema: $ref: '#/components/schemas/URL' deprecated: false + /v1/tool-runtime/rag-tool/insert: + post: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ToolRuntime + summary: >- + Index documents so they can be used by the RAG system. + description: >- + Index documents so they can be used by the RAG system. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/InsertRequest' + required: true + deprecated: false + /v1/tool-runtime/rag-tool/query: + post: + responses: + '200': + description: >- + RAGQueryResult containing the retrieved content and metadata + content: + application/json: + schema: + $ref: '#/components/schemas/RAGQueryResult' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ToolRuntime + summary: >- + Query the RAG system for context; typically invoked by the agent. + description: >- + Query the RAG system for context; typically invoked by the agent. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/QueryRequest' + required: true + deprecated: false /v1/toolgroups: get: responses: @@ -8645,6 +8708,274 @@ components: title: ListToolDefsResponse description: >- Response containing a list of tool definitions. + RAGDocument: + type: object + properties: + document_id: + type: string + description: The unique identifier for the document. + content: + oneOf: + - type: string + - $ref: '#/components/schemas/InterleavedContentItem' + - type: array + items: + $ref: '#/components/schemas/InterleavedContentItem' + - $ref: '#/components/schemas/URL' + description: The content of the document. + mime_type: + type: string + description: The MIME type of the document. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Additional metadata for the document. + additionalProperties: false + required: + - document_id + - content + - metadata + title: RAGDocument + description: >- + A document to be used for document ingestion in the RAG Tool. 
+ InsertRequest: + type: object + properties: + documents: + type: array + items: + $ref: '#/components/schemas/RAGDocument' + description: >- + List of documents to index in the RAG system + vector_db_id: + type: string + description: >- + ID of the vector database to store the document embeddings + chunk_size_in_tokens: + type: integer + description: >- + (Optional) Size in tokens for document chunking during indexing + additionalProperties: false + required: + - documents + - vector_db_id + - chunk_size_in_tokens + title: InsertRequest + DefaultRAGQueryGeneratorConfig: + type: object + properties: + type: + type: string + const: default + default: default + description: >- + Type of query generator, always 'default' + separator: + type: string + default: ' ' + description: >- + String separator used to join query terms + additionalProperties: false + required: + - type + - separator + title: DefaultRAGQueryGeneratorConfig + description: >- + Configuration for the default RAG query generator. + LLMRAGQueryGeneratorConfig: + type: object + properties: + type: + type: string + const: llm + default: llm + description: Type of query generator, always 'llm' + model: + type: string + description: >- + Name of the language model to use for query generation + template: + type: string + description: >- + Template string for formatting the query generation prompt + additionalProperties: false + required: + - type + - model + - template + title: LLMRAGQueryGeneratorConfig + description: >- + Configuration for the LLM-based RAG query generator. + RAGQueryConfig: + type: object + properties: + query_generator_config: + oneOf: + - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig' + - $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig' + discriminator: + propertyName: type + mapping: + default: '#/components/schemas/DefaultRAGQueryGeneratorConfig' + llm: '#/components/schemas/LLMRAGQueryGeneratorConfig' + description: Configuration for the query generator. + max_tokens_in_context: + type: integer + default: 4096 + description: Maximum number of tokens in the context. + max_chunks: + type: integer + default: 5 + description: Maximum number of chunks to retrieve. + chunk_template: + type: string + default: > + Result {index} + + Content: {chunk.content} + + Metadata: {metadata} + description: >- + Template for formatting each retrieved chunk in the context. Available + placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk + content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: + {chunk.content}\nMetadata: {metadata}\n" + mode: + $ref: '#/components/schemas/RAGSearchMode' + default: vector + description: >- + Search mode for retrieval—either "vector", "keyword", or "hybrid". Default + "vector". + ranker: + $ref: '#/components/schemas/Ranker' + description: >- + Configuration for the ranker to use in hybrid search. Defaults to RRF + ranker. + additionalProperties: false + required: + - query_generator_config + - max_tokens_in_context + - max_chunks + - chunk_template + title: RAGQueryConfig + description: >- + Configuration for the RAG query generation. 
+ RAGSearchMode: + type: string + enum: + - vector + - keyword + - hybrid + title: RAGSearchMode + description: >- + Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search + for semantic matching - KEYWORD: Uses keyword-based search for exact matching + - HYBRID: Combines both vector and keyword search for better results + RRFRanker: + type: object + properties: + type: + type: string + const: rrf + default: rrf + description: The type of ranker, always "rrf" + impact_factor: + type: number + default: 60.0 + description: >- + The impact factor for RRF scoring. Higher values give more weight to higher-ranked + results. Must be greater than 0 + additionalProperties: false + required: + - type + - impact_factor + title: RRFRanker + description: >- + Reciprocal Rank Fusion (RRF) ranker configuration. + Ranker: + oneOf: + - $ref: '#/components/schemas/RRFRanker' + - $ref: '#/components/schemas/WeightedRanker' + discriminator: + propertyName: type + mapping: + rrf: '#/components/schemas/RRFRanker' + weighted: '#/components/schemas/WeightedRanker' + WeightedRanker: + type: object + properties: + type: + type: string + const: weighted + default: weighted + description: The type of ranker, always "weighted" + alpha: + type: number + default: 0.5 + description: >- + Weight factor between 0 and 1. 0 means only use keyword scores, 1 means + only use vector scores, values in between blend both scores. + additionalProperties: false + required: + - type + - alpha + title: WeightedRanker + description: >- + Weighted ranker configuration that combines vector and keyword scores. + QueryRequest: + type: object + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + The query content to search for in the indexed documents + vector_db_ids: + type: array + items: + type: string + description: >- + List of vector database IDs to search within + query_config: + $ref: '#/components/schemas/RAGQueryConfig' + description: >- + (Optional) Configuration parameters for the query operation + additionalProperties: false + required: + - content + - vector_db_ids + title: QueryRequest + RAGQueryResult: + type: object + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + (Optional) The retrieved content from the query + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + Additional metadata about the query result + additionalProperties: false + required: + - metadata + title: RAGQueryResult + description: >- + Result of a RAG query containing retrieved content and metadata. 
ToolGroup: type: object properties: diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html index 7930b28e6..38122ebc0 100644 --- a/docs/static/stainless-llama-stack-spec.html +++ b/docs/static/stainless-llama-stack-spec.html @@ -2624,6 +2624,89 @@ "deprecated": false } }, + "/v1/tool-runtime/rag-tool/insert": { + "post": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "ToolRuntime" + ], + "summary": "Index documents so they can be used by the RAG system.", + "description": "Index documents so they can be used by the RAG system.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InsertRequest" + } + } + }, + "required": true + }, + "deprecated": false + } + }, + "/v1/tool-runtime/rag-tool/query": { + "post": { + "responses": { + "200": { + "description": "RAGQueryResult containing the retrieved content and metadata", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RAGQueryResult" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "ToolRuntime" + ], + "summary": "Query the RAG system for context; typically invoked by the agent.", + "description": "Query the RAG system for context; typically invoked by the agent.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/QueryRequest" + } + } + }, + "required": true + }, + "deprecated": false + } + }, "/v1/toolgroups": { "get": { "responses": { @@ -12972,6 +13055,346 @@ "title": "ListToolDefsResponse", "description": "Response containing a list of tool definitions." }, + "RAGDocument": { + "type": "object", + "properties": { + "document_id": { + "type": "string", + "description": "The unique identifier for the document." + }, + "content": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/InterleavedContentItem" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/InterleavedContentItem" + } + }, + { + "$ref": "#/components/schemas/URL" + } + ], + "description": "The content of the document." + }, + "mime_type": { + "type": "string", + "description": "The MIME type of the document." + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + }, + "description": "Additional metadata for the document." + } + }, + "additionalProperties": false, + "required": [ + "document_id", + "content", + "metadata" + ], + "title": "RAGDocument", + "description": "A document to be used for document ingestion in the RAG Tool." 
+ }, + "InsertRequest": { + "type": "object", + "properties": { + "documents": { + "type": "array", + "items": { + "$ref": "#/components/schemas/RAGDocument" + }, + "description": "List of documents to index in the RAG system" + }, + "vector_db_id": { + "type": "string", + "description": "ID of the vector database to store the document embeddings" + }, + "chunk_size_in_tokens": { + "type": "integer", + "description": "(Optional) Size in tokens for document chunking during indexing" + } + }, + "additionalProperties": false, + "required": [ + "documents", + "vector_db_id", + "chunk_size_in_tokens" + ], + "title": "InsertRequest" + }, + "DefaultRAGQueryGeneratorConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "default", + "default": "default", + "description": "Type of query generator, always 'default'" + }, + "separator": { + "type": "string", + "default": " ", + "description": "String separator used to join query terms" + } + }, + "additionalProperties": false, + "required": [ + "type", + "separator" + ], + "title": "DefaultRAGQueryGeneratorConfig", + "description": "Configuration for the default RAG query generator." + }, + "LLMRAGQueryGeneratorConfig": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "llm", + "default": "llm", + "description": "Type of query generator, always 'llm'" + }, + "model": { + "type": "string", + "description": "Name of the language model to use for query generation" + }, + "template": { + "type": "string", + "description": "Template string for formatting the query generation prompt" + } + }, + "additionalProperties": false, + "required": [ + "type", + "model", + "template" + ], + "title": "LLMRAGQueryGeneratorConfig", + "description": "Configuration for the LLM-based RAG query generator." + }, + "RAGQueryConfig": { + "type": "object", + "properties": { + "query_generator_config": { + "oneOf": [ + { + "$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig" + }, + { + "$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "default": "#/components/schemas/DefaultRAGQueryGeneratorConfig", + "llm": "#/components/schemas/LLMRAGQueryGeneratorConfig" + } + }, + "description": "Configuration for the query generator." + }, + "max_tokens_in_context": { + "type": "integer", + "default": 4096, + "description": "Maximum number of tokens in the context." + }, + "max_chunks": { + "type": "integer", + "default": 5, + "description": "Maximum number of chunks to retrieve." + }, + "chunk_template": { + "type": "string", + "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n", + "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" + }, + "mode": { + "$ref": "#/components/schemas/RAGSearchMode", + "default": "vector", + "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"." + }, + "ranker": { + "$ref": "#/components/schemas/Ranker", + "description": "Configuration for the ranker to use in hybrid search. Defaults to RRF ranker." 
+ } + }, + "additionalProperties": false, + "required": [ + "query_generator_config", + "max_tokens_in_context", + "max_chunks", + "chunk_template" + ], + "title": "RAGQueryConfig", + "description": "Configuration for the RAG query generation." + }, + "RAGSearchMode": { + "type": "string", + "enum": [ + "vector", + "keyword", + "hybrid" + ], + "title": "RAGSearchMode", + "description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results" + }, + "RRFRanker": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "rrf", + "default": "rrf", + "description": "The type of ranker, always \"rrf\"" + }, + "impact_factor": { + "type": "number", + "default": 60.0, + "description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0" + } + }, + "additionalProperties": false, + "required": [ + "type", + "impact_factor" + ], + "title": "RRFRanker", + "description": "Reciprocal Rank Fusion (RRF) ranker configuration." + }, + "Ranker": { + "oneOf": [ + { + "$ref": "#/components/schemas/RRFRanker" + }, + { + "$ref": "#/components/schemas/WeightedRanker" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "rrf": "#/components/schemas/RRFRanker", + "weighted": "#/components/schemas/WeightedRanker" + } + } + }, + "WeightedRanker": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "weighted", + "default": "weighted", + "description": "The type of ranker, always \"weighted\"" + }, + "alpha": { + "type": "number", + "default": 0.5, + "description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores." + } + }, + "additionalProperties": false, + "required": [ + "type", + "alpha" + ], + "title": "WeightedRanker", + "description": "Weighted ranker configuration that combines vector and keyword scores." + }, + "QueryRequest": { + "type": "object", + "properties": { + "content": { + "$ref": "#/components/schemas/InterleavedContent", + "description": "The query content to search for in the indexed documents" + }, + "vector_db_ids": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of vector database IDs to search within" + }, + "query_config": { + "$ref": "#/components/schemas/RAGQueryConfig", + "description": "(Optional) Configuration parameters for the query operation" + } + }, + "additionalProperties": false, + "required": [ + "content", + "vector_db_ids" + ], + "title": "QueryRequest" + }, + "RAGQueryResult": { + "type": "object", + "properties": { + "content": { + "$ref": "#/components/schemas/InterleavedContent", + "description": "(Optional) The retrieved content from the query" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + }, + "description": "Additional metadata about the query result" + } + }, + "additionalProperties": false, + "required": [ + "metadata" + ], + "title": "RAGQueryResult", + "description": "Result of a RAG query containing retrieved content and metadata." 
+ }, "ToolGroup": { "type": "object", "properties": { diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 98a309f12..93049a14a 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -2039,6 +2039,69 @@ paths: schema: $ref: '#/components/schemas/URL' deprecated: false + /v1/tool-runtime/rag-tool/insert: + post: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ToolRuntime + summary: >- + Index documents so they can be used by the RAG system. + description: >- + Index documents so they can be used by the RAG system. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/InsertRequest' + required: true + deprecated: false + /v1/tool-runtime/rag-tool/query: + post: + responses: + '200': + description: >- + RAGQueryResult containing the retrieved content and metadata + content: + application/json: + schema: + $ref: '#/components/schemas/RAGQueryResult' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - ToolRuntime + summary: >- + Query the RAG system for context; typically invoked by the agent. + description: >- + Query the RAG system for context; typically invoked by the agent. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/QueryRequest' + required: true + deprecated: false /v1/toolgroups: get: responses: @@ -9858,6 +9921,274 @@ components: title: ListToolDefsResponse description: >- Response containing a list of tool definitions. + RAGDocument: + type: object + properties: + document_id: + type: string + description: The unique identifier for the document. + content: + oneOf: + - type: string + - $ref: '#/components/schemas/InterleavedContentItem' + - type: array + items: + $ref: '#/components/schemas/InterleavedContentItem' + - $ref: '#/components/schemas/URL' + description: The content of the document. + mime_type: + type: string + description: The MIME type of the document. + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: Additional metadata for the document. + additionalProperties: false + required: + - document_id + - content + - metadata + title: RAGDocument + description: >- + A document to be used for document ingestion in the RAG Tool. 
+ InsertRequest: + type: object + properties: + documents: + type: array + items: + $ref: '#/components/schemas/RAGDocument' + description: >- + List of documents to index in the RAG system + vector_db_id: + type: string + description: >- + ID of the vector database to store the document embeddings + chunk_size_in_tokens: + type: integer + description: >- + (Optional) Size in tokens for document chunking during indexing + additionalProperties: false + required: + - documents + - vector_db_id + - chunk_size_in_tokens + title: InsertRequest + DefaultRAGQueryGeneratorConfig: + type: object + properties: + type: + type: string + const: default + default: default + description: >- + Type of query generator, always 'default' + separator: + type: string + default: ' ' + description: >- + String separator used to join query terms + additionalProperties: false + required: + - type + - separator + title: DefaultRAGQueryGeneratorConfig + description: >- + Configuration for the default RAG query generator. + LLMRAGQueryGeneratorConfig: + type: object + properties: + type: + type: string + const: llm + default: llm + description: Type of query generator, always 'llm' + model: + type: string + description: >- + Name of the language model to use for query generation + template: + type: string + description: >- + Template string for formatting the query generation prompt + additionalProperties: false + required: + - type + - model + - template + title: LLMRAGQueryGeneratorConfig + description: >- + Configuration for the LLM-based RAG query generator. + RAGQueryConfig: + type: object + properties: + query_generator_config: + oneOf: + - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig' + - $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig' + discriminator: + propertyName: type + mapping: + default: '#/components/schemas/DefaultRAGQueryGeneratorConfig' + llm: '#/components/schemas/LLMRAGQueryGeneratorConfig' + description: Configuration for the query generator. + max_tokens_in_context: + type: integer + default: 4096 + description: Maximum number of tokens in the context. + max_chunks: + type: integer + default: 5 + description: Maximum number of chunks to retrieve. + chunk_template: + type: string + default: > + Result {index} + + Content: {chunk.content} + + Metadata: {metadata} + description: >- + Template for formatting each retrieved chunk in the context. Available + placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk + content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: + {chunk.content}\nMetadata: {metadata}\n" + mode: + $ref: '#/components/schemas/RAGSearchMode' + default: vector + description: >- + Search mode for retrieval—either "vector", "keyword", or "hybrid". Default + "vector". + ranker: + $ref: '#/components/schemas/Ranker' + description: >- + Configuration for the ranker to use in hybrid search. Defaults to RRF + ranker. + additionalProperties: false + required: + - query_generator_config + - max_tokens_in_context + - max_chunks + - chunk_template + title: RAGQueryConfig + description: >- + Configuration for the RAG query generation. 
+ RAGSearchMode: + type: string + enum: + - vector + - keyword + - hybrid + title: RAGSearchMode + description: >- + Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search + for semantic matching - KEYWORD: Uses keyword-based search for exact matching + - HYBRID: Combines both vector and keyword search for better results + RRFRanker: + type: object + properties: + type: + type: string + const: rrf + default: rrf + description: The type of ranker, always "rrf" + impact_factor: + type: number + default: 60.0 + description: >- + The impact factor for RRF scoring. Higher values give more weight to higher-ranked + results. Must be greater than 0 + additionalProperties: false + required: + - type + - impact_factor + title: RRFRanker + description: >- + Reciprocal Rank Fusion (RRF) ranker configuration. + Ranker: + oneOf: + - $ref: '#/components/schemas/RRFRanker' + - $ref: '#/components/schemas/WeightedRanker' + discriminator: + propertyName: type + mapping: + rrf: '#/components/schemas/RRFRanker' + weighted: '#/components/schemas/WeightedRanker' + WeightedRanker: + type: object + properties: + type: + type: string + const: weighted + default: weighted + description: The type of ranker, always "weighted" + alpha: + type: number + default: 0.5 + description: >- + Weight factor between 0 and 1. 0 means only use keyword scores, 1 means + only use vector scores, values in between blend both scores. + additionalProperties: false + required: + - type + - alpha + title: WeightedRanker + description: >- + Weighted ranker configuration that combines vector and keyword scores. + QueryRequest: + type: object + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + The query content to search for in the indexed documents + vector_db_ids: + type: array + items: + type: string + description: >- + List of vector database IDs to search within + query_config: + $ref: '#/components/schemas/RAGQueryConfig' + description: >- + (Optional) Configuration parameters for the query operation + additionalProperties: false + required: + - content + - vector_db_ids + title: QueryRequest + RAGQueryResult: + type: object + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + (Optional) The retrieved content from the query + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + Additional metadata about the query result + additionalProperties: false + required: + - metadata + title: RAGQueryResult + description: >- + Result of a RAG query containing retrieved content and metadata. ToolGroup: type: object properties: diff --git a/llama_stack/apis/tools/__init__.py b/llama_stack/apis/tools/__init__.py index 2908d1c62..b25310ecf 100644 --- a/llama_stack/apis/tools/__init__.py +++ b/llama_stack/apis/tools/__init__.py @@ -4,4 +4,5 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from .rag_tool import * from .tools import * diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py new file mode 100644 index 000000000..ed7847e23 --- /dev/null +++ b/llama_stack/apis/tools/rag_tool.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
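+
+# Illustrative usage (hypothetical vector DB id; keyword style assumed for the client
+# SDK surface referenced in the docs above):
+#   client.tool_runtime.rag_tool.insert(documents=docs, vector_db_id="my-db", chunk_size_in_tokens=512)
+#   result = client.tool_runtime.rag_tool.query(content="my question", vector_db_ids=["my-db"])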
+ +from enum import Enum, StrEnum +from typing import Annotated, Any, Literal, Protocol + +from pydantic import BaseModel, Field, field_validator +from typing_extensions import runtime_checkable + +from llama_stack.apis.common.content_types import URL, InterleavedContent +from llama_stack.apis.version import LLAMA_STACK_API_V1 +from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod + + +@json_schema_type +class RRFRanker(BaseModel): + """ + Reciprocal Rank Fusion (RRF) ranker configuration. + + :param type: The type of ranker, always "rrf" + :param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. + Must be greater than 0 + """ + + type: Literal["rrf"] = "rrf" + impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance + + +@json_schema_type +class WeightedRanker(BaseModel): + """ + Weighted ranker configuration that combines vector and keyword scores. + + :param type: The type of ranker, always "weighted" + :param alpha: Weight factor between 0 and 1. + 0 means only use keyword scores, + 1 means only use vector scores, + values in between blend both scores. + """ + + type: Literal["weighted"] = "weighted" + alpha: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.", + ) + + +Ranker = Annotated[ + RRFRanker | WeightedRanker, + Field(discriminator="type"), +] +register_schema(Ranker, name="Ranker") + + +@json_schema_type +class RAGDocument(BaseModel): + """ + A document to be used for document ingestion in the RAG Tool. + + :param document_id: The unique identifier for the document. + :param content: The content of the document. + :param mime_type: The MIME type of the document. + :param metadata: Additional metadata for the document. + """ + + document_id: str + content: InterleavedContent | URL + mime_type: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +@json_schema_type +class RAGQueryResult(BaseModel): + """Result of a RAG query containing retrieved content and metadata. + + :param content: (Optional) The retrieved content from the query + :param metadata: Additional metadata about the query result + """ + + content: InterleavedContent | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +@json_schema_type +class RAGQueryGenerator(Enum): + """Types of query generators for RAG systems. + + :cvar default: Default query generator using simple text processing + :cvar llm: LLM-based query generator for enhanced query understanding + :cvar custom: Custom query generator implementation + """ + + default = "default" + llm = "llm" + custom = "custom" + + +@json_schema_type +class RAGSearchMode(StrEnum): + """ + Search modes for RAG query retrieval: + - VECTOR: Uses vector similarity search for semantic matching + - KEYWORD: Uses keyword-based search for exact matching + - HYBRID: Combines both vector and keyword search for better results + """ + + VECTOR = "vector" + KEYWORD = "keyword" + HYBRID = "hybrid" + + +@json_schema_type +class DefaultRAGQueryGeneratorConfig(BaseModel): + """Configuration for the default RAG query generator. 
+ + :param type: Type of query generator, always 'default' + :param separator: String separator used to join query terms + """ + + type: Literal["default"] = "default" + separator: str = " " + + +@json_schema_type +class LLMRAGQueryGeneratorConfig(BaseModel): + """Configuration for the LLM-based RAG query generator. + + :param type: Type of query generator, always 'llm' + :param model: Name of the language model to use for query generation + :param template: Template string for formatting the query generation prompt + """ + + type: Literal["llm"] = "llm" + model: str + template: str + + +RAGQueryGeneratorConfig = Annotated[ + DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig, + Field(discriminator="type"), +] +register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig") + + +@json_schema_type +class RAGQueryConfig(BaseModel): + """ + Configuration for the RAG query generation. + + :param query_generator_config: Configuration for the query generator. + :param max_tokens_in_context: Maximum number of tokens in the context. + :param max_chunks: Maximum number of chunks to retrieve. + :param chunk_template: Template for formatting each retrieved chunk in the context. + Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). + Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n" + :param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector". + :param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker. + """ + + # This config defines how a query is generated using the messages + # for memory bank retrieval. + query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig()) + max_tokens_in_context: int = 4096 + max_chunks: int = 5 + chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" + mode: RAGSearchMode | None = RAGSearchMode.VECTOR + ranker: Ranker | None = Field(default=None) # Only used for hybrid mode + + @field_validator("chunk_template") + def validate_chunk_template(cls, v: str) -> str: + if "{chunk.content}" not in v: + raise ValueError("chunk_template must contain {chunk.content}") + if "{index}" not in v: + raise ValueError("chunk_template must contain {index}") + if len(v) == 0: + raise ValueError("chunk_template must not be empty") + return v + + +@runtime_checkable +@trace_protocol +class RAGToolRuntime(Protocol): + @webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1) + async def insert( + self, + documents: list[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + """Index documents so they can be used by the RAG system. + + :param documents: List of documents to index in the RAG system + :param vector_db_id: ID of the vector database to store the document embeddings + :param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing + """ + ... + + @webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1) + async def query( + self, + content: InterleavedContent, + vector_db_ids: list[str], + query_config: RAGQueryConfig | None = None, + ) -> RAGQueryResult: + """Query the RAG system for context; typically invoked by the agent. 
+ + :param content: The query content to search for in the indexed documents + :param vector_db_ids: List of vector database IDs to search within + :param query_config: (Optional) Configuration parameters for the query operation + :returns: RAGQueryResult containing the retrieved content and metadata + """ + ...
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index feac0d33e..b6a1a2543 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum from typing import Any, Literal, Protocol from pydantic import BaseModel @@ -15,6 +16,8 @@ from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod +from .rag_tool import RAGToolRuntime + @json_schema_type class ToolDef(BaseModel): @@ -178,11 +181,22 @@ class ToolGroups(Protocol): ... +class SpecialToolGroup(Enum): + """Special tool groups with predefined functionality. + + :cvar rag_tool: Retrieval-Augmented Generation tool group for document search and retrieval + """ + + rag_tool = "rag_tool" + + @runtime_checkable @trace_protocol class ToolRuntime(Protocol): tool_store: ToolStore | None = None + rag_tool: RAGToolRuntime | None = None + # TODO: This needs to be renamed once the OpenAPI generator name conflict issue is fixed. @webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1) async def list_runtime_tools(
diff --git a/llama_stack/core/routers/tool_runtime.py b/llama_stack/core/routers/tool_runtime.py index 7c5bb25c6..be4c13905 100644 --- a/llama_stack/core/routers/tool_runtime.py +++ b/llama_stack/core/routers/tool_runtime.py @@ -8,8 +8,16 @@ from typing import Any from llama_stack.apis.common.content_types import ( URL, + InterleavedContent, +) +from llama_stack.apis.tools import ( + ListToolDefsResponse, + RAGDocument, + RAGQueryConfig, + RAGQueryResult, + RAGToolRuntime, + ToolRuntime, ) -from llama_stack.apis.tools import ListToolDefsResponse, ToolRuntime from llama_stack.log import get_logger from ..routing_tables.toolgroups import ToolGroupsRoutingTable @@ -18,6 +26,36 @@ logger = get_logger(name=__name__, category="core::routers") class ToolRuntimeRouter(ToolRuntime): + class RagToolImpl(RAGToolRuntime): + def __init__( + self, + routing_table: ToolGroupsRoutingTable, + ) -> None: + logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") + self.routing_table = routing_table + + async def query( + self, + content: InterleavedContent, + vector_db_ids: list[str], + query_config: RAGQueryConfig | None = None, + ) -> RAGQueryResult: + logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") + provider = await self.routing_table.get_provider_impl("knowledge_search") + return await provider.query(content, vector_db_ids, query_config) + + async def insert( + self, + documents: list[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + logger.debug( + f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" + ) + provider = await self.routing_table.get_provider_impl("insert_into_memory") + return await provider.insert(documents, vector_db_id, chunk_size_in_tokens) + def __init__( self, routing_table: ToolGroupsRoutingTable, @@ -25,6 +63,11 @@
class ToolRuntimeRouter(ToolRuntime): logger.debug("Initializing ToolRuntimeRouter") self.routing_table = routing_table + # HACK ALERT this should be in sync with "get_all_api_endpoints()" + self.rag_tool = self.RagToolImpl(routing_table) + for method in ("query", "insert"): + setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) + async def initialize(self) -> None: logger.debug("ToolRuntimeRouter.initialize") pass diff --git a/llama_stack/core/server/routes.py b/llama_stack/core/server/routes.py index ed76ea86f..4970d0bf8 100644 --- a/llama_stack/core/server/routes.py +++ b/llama_stack/core/server/routes.py @@ -13,6 +13,7 @@ from aiohttp import hdrs from starlette.routing import Route from llama_stack.apis.datatypes import Api, ExternalApiSpec +from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.core.resolver import api_protocol_map from llama_stack.schema_utils import WebMethod @@ -24,16 +25,33 @@ RouteImpls = dict[str, PathImpl] RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod] +def toolgroup_protocol_map(): + return { + SpecialToolGroup.rag_tool: RAGToolRuntime, + } + + def get_all_api_routes( external_apis: dict[Api, ExternalApiSpec] | None = None, ) -> dict[Api, list[tuple[Route, WebMethod]]]: apis = {} protocols = api_protocol_map(external_apis) + toolgroup_protocols = toolgroup_protocol_map() for api, protocol in protocols.items(): routes = [] protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) + # HACK ALERT + if api == Api.tool_runtime: + for tool_group in SpecialToolGroup: + sub_protocol = toolgroup_protocols[tool_group] + sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction) + for name, method in sub_protocol_methods: + if not hasattr(method, "__webmethod__"): + continue + protocol_methods.append((f"{tool_group.value}.{name}", method)) + for name, method in protocol_methods: # Get all webmethods for this method (supports multiple decorators) webmethods = getattr(method, "__webmethods__", []) diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index 49100b4bc..4cf1d072d 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -32,7 +32,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration from llama_stack.apis.telemetry import Telemetry -from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig @@ -80,6 +80,7 @@ class LlamaStack( Inspect, ToolGroups, ToolRuntime, + RAGToolRuntime, Files, Prompts, Conversations, diff --git a/llama_stack/distributions/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml index 3cf43de15..c01e415a9 100644 --- a/llama_stack/distributions/ci-tests/build.yaml +++ b/llama_stack/distributions/ci-tests/build.yaml @@ -48,6 +48,7 @@ distribution_spec: tool_runtime: - provider_type: remote::brave-search - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol batches: - provider_type: inline::reference diff --git a/llama_stack/distributions/ci-tests/run.yaml 
b/llama_stack/distributions/ci-tests/run.yaml index f403527fc..ecf9eed3b 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -216,6 +216,8 @@ providers: config: api_key: ${env.TAVILY_SEARCH_API_KEY:=} max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol batches: @@ -261,6 +263,8 @@ registered_resources: tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/dell/build.yaml b/llama_stack/distributions/dell/build.yaml index 0275a47a1..7bc26ca9e 100644 --- a/llama_stack/distributions/dell/build.yaml +++ b/llama_stack/distributions/dell/build.yaml @@ -26,6 +26,7 @@ distribution_spec: tool_runtime: - provider_type: remote::brave-search - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime image_type: venv additional_pip_packages: - aiosqlite diff --git a/llama_stack/distributions/dell/dell.py b/llama_stack/distributions/dell/dell.py index 708ba0b10..88e72688f 100644 --- a/llama_stack/distributions/dell/dell.py +++ b/llama_stack/distributions/dell/dell.py @@ -45,6 +45,7 @@ def get_distribution_template() -> DistributionTemplate: "tool_runtime": [ BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="inline::rag-runtime"), ], } name = "dell" @@ -97,6 +98,10 @@ def get_distribution_template() -> DistributionTemplate: toolgroup_id="builtin::websearch", provider_id="brave-search", ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), ] return DistributionTemplate( diff --git a/llama_stack/distributions/dell/run-with-safety.yaml b/llama_stack/distributions/dell/run-with-safety.yaml index 062c50e2b..2563f2f4b 100644 --- a/llama_stack/distributions/dell/run-with-safety.yaml +++ b/llama_stack/distributions/dell/run-with-safety.yaml @@ -87,6 +87,8 @@ providers: config: api_key: ${env.TAVILY_SEARCH_API_KEY:=} max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime storage: backends: kv_default: @@ -131,6 +133,8 @@ registered_resources: tool_groups: - toolgroup_id: builtin::websearch provider_id: brave-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/dell/run.yaml b/llama_stack/distributions/dell/run.yaml index 42e0658bd..7bada394f 100644 --- a/llama_stack/distributions/dell/run.yaml +++ b/llama_stack/distributions/dell/run.yaml @@ -83,6 +83,8 @@ providers: config: api_key: ${env.TAVILY_SEARCH_API_KEY:=} max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime storage: backends: kv_default: @@ -122,6 +124,8 @@ registered_resources: tool_groups: - toolgroup_id: builtin::websearch provider_id: brave-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/meta-reference-gpu/build.yaml b/llama_stack/distributions/meta-reference-gpu/build.yaml index 74da29bb8..1513742a7 100644 --- a/llama_stack/distributions/meta-reference-gpu/build.yaml +++ b/llama_stack/distributions/meta-reference-gpu/build.yaml @@ -24,6 +24,7 @@ distribution_spec: tool_runtime: - provider_type: remote::brave-search - provider_type: remote::tavily-search + - provider_type: 
inline::rag-runtime - provider_type: remote::model-context-protocol image_type: venv additional_pip_packages: diff --git a/llama_stack/distributions/meta-reference-gpu/meta_reference.py b/llama_stack/distributions/meta-reference-gpu/meta_reference.py index aa66d43a0..4e4ddef33 100644 --- a/llama_stack/distributions/meta-reference-gpu/meta_reference.py +++ b/llama_stack/distributions/meta-reference-gpu/meta_reference.py @@ -47,6 +47,7 @@ def get_distribution_template() -> DistributionTemplate: "tool_runtime": [ BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="inline::rag-runtime"), BuildProvider(provider_type="remote::model-context-protocol"), ], } @@ -91,6 +92,10 @@ def get_distribution_template() -> DistributionTemplate: toolgroup_id="builtin::websearch", provider_id="tavily-search", ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), ] return DistributionTemplate( diff --git a/llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml b/llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml index 6e74201db..01b5db4f9 100644 --- a/llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml @@ -98,6 +98,8 @@ providers: config: api_key: ${env.TAVILY_SEARCH_API_KEY:=} max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol storage: @@ -144,6 +146,8 @@ registered_resources: tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/meta-reference-gpu/run.yaml b/llama_stack/distributions/meta-reference-gpu/run.yaml index 92934ca74..87c33dde0 100644 --- a/llama_stack/distributions/meta-reference-gpu/run.yaml +++ b/llama_stack/distributions/meta-reference-gpu/run.yaml @@ -88,6 +88,8 @@ providers: config: api_key: ${env.TAVILY_SEARCH_API_KEY:=} max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol storage: @@ -129,6 +131,8 @@ registered_resources: tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/nvidia/build.yaml b/llama_stack/distributions/nvidia/build.yaml index 3412ea15b..8ddd12439 100644 --- a/llama_stack/distributions/nvidia/build.yaml +++ b/llama_stack/distributions/nvidia/build.yaml @@ -19,7 +19,8 @@ distribution_spec: - provider_type: remote::nvidia scoring: - provider_type: inline::basic - tool_runtime: [] + tool_runtime: + - provider_type: inline::rag-runtime files: - provider_type: inline::localfs image_type: venv diff --git a/llama_stack/distributions/nvidia/nvidia.py b/llama_stack/distributions/nvidia/nvidia.py index 889f83aa5..a92a2e6f8 100644 --- a/llama_stack/distributions/nvidia/nvidia.py +++ b/llama_stack/distributions/nvidia/nvidia.py @@ -28,7 +28,7 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate: BuildProvider(provider_type="remote::nvidia"), ], "scoring": [BuildProvider(provider_type="inline::basic")], - "tool_runtime": [], + "tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")], "files": 
[BuildProvider(provider_type="inline::localfs")], } @@ -66,7 +66,12 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate: provider_id="nvidia", ) - default_tool_groups: list[ToolGroupInput] = [] + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), + ] return DistributionTemplate( name=name, diff --git a/llama_stack/distributions/nvidia/run-with-safety.yaml b/llama_stack/distributions/nvidia/run-with-safety.yaml index dca29ed2a..c23d0f9cb 100644 --- a/llama_stack/distributions/nvidia/run-with-safety.yaml +++ b/llama_stack/distributions/nvidia/run-with-safety.yaml @@ -80,7 +80,9 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - tool_runtime: [] + tool_runtime: + - provider_id: rag-runtime + provider_type: inline::rag-runtime files: - provider_id: meta-reference-files provider_type: inline::localfs @@ -126,7 +128,9 @@ registered_resources: datasets: [] scoring_fns: [] benchmarks: [] - tool_groups: [] + tool_groups: + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/nvidia/run.yaml b/llama_stack/distributions/nvidia/run.yaml index e35d9c44c..81e744d53 100644 --- a/llama_stack/distributions/nvidia/run.yaml +++ b/llama_stack/distributions/nvidia/run.yaml @@ -69,7 +69,9 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - tool_runtime: [] + tool_runtime: + - provider_id: rag-runtime + provider_type: inline::rag-runtime files: - provider_id: meta-reference-files provider_type: inline::localfs @@ -105,7 +107,9 @@ registered_resources: datasets: [] scoring_fns: [] benchmarks: [] - tool_groups: [] + tool_groups: + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/open-benchmark/build.yaml b/llama_stack/distributions/open-benchmark/build.yaml index 9fc0e9eb0..05acd98e3 100644 --- a/llama_stack/distributions/open-benchmark/build.yaml +++ b/llama_stack/distributions/open-benchmark/build.yaml @@ -28,6 +28,7 @@ distribution_spec: tool_runtime: - provider_type: remote::brave-search - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol image_type: venv additional_pip_packages: diff --git a/llama_stack/distributions/open-benchmark/open_benchmark.py b/llama_stack/distributions/open-benchmark/open_benchmark.py index cceec74fd..2b7760894 100644 --- a/llama_stack/distributions/open-benchmark/open_benchmark.py +++ b/llama_stack/distributions/open-benchmark/open_benchmark.py @@ -118,6 +118,7 @@ def get_distribution_template() -> DistributionTemplate: "tool_runtime": [ BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="inline::rag-runtime"), BuildProvider(provider_type="remote::model-context-protocol"), ], } @@ -153,6 +154,10 @@ def get_distribution_template() -> DistributionTemplate: toolgroup_id="builtin::websearch", provider_id="tavily-search", ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), ] models, _ = get_model_registry(available_models) diff --git a/llama_stack/distributions/open-benchmark/run.yaml b/llama_stack/distributions/open-benchmark/run.yaml index 8f63e4417..4fd0e199b 100644 --- a/llama_stack/distributions/open-benchmark/run.yaml +++ b/llama_stack/distributions/open-benchmark/run.yaml @@ -118,6 +118,8 @@ providers: config: api_key: 
${env.TAVILY_SEARCH_API_KEY:=} max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol storage: @@ -242,6 +244,8 @@ registered_resources: tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/postgres-demo/build.yaml b/llama_stack/distributions/postgres-demo/build.yaml index 99b4edeb3..063dc3999 100644 --- a/llama_stack/distributions/postgres-demo/build.yaml +++ b/llama_stack/distributions/postgres-demo/build.yaml @@ -14,6 +14,7 @@ distribution_spec: tool_runtime: - provider_type: remote::brave-search - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol image_type: venv additional_pip_packages: diff --git a/llama_stack/distributions/postgres-demo/postgres_demo.py b/llama_stack/distributions/postgres-demo/postgres_demo.py index 9f8d35cb1..876370ef3 100644 --- a/llama_stack/distributions/postgres-demo/postgres_demo.py +++ b/llama_stack/distributions/postgres-demo/postgres_demo.py @@ -45,6 +45,7 @@ def get_distribution_template() -> DistributionTemplate: "tool_runtime": [ BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="inline::rag-runtime"), BuildProvider(provider_type="remote::model-context-protocol"), ], } @@ -65,6 +66,10 @@ def get_distribution_template() -> DistributionTemplate: toolgroup_id="builtin::websearch", provider_id="tavily-search", ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), ] default_models = [ diff --git a/llama_stack/distributions/postgres-demo/run.yaml b/llama_stack/distributions/postgres-demo/run.yaml index 67222969c..0d7ecff48 100644 --- a/llama_stack/distributions/postgres-demo/run.yaml +++ b/llama_stack/distributions/postgres-demo/run.yaml @@ -54,6 +54,8 @@ providers: config: api_key: ${env.TAVILY_SEARCH_API_KEY:=} max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol storage: @@ -105,6 +107,8 @@ registered_resources: tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/starter-gpu/build.yaml b/llama_stack/distributions/starter-gpu/build.yaml index 678d7995d..b2e2a0c85 100644 --- a/llama_stack/distributions/starter-gpu/build.yaml +++ b/llama_stack/distributions/starter-gpu/build.yaml @@ -49,6 +49,7 @@ distribution_spec: tool_runtime: - provider_type: remote::brave-search - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol batches: - provider_type: inline::reference diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml index 4764dc02c..92483c78e 100644 --- a/llama_stack/distributions/starter-gpu/run.yaml +++ b/llama_stack/distributions/starter-gpu/run.yaml @@ -219,6 +219,8 @@ providers: config: api_key: ${env.TAVILY_SEARCH_API_KEY:=} max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol batches: @@ -264,6 +266,8 @@ 
registered_resources: tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/starter/build.yaml b/llama_stack/distributions/starter/build.yaml index e6cd3c688..baa80ef3e 100644 --- a/llama_stack/distributions/starter/build.yaml +++ b/llama_stack/distributions/starter/build.yaml @@ -49,6 +49,7 @@ distribution_spec: tool_runtime: - provider_type: remote::brave-search - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol batches: - provider_type: inline::reference diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index 88358501e..3b9d8f890 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -216,6 +216,8 @@ providers: config: api_key: ${env.TAVILY_SEARCH_API_KEY:=} max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol batches: @@ -261,6 +263,8 @@ registered_resources: tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index bad6279bd..c8c7101a6 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -140,6 +140,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate: "tool_runtime": [ BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="inline::rag-runtime"), BuildProvider(provider_type="remote::model-context-protocol"), ], "batches": [ @@ -161,6 +162,10 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate: toolgroup_id="builtin::websearch", provider_id="tavily-search", ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), ] default_shields = [ # if the diff --git a/llama_stack/distributions/watsonx/build.yaml b/llama_stack/distributions/watsonx/build.yaml index d2c396085..dba1a94e2 100644 --- a/llama_stack/distributions/watsonx/build.yaml +++ b/llama_stack/distributions/watsonx/build.yaml @@ -23,6 +23,7 @@ distribution_spec: tool_runtime: - provider_type: remote::brave-search - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol files: - provider_type: inline::localfs diff --git a/llama_stack/distributions/watsonx/run.yaml b/llama_stack/distributions/watsonx/run.yaml index ddc7e095f..ca3c8402d 100644 --- a/llama_stack/distributions/watsonx/run.yaml +++ b/llama_stack/distributions/watsonx/run.yaml @@ -83,6 +83,8 @@ providers: config: api_key: ${env.TAVILY_SEARCH_API_KEY:=} max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol files: @@ -123,6 +125,8 @@ registered_resources: tool_groups: - toolgroup_id: builtin::websearch provider_id: tavily-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime server: port: 8321 telemetry: diff --git a/llama_stack/distributions/watsonx/watsonx.py b/llama_stack/distributions/watsonx/watsonx.py index 
b16f76fcb..d79aea872 100644 --- a/llama_stack/distributions/watsonx/watsonx.py +++ b/llama_stack/distributions/watsonx/watsonx.py @@ -33,6 +33,7 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate: "tool_runtime": [ BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="inline::rag-runtime"), BuildProvider(provider_type="remote::model-context-protocol"), ], "files": [BuildProvider(provider_type="inline::localfs")], @@ -49,6 +50,10 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate: toolgroup_id="builtin::websearch", provider_id="tavily-search", ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), ] files_provider = Provider( diff --git a/llama_stack/providers/inline/tool_runtime/__init__.py b/llama_stack/providers/inline/tool_runtime/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/tool_runtime/rag/__init__.py b/llama_stack/providers/inline/tool_runtime/rag/__init__.py new file mode 100644 index 000000000..f9a7e7b89 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/rag/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.providers.datatypes import Api + +from .config import RagToolRuntimeConfig + + +async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]): + from .memory import MemoryToolRuntimeImpl + + impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/tool_runtime/rag/config.py b/llama_stack/providers/inline/tool_runtime/rag/config.py new file mode 100644 index 000000000..43ba78e65 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/rag/config.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel + + +class RagToolRuntimeConfig(BaseModel): + @classmethod + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: + return {} diff --git a/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py b/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py new file mode 100644 index 000000000..14cbec49d --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
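+# --- Editor's sketch (hedged; not part of the reverted file) ----------------
+# The two query-generator configs that generate_rag_query() below dispatches
+# on; both classes are defined in llama_stack/apis/tools/rag_tool.py earlier
+# in this patch. The model ID and template text are hypothetical.
+#
+#     from llama_stack.apis.tools.rag_tool import (
+#         DefaultRAGQueryGeneratorConfig,
+#         LLMRAGQueryGeneratorConfig,
+#     )
+#
+#     default_cfg = DefaultRAGQueryGeneratorConfig(separator=" ")  # joins content items verbatim
+#     llm_cfg = LLMRAGQueryGeneratorConfig(
+#         model="llama3.2-3b-instruct",  # hypothetical model ID
+#         template="Rewrite as a search query: {{ messages }}",  # Jinja2; rendered with {"messages": [...]}
+#     )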
+ + +from jinja2 import Template + +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam +from llama_stack.apis.tools.rag_tool import ( + DefaultRAGQueryGeneratorConfig, + LLMRAGQueryGeneratorConfig, + RAGQueryGenerator, + RAGQueryGeneratorConfig, +) +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) + + +async def generate_rag_query( + config: RAGQueryGeneratorConfig, + content: InterleavedContent, + **kwargs, +): + """ + Generates a query that will be used for + retrieving relevant information from the memory bank. + """ + if config.type == RAGQueryGenerator.default.value: + query = await default_rag_query_generator(config, content, **kwargs) + elif config.type == RAGQueryGenerator.llm.value: + query = await llm_rag_query_generator(config, content, **kwargs) + else: + raise NotImplementedError(f"Unsupported memory query generator {config.type}") + return query + + +async def default_rag_query_generator( + config: DefaultRAGQueryGeneratorConfig, + content: InterleavedContent, + **kwargs, +): + return interleaved_content_as_str(content, sep=config.separator) + + +async def llm_rag_query_generator( + config: LLMRAGQueryGeneratorConfig, + content: InterleavedContent, + **kwargs, +): + assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api" + inference_api = kwargs["inference_api"] + + messages = [] + if isinstance(content, list): + messages = [interleaved_content_as_str(m) for m in content] + else: + messages = [interleaved_content_as_str(content)] + + template = Template(config.template) + rendered_content: str = template.render({"messages": messages}) + + model = config.model + message = OpenAIUserMessageParam(content=rendered_content) + params = OpenAIChatCompletionRequestWithExtraBody( + model=model, + messages=[message], + stream=False, + ) + response = await inference_api.openai_chat_completion(params) + + query = response.choices[0].message.content + + return query diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py new file mode 100644 index 000000000..dc3dfbbca --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -0,0 +1,332 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
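+# --- Editor's sketch (hedged; not part of the reverted file) ----------------
+# Driving the two rag_tool endpoints that MemoryToolRuntimeImpl below serves,
+# via the llama-stack client; the base URL, store ID, and document are
+# placeholder values.
+#
+#     from llama_stack_client import LlamaStackClient
+#
+#     client = LlamaStackClient(base_url="http://localhost:8321")
+#     client.tool_runtime.rag_tool.insert(
+#         documents=[{"document_id": "doc-1", "content": "Llamas are camelids.", "metadata": {}}],
+#         vector_db_id="vs_demo",
+#         chunk_size_in_tokens=512,
+#     )
+#     result = client.tool_runtime.rag_tool.query(
+#         content="What are llamas?",
+#         vector_db_ids=["vs_demo"],
+#     )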
+ +import asyncio +import base64 +import io +import mimetypes +from typing import Any + +import httpx +from fastapi import UploadFile +from pydantic import TypeAdapter + +from llama_stack.apis.common.content_types import ( + URL, + InterleavedContent, + InterleavedContentItem, + TextContentItem, +) +from llama_stack.apis.files import Files, OpenAIFilePurpose +from llama_stack.apis.inference import Inference +from llama_stack.apis.tools import ( + ListToolDefsResponse, + RAGDocument, + RAGQueryConfig, + RAGQueryResult, + RAGToolRuntime, + ToolDef, + ToolGroup, + ToolInvocationResult, + ToolRuntime, +) +from llama_stack.apis.vector_io import ( + QueryChunksResponse, + VectorIO, + VectorStoreChunkingStrategyStatic, + VectorStoreChunkingStrategyStaticConfig, +) +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate +from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str +from llama_stack.providers.utils.memory.vector_store import parse_data_url + +from .config import RagToolRuntimeConfig +from .context_retriever import generate_rag_query + +log = get_logger(name=__name__, category="tool_runtime") + + +async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]: + """Get raw binary data and mime type from a RAGDocument for file upload.""" + if isinstance(doc.content, URL): + if doc.content.uri.startswith("data:"): + parts = parse_data_url(doc.content.uri) + mime_type = parts["mimetype"] + data = parts["data"] + + if parts["is_base64"]: + file_data = base64.b64decode(data) + else: + file_data = data.encode("utf-8") + + return file_data, mime_type + else: + async with httpx.AsyncClient() as client: + r = await client.get(doc.content.uri) + r.raise_for_status() + mime_type = r.headers.get("content-type", "application/octet-stream") + return r.content, mime_type + else: + if isinstance(doc.content, str): + content_str = doc.content + else: + content_str = interleaved_content_as_str(doc.content) + + if content_str.startswith("data:"): + parts = parse_data_url(content_str) + mime_type = parts["mimetype"] + data = parts["data"] + + if parts["is_base64"]: + file_data = base64.b64decode(data) + else: + file_data = data.encode("utf-8") + + return file_data, mime_type + else: + return content_str.encode("utf-8"), "text/plain" + + +class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime): + def __init__( + self, + config: RagToolRuntimeConfig, + vector_io_api: VectorIO, + inference_api: Inference, + files_api: Files, + ): + self.config = config + self.vector_io_api = vector_io_api + self.inference_api = inference_api + self.files_api = files_api + + async def initialize(self): + pass + + async def shutdown(self): + pass + + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: + pass + + async def unregister_toolgroup(self, toolgroup_id: str) -> None: + return + + async def insert( + self, + documents: list[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + if not documents: + return + + for doc in documents: + try: + try: + file_data, mime_type = await raw_data_from_doc(doc) + except Exception as e: + log.error(f"Failed to extract content from document {doc.document_id}: {e}") + continue + + file_extension = mimetypes.guess_extension(mime_type) or ".txt" + filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}") + + file_obj = io.BytesIO(file_data) + file_obj.name = filename + + upload_file = 
UploadFile(file=file_obj, filename=filename) + + try: + created_file = await self.files_api.openai_upload_file( + file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS + ) + except Exception as e: + log.error(f"Failed to upload file for document {doc.document_id}: {e}") + continue + + chunking_strategy = VectorStoreChunkingStrategyStatic( + static=VectorStoreChunkingStrategyStaticConfig( + max_chunk_size_tokens=chunk_size_in_tokens, + chunk_overlap_tokens=chunk_size_in_tokens // 4, + ) + ) + + try: + await self.vector_io_api.openai_attach_file_to_vector_store( + vector_store_id=vector_db_id, + file_id=created_file.id, + attributes=doc.metadata, + chunking_strategy=chunking_strategy, + ) + except Exception as e: + log.error( + f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}" + ) + continue + + except Exception as e: + log.error(f"Unexpected error processing document {doc.document_id}: {e}") + continue + + async def query( + self, + content: InterleavedContent, + vector_db_ids: list[str], + query_config: RAGQueryConfig | None = None, + ) -> RAGQueryResult: + if not vector_db_ids: + raise ValueError( + "No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID." + ) + + query_config = query_config or RAGQueryConfig() + query = await generate_rag_query( + query_config.query_generator_config, + content, + inference_api=self.inference_api, + ) + tasks = [ + self.vector_io_api.query_chunks( + vector_db_id=vector_db_id, + query=query, + params={ + "mode": query_config.mode, + "max_chunks": query_config.max_chunks, + "score_threshold": 0.0, + "ranker": query_config.ranker, + }, + ) + for vector_db_id in vector_db_ids + ] + results: list[QueryChunksResponse] = await asyncio.gather(*tasks) + + chunks = [] + scores = [] + + for vector_db_id, result in zip(vector_db_ids, results, strict=False): + for chunk, score in zip(result.chunks, result.scores, strict=False): + if not hasattr(chunk, "metadata") or chunk.metadata is None: + chunk.metadata = {} + chunk.metadata["vector_db_id"] = vector_db_id + + chunks.append(chunk) + scores.append(score) + + if not chunks: + return RAGQueryResult(content=None) + + # sort by score + chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) # type: ignore + chunks = chunks[: query_config.max_chunks] + + tokens = 0 + picked: list[InterleavedContentItem] = [ + TextContentItem( + text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n" + ) + ] + for i, chunk in enumerate(chunks): + metadata = chunk.metadata + tokens += metadata.get("token_count", 0) + tokens += metadata.get("metadata_token_count", 0) + + if tokens > query_config.max_tokens_in_context: + log.error( + f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", + ) + break + + # Add useful keys from chunk_metadata to metadata and remove some from metadata + chunk_metadata_keys_to_include_from_context = [ + "chunk_id", + "document_id", + "source", + ] + metadata_keys_to_exclude_from_context = [ + "token_count", + "metadata_token_count", + "vector_db_id", + ] + metadata_for_context = {} + for k in chunk_metadata_keys_to_include_from_context: + metadata_for_context[k] = getattr(chunk.chunk_metadata, k) + for k in metadata: + if k not in metadata_keys_to_exclude_from_context: + metadata_for_context[k] = metadata[k] + + text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, 
metadata=metadata_for_context) + picked.append(TextContentItem(text=text_content)) + + picked.append(TextContentItem(text="END of knowledge_search tool results.\n")) + picked.append( + TextContentItem( + text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n', + ) + ) + + return RAGQueryResult( + content=picked, + metadata={ + "document_ids": [c.document_id for c in chunks[: len(picked)]], + "chunks": [c.content for c in chunks[: len(picked)]], + "scores": scores[: len(picked)], + "vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]], + }, + ) + + async def list_runtime_tools( + self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + ) -> ListToolDefsResponse: + # Parameters are not listed since these methods are not yet invoked automatically + # by the LLM. The method is only implemented so things like /tools can list without + # encountering fatals. + return ListToolDefsResponse( + data=[ + ToolDef( + name="insert_into_memory", + description="Insert documents into memory", + ), + ToolDef( + name="knowledge_search", + description="Search for information in a database.", + input_schema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The query to search for. Can be a natural language sentence or keywords.", + } + }, + "required": ["query"], + }, + ), + ] + ) + + async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: + vector_db_ids = kwargs.get("vector_db_ids", []) + query_config = kwargs.get("query_config") + if query_config: + query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config) + else: + query_config = RAGQueryConfig() + + query = kwargs["query"] + result = await self.query( + content=query, + vector_db_ids=vector_db_ids, + query_config=query_config, + ) + + return ToolInvocationResult( + content=result.content or [], + metadata={ + **(result.metadata or {}), + "citation_files": getattr(result, "citation_files", None), + }, + ) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 2e52e2d12..35afb296d 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -42,7 +42,6 @@ def available_providers() -> list[ProviderSpec]: # CrossEncoder depends on torchao.quantization pip_packages=[ "torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu", - "numpy tqdm transformers", "sentence-transformers --no-deps", # required by some SentenceTransformers architectures for tensor rearrange/merge ops "einops", diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 514d9d0a0..39dc7fccd 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -7,13 +7,33 @@ from llama_stack.providers.datatypes import ( Api, + InlineProviderSpec, ProviderSpec, RemoteProviderSpec, ) +from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS def available_providers() -> list[ProviderSpec]: return [ + InlineProviderSpec( + api=Api.tool_runtime, + provider_type="inline::rag-runtime", + pip_packages=DEFAULT_VECTOR_IO_DEPS + + [ + "tqdm", + "numpy", + "scikit-learn", + "scipy", + "nltk", + "sentencepiece", + "transformers", + ], + module="llama_stack.providers.inline.tool_runtime.rag", + 
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig", + api_dependencies=[Api.vector_io, Api.inference, Api.files], + description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.", + ), RemoteProviderSpec( api=Api.tool_runtime, adapter_type="brave-search", diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index db81ea35d..ff3b8486f 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -119,7 +119,7 @@ Datasets that can fit in memory, frequent reads | Faiss | Optimized for speed, i #### Empirical Example Consider the histogram below in which 10,000 randomly generated strings were inserted -in batches of 100 into both Faiss and sqlite-vec. +in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`. ```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png :alt: Comparison of SQLite-Vec and Faiss write times diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 9e9c9a08a..6c8746e92 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -12,14 +12,17 @@ from dataclasses import dataclass from typing import Any from urllib.parse import unquote +import httpx import numpy as np from numpy.typing import NDArray from pydantic import BaseModel from llama_stack.apis.common.content_types import ( + URL, InterleavedContent, ) from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody +from llama_stack.apis.tools import RAGDocument from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse from llama_stack.apis.vector_stores import VectorStore from llama_stack.log import get_logger @@ -126,6 +129,31 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en return "" +async def content_from_doc(doc: RAGDocument) -> str: + if isinstance(doc.content, URL): + if doc.content.uri.startswith("data:"): + return content_from_data(doc.content.uri) + async with httpx.AsyncClient() as client: + r = await client.get(doc.content.uri) + if doc.mime_type == "application/pdf": + return parse_pdf(r.content) + return r.text + elif isinstance(doc.content, str): + pattern = re.compile("^(https?://|file://|data:)") + if pattern.match(doc.content): + if doc.content.startswith("data:"): + return content_from_data(doc.content) + async with httpx.AsyncClient() as client: + r = await client.get(doc.content) + if doc.mime_type == "application/pdf": + return parse_pdf(r.content) + return r.text + return doc.content + else: + # will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent + return interleaved_content_as_str(doc.content) + + def make_overlapped_chunks( document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any] ) -> list[Chunk]: diff --git a/tests/unit/providers/utils/memory/test_vector_store.py b/tests/unit/providers/utils/memory/test_vector_store.py index 3a5cd5bf7..590bdd1d2 100644 --- a/tests/unit/providers/utils/memory/test_vector_store.py +++ b/tests/unit/providers/utils/memory/test_vector_store.py @@ -4,11 +4,138 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type +from llama_stack.apis.common.content_types import URL, TextContentItem +from llama_stack.apis.tools import RAGDocument +from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc + + +async def test_content_from_doc_with_url(): + """Test extracting content from RAGDocument with URL content.""" + mock_url = URL(uri="https://example.com") + mock_doc = RAGDocument(document_id="foo", content=mock_url) + + mock_response = MagicMock() + mock_response.text = "Sample content from URL" + + with patch("httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_client.return_value.__aenter__.return_value = mock_instance + + result = await content_from_doc(mock_doc) + + assert result == "Sample content from URL" + mock_instance.get.assert_called_once_with(mock_url.uri) + + +async def test_content_from_doc_with_pdf_url(): + """Test extracting content from RAGDocument with URL pointing to a PDF.""" + mock_url = URL(uri="https://example.com/document.pdf") + mock_doc = RAGDocument(document_id="foo", content=mock_url, mime_type="application/pdf") + + mock_response = MagicMock() + mock_response.content = b"PDF binary data" + + with ( + patch("httpx.AsyncClient") as mock_client, + patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf, + ): + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_client.return_value.__aenter__.return_value = mock_instance + mock_parse_pdf.return_value = "Extracted PDF content" + + result = await content_from_doc(mock_doc) + + assert result == "Extracted PDF content" + mock_instance.get.assert_called_once_with(mock_url.uri) + mock_parse_pdf.assert_called_once_with(b"PDF binary data") + + +async def test_content_from_doc_with_data_url(): + """Test extracting content from RAGDocument with data URL content.""" + data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded + mock_url = URL(uri=data_url) + mock_doc = RAGDocument(document_id="foo", content=mock_url) + + with patch("llama_stack.providers.utils.memory.vector_store.content_from_data") as mock_content_from_data: + mock_content_from_data.return_value = "Hello World" + + result = await content_from_doc(mock_doc) + + assert result == "Hello World" + mock_content_from_data.assert_called_once_with(data_url) + + +async def test_content_from_doc_with_string(): + """Test extracting content from RAGDocument with string content.""" + content_string = "This is plain text content" + mock_doc = RAGDocument(document_id="foo", content=content_string) + + result = await content_from_doc(mock_doc) + + assert result == content_string + + +async def test_content_from_doc_with_string_url(): + """Test extracting content from RAGDocument with string URL content.""" + url_string = "https://example.com" + mock_doc = RAGDocument(document_id="foo", content=url_string) + + mock_response = MagicMock() + mock_response.text = "Sample content from URL string" + + with patch("httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_client.return_value.__aenter__.return_value = mock_instance + + result = await content_from_doc(mock_doc) + + assert result == "Sample content from URL string" + 
mock_instance.get.assert_called_once_with(url_string) + + +async def test_content_from_doc_with_string_pdf_url(): + """Test extracting content from RAGDocument with string URL pointing to a PDF.""" + url_string = "https://example.com/document.pdf" + mock_doc = RAGDocument(document_id="foo", content=url_string, mime_type="application/pdf") + + mock_response = MagicMock() + mock_response.content = b"PDF binary data" + + with ( + patch("httpx.AsyncClient") as mock_client, + patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf, + ): + mock_instance = AsyncMock() + mock_instance.get.return_value = mock_response + mock_client.return_value.__aenter__.return_value = mock_instance + mock_parse_pdf.return_value = "Extracted PDF content from string URL" + + result = await content_from_doc(mock_doc) + + assert result == "Extracted PDF content from string URL" + mock_instance.get.assert_called_once_with(url_string) + mock_parse_pdf.assert_called_once_with(b"PDF binary data") + + +async def test_content_from_doc_with_interleaved_content(): + """Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit).""" + interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")] + mock_doc = RAGDocument(document_id="foo", content=interleaved_content) + + with patch("llama_stack.providers.utils.memory.vector_store.interleaved_content_as_str") as mock_interleaved: + mock_interleaved.return_value = "First item\nSecond item" + + result = await content_from_doc(mock_doc) + + assert result == "First item\nSecond item" + mock_interleaved.assert_called_once_with(interleaved_content) def test_content_from_data_and_mime_type_success_utf8(): @@ -51,3 +178,41 @@ def test_content_from_data_and_mime_type_both_encodings_fail(): # Should raise an exception instead of returning empty string with pytest.raises(UnicodeDecodeError): content_from_data_and_mime_type(data, mime_type) + + +async def test_memory_tool_error_handling(): + """Test that memory tool handles various failures gracefully without crashing.""" + from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig + from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl + + config = RagToolRuntimeConfig() + memory_tool = MemoryToolRuntimeImpl( + config=config, + vector_io_api=AsyncMock(), + inference_api=AsyncMock(), + files_api=AsyncMock(), + ) + + docs = [ + RAGDocument(document_id="good_doc", content="Good content", metadata={}), + RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}), + RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}), + ] + + mock_file1 = MagicMock() + mock_file1.id = "file_good1" + mock_file2 = MagicMock() + mock_file2.id = "file_good2" + memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2] + + with patch("httpx.AsyncClient") as mock_client: + mock_instance = AsyncMock() + mock_instance.get.side_effect = Exception("Bad URL") + mock_client.return_value.__aenter__.return_value = mock_instance + + # won't raise exception despite one document failing + await memory_tool.insert(docs, "vector_store_123") + + # processed 2 documents successfully, skipped 1 + assert memory_tool.files_api.openai_upload_file.call_count == 2 + assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2 diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py new file 
mode 100644 index 000000000..c012bc4f0 --- /dev/null +++ b/tests/unit/rag/test_rag_query.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from llama_stack.apis.tools.rag_tool import RAGQueryConfig +from llama_stack.apis.vector_io import ( + Chunk, + ChunkMetadata, + QueryChunksResponse, +) +from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl + + +class TestRagQuery: + async def test_query_raises_on_empty_vector_store_ids(self): + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() + ) + with pytest.raises(ValueError): + await rag_tool.query(content=MagicMock(), vector_db_ids=[]) + + async def test_query_chunk_metadata_handling(self): + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() + ) + content = "test query content" + vector_db_ids = ["db1"] + + chunk_metadata = ChunkMetadata( + document_id="doc1", + chunk_id="chunk1", + source="test_source", + metadata_token_count=5, + ) + interleaved_content = MagicMock() + chunk = Chunk( + content=interleaved_content, + metadata={ + "key1": "value1", + "token_count": 10, + "metadata_token_count": 5, + # Note this is inserted into `metadata` during MemoryToolRuntimeImpl().insert() + "document_id": "doc1", + }, + stored_chunk_id="chunk1", + chunk_metadata=chunk_metadata, + ) + + query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0]) + + rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response) + result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids) + + assert result is not None + expected_metadata_string = ( + "Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}" + ) + assert expected_metadata_string in result.content[1].text + assert result.content is not None + + async def test_query_raises_incorrect_mode(self): + with pytest.raises(ValueError): + RAGQueryConfig(mode="invalid_mode") + + async def test_query_accepts_valid_modes(self): + default_config = RAGQueryConfig() # Test default (vector) + assert default_config.mode == "vector" + vector_config = RAGQueryConfig(mode="vector") # Test vector + assert vector_config.mode == "vector" + keyword_config = RAGQueryConfig(mode="keyword") # Test keyword + assert keyword_config.mode == "keyword" + hybrid_config = RAGQueryConfig(mode="hybrid") # Test hybrid + assert hybrid_config.mode == "hybrid" + + # Test that invalid mode raises an error + with pytest.raises(ValueError): + RAGQueryConfig(mode="wrong_mode") + + async def test_query_adds_vector_store_id_to_chunk_metadata(self): + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), + vector_io_api=MagicMock(), + inference_api=MagicMock(), + files_api=MagicMock(), + ) + + vector_db_ids = ["db1", "db2"] + + # Fake chunks from each DB + chunk_metadata1 = ChunkMetadata( + document_id="doc1", + chunk_id="chunk1", + source="test_source1", + metadata_token_count=5, + ) + chunk1 = Chunk( + content="chunk from db1", + metadata={"vector_db_id": "db1", "document_id": "doc1"}, + stored_chunk_id="c1", + chunk_metadata=chunk_metadata1, + ) + + chunk_metadata2 = ChunkMetadata( + document_id="doc2", + chunk_id="chunk2", + 
+    async def test_query_adds_vector_store_id_to_chunk_metadata(self):
+        rag_tool = MemoryToolRuntimeImpl(
+            config=MagicMock(),
+            vector_io_api=MagicMock(),
+            inference_api=MagicMock(),
+            files_api=MagicMock(),
+        )
+
+        vector_db_ids = ["db1", "db2"]
+
+        # Fake chunks from each DB
+        chunk_metadata1 = ChunkMetadata(
+            document_id="doc1",
+            chunk_id="chunk1",
+            source="test_source1",
+            metadata_token_count=5,
+        )
+        chunk1 = Chunk(
+            content="chunk from db1",
+            metadata={"vector_db_id": "db1", "document_id": "doc1"},
+            stored_chunk_id="c1",
+            chunk_metadata=chunk_metadata1,
+        )
+
+        chunk_metadata2 = ChunkMetadata(
+            document_id="doc2",
+            chunk_id="chunk2",
+            source="test_source2",
+            metadata_token_count=5,
+        )
+        chunk2 = Chunk(
+            content="chunk from db2",
+            metadata={"vector_db_id": "db2", "document_id": "doc2"},
+            stored_chunk_id="c2",
+            chunk_metadata=chunk_metadata2,
+        )
+
+        rag_tool.vector_io_api.query_chunks = AsyncMock(
+            side_effect=[
+                QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
+                QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
+            ]
+        )
+
+        result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
+        returned_chunks = result.metadata["chunks"]
+        returned_scores = result.metadata["scores"]
+        returned_doc_ids = result.metadata["document_ids"]
+        returned_vector_db_ids = result.metadata["vector_db_ids"]
+
+        assert returned_chunks == ["chunk from db1", "chunk from db2"]
+        assert returned_scores == (0.9, 0.8)
+        assert returned_doc_ids == ["doc1", "doc2"]
+        assert returned_vector_db_ids == ["db1", "db2"]
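Note for reviewers: test_query_adds_vector_store_id_to_chunk_metadata above asserts the fan-out contract: every configured store is queried and the results are merged into parallel metadata lists (chunks, scores, document_ids, vector_db_ids). Below is a minimal sketch of that aggregation shape, assuming a standalone helper; fan_out_query and its query_chunks parameter are illustrative, not the actual query implementation.

    # Sketch only: fan out one query across several stores and merge results
    # into the parallel lists the asserts above check.
    async def fan_out_query(query_chunks, content, vector_db_ids):
        chunks, scores, doc_ids, store_ids = [], [], [], []
        for db_id in vector_db_ids:
            response = await query_chunks(db_id, content)
            for chunk, score in zip(response.chunks, response.scores):
                chunks.append(chunk.content)
                scores.append(score)
                doc_ids.append(chunk.metadata.get("document_id"))
                # Tag each hit with the store it came from.
                store_ids.append(db_id)
        return {
            "chunks": chunks,
            "scores": tuple(scores),  # the test compares scores as a tuple
            "document_ids": doc_ids,
            "vector_db_ids": store_ids,
        }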
"https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf" + doc = RAGDocument( + document_id="dummy", + content=URL( + uri=url, + ), + mime_type="application/pdf", + metadata={}, + ) + content = await content_from_doc(doc) + assert content in DUMMY_PDF_TEXT_CHOICES + @pytest.mark.parametrize( "window_len, overlap_len, expected_chunks", [