Merge bf442eb3f3 into sapling-pr-archive-ehhuang

ehhuang 2025-10-21 09:47:38 -07:00 committed by GitHub
commit a7c0ec991b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
97 changed files with 1986 additions and 5308 deletions

View file

@@ -86,10 +86,9 @@ runs:
      if: ${{ always() }}
      shell: bash
      run: |
-       sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log || true
-       distro_name=$(echo "${{ inputs.stack-config }}" | sed 's/^docker://' | sed 's/^server://')
-       stack_container_name="llama-stack-test-$distro_name"
-       sudo docker logs $stack_container_name > docker-${distro_name}-${{ inputs.inference-mode }}.log || true
+       # Ollama logs (if ollama container exists)
+       sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log 2>&1 || true
+       # Note: distro container logs are now dumped in integration-tests.sh before container is removed
    - name: Upload logs
      if: ${{ always() }}

View file

@@ -99,7 +99,7 @@ jobs:
      owner: context.repo.owner,
      repo: context.repo.repo,
      issue_number: ${{ steps.check_author.outputs.pr_number }},
-     body: `⏳ Running pre-commit hooks on PR #${{ steps.check_author.outputs.pr_number }}...`
+     body: `⏳ Running [pre-commit hooks](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) on PR #${{ steps.check_author.outputs.pr_number }}...`
      });
    - name: Checkout PR branch (same-repo)

View file

@@ -208,19 +208,6 @@ resources:
type: http
endpoint: post /v1/conversations/{conversation_id}/items
datasets:
models:
list_datasets_response: ListDatasetsResponse
methods:
register: post /v1beta/datasets
retrieve: get /v1beta/datasets/{dataset_id}
list:
endpoint: get /v1beta/datasets
paginated: false
unregister: delete /v1beta/datasets/{dataset_id}
iterrows: get /v1beta/datasetio/iterrows/{dataset_id}
appendrows: post /v1beta/datasetio/append-rows/{dataset_id}
inspect:
models:
healthInfo: HealthInfo
@@ -521,6 +508,21 @@ resources:
stream_event_model: alpha.agents.turn.agent_turn_response_stream_chunk
param_discriminator: stream
beta:
subresources:
datasets:
models:
list_datasets_response: ListDatasetsResponse
methods:
register: post /v1beta/datasets
retrieve: get /v1beta/datasets/{dataset_id}
list:
endpoint: get /v1beta/datasets
paginated: false
unregister: delete /v1beta/datasets/{dataset_id}
iterrows: get /v1beta/datasetio/iterrows/{dataset_id}
appendrows: post /v1beta/datasetio/append-rows/{dataset_id}
settings:
license: MIT

View file

@@ -2039,69 +2039,6 @@ paths:
schema:
$ref: '#/components/schemas/URL'
deprecated: false
/v1/tool-runtime/rag-tool/insert:
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Index documents so they can be used by the RAG system.
description: >-
Index documents so they can be used by the RAG system.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/InsertRequest'
required: true
deprecated: false
/v1/tool-runtime/rag-tool/query:
post:
responses:
'200':
description: >-
RAGQueryResult containing the retrieved content and metadata
content:
application/json:
schema:
$ref: '#/components/schemas/RAGQueryResult'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Query the RAG system for context; typically invoked by the agent.
description: >-
Query the RAG system for context; typically invoked by the agent.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/QueryRequest'
required: true
deprecated: false
/v1/toolgroups:
get:
responses:
@@ -6440,7 +6377,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -9132,7 +9069,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -9440,7 +9377,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -9921,274 +9858,6 @@ components:
title: ListToolDefsResponse
description: >-
  Response containing a list of tool definitions.
RAGDocument:
type: object
properties:
document_id:
type: string
description: The unique identifier for the document.
content:
oneOf:
- type: string
- $ref: '#/components/schemas/InterleavedContentItem'
- type: array
items:
$ref: '#/components/schemas/InterleavedContentItem'
- $ref: '#/components/schemas/URL'
description: The content of the document.
mime_type:
type: string
description: The MIME type of the document.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Additional metadata for the document.
additionalProperties: false
required:
- document_id
- content
- metadata
title: RAGDocument
description: >-
A document to be used for document ingestion in the RAG Tool.
InsertRequest:
type: object
properties:
documents:
type: array
items:
$ref: '#/components/schemas/RAGDocument'
description: >-
List of documents to index in the RAG system
vector_db_id:
type: string
description: >-
ID of the vector database to store the document embeddings
chunk_size_in_tokens:
type: integer
description: >-
(Optional) Size in tokens for document chunking during indexing
additionalProperties: false
required:
- documents
- vector_db_id
- chunk_size_in_tokens
title: InsertRequest
DefaultRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: default
default: default
description: >-
Type of query generator, always 'default'
separator:
type: string
default: ' '
description: >-
String separator used to join query terms
additionalProperties: false
required:
- type
- separator
title: DefaultRAGQueryGeneratorConfig
description: >-
Configuration for the default RAG query generator.
LLMRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: llm
default: llm
description: Type of query generator, always 'llm'
model:
type: string
description: >-
Name of the language model to use for query generation
template:
type: string
description: >-
Template string for formatting the query generation prompt
additionalProperties: false
required:
- type
- model
- template
title: LLMRAGQueryGeneratorConfig
description: >-
Configuration for the LLM-based RAG query generator.
RAGQueryConfig:
type: object
properties:
query_generator_config:
oneOf:
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
- $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
discriminator:
propertyName: type
mapping:
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
description: Configuration for the query generator.
max_tokens_in_context:
type: integer
default: 4096
description: Maximum number of tokens in the context.
max_chunks:
type: integer
default: 5
description: Maximum number of chunks to retrieve.
chunk_template:
type: string
default: >
Result {index}
Content: {chunk.content}
Metadata: {metadata}
description: >-
Template for formatting each retrieved chunk in the context. Available
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
{chunk.content}\nMetadata: {metadata}\n"
mode:
$ref: '#/components/schemas/RAGSearchMode'
default: vector
description: >-
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
"vector".
ranker:
$ref: '#/components/schemas/Ranker'
description: >-
Configuration for the ranker to use in hybrid search. Defaults to RRF
ranker.
additionalProperties: false
required:
- query_generator_config
- max_tokens_in_context
- max_chunks
- chunk_template
title: RAGQueryConfig
description: >-
Configuration for the RAG query generation.
RAGSearchMode:
type: string
enum:
- vector
- keyword
- hybrid
title: RAGSearchMode
description: >-
Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
for semantic matching - KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
RRFRanker:
type: object
properties:
type:
type: string
const: rrf
default: rrf
description: The type of ranker, always "rrf"
impact_factor:
type: number
default: 60.0
description: >-
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
results. Must be greater than 0
additionalProperties: false
required:
- type
- impact_factor
title: RRFRanker
description: >-
Reciprocal Rank Fusion (RRF) ranker configuration.
Ranker:
oneOf:
- $ref: '#/components/schemas/RRFRanker'
- $ref: '#/components/schemas/WeightedRanker'
discriminator:
propertyName: type
mapping:
rrf: '#/components/schemas/RRFRanker'
weighted: '#/components/schemas/WeightedRanker'
WeightedRanker:
type: object
properties:
type:
type: string
const: weighted
default: weighted
description: The type of ranker, always "weighted"
alpha:
type: number
default: 0.5
description: >-
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
only use vector scores, values in between blend both scores.
additionalProperties: false
required:
- type
- alpha
title: WeightedRanker
description: >-
Weighted ranker configuration that combines vector and keyword scores.
QueryRequest:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The query content to search for in the indexed documents
vector_db_ids:
type: array
items:
type: string
description: >-
List of vector database IDs to search within
query_config:
$ref: '#/components/schemas/RAGQueryConfig'
description: >-
(Optional) Configuration parameters for the query operation
additionalProperties: false
required:
- content
- vector_db_ids
title: QueryRequest
RAGQueryResult:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
(Optional) The retrieved content from the query
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
Additional metadata about the query result
additionalProperties: false
required:
- metadata
title: RAGQueryResult
description: >-
Result of a RAG query containing retrieved content and metadata.
ToolGroup:
type: object
properties:
@@ -10203,7 +9872,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -11325,7 +10994,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -12652,7 +12321,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark

View file

@@ -21,7 +21,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
  | inference | `inline::meta-reference` |
  | safety | `inline::llama-guard` |
  | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
- | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
+ | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::model-context-protocol` |
  | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@@ -16,7 +16,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
  | post_training | `remote::nvidia` |
  | safety | `remote::nvidia` |
  | scoring | `inline::basic` |
- | tool_runtime | `inline::rag-runtime` |
+ | tool_runtime | |
  | vector_io | `inline::faiss` |

View file

@@ -28,7 +28,7 @@ description: |
  #### Empirical Example
  Consider the histogram below in which 10,000 randomly generated strings were inserted
- in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`.
+ in batches of 100 into both Faiss and sqlite-vec.
  ```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png
  :alt: Comparison of SQLite-Vec and Faiss write times
@@ -233,7 +233,7 @@ Datasets that can fit in memory, frequent reads | Faiss | Optimized for speed, i
  #### Empirical Example
  Consider the histogram below in which 10,000 randomly generated strings were inserted
- in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`.
+ in batches of 100 into both Faiss and sqlite-vec.
  ```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png
  :alt: Comparison of SQLite-Vec and Faiss write times

View file

@@ -32,7 +32,6 @@ Commands:
  scoring_functions Manage scoring functions.
  shields Manage safety shield services.
  toolgroups Manage available tool groups.
- vector_dbs Manage vector databases.
  ```
  ### `llama-stack-client configure`
@@ -211,53 +210,6 @@ Unregister a model from distribution endpoint
  llama-stack-client models unregister <model_id>
  ```
## Vector DB Management
Manage vector databases.
### `llama-stack-client vector_dbs list`
Show available vector dbs on distribution endpoint
```bash
llama-stack-client vector_dbs list
```
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ provider_resource_id ┃ vector_db_type ┃ params ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ my_demo_vector_db │ faiss │ my_demo_vector_db │ │ embedding_dimension: 768 │
│ │ │ │ │ embedding_model: nomic-embed-text-v1.5 │
│ │ │ │ │ type: vector_db │
│ │ │ │ │ │
└──────────────────────────┴─────────────┴──────────────────────────┴────────────────┴───────────────────────────────────┘
```
### `llama-stack-client vector_dbs register`
Create a new vector db
```bash
llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>]
```
Required arguments:
- `VECTOR_DB_ID`: Vector DB ID
Optional arguments:
- `--provider-id`: Provider ID for the vector db
- `--provider-vector-db-id`: Provider's vector db ID
- `--embedding-model`: Embedding model to use. Default: `nomic-embed-text-v1.5`
- `--embedding-dimension`: Dimension of embeddings. Default: 768
### `llama-stack-client vector_dbs unregister`
Delete a vector db
```bash
llama-stack-client vector_dbs unregister <vector-db-id>
```
Required arguments:
- `VECTOR_DB_ID`: Vector DB ID
## Shield Management
Manage safety shield services.
### `llama-stack-client shields list`

File diff suppressed because one or more lines are too long

View file

@@ -196,16 +196,10 @@ def _get_endpoint_functions(
  def _get_defining_class(member_fn: str, derived_cls: type) -> type:
      "Find the class in which a member function is first defined in a class inheritance hierarchy."
-     # This import must be dynamic here
-     from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime
      # iterate in reverse member resolution order to find most specific class first
      for cls in reversed(inspect.getmro(derived_cls)):
          for name, _ in inspect.getmembers(cls, inspect.isfunction):
              if name == member_fn:
-                 # HACK ALERT
-                 if cls == RAGToolRuntime:
-                     return ToolRuntime
                  return cls
      raise ValidationError(

View file

@@ -5547,7 +5547,7 @@
    "enum": [
    "model",
    "shield",
-   "vector_db",
+   "vector_store",
    "dataset",
    "scoring_function",
    "benchmark",
@@ -5798,7 +5798,7 @@
    "enum": [
    "model",
    "shield",
-   "vector_db",
+   "vector_store",
    "dataset",
    "scoring_function",
    "benchmark",

View file

@@ -4114,7 +4114,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -4303,7 +4303,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark

View file

@@ -1850,7 +1850,7 @@
    "enum": [
    "model",
    "shield",
-   "vector_db",
+   "vector_store",
    "dataset",
    "scoring_function",
    "benchmark",
@@ -3983,7 +3983,7 @@
    "enum": [
    "model",
    "shield",
-   "vector_db",
+   "vector_store",
    "dataset",
    "scoring_function",
    "benchmark",

View file

@@ -1320,7 +1320,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -2927,7 +2927,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark

View file

@@ -2624,89 +2624,6 @@
"deprecated": false
}
},
"/v1/tool-runtime/rag-tool/insert": {
"post": {
"responses": {
"200": {
"description": "OK"
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"ToolRuntime"
],
"summary": "Index documents so they can be used by the RAG system.",
"description": "Index documents so they can be used by the RAG system.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/InsertRequest"
}
}
},
"required": true
},
"deprecated": false
}
},
"/v1/tool-runtime/rag-tool/query": {
"post": {
"responses": {
"200": {
"description": "RAGQueryResult containing the retrieved content and metadata",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RAGQueryResult"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"ToolRuntime"
],
"summary": "Query the RAG system for context; typically invoked by the agent.",
"description": "Query the RAG system for context; typically invoked by the agent.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/QueryRequest"
}
}
},
"required": true
},
"deprecated": false
}
},
"/v1/toolgroups": { "/v1/toolgroups": {
"get": { "get": {
"responses": { "responses": {
@ -6800,7 +6717,7 @@
"enum": [ "enum": [
"model", "model",
"shield", "shield",
"vector_db", "vector_store",
"dataset", "dataset",
"scoring_function", "scoring_function",
"benchmark", "benchmark",
@ -10205,7 +10122,7 @@
"enum": [ "enum": [
"model", "model",
"shield", "shield",
"vector_db", "vector_store",
"dataset", "dataset",
"scoring_function", "scoring_function",
"benchmark", "benchmark",
@ -10687,7 +10604,7 @@
"enum": [ "enum": [
"model", "model",
"shield", "shield",
"vector_db", "vector_store",
"dataset", "dataset",
"scoring_function", "scoring_function",
"benchmark", "benchmark",
@ -11383,346 +11300,6 @@
"title": "ListToolDefsResponse", "title": "ListToolDefsResponse",
"description": "Response containing a list of tool definitions." "description": "Response containing a list of tool definitions."
}, },
"RAGDocument": {
"type": "object",
"properties": {
"document_id": {
"type": "string",
"description": "The unique identifier for the document."
},
"content": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/InterleavedContentItem"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContentItem"
}
},
{
"$ref": "#/components/schemas/URL"
}
],
"description": "The content of the document."
},
"mime_type": {
"type": "string",
"description": "The MIME type of the document."
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Additional metadata for the document."
}
},
"additionalProperties": false,
"required": [
"document_id",
"content",
"metadata"
],
"title": "RAGDocument",
"description": "A document to be used for document ingestion in the RAG Tool."
},
"InsertRequest": {
"type": "object",
"properties": {
"documents": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RAGDocument"
},
"description": "List of documents to index in the RAG system"
},
"vector_db_id": {
"type": "string",
"description": "ID of the vector database to store the document embeddings"
},
"chunk_size_in_tokens": {
"type": "integer",
"description": "(Optional) Size in tokens for document chunking during indexing"
}
},
"additionalProperties": false,
"required": [
"documents",
"vector_db_id",
"chunk_size_in_tokens"
],
"title": "InsertRequest"
},
"DefaultRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "default",
"default": "default",
"description": "Type of query generator, always 'default'"
},
"separator": {
"type": "string",
"default": " ",
"description": "String separator used to join query terms"
}
},
"additionalProperties": false,
"required": [
"type",
"separator"
],
"title": "DefaultRAGQueryGeneratorConfig",
"description": "Configuration for the default RAG query generator."
},
"LLMRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "llm",
"default": "llm",
"description": "Type of query generator, always 'llm'"
},
"model": {
"type": "string",
"description": "Name of the language model to use for query generation"
},
"template": {
"type": "string",
"description": "Template string for formatting the query generation prompt"
}
},
"additionalProperties": false,
"required": [
"type",
"model",
"template"
],
"title": "LLMRAGQueryGeneratorConfig",
"description": "Configuration for the LLM-based RAG query generator."
},
"RAGQueryConfig": {
"type": "object",
"properties": {
"query_generator_config": {
"oneOf": [
{
"$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig"
},
{
"$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"default": "#/components/schemas/DefaultRAGQueryGeneratorConfig",
"llm": "#/components/schemas/LLMRAGQueryGeneratorConfig"
}
},
"description": "Configuration for the query generator."
},
"max_tokens_in_context": {
"type": "integer",
"default": 4096,
"description": "Maximum number of tokens in the context."
},
"max_chunks": {
"type": "integer",
"default": 5,
"description": "Maximum number of chunks to retrieve."
},
"chunk_template": {
"type": "string",
"default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
"description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
},
"mode": {
"$ref": "#/components/schemas/RAGSearchMode",
"default": "vector",
"description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
},
"ranker": {
"$ref": "#/components/schemas/Ranker",
"description": "Configuration for the ranker to use in hybrid search. Defaults to RRF ranker."
}
},
"additionalProperties": false,
"required": [
"query_generator_config",
"max_tokens_in_context",
"max_chunks",
"chunk_template"
],
"title": "RAGQueryConfig",
"description": "Configuration for the RAG query generation."
},
"RAGSearchMode": {
"type": "string",
"enum": [
"vector",
"keyword",
"hybrid"
],
"title": "RAGSearchMode",
"description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results"
},
"RRFRanker": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "rrf",
"default": "rrf",
"description": "The type of ranker, always \"rrf\""
},
"impact_factor": {
"type": "number",
"default": 60.0,
"description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0"
}
},
"additionalProperties": false,
"required": [
"type",
"impact_factor"
],
"title": "RRFRanker",
"description": "Reciprocal Rank Fusion (RRF) ranker configuration."
},
"Ranker": {
"oneOf": [
{
"$ref": "#/components/schemas/RRFRanker"
},
{
"$ref": "#/components/schemas/WeightedRanker"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"rrf": "#/components/schemas/RRFRanker",
"weighted": "#/components/schemas/WeightedRanker"
}
}
},
"WeightedRanker": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "weighted",
"default": "weighted",
"description": "The type of ranker, always \"weighted\""
},
"alpha": {
"type": "number",
"default": 0.5,
"description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores."
}
},
"additionalProperties": false,
"required": [
"type",
"alpha"
],
"title": "WeightedRanker",
"description": "Weighted ranker configuration that combines vector and keyword scores."
},
"QueryRequest": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"description": "The query content to search for in the indexed documents"
},
"vector_db_ids": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of vector database IDs to search within"
},
"query_config": {
"$ref": "#/components/schemas/RAGQueryConfig",
"description": "(Optional) Configuration parameters for the query operation"
}
},
"additionalProperties": false,
"required": [
"content",
"vector_db_ids"
],
"title": "QueryRequest"
},
"RAGQueryResult": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"description": "(Optional) The retrieved content from the query"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Additional metadata about the query result"
}
},
"additionalProperties": false,
"required": [
"metadata"
],
"title": "RAGQueryResult",
"description": "Result of a RAG query containing retrieved content and metadata."
},
"ToolGroup": { "ToolGroup": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -11740,7 +11317,7 @@
"enum": [ "enum": [
"model", "model",
"shield", "shield",
"vector_db", "vector_store",
"dataset", "dataset",
"scoring_function", "scoring_function",
"benchmark", "benchmark",

View file

@@ -2036,69 +2036,6 @@ paths:
schema:
$ref: '#/components/schemas/URL'
deprecated: false
/v1/tool-runtime/rag-tool/insert:
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Index documents so they can be used by the RAG system.
description: >-
Index documents so they can be used by the RAG system.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/InsertRequest'
required: true
deprecated: false
/v1/tool-runtime/rag-tool/query:
post:
responses:
'200':
description: >-
RAGQueryResult containing the retrieved content and metadata
content:
application/json:
schema:
$ref: '#/components/schemas/RAGQueryResult'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Query the RAG system for context; typically invoked by the agent.
description: >-
Query the RAG system for context; typically invoked by the agent.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/QueryRequest'
required: true
deprecated: false
/v1/toolgroups:
get:
responses:
@@ -5227,7 +5164,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -7919,7 +7856,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -8227,7 +8164,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -8708,274 +8645,6 @@ components:
title: ListToolDefsResponse
description: >-
  Response containing a list of tool definitions.
RAGDocument:
type: object
properties:
document_id:
type: string
description: The unique identifier for the document.
content:
oneOf:
- type: string
- $ref: '#/components/schemas/InterleavedContentItem'
- type: array
items:
$ref: '#/components/schemas/InterleavedContentItem'
- $ref: '#/components/schemas/URL'
description: The content of the document.
mime_type:
type: string
description: The MIME type of the document.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Additional metadata for the document.
additionalProperties: false
required:
- document_id
- content
- metadata
title: RAGDocument
description: >-
A document to be used for document ingestion in the RAG Tool.
InsertRequest:
type: object
properties:
documents:
type: array
items:
$ref: '#/components/schemas/RAGDocument'
description: >-
List of documents to index in the RAG system
vector_db_id:
type: string
description: >-
ID of the vector database to store the document embeddings
chunk_size_in_tokens:
type: integer
description: >-
(Optional) Size in tokens for document chunking during indexing
additionalProperties: false
required:
- documents
- vector_db_id
- chunk_size_in_tokens
title: InsertRequest
DefaultRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: default
default: default
description: >-
Type of query generator, always 'default'
separator:
type: string
default: ' '
description: >-
String separator used to join query terms
additionalProperties: false
required:
- type
- separator
title: DefaultRAGQueryGeneratorConfig
description: >-
Configuration for the default RAG query generator.
LLMRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: llm
default: llm
description: Type of query generator, always 'llm'
model:
type: string
description: >-
Name of the language model to use for query generation
template:
type: string
description: >-
Template string for formatting the query generation prompt
additionalProperties: false
required:
- type
- model
- template
title: LLMRAGQueryGeneratorConfig
description: >-
Configuration for the LLM-based RAG query generator.
RAGQueryConfig:
type: object
properties:
query_generator_config:
oneOf:
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
- $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
discriminator:
propertyName: type
mapping:
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
description: Configuration for the query generator.
max_tokens_in_context:
type: integer
default: 4096
description: Maximum number of tokens in the context.
max_chunks:
type: integer
default: 5
description: Maximum number of chunks to retrieve.
chunk_template:
type: string
default: >
Result {index}
Content: {chunk.content}
Metadata: {metadata}
description: >-
Template for formatting each retrieved chunk in the context. Available
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
{chunk.content}\nMetadata: {metadata}\n"
mode:
$ref: '#/components/schemas/RAGSearchMode'
default: vector
description: >-
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
"vector".
ranker:
$ref: '#/components/schemas/Ranker'
description: >-
Configuration for the ranker to use in hybrid search. Defaults to RRF
ranker.
additionalProperties: false
required:
- query_generator_config
- max_tokens_in_context
- max_chunks
- chunk_template
title: RAGQueryConfig
description: >-
Configuration for the RAG query generation.
RAGSearchMode:
type: string
enum:
- vector
- keyword
- hybrid
title: RAGSearchMode
description: >-
Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
for semantic matching - KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
RRFRanker:
type: object
properties:
type:
type: string
const: rrf
default: rrf
description: The type of ranker, always "rrf"
impact_factor:
type: number
default: 60.0
description: >-
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
results. Must be greater than 0
additionalProperties: false
required:
- type
- impact_factor
title: RRFRanker
description: >-
Reciprocal Rank Fusion (RRF) ranker configuration.
Ranker:
oneOf:
- $ref: '#/components/schemas/RRFRanker'
- $ref: '#/components/schemas/WeightedRanker'
discriminator:
propertyName: type
mapping:
rrf: '#/components/schemas/RRFRanker'
weighted: '#/components/schemas/WeightedRanker'
WeightedRanker:
type: object
properties:
type:
type: string
const: weighted
default: weighted
description: The type of ranker, always "weighted"
alpha:
type: number
default: 0.5
description: >-
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
only use vector scores, values in between blend both scores.
additionalProperties: false
required:
- type
- alpha
title: WeightedRanker
description: >-
Weighted ranker configuration that combines vector and keyword scores.
QueryRequest:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The query content to search for in the indexed documents
vector_db_ids:
type: array
items:
type: string
description: >-
List of vector database IDs to search within
query_config:
$ref: '#/components/schemas/RAGQueryConfig'
description: >-
(Optional) Configuration parameters for the query operation
additionalProperties: false
required:
- content
- vector_db_ids
title: QueryRequest
RAGQueryResult:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
(Optional) The retrieved content from the query
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
Additional metadata about the query result
additionalProperties: false
required:
- metadata
title: RAGQueryResult
description: >-
Result of a RAG query containing retrieved content and metadata.
ToolGroup:
type: object
properties:
@@ -8990,7 +8659,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark

View file

@@ -2624,89 +2624,6 @@
"deprecated": false
}
},
"/v1/tool-runtime/rag-tool/insert": {
"post": {
"responses": {
"200": {
"description": "OK"
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"ToolRuntime"
],
"summary": "Index documents so they can be used by the RAG system.",
"description": "Index documents so they can be used by the RAG system.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/InsertRequest"
}
}
},
"required": true
},
"deprecated": false
}
},
"/v1/tool-runtime/rag-tool/query": {
"post": {
"responses": {
"200": {
"description": "RAGQueryResult containing the retrieved content and metadata",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RAGQueryResult"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"ToolRuntime"
],
"summary": "Query the RAG system for context; typically invoked by the agent.",
"description": "Query the RAG system for context; typically invoked by the agent.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/QueryRequest"
}
}
},
"required": true
},
"deprecated": false
}
},
"/v1/toolgroups": { "/v1/toolgroups": {
"get": { "get": {
"responses": { "responses": {
@ -8472,7 +8389,7 @@
"enum": [ "enum": [
"model", "model",
"shield", "shield",
"vector_db", "vector_store",
"dataset", "dataset",
"scoring_function", "scoring_function",
"benchmark", "benchmark",
@ -11877,7 +11794,7 @@
"enum": [ "enum": [
"model", "model",
"shield", "shield",
"vector_db", "vector_store",
"dataset", "dataset",
"scoring_function", "scoring_function",
"benchmark", "benchmark",
@ -12359,7 +12276,7 @@
"enum": [ "enum": [
"model", "model",
"shield", "shield",
"vector_db", "vector_store",
"dataset", "dataset",
"scoring_function", "scoring_function",
"benchmark", "benchmark",
@ -13055,346 +12972,6 @@
"title": "ListToolDefsResponse", "title": "ListToolDefsResponse",
"description": "Response containing a list of tool definitions." "description": "Response containing a list of tool definitions."
}, },
"RAGDocument": {
"type": "object",
"properties": {
"document_id": {
"type": "string",
"description": "The unique identifier for the document."
},
"content": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/InterleavedContentItem"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContentItem"
}
},
{
"$ref": "#/components/schemas/URL"
}
],
"description": "The content of the document."
},
"mime_type": {
"type": "string",
"description": "The MIME type of the document."
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Additional metadata for the document."
}
},
"additionalProperties": false,
"required": [
"document_id",
"content",
"metadata"
],
"title": "RAGDocument",
"description": "A document to be used for document ingestion in the RAG Tool."
},
"InsertRequest": {
"type": "object",
"properties": {
"documents": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RAGDocument"
},
"description": "List of documents to index in the RAG system"
},
"vector_db_id": {
"type": "string",
"description": "ID of the vector database to store the document embeddings"
},
"chunk_size_in_tokens": {
"type": "integer",
"description": "(Optional) Size in tokens for document chunking during indexing"
}
},
"additionalProperties": false,
"required": [
"documents",
"vector_db_id",
"chunk_size_in_tokens"
],
"title": "InsertRequest"
},
"DefaultRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "default",
"default": "default",
"description": "Type of query generator, always 'default'"
},
"separator": {
"type": "string",
"default": " ",
"description": "String separator used to join query terms"
}
},
"additionalProperties": false,
"required": [
"type",
"separator"
],
"title": "DefaultRAGQueryGeneratorConfig",
"description": "Configuration for the default RAG query generator."
},
"LLMRAGQueryGeneratorConfig": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "llm",
"default": "llm",
"description": "Type of query generator, always 'llm'"
},
"model": {
"type": "string",
"description": "Name of the language model to use for query generation"
},
"template": {
"type": "string",
"description": "Template string for formatting the query generation prompt"
}
},
"additionalProperties": false,
"required": [
"type",
"model",
"template"
],
"title": "LLMRAGQueryGeneratorConfig",
"description": "Configuration for the LLM-based RAG query generator."
},
"RAGQueryConfig": {
"type": "object",
"properties": {
"query_generator_config": {
"oneOf": [
{
"$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig"
},
{
"$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"default": "#/components/schemas/DefaultRAGQueryGeneratorConfig",
"llm": "#/components/schemas/LLMRAGQueryGeneratorConfig"
}
},
"description": "Configuration for the query generator."
},
"max_tokens_in_context": {
"type": "integer",
"default": 4096,
"description": "Maximum number of tokens in the context."
},
"max_chunks": {
"type": "integer",
"default": 5,
"description": "Maximum number of chunks to retrieve."
},
"chunk_template": {
"type": "string",
"default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
"description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
},
"mode": {
"$ref": "#/components/schemas/RAGSearchMode",
"default": "vector",
"description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
},
"ranker": {
"$ref": "#/components/schemas/Ranker",
"description": "Configuration for the ranker to use in hybrid search. Defaults to RRF ranker."
}
},
"additionalProperties": false,
"required": [
"query_generator_config",
"max_tokens_in_context",
"max_chunks",
"chunk_template"
],
"title": "RAGQueryConfig",
"description": "Configuration for the RAG query generation."
},
"RAGSearchMode": {
"type": "string",
"enum": [
"vector",
"keyword",
"hybrid"
],
"title": "RAGSearchMode",
"description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results"
},
"RRFRanker": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "rrf",
"default": "rrf",
"description": "The type of ranker, always \"rrf\""
},
"impact_factor": {
"type": "number",
"default": 60.0,
"description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0"
}
},
"additionalProperties": false,
"required": [
"type",
"impact_factor"
],
"title": "RRFRanker",
"description": "Reciprocal Rank Fusion (RRF) ranker configuration."
},
"Ranker": {
"oneOf": [
{
"$ref": "#/components/schemas/RRFRanker"
},
{
"$ref": "#/components/schemas/WeightedRanker"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"rrf": "#/components/schemas/RRFRanker",
"weighted": "#/components/schemas/WeightedRanker"
}
}
},
"WeightedRanker": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "weighted",
"default": "weighted",
"description": "The type of ranker, always \"weighted\""
},
"alpha": {
"type": "number",
"default": 0.5,
"description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores."
}
},
"additionalProperties": false,
"required": [
"type",
"alpha"
],
"title": "WeightedRanker",
"description": "Weighted ranker configuration that combines vector and keyword scores."
},
"QueryRequest": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"description": "The query content to search for in the indexed documents"
},
"vector_db_ids": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of vector database IDs to search within"
},
"query_config": {
"$ref": "#/components/schemas/RAGQueryConfig",
"description": "(Optional) Configuration parameters for the query operation"
}
},
"additionalProperties": false,
"required": [
"content",
"vector_db_ids"
],
"title": "QueryRequest"
},
"RAGQueryResult": {
"type": "object",
"properties": {
"content": {
"$ref": "#/components/schemas/InterleavedContent",
"description": "(Optional) The retrieved content from the query"
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
},
"description": "Additional metadata about the query result"
}
},
"additionalProperties": false,
"required": [
"metadata"
],
"title": "RAGQueryResult",
"description": "Result of a RAG query containing retrieved content and metadata."
},
"ToolGroup": { "ToolGroup": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -13412,7 +12989,7 @@
"enum": [ "enum": [
"model", "model",
"shield", "shield",
"vector_db", "vector_store",
"dataset", "dataset",
"scoring_function", "scoring_function",
"benchmark", "benchmark",
@ -14959,7 +14536,7 @@
"enum": [ "enum": [
"model", "model",
"shield", "shield",
"vector_db", "vector_store",
"dataset", "dataset",
"scoring_function", "scoring_function",
"benchmark", "benchmark",
@ -16704,7 +16281,7 @@
"enum": [ "enum": [
"model", "model",
"shield", "shield",
"vector_db", "vector_store",
"dataset", "dataset",
"scoring_function", "scoring_function",
"benchmark", "benchmark",

View file

@@ -2039,69 +2039,6 @@ paths:
schema:
$ref: '#/components/schemas/URL'
deprecated: false
/v1/tool-runtime/rag-tool/insert:
post:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Index documents so they can be used by the RAG system.
description: >-
Index documents so they can be used by the RAG system.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/InsertRequest'
required: true
deprecated: false
/v1/tool-runtime/rag-tool/query:
post:
responses:
'200':
description: >-
RAGQueryResult containing the retrieved content and metadata
content:
application/json:
schema:
$ref: '#/components/schemas/RAGQueryResult'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- ToolRuntime
summary: >-
Query the RAG system for context; typically invoked by the agent.
description: >-
Query the RAG system for context; typically invoked by the agent.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/QueryRequest'
required: true
deprecated: false
/v1/toolgroups:
get:
responses:
@@ -6440,7 +6377,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -9132,7 +9069,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -9440,7 +9377,7 @@ components:
    enum:
    - model
    - shield
-   - vector_db
+   - vector_store
    - dataset
    - scoring_function
    - benchmark
@@ -9921,274 +9858,6 @@ components:
title: ListToolDefsResponse
description: >-
  Response containing a list of tool definitions.
RAGDocument:
type: object
properties:
document_id:
type: string
description: The unique identifier for the document.
content:
oneOf:
- type: string
- $ref: '#/components/schemas/InterleavedContentItem'
- type: array
items:
$ref: '#/components/schemas/InterleavedContentItem'
- $ref: '#/components/schemas/URL'
description: The content of the document.
mime_type:
type: string
description: The MIME type of the document.
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: Additional metadata for the document.
additionalProperties: false
required:
- document_id
- content
- metadata
title: RAGDocument
description: >-
A document to be used for document ingestion in the RAG Tool.
InsertRequest:
type: object
properties:
documents:
type: array
items:
$ref: '#/components/schemas/RAGDocument'
description: >-
List of documents to index in the RAG system
vector_db_id:
type: string
description: >-
ID of the vector database to store the document embeddings
chunk_size_in_tokens:
type: integer
description: >-
(Optional) Size in tokens for document chunking during indexing
additionalProperties: false
required:
- documents
- vector_db_id
- chunk_size_in_tokens
title: InsertRequest
DefaultRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: default
default: default
description: >-
Type of query generator, always 'default'
separator:
type: string
default: ' '
description: >-
String separator used to join query terms
additionalProperties: false
required:
- type
- separator
title: DefaultRAGQueryGeneratorConfig
description: >-
Configuration for the default RAG query generator.
LLMRAGQueryGeneratorConfig:
type: object
properties:
type:
type: string
const: llm
default: llm
description: Type of query generator, always 'llm'
model:
type: string
description: >-
Name of the language model to use for query generation
template:
type: string
description: >-
Template string for formatting the query generation prompt
additionalProperties: false
required:
- type
- model
- template
title: LLMRAGQueryGeneratorConfig
description: >-
Configuration for the LLM-based RAG query generator.
RAGQueryConfig:
type: object
properties:
query_generator_config:
oneOf:
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
- $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
discriminator:
propertyName: type
mapping:
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
description: Configuration for the query generator.
max_tokens_in_context:
type: integer
default: 4096
description: Maximum number of tokens in the context.
max_chunks:
type: integer
default: 5
description: Maximum number of chunks to retrieve.
chunk_template:
type: string
default: >
Result {index}
Content: {chunk.content}
Metadata: {metadata}
description: >-
Template for formatting each retrieved chunk in the context. Available
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
{chunk.content}\nMetadata: {metadata}\n"
mode:
$ref: '#/components/schemas/RAGSearchMode'
default: vector
description: >-
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
"vector".
ranker:
$ref: '#/components/schemas/Ranker'
description: >-
Configuration for the ranker to use in hybrid search. Defaults to RRF
ranker.
additionalProperties: false
required:
- query_generator_config
- max_tokens_in_context
- max_chunks
- chunk_template
title: RAGQueryConfig
description: >-
Configuration for the RAG query generation.
RAGSearchMode:
type: string
enum:
- vector
- keyword
- hybrid
title: RAGSearchMode
description: >-
Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
for semantic matching - KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
RRFRanker:
type: object
properties:
type:
type: string
const: rrf
default: rrf
description: The type of ranker, always "rrf"
impact_factor:
type: number
default: 60.0
description: >-
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
results. Must be greater than 0
additionalProperties: false
required:
- type
- impact_factor
title: RRFRanker
description: >-
Reciprocal Rank Fusion (RRF) ranker configuration.
Ranker:
oneOf:
- $ref: '#/components/schemas/RRFRanker'
- $ref: '#/components/schemas/WeightedRanker'
discriminator:
propertyName: type
mapping:
rrf: '#/components/schemas/RRFRanker'
weighted: '#/components/schemas/WeightedRanker'
WeightedRanker:
type: object
properties:
type:
type: string
const: weighted
default: weighted
description: The type of ranker, always "weighted"
alpha:
type: number
default: 0.5
description: >-
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
only use vector scores, values in between blend both scores.
additionalProperties: false
required:
- type
- alpha
title: WeightedRanker
description: >-
Weighted ranker configuration that combines vector and keyword scores.
QueryRequest:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
The query content to search for in the indexed documents
vector_db_ids:
type: array
items:
type: string
description: >-
List of vector database IDs to search within
query_config:
$ref: '#/components/schemas/RAGQueryConfig'
description: >-
(Optional) Configuration parameters for the query operation
additionalProperties: false
required:
- content
- vector_db_ids
title: QueryRequest
RAGQueryResult:
type: object
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
description: >-
(Optional) The retrieved content from the query
metadata:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description: >-
Additional metadata about the query result
additionalProperties: false
required:
- metadata
title: RAGQueryResult
description: >-
Result of a RAG query containing retrieved content and metadata.
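For reference, a request body for the rag-tool query endpoint being removed here can be assembled entirely from the schemas above; the store id, query text, and metadata below are illustrative placeholders, not values from this PR.

# Illustrative payload matching QueryRequest + RAGQueryConfig above (hybrid mode with a
# weighted ranker); "vs_demo" and the query text are made-up example values.
query_request = {
    "content": "How do I rotate my API keys?",
    "vector_db_ids": ["vs_demo"],
    "query_config": {
        "query_generator_config": {"type": "default", "separator": " "},
        "max_tokens_in_context": 4096,
        "max_chunks": 5,
        "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
        "mode": "hybrid",
        "ranker": {"type": "weighted", "alpha": 0.5},
    },
}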
ToolGroup: ToolGroup:
type: object type: object
properties: properties:
@ -10203,7 +9872,7 @@ components:
enum: enum:
- model - model
- shield - shield
- vector_db - vector_store
- dataset - dataset
- scoring_function - scoring_function
- benchmark - benchmark
@ -11325,7 +10994,7 @@ components:
enum: enum:
- model - model
- shield - shield
- vector_db - vector_store
- dataset - dataset
- scoring_function - scoring_function
- benchmark - benchmark
@ -12652,7 +12321,7 @@ components:
enum: enum:
- model - model
- shield - shield
- vector_db - vector_store
- dataset - dataset
- scoring_function - scoring_function
- benchmark - benchmark

View file

@ -121,7 +121,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
models = "models" models = "models"
shields = "shields" shields = "shields"
vector_dbs = "vector_dbs" # only used for routing vector_stores = "vector_stores" # only used for routing table
datasets = "datasets" datasets = "datasets"
scoring_functions = "scoring_functions" scoring_functions = "scoring_functions"
benchmarks = "benchmarks" benchmarks = "benchmarks"

View file

@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
class ResourceType(StrEnum): class ResourceType(StrEnum):
model = "model" model = "model"
shield = "shield" shield = "shield"
vector_db = "vector_db" vector_store = "vector_store"
dataset = "dataset" dataset = "dataset"
scoring_function = "scoring_function" scoring_function = "scoring_function"
benchmark = "benchmark" benchmark = "benchmark"
@ -34,4 +34,4 @@ class Resource(BaseModel):
provider_id: str = Field(description="ID of the provider that owns this resource") provider_id: str = Field(description="ID of the provider that owns this resource")
type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)") type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_store', etc.)")

View file

@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .rag_tool import *
from .tools import * from .tools import *

View file

@ -1,218 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field, field_validator
from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class RRFRanker(BaseModel):
"""
Reciprocal Rank Fusion (RRF) ranker configuration.
:param type: The type of ranker, always "rrf"
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
Must be greater than 0
"""
type: Literal["rrf"] = "rrf"
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
@json_schema_type
class WeightedRanker(BaseModel):
"""
Weighted ranker configuration that combines vector and keyword scores.
:param type: The type of ranker, always "weighted"
:param alpha: Weight factor between 0 and 1.
0 means only use keyword scores,
1 means only use vector scores,
values in between blend both scores.
"""
type: Literal["weighted"] = "weighted"
alpha: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
)
Ranker = Annotated[
RRFRanker | WeightedRanker,
Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")
@json_schema_type
class RAGDocument(BaseModel):
"""
A document to be used for document ingestion in the RAG Tool.
:param document_id: The unique identifier for the document.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
:param metadata: Additional metadata for the document.
"""
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryResult(BaseModel):
"""Result of a RAG query containing retrieved content and metadata.
:param content: (Optional) The retrieved content from the query
:param metadata: Additional metadata about the query result
"""
content: InterleavedContent | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryGenerator(Enum):
"""Types of query generators for RAG systems.
:cvar default: Default query generator using simple text processing
:cvar llm: LLM-based query generator for enhanced query understanding
:cvar custom: Custom query generator implementation
"""
default = "default"
llm = "llm"
custom = "custom"
@json_schema_type
class RAGSearchMode(StrEnum):
"""
Search modes for RAG query retrieval:
- VECTOR: Uses vector similarity search for semantic matching
- KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
"""
VECTOR = "vector"
KEYWORD = "keyword"
HYBRID = "hybrid"
@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the default RAG query generator.
:param type: Type of query generator, always 'default'
:param separator: String separator used to join query terms
"""
type: Literal["default"] = "default"
separator: str = " "
@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the LLM-based RAG query generator.
:param type: Type of query generator, always 'llm'
:param model: Name of the language model to use for query generation
:param template: Template string for formatting the query generation prompt
"""
type: Literal["llm"] = "llm"
model: str
template: str
RAGQueryGeneratorConfig = Annotated[
DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type
class RAGQueryConfig(BaseModel):
"""
Configuration for the RAG query generation.
:param query_generator_config: Configuration for the query generator.
:param max_tokens_in_context: Maximum number of tokens in the context.
:param max_chunks: Maximum number of chunks to retrieve.
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
    :param mode: Search mode for retrieval: either "vector", "keyword", or "hybrid". Default "vector".
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
"""
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
max_tokens_in_context: int = 4096
max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: RAGSearchMode | None = RAGSearchMode.VECTOR
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:
if "{chunk.content}" not in v:
raise ValueError("chunk_template must contain {chunk.content}")
if "{index}" not in v:
raise ValueError("chunk_template must contain {index}")
if len(v) == 0:
raise ValueError("chunk_template must not be empty")
return v
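The validator above only guarantees that {chunk.content} and {index} appear in the template; a minimal rendering sketch, separate from the file above (the _Chunk stand-in and metadata values are hypothetical, the real chunk type lives in the vector-io provider code):

from dataclasses import dataclass

@dataclass
class _Chunk:  # stand-in for the provider-side chunk object
    content: str

_template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
print(_template.format(index=1, chunk=_Chunk(content="..."), metadata={"document_id": "doc-1"}))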
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
"""Index documents so they can be used by the RAG system.
:param documents: List of documents to index in the RAG system
:param vector_db_id: ID of the vector database to store the document embeddings
:param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
"""
...
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent.
:param content: The query content to search for in the indexed documents
:param vector_db_ids: List of vector database IDs to search within
:param query_config: (Optional) Configuration parameters for the query operation
:returns: RAGQueryResult containing the retrieved content and metadata
"""
...
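A hedged usage sketch for the protocol above, which this PR removes in favor of the OpenAI-compatible vector-store APIs; rag_tool stands for any RAGToolRuntime implementation, and the ids and document text are illustrative:

async def index_and_query(rag_tool: RAGToolRuntime) -> RAGQueryResult:
    docs = [
        RAGDocument(
            document_id="doc-1",
            content="Example document text to index.",
            mime_type="text/plain",
            metadata={"source": "example"},
        )
    ]
    # chunk_size_in_tokens defaults to 512; shown explicitly for clarity
    await rag_tool.insert(documents=docs, vector_db_id="vs_demo", chunk_size_in_tokens=512)
    return await rag_tool.query(
        content="What does the example document say?",
        vector_db_ids=["vs_demo"],
        query_config=RAGQueryConfig(mode=RAGSearchMode.VECTOR),
    )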

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import Any, Literal, Protocol from typing import Any, Literal, Protocol
from pydantic import BaseModel from pydantic import BaseModel
@ -16,8 +15,6 @@ from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime
@json_schema_type @json_schema_type
class ToolDef(BaseModel): class ToolDef(BaseModel):
@ -181,22 +178,11 @@ class ToolGroups(Protocol):
... ...
class SpecialToolGroup(Enum):
"""Special tool groups with predefined functionality.
:cvar rag_tool: Retrieval-Augmented Generation tool group for document search and retrieval
"""
rag_tool = "rag_tool"
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class ToolRuntime(Protocol): class ToolRuntime(Protocol):
tool_store: ToolStore | None = None tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1) @webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_runtime_tools( async def list_runtime_tools(

View file

@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class VectorDB(Resource):
"""Vector database resource for storing and querying vector embeddings.
:param type: Type of resource, always 'vector_db' for vector databases
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
"""
type: Literal[ResourceType.vector_db] = ResourceType.vector_db
embedding_model: str
embedding_dimension: int
vector_db_name: str | None = None
@property
def vector_db_id(self) -> str:
return self.identifier
@property
def provider_vector_db_id(self) -> str | None:
return self.provider_resource_id
class VectorDBInput(BaseModel):
"""Input parameters for creating or configuring a vector database.
:param vector_db_id: Unique identifier for the vector database
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
:param provider_vector_db_id: (Optional) Provider-specific identifier for the vector database
"""
vector_db_id: str
embedding_model: str
embedding_dimension: int
provider_id: str | None = None
provider_vector_db_id: str | None = None
class ListVectorDBsResponse(BaseModel):
"""Response from listing vector databases.
:param data: List of vector databases
"""
data: list[VectorDB]
@runtime_checkable
class VectorDBs(Protocol):
"""Internal protocol for vector_dbs routing - no public API endpoints."""
async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""Internal method to list vector databases."""
...
async def get_vector_db(
self,
vector_db_id: str,
) -> VectorDB:
"""Internal method to get a vector database by ID."""
...
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
"""Internal method to register a vector database."""
...
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Internal method to unregister a vector database."""
...

View file

@ -15,7 +15,7 @@ from fastapi import Body
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
@ -140,6 +140,7 @@ class VectorStoreFileCounts(BaseModel):
total: int total: int
# TODO: rename this as OpenAIVectorStore
@json_schema_type @json_schema_type
class VectorStoreObject(BaseModel): class VectorStoreObject(BaseModel):
"""OpenAI Vector Store object. """OpenAI Vector Store object.
@ -517,17 +518,18 @@ class OpenAICreateVectorStoreFileBatchRequestWithExtraBody(BaseModel, extra="all
chunking_strategy: VectorStoreChunkingStrategy | None = None chunking_strategy: VectorStoreChunkingStrategy | None = None
class VectorDBStore(Protocol): class VectorStoreTable(Protocol):
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ... def get_vector_store(self, vector_store_id: str) -> VectorStore | None: ...
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class VectorIO(Protocol): class VectorIO(Protocol):
vector_db_store: VectorDBStore | None = None vector_store_table: VectorStoreTable | None = None
# this will just block now until chunks are inserted, but it should # this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion # probably return a Job instance which can be polled for completion
# TODO: rename vector_db_id to vector_store_id once Stainless is working
@webmethod(route="/vector-io/insert", method="POST", level=LLAMA_STACK_API_V1) @webmethod(route="/vector-io/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert_chunks( async def insert_chunks(
self, self,
@ -546,6 +548,7 @@ class VectorIO(Protocol):
""" """
... ...
# TODO: rename vector_db_id to vector_store_id once Stainless is working
@webmethod(route="/vector-io/query", method="POST", level=LLAMA_STACK_API_V1) @webmethod(route="/vector-io/query", method="POST", level=LLAMA_STACK_API_V1)
async def query_chunks( async def query_chunks(
self, self,

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .vector_dbs import * from .vector_stores import *

View file

@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
# Internal resource type for storing the vector store routing and other information
class VectorStore(Resource):
"""Vector database resource for storing and querying vector embeddings.
:param type: Type of resource, always 'vector_store' for vector stores
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
"""
type: Literal[ResourceType.vector_store] = ResourceType.vector_store
embedding_model: str
embedding_dimension: int
vector_store_name: str | None = None
@property
def vector_store_id(self) -> str:
return self.identifier
@property
def provider_vector_store_id(self) -> str | None:
return self.provider_resource_id
class VectorStoreInput(BaseModel):
"""Input parameters for creating or configuring a vector database.
:param vector_store_id: Unique identifier for the vector store
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
:param provider_vector_store_id: (Optional) Provider-specific identifier for the vector store
"""
vector_store_id: str
embedding_model: str
embedding_dimension: int
provider_id: str | None = None
provider_vector_store_id: str | None = None
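A minimal sketch of declaring one of these stores in a run config via the input model above; the identifier, embedding model, and provider id are illustrative values (the dimension just has to match the chosen embedding model):

vs_input = VectorStoreInput(
    vector_store_id="vs_demo",
    embedding_model="all-MiniLM-L6-v2",   # example embedding model name
    embedding_dimension=384,
    provider_id="faiss",                  # optional: pin a specific vector-io provider
)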

View file

@ -41,7 +41,7 @@ class AccessRule(BaseModel):
A rule defines a list of action either to permit or to forbid. It may specify a A rule defines a list of action either to permit or to forbid. It may specify a
principal or a resource that must match for the rule to take effect. The resource principal or a resource that must match for the rule to take effect. The resource
to match should be specified in the form of a type qualified identifier, e.g. to match should be specified in the form of a type qualified identifier, e.g.
model::my-model or vector_db::some-db, or a wildcard for all resources of a type, model::my-model or vector_store::some-db, or a wildcard for all resources of a type,
e.g. model::*. If the principal or resource are not specified, they will match all e.g. model::*. If the principal or resource are not specified, they will match all
requests. requests.
@ -79,9 +79,9 @@ class AccessRule(BaseModel):
description: any user has read access to any resource created by a member of their team description: any user has read access to any resource created by a member of their team
- forbid: - forbid:
actions: [create, read, delete] actions: [create, read, delete]
resource: vector_db::* resource: vector_store::*
unless: user with admin in roles unless: user with admin in roles
description: only user with admin role can use vector_db resources description: only user with admin role can use vector_store resources
""" """

View file

@ -23,8 +23,8 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
from llama_stack.core.access_control.datatypes import AccessRule from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.storage.datatypes import ( from llama_stack.core.storage.datatypes import (
KVStoreReference, KVStoreReference,
@ -71,7 +71,7 @@ class ShieldWithOwner(Shield, ResourceWithOwner):
pass pass
class VectorDBWithOwner(VectorDB, ResourceWithOwner): class VectorStoreWithOwner(VectorStore, ResourceWithOwner):
pass pass
@ -91,12 +91,12 @@ class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
pass pass
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | ToolGroup RoutableObject = Model | Shield | VectorStore | Dataset | ScoringFn | Benchmark | ToolGroup
RoutableObjectWithProvider = Annotated[ RoutableObjectWithProvider = Annotated[
ModelWithOwner ModelWithOwner
| ShieldWithOwner | ShieldWithOwner
| VectorDBWithOwner | VectorStoreWithOwner
| DatasetWithOwner | DatasetWithOwner
| ScoringFnWithOwner | ScoringFnWithOwner
| BenchmarkWithOwner | BenchmarkWithOwner
@ -427,7 +427,7 @@ class RegisteredResources(BaseModel):
models: list[ModelInput] = Field(default_factory=list) models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list) shields: list[ShieldInput] = Field(default_factory=list)
vector_dbs: list[VectorDBInput] = Field(default_factory=list) vector_stores: list[VectorStoreInput] = Field(default_factory=list)
datasets: list[DatasetInput] = Field(default_factory=list) datasets: list[DatasetInput] = Field(default_factory=list)
scoring_fns: list[ScoringFnInput] = Field(default_factory=list) scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
benchmarks: list[BenchmarkInput] = Field(default_factory=list) benchmarks: list[BenchmarkInput] = Field(default_factory=list)

View file

@ -64,7 +64,7 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
router_api=Api.tool_runtime, router_api=Api.tool_runtime,
), ),
AutoRoutedApiInfo( AutoRoutedApiInfo(
routing_table_api=Api.vector_dbs, routing_table_api=Api.vector_stores,
router_api=Api.vector_io, router_api=Api.vector_io,
), ),
] ]

View file

@ -29,8 +29,8 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl from llama_stack.core.client import get_client_impl
from llama_stack.core.datatypes import ( from llama_stack.core.datatypes import (
@ -82,7 +82,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.inspect: Inspect, Api.inspect: Inspect,
Api.batches: Batches, Api.batches: Batches,
Api.vector_io: VectorIO, Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs, Api.vector_stores: VectorStore,
Api.models: Models, Api.models: Models,
Api.safety: Safety, Api.safety: Safety,
Api.shields: Shields, Api.shields: Shields,

View file

@ -29,7 +29,7 @@ async def get_routing_table_impl(
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable from ..routing_tables.toolgroups import ToolGroupsRoutingTable
from ..routing_tables.vector_dbs import VectorDBsRoutingTable from ..routing_tables.vector_stores import VectorStoresRoutingTable
api_to_tables = { api_to_tables = {
"models": ModelsRoutingTable, "models": ModelsRoutingTable,
@ -38,7 +38,7 @@ async def get_routing_table_impl(
"scoring_functions": ScoringFunctionsRoutingTable, "scoring_functions": ScoringFunctionsRoutingTable,
"benchmarks": BenchmarksRoutingTable, "benchmarks": BenchmarksRoutingTable,
"tool_groups": ToolGroupsRoutingTable, "tool_groups": ToolGroupsRoutingTable,
"vector_dbs": VectorDBsRoutingTable, "vector_stores": VectorStoresRoutingTable,
} }
if api.value not in api_to_tables: if api.value not in api_to_tables:

View file

@ -8,16 +8,8 @@ from typing import Any
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
URL, URL,
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
) )
from llama_stack.apis.tools import ListToolDefsResponse, ToolRuntime
from llama_stack.log import get_logger from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable from ..routing_tables.toolgroups import ToolGroupsRoutingTable
@ -26,36 +18,6 @@ logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime): class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_db_ids, query_config)
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
def __init__( def __init__(
self, self,
routing_table: ToolGroupsRoutingTable, routing_table: ToolGroupsRoutingTable,
@ -63,11 +25,6 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("Initializing ToolRuntimeRouter") logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None: async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize") logger.debug("ToolRuntimeRouter.initialize")
pass pass

View file

@ -71,25 +71,6 @@ class VectorIORouter(VectorIO):
raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model") raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model")
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
vector_db_name,
provider_vector_db_id,
)
async def insert_chunks( async def insert_chunks(
self, self,
vector_db_id: str, vector_db_id: str,
@ -165,22 +146,22 @@ class VectorIORouter(VectorIO):
else: else:
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0] provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
vector_db_id = f"vs_{uuid.uuid4()}" vector_store_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db( registered_vector_store = await self.routing_table.register_vector_store(
vector_db_id=vector_db_id, vector_store_id=vector_store_id,
embedding_model=embedding_model, embedding_model=embedding_model,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
provider_id=provider_id, provider_id=provider_id,
provider_vector_db_id=vector_db_id, provider_vector_store_id=vector_store_id,
vector_db_name=params.name, vector_store_name=params.name,
) )
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier) provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier)
# Update model_extra with registered values so provider uses the already-registered vector_db # Update model_extra with registered values so provider uses the already-registered vector_store
if params.model_extra is None: if params.model_extra is None:
params.model_extra = {} params.model_extra = {}
params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id
params.model_extra["provider_id"] = registered_vector_db.provider_id params.model_extra["provider_id"] = registered_vector_store.provider_id
if embedding_model is not None: if embedding_model is not None:
params.model_extra["embedding_model"] = embedding_model params.model_extra["embedding_model"] = embedding_model
if embedding_dimension is not None: if embedding_dimension is not None:
@ -198,15 +179,15 @@ class VectorIORouter(VectorIO):
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}") logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
# Route to default provider for now - could aggregate from all providers in the future # Route to default provider for now - could aggregate from all providers in the future
# call retrieve on each vector dbs to get list of vector stores # call retrieve on each vector dbs to get list of vector stores
vector_dbs = await self.routing_table.get_all_with_type("vector_db") vector_stores = await self.routing_table.get_all_with_type("vector_store")
all_stores = [] all_stores = []
for vector_db in vector_dbs: for vector_store in vector_stores:
try: try:
provider = await self.routing_table.get_provider_impl(vector_db.identifier) provider = await self.routing_table.get_provider_impl(vector_store.identifier)
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier) vector_store = await provider.openai_retrieve_vector_store(vector_store.identifier)
all_stores.append(vector_store) all_stores.append(vector_store)
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}") logger.error(f"Error retrieving vector store {vector_store.identifier}: {e}")
continue continue
# Sort by created_at # Sort by created_at

View file

@ -41,7 +41,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
elif api == Api.safety: elif api == Api.safety:
return await p.register_shield(obj) return await p.register_shield(obj)
elif api == Api.vector_io: elif api == Api.vector_io:
return await p.register_vector_db(obj) return await p.register_vector_store(obj)
elif api == Api.datasetio: elif api == Api.datasetio:
return await p.register_dataset(obj) return await p.register_dataset(obj)
elif api == Api.scoring: elif api == Api.scoring:
@ -57,7 +57,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p) api = get_impl_api(p)
if api == Api.vector_io: if api == Api.vector_io:
return await p.unregister_vector_db(obj.identifier) return await p.unregister_vector_store(obj.identifier)
elif api == Api.inference: elif api == Api.inference:
return await p.unregister_model(obj.identifier) return await p.unregister_model(obj.identifier)
elif api == Api.safety: elif api == Api.safety:
@ -108,7 +108,7 @@ class CommonRoutingTableImpl(RoutingTable):
elif api == Api.safety: elif api == Api.safety:
p.shield_store = self p.shield_store = self
elif api == Api.vector_io: elif api == Api.vector_io:
p.vector_db_store = self p.vector_store_store = self
elif api == Api.datasetio: elif api == Api.datasetio:
p.dataset_store = self p.dataset_store = self
elif api == Api.scoring: elif api == Api.scoring:
@ -134,15 +134,15 @@ class CommonRoutingTableImpl(RoutingTable):
from .scoring_functions import ScoringFunctionsRoutingTable from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable from .toolgroups import ToolGroupsRoutingTable
from .vector_dbs import VectorDBsRoutingTable from .vector_stores import VectorStoresRoutingTable
def apiname_object(): def apiname_object():
if isinstance(self, ModelsRoutingTable): if isinstance(self, ModelsRoutingTable):
return ("Inference", "model") return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable): elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield") return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable): elif isinstance(self, VectorStoresRoutingTable):
return ("VectorIO", "vector_db") return ("VectorIO", "vector_store")
elif isinstance(self, DatasetsRoutingTable): elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset") return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable): elif isinstance(self, ScoringFunctionsRoutingTable):

View file

@ -6,15 +6,12 @@
from typing import Any from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType from llama_stack.apis.resource import ResourceType
# Removed VectorDBs import to avoid exposing public API # Removed VectorStores import to avoid exposing public API
from llama_stack.apis.vector_io.vector_io import ( from llama_stack.apis.vector_io.vector_io import (
OpenAICreateVectorStoreRequestWithExtraBody,
SearchRankingOptions, SearchRankingOptions,
VectorStoreChunkingStrategy, VectorStoreChunkingStrategy,
VectorStoreDeleteResponse, VectorStoreDeleteResponse,
@ -26,7 +23,7 @@ from llama_stack.apis.vector_io.vector_io import (
VectorStoreSearchResponsePage, VectorStoreSearchResponsePage,
) )
from llama_stack.core.datatypes import ( from llama_stack.core.datatypes import (
VectorDBWithOwner, VectorStoreWithOwner,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -35,23 +32,23 @@ from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core::routing_tables") logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl): class VectorStoresRoutingTable(CommonRoutingTableImpl):
"""Internal routing table for vector_db operations. """Internal routing table for vector_store operations.
Does not inherit from VectorDBs to avoid exposing public API endpoints. Does not inherit from VectorStores to avoid exposing public API endpoints.
Only provides internal routing functionality for VectorIORouter. Only provides internal routing functionality for VectorIORouter.
""" """
# Internal methods only - no public API exposure # Internal methods only - no public API exposure
async def register_vector_db( async def register_vector_store(
self, self,
vector_db_id: str, vector_store_id: str,
embedding_model: str, embedding_model: str,
embedding_dimension: int | None = 384, embedding_dimension: int | None = 384,
provider_id: str | None = None, provider_id: str | None = None,
provider_vector_db_id: str | None = None, provider_vector_store_id: str | None = None,
vector_db_name: str | None = None, vector_store_name: str | None = None,
) -> Any: ) -> Any:
if provider_id is None: if provider_id is None:
if len(self.impls_by_provider_id) > 0: if len(self.impls_by_provider_id) > 0:
@ -67,52 +64,24 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
raise ModelNotFoundError(embedding_model) raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding: if model.model_type != ModelType.embedding:
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding) raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
try: vector_store = VectorStoreWithOwner(
provider = self.impls_by_provider_id[provider_id] identifier=vector_store_id,
except KeyError: type=ResourceType.vector_store.value,
available_providers = list(self.impls_by_provider_id.keys())
raise ValueError(
f"Provider '{provider_id}' not found in routing table. Available providers: {available_providers}"
) from None
logger.warning(
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
)
request = OpenAICreateVectorStoreRequestWithExtraBody(
name=vector_db_name or vector_db_id,
embedding_model=embedding_model,
embedding_dimension=model.metadata["embedding_dimension"],
provider_id=provider_id, provider_id=provider_id,
provider_vector_db_id=provider_vector_db_id, provider_resource_id=provider_vector_store_id,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
vector_store_name=vector_store_name,
) )
vector_store = await provider.openai_create_vector_store(request) await self.register_object(vector_store)
return vector_store
vector_store_id = vector_store.id
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
logger.warning(
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
)
vector_db_data = {
"identifier": vector_store_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": actual_provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_store.name,
}
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def openai_retrieve_vector_store( async def openai_retrieve_vector_store(
self, self,
vector_store_id: str, vector_store_id: str,
) -> VectorStoreObject: ) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id) return await provider.openai_retrieve_vector_store(vector_store_id)
@ -123,7 +92,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
expires_after: dict[str, Any] | None = None, expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
) -> VectorStoreObject: ) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id) await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store( return await provider.openai_update_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
@ -136,18 +105,18 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
self, self,
vector_store_id: str, vector_store_id: str,
) -> VectorStoreDeleteResponse: ) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id) await self.assert_action_allowed("delete", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id) result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id) await self.unregister_vector_store(vector_store_id)
return result return result
async def unregister_vector_db(self, vector_store_id: str) -> None: async def unregister_vector_store(self, vector_store_id: str) -> None:
"""Remove the vector store from the routing table registry.""" """Remove the vector store from the routing table registry."""
try: try:
vector_db_obj = await self.get_object_by_identifier("vector_db", vector_store_id) vector_store_obj = await self.get_object_by_identifier("vector_store", vector_store_id)
if vector_db_obj: if vector_store_obj:
await self.unregister_object(vector_db_obj) await self.unregister_object(vector_store_obj)
except Exception as e: except Exception as e:
# Log the error but don't fail the operation # Log the error but don't fail the operation
logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}") logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")
@ -162,7 +131,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
rewrite_query: bool | None = False, rewrite_query: bool | None = False,
search_mode: str | None = "vector", search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage: ) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store( return await provider.openai_search_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
@ -181,7 +150,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
attributes: dict[str, Any] | None = None, attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None, chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id) await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store( return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
@ -199,7 +168,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
before: str | None = None, before: str | None = None,
filter: VectorStoreFileStatus | None = None, filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]: ) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store( return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
@ -215,7 +184,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
vector_store_id: str, vector_store_id: str,
file_id: str, file_id: str,
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file( return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
@ -227,7 +196,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
vector_store_id: str, vector_store_id: str,
file_id: str, file_id: str,
) -> VectorStoreFileContentsResponse: ) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents( return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
@ -240,7 +209,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
file_id: str, file_id: str,
attributes: dict[str, Any], attributes: dict[str, Any],
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id) await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file( return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
@ -253,7 +222,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
vector_store_id: str, vector_store_id: str,
file_id: str, file_id: str,
) -> VectorStoreFileDeleteResponse: ) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id) await self.assert_action_allowed("delete", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file( return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
@ -267,7 +236,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
attributes: dict[str, Any] | None = None, attributes: dict[str, Any] | None = None,
chunking_strategy: Any | None = None, chunking_strategy: Any | None = None,
): ):
await self.assert_action_allowed("update", "vector_db", vector_store_id) await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch( return await provider.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
@ -281,7 +250,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
batch_id: str, batch_id: str,
vector_store_id: str, vector_store_id: str,
): ):
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_batch( return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id, batch_id=batch_id,
@ -298,7 +267,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
limit: int | None = 20, limit: int | None = 20,
order: str | None = "desc", order: str | None = "desc",
): ):
await self.assert_action_allowed("read", "vector_db", vector_store_id) await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store_file_batch( return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id, batch_id=batch_id,
@ -315,7 +284,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
batch_id: str, batch_id: str,
vector_store_id: str, vector_store_id: str,
): ):
await self.assert_action_allowed("update", "vector_db", vector_store_id) await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_cancel_vector_store_file_batch( return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id, batch_id=batch_id,
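Every method in the renamed table follows the same guard-then-delegate shape; a condensed sketch (table stands for a VectorStoresRoutingTable instance, the id is illustrative):

async def retrieve_store(table: VectorStoresRoutingTable, vector_store_id: str) -> VectorStoreObject:
    # access control is now keyed on the "vector_store" resource type instead of "vector_db"
    await table.assert_action_allowed("read", "vector_store", vector_store_id)
    provider = await table.get_provider_impl(vector_store_id)
    return await provider.openai_retrieve_vector_store(vector_store_id)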

View file

@ -13,7 +13,6 @@ from aiohttp import hdrs
from starlette.routing import Route from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.core.resolver import api_protocol_map from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod from llama_stack.schema_utils import WebMethod
@ -25,33 +24,16 @@ RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod] RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map():
return {
SpecialToolGroup.rag_tool: RAGToolRuntime,
}
def get_all_api_routes( def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None, external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]: ) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {} apis = {}
protocols = api_protocol_map(external_apis) protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items(): for api, protocol in protocols.items():
routes = [] routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods: for name, method in protocol_methods:
# Get all webmethods for this method (supports multiple decorators) # Get all webmethods for this method (supports multiple decorators)
webmethods = getattr(method, "__webmethods__", []) webmethods = getattr(method, "__webmethods__", [])

View file

@ -32,7 +32,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
@ -80,7 +80,6 @@ class LlamaStack(
Inspect, Inspect,
ToolGroups, ToolGroups,
ToolRuntime, ToolRuntime,
RAGToolRuntime,
Files, Files,
Prompts, Prompts,
Conversations, Conversations,

View file

@ -32,7 +32,7 @@ def tool_chat_page():
tool_groups_list = [tool_group.identifier for tool_group in tool_groups] tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")] mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")] builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
selected_vector_dbs = [] selected_vector_stores = []
def reset_agent(): def reset_agent():
st.session_state.clear() st.session_state.clear()
@ -55,13 +55,13 @@ def tool_chat_page():
) )
if "builtin::rag" in toolgroup_selection: if "builtin::rag" in toolgroup_selection:
vector_dbs = llama_stack_api.client.vector_dbs.list() or [] vector_stores = llama_stack_api.client.vector_stores.list() or []
if not vector_dbs: if not vector_stores:
st.info("No vector databases available for selection.") st.info("No vector databases available for selection.")
vector_dbs = [vector_db.identifier for vector_db in vector_dbs] vector_stores = [vector_store.identifier for vector_store in vector_stores]
selected_vector_dbs = st.multiselect( selected_vector_stores = st.multiselect(
label="Select Document Collections to use in RAG queries", label="Select Document Collections to use in RAG queries",
options=vector_dbs, options=vector_stores,
on_change=reset_agent, on_change=reset_agent,
) )
@ -119,7 +119,7 @@ def tool_chat_page():
tool_dict = dict( tool_dict = dict(
name="builtin::rag", name="builtin::rag",
args={ args={
"vector_db_ids": list(selected_vector_dbs), "vector_store_ids": list(selected_vector_stores),
}, },
) )
toolgroup_selection[i] = tool_dict toolgroup_selection[i] = tool_dict
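End to end, the renamed playground flow reduces to listing stores through the client and handing the selected identifiers to the RAG tool; a condensed sketch with illustrative selections:

vector_stores = llama_stack_api.client.vector_stores.list() or []
store_ids = [vs.identifier for vs in vector_stores]
rag_tool = {"name": "builtin::rag", "args": {"vector_store_ids": store_ids}}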

View file

@ -48,7 +48,6 @@ distribution_spec:
tool_runtime: tool_runtime:
- provider_type: remote::brave-search - provider_type: remote::brave-search
- provider_type: remote::tavily-search - provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol - provider_type: remote::model-context-protocol
batches: batches:
- provider_type: inline::reference - provider_type: inline::reference

View file

@ -216,8 +216,6 @@ providers:
config: config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=} api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3 max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
batches: batches:
@ -263,8 +261,6 @@ registered_resources:
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: tavily-search provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -26,7 +26,6 @@ distribution_spec:
tool_runtime: tool_runtime:
- provider_type: remote::brave-search - provider_type: remote::brave-search
- provider_type: remote::tavily-search - provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
image_type: venv image_type: venv
additional_pip_packages: additional_pip_packages:
- aiosqlite - aiosqlite

View file

@ -45,7 +45,6 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [ "tool_runtime": [
BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"), BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
], ],
} }
name = "dell" name = "dell"
@ -98,10 +97,6 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch", toolgroup_id="builtin::websearch",
provider_id="brave-search", provider_id="brave-search",
), ),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
] ]
return DistributionTemplate( return DistributionTemplate(

View file

@ -87,8 +87,6 @@ providers:
config: config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=} api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3 max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
storage: storage:
backends: backends:
kv_default: kv_default:
@ -133,8 +131,6 @@ registered_resources:
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: brave-search provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -83,8 +83,6 @@ providers:
config: config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=} api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3 max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
storage: storage:
backends: backends:
kv_default: kv_default:
@ -124,8 +122,6 @@ registered_resources:
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: brave-search provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -24,7 +24,6 @@ distribution_spec:
tool_runtime: tool_runtime:
- provider_type: remote::brave-search - provider_type: remote::brave-search
- provider_type: remote::tavily-search - provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol - provider_type: remote::model-context-protocol
image_type: venv image_type: venv
additional_pip_packages: additional_pip_packages:

View file

@ -47,7 +47,6 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [ "tool_runtime": [
BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"), BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"), BuildProvider(provider_type="remote::model-context-protocol"),
], ],
} }
@ -92,10 +91,6 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch", toolgroup_id="builtin::websearch",
provider_id="tavily-search", provider_id="tavily-search",
), ),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
] ]
return DistributionTemplate( return DistributionTemplate(

View file

@ -98,8 +98,6 @@ providers:
config: config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=} api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3 max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
storage: storage:
@ -146,8 +144,6 @@ registered_resources:
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: tavily-search provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -88,8 +88,6 @@ providers:
config: config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=} api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3 max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
storage: storage:
@ -131,8 +129,6 @@ registered_resources:
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: tavily-search provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -19,8 +19,7 @@ distribution_spec:
- provider_type: remote::nvidia - provider_type: remote::nvidia
scoring: scoring:
- provider_type: inline::basic - provider_type: inline::basic
tool_runtime: tool_runtime: []
- provider_type: inline::rag-runtime
files: files:
- provider_type: inline::localfs - provider_type: inline::localfs
image_type: venv image_type: venv

View file

@ -28,7 +28,7 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
BuildProvider(provider_type="remote::nvidia"), BuildProvider(provider_type="remote::nvidia"),
], ],
"scoring": [BuildProvider(provider_type="inline::basic")], "scoring": [BuildProvider(provider_type="inline::basic")],
"tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")], "tool_runtime": [],
"files": [BuildProvider(provider_type="inline::localfs")], "files": [BuildProvider(provider_type="inline::localfs")],
} }
@ -66,12 +66,7 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
provider_id="nvidia", provider_id="nvidia",
) )
default_tool_groups = [ default_tool_groups: list[ToolGroupInput] = []
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
return DistributionTemplate( return DistributionTemplate(
name=name, name=name,

View file

@ -80,9 +80,7 @@ providers:
scoring: scoring:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
tool_runtime: tool_runtime: []
- provider_id: rag-runtime
provider_type: inline::rag-runtime
files: files:
- provider_id: meta-reference-files - provider_id: meta-reference-files
provider_type: inline::localfs provider_type: inline::localfs
@ -128,9 +126,7 @@ registered_resources:
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []
benchmarks: [] benchmarks: []
tool_groups: tool_groups: []
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -69,9 +69,7 @@ providers:
scoring: scoring:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
tool_runtime: tool_runtime: []
- provider_id: rag-runtime
provider_type: inline::rag-runtime
files: files:
- provider_id: meta-reference-files - provider_id: meta-reference-files
provider_type: inline::localfs provider_type: inline::localfs
@ -107,9 +105,7 @@ registered_resources:
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []
benchmarks: [] benchmarks: []
tool_groups: tool_groups: []
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -28,7 +28,6 @@ distribution_spec:
tool_runtime: tool_runtime:
- provider_type: remote::brave-search - provider_type: remote::brave-search
- provider_type: remote::tavily-search - provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol - provider_type: remote::model-context-protocol
image_type: venv image_type: venv
additional_pip_packages: additional_pip_packages:

View file

@ -118,7 +118,6 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [ "tool_runtime": [
BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"), BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"), BuildProvider(provider_type="remote::model-context-protocol"),
], ],
} }
@ -154,10 +153,6 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch", toolgroup_id="builtin::websearch",
provider_id="tavily-search", provider_id="tavily-search",
), ),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
] ]
models, _ = get_model_registry(available_models) models, _ = get_model_registry(available_models)

View file

@ -118,8 +118,6 @@ providers:
config: config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=} api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3 max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
storage: storage:
@ -244,8 +242,6 @@ registered_resources:
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: tavily-search provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -14,7 +14,6 @@ distribution_spec:
tool_runtime: tool_runtime:
- provider_type: remote::brave-search - provider_type: remote::brave-search
- provider_type: remote::tavily-search - provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol - provider_type: remote::model-context-protocol
image_type: venv image_type: venv
additional_pip_packages: additional_pip_packages:

View file

@ -45,7 +45,6 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [ "tool_runtime": [
BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"), BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"), BuildProvider(provider_type="remote::model-context-protocol"),
], ],
} }
@ -66,10 +65,6 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch", toolgroup_id="builtin::websearch",
provider_id="tavily-search", provider_id="tavily-search",
), ),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
] ]
default_models = [ default_models = [

View file

@ -54,8 +54,6 @@ providers:
config: config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=} api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3 max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
storage: storage:
@ -107,8 +105,6 @@ registered_resources:
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: tavily-search provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -49,7 +49,6 @@ distribution_spec:
tool_runtime: tool_runtime:
- provider_type: remote::brave-search - provider_type: remote::brave-search
- provider_type: remote::tavily-search - provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol - provider_type: remote::model-context-protocol
batches: batches:
- provider_type: inline::reference - provider_type: inline::reference

View file

@ -219,8 +219,6 @@ providers:
config: config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=} api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3 max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
batches: batches:
@ -266,8 +264,6 @@ registered_resources:
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: tavily-search provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -49,7 +49,6 @@ distribution_spec:
tool_runtime: tool_runtime:
- provider_type: remote::brave-search - provider_type: remote::brave-search
- provider_type: remote::tavily-search - provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol - provider_type: remote::model-context-protocol
batches: batches:
- provider_type: inline::reference - provider_type: inline::reference

View file

@ -216,8 +216,6 @@ providers:
config: config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=} api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3 max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
batches: batches:
@ -263,8 +261,6 @@ registered_resources:
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: tavily-search provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -140,7 +140,6 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
"tool_runtime": [ "tool_runtime": [
BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"), BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"), BuildProvider(provider_type="remote::model-context-protocol"),
], ],
"batches": [ "batches": [
@ -162,10 +161,6 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
toolgroup_id="builtin::websearch", toolgroup_id="builtin::websearch",
provider_id="tavily-search", provider_id="tavily-search",
), ),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
] ]
default_shields = [ default_shields = [
# if the # if the

View file

@ -23,7 +23,6 @@ distribution_spec:
tool_runtime: tool_runtime:
- provider_type: remote::brave-search - provider_type: remote::brave-search
- provider_type: remote::tavily-search - provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol - provider_type: remote::model-context-protocol
files: files:
- provider_type: inline::localfs - provider_type: inline::localfs

View file

@ -83,8 +83,6 @@ providers:
config: config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=} api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3 max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
files: files:
@ -125,8 +123,6 @@ registered_resources:
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch - toolgroup_id: builtin::websearch
provider_id: tavily-search provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server: server:
port: 8321 port: 8321
telemetry: telemetry:

View file

@ -33,7 +33,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
"tool_runtime": [ "tool_runtime": [
BuildProvider(provider_type="remote::brave-search"), BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"), BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"), BuildProvider(provider_type="remote::model-context-protocol"),
], ],
"files": [BuildProvider(provider_type="inline::localfs")], "files": [BuildProvider(provider_type="inline::localfs")],
@ -50,10 +49,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
toolgroup_id="builtin::websearch", toolgroup_id="builtin::websearch",
provider_id="tavily-search", provider_id="tavily-search",
), ),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
] ]
files_provider = Provider( files_provider = Provider(

View file

@ -17,7 +17,7 @@ from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolGroup from llama_stack.apis.tools import ToolGroup
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_stores import VectorStore
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -68,10 +68,10 @@ class ShieldsProtocolPrivate(Protocol):
async def unregister_shield(self, identifier: str) -> None: ... async def unregister_shield(self, identifier: str) -> None: ...
class VectorDBsProtocolPrivate(Protocol): class VectorStoresProtocolPrivate(Protocol):
async def register_vector_db(self, vector_db: VectorDB) -> None: ... async def register_vector_store(self, vector_store: VectorStore) -> None: ...
async def unregister_vector_db(self, vector_db_id: str) -> None: ... async def unregister_vector_store(self, vector_store_id: str) -> None: ...
class DatasetsProtocolPrivate(Protocol): class DatasetsProtocolPrivate(Protocol):
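The provider-private protocol is renamed along with both of its methods, so in-tree vector IO providers move from `register_vector_db`/`unregister_vector_db` to the store-based names. A minimal sketch of the surface an adapter now implements, assuming only the signatures shown above:

```python
# Sketch of the renamed provider-private protocol surface (signatures from the diff).
from llama_stack.apis.vector_stores import VectorStore


class ExampleVectorIOAdapter:  # hypothetical adapter, for illustration only
    async def register_vector_store(self, vector_store: VectorStore) -> None:
        """Persist the store's metadata and build its index."""
        ...

    async def unregister_vector_store(self, vector_store_id: str) -> None:
        """Drop the index and any persisted metadata for this store."""
        ...
```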

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,19 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.providers.datatypes import Api
from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
await impl.initialize()
return impl

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
class RagToolRuntimeConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -1,77 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
RAGQueryGenerator,
RAGQueryGeneratorConfig,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
async def generate_rag_query(
config: RAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
"""
Generates a query that will be used for
retrieving relevant information from the memory bank.
"""
if config.type == RAGQueryGenerator.default.value:
query = await default_rag_query_generator(config, content, **kwargs)
elif config.type == RAGQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, content, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query
async def default_rag_query_generator(
config: DefaultRAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
return interleaved_content_as_str(content, sep=config.separator)
async def llm_rag_query_generator(
config: LLMRAGQueryGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
messages = []
if isinstance(content, list):
messages = [interleaved_content_as_str(m) for m in content]
else:
messages = [interleaved_content_as_str(content)]
template = Template(config.template)
rendered_content: str = template.render({"messages": messages})
model = config.model
message = OpenAIUserMessageParam(content=rendered_content)
params = OpenAIChatCompletionRequestWithExtraBody(
model=model,
messages=[message],
stream=False,
)
response = await inference_api.openai_chat_completion(params)
query = response.choices[0].message.content
return query

View file

@ -1,332 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
import mimetypes
from typing import Any
import httpx
from fastapi import UploadFile
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
VectorStoreChunkingStrategyStatic,
VectorStoreChunkingStrategyStaticConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
log = get_logger(name=__name__, category="tool_runtime")
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
r.raise_for_status()
mime_type = r.headers.get("content-type", "application/octet-stream")
return r.content, mime_type
else:
if isinstance(doc.content, str):
content_str = doc.content
else:
content_str = interleaved_content_as_str(doc.content)
if content_str.startswith("data:"):
parts = parse_data_url(content_str)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
return content_str.encode("utf-8"), "text/plain"
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
config: RagToolRuntimeConfig,
vector_io_api: VectorIO,
inference_api: Inference,
files_api: Files,
):
self.config = config
self.vector_io_api = vector_io_api
self.inference_api = inference_api
self.files_api = files_api
async def initialize(self):
pass
async def shutdown(self):
pass
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
if not documents:
return
for doc in documents:
try:
try:
file_data, mime_type = await raw_data_from_doc(doc)
except Exception as e:
log.error(f"Failed to extract content from document {doc.document_id}: {e}")
continue
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_obj = io.BytesIO(file_data)
file_obj.name = filename
upload_file = UploadFile(file=file_obj, filename=filename)
try:
created_file = await self.files_api.openai_upload_file(
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
except Exception as e:
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
continue
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4,
)
)
try:
await self.vector_io_api.openai_attach_file_to_vector_store(
vector_store_id=vector_db_id,
file_id=created_file.id,
attributes=doc.metadata,
chunking_strategy=chunking_strategy,
)
except Exception as e:
log.error(
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
)
continue
except Exception as e:
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
continue
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
if not vector_db_ids:
raise ValueError(
"No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
)
query_config = query_config or RAGQueryConfig()
query = await generate_rag_query(
query_config.query_generator_config,
content,
inference_api=self.inference_api,
)
tasks = [
self.vector_io_api.query_chunks(
vector_db_id=vector_db_id,
query=query,
params={
"mode": query_config.mode,
"max_chunks": query_config.max_chunks,
"score_threshold": 0.0,
"ranker": query_config.ranker,
},
)
for vector_db_id in vector_db_ids
]
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
chunks = []
scores = []
for vector_db_id, result in zip(vector_db_ids, results, strict=False):
for chunk, score in zip(result.chunks, result.scores, strict=False):
if not hasattr(chunk, "metadata") or chunk.metadata is None:
chunk.metadata = {}
chunk.metadata["vector_db_id"] = vector_db_id
chunks.append(chunk)
scores.append(score)
if not chunks:
return RAGQueryResult(content=None)
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) # type: ignore
chunks = chunks[: query_config.max_chunks]
tokens = 0
picked: list[InterleavedContentItem] = [
TextContentItem(
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
]
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
tokens += metadata.get("token_count", 0)
tokens += metadata.get("metadata_token_count", 0)
if tokens > query_config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
# Add useful keys from chunk_metadata to metadata and remove some from metadata
chunk_metadata_keys_to_include_from_context = [
"chunk_id",
"document_id",
"source",
]
metadata_keys_to_exclude_from_context = [
"token_count",
"metadata_token_count",
"vector_db_id",
]
metadata_for_context = {}
for k in chunk_metadata_keys_to_include_from_context:
metadata_for_context[k] = getattr(chunk.chunk_metadata, k)
for k in metadata:
if k not in metadata_keys_to_exclude_from_context:
metadata_for_context[k] = metadata[k]
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_for_context)
picked.append(TextContentItem(text=text_content))
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
picked.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
)
)
return RAGQueryResult(
content=picked,
metadata={
"document_ids": [c.document_id for c in chunks[: len(picked)]],
"chunks": [c.content for c in chunks[: len(picked)]],
"scores": scores[: len(picked)],
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
},
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
# Parameters are not listed since these methods are not yet invoked automatically
# by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals.
return ListToolDefsResponse(
data=[
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
),
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for. Can be a natural language sentence or keywords.",
}
},
"required": ["query"],
},
),
]
)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
vector_db_ids = kwargs.get("vector_db_ids", [])
query_config = kwargs.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
query_config = RAGQueryConfig()
query = kwargs["query"]
result = await self.query(
content=query,
vector_db_ids=vector_db_ids,
query_config=query_config,
)
return ToolInvocationResult(
content=result.content or [],
metadata={
**(result.metadata or {}),
"citation_files": getattr(result, "citation_files", None),
},
)
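With the inline rag-runtime provider deleted, the ingestion path its `insert()` implemented — upload each document through the Files API, then attach the file to a vector store with a static chunking strategy — is driven by callers directly. A hedged sketch of that flow from a client, assuming an OpenAI-compatible client surface (`files.create`, `vector_stores.files.create`); check your client version for the exact method names:

```python
# Hedged sketch: document ingestion without rag_tool.insert().
# The vector store id and document bytes are hypothetical.
from io import BytesIO

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

doc_bytes = b"contents of the document to index"
uploaded = client.files.create(
    file=("notes.txt", BytesIO(doc_bytes)),  # filename plus binary payload
    purpose="assistants",
)
client.vector_stores.files.create(
    vector_store_id="vs_product_docs",
    file_id=uploaded.id,
    # chunking happens server-side; the deleted provider used a static strategy
    # (a max chunk size in tokens with 25% overlap) when attaching files.
)
```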

View file

@ -17,21 +17,21 @@ from numpy.typing import NDArray
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorStoresProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from .config import FaissVectorIOConfig from .config import FaissVectorIOConfig
logger = get_logger(name=__name__, category="vector_io") logger = get_logger(name=__name__, category="vector_io")
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"
FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::" FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::" OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
@ -176,28 +176,28 @@ class FaissIndex(EmbeddingIndex):
) )
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None: def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None) super().__init__(files_api=files_api, kvstore=None)
self.config = config self.config = config
self.inference_api = inference_api self.inference_api = inference_api
self.cache: dict[str, VectorDBWithIndex] = {} self.cache: dict[str, VectorStoreWithIndex] = {}
async def initialize(self) -> None: async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.persistence) self.kvstore = await kvstore_impl(self.config.persistence)
# Load existing banks from kvstore # Load existing banks from kvstore
start_key = VECTOR_DBS_PREFIX start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff" end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs: for vector_store_data in stored_vector_stores:
vector_db = VectorDB.model_validate_json(vector_db_data) vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorDBWithIndex( index = VectorStoreWithIndex(
vector_db, vector_store,
await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier), await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
self.inference_api, self.inference_api,
) )
self.cache[vector_db.identifier] = index self.cache[vector_store.identifier] = index
# Load existing OpenAI vector stores into the in-memory cache # Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores() await self.initialize_openai_vector_stores()
@ -222,32 +222,31 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
except Exception as e: except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def register_vector_db(self, vector_db: VectorDB) -> None: async def register_vector_store(self, vector_store: VectorStore) -> None:
assert self.kvstore is not None assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}" key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
await self.kvstore.set(key=key, value=vector_db.model_dump_json()) await self.kvstore.set(key=key, value=vector_store.model_dump_json())
# Store in cache # Store in cache
self.cache[vector_db.identifier] = VectorDBWithIndex( self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_db=vector_db, vector_store=vector_store,
index=await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier), index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
async def list_vector_dbs(self) -> list[VectorDB]: async def list_vector_stores(self) -> list[VectorStore]:
return [i.vector_db for i in self.cache.values()] return [i.vector_store for i in self.cache.values()]
async def unregister_vector_db(self, vector_db_id: str) -> None: async def unregister_vector_store(self, vector_store_id: str) -> None:
assert self.kvstore is not None assert self.kvstore is not None
if vector_db_id not in self.cache: if vector_store_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return return
await self.cache[vector_db_id].index.delete() await self.cache[vector_store_id].index.delete()
del self.cache[vector_db_id] del self.cache[vector_store_id]
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}") await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = self.cache.get(vector_db_id) index = self.cache.get(vector_db_id)
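Note that the Faiss adapter's persistence key prefix changes along with the rename, so registration records are written under `vector_stores:v3::` rather than `vector_dbs:v3::`. A small sketch of the key construction used by `register_vector_store()` and the range scan in `initialize()`, taken from the constants above; as written, entries persisted under the old prefix fall outside the scanned range:

```python
# Sketch of the Faiss adapter's persistence keys after the rename
# (constants copied from the diff above).
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"   # previously "vector_dbs:v3::"

def registration_key(vector_store_id: str) -> str:
    # key used by register_vector_store() when storing the serialized VectorStore
    return f"{VECTOR_DBS_PREFIX}{vector_store_id}"

# initialize() reloads everything in [prefix, prefix + "\xff")
scan_range = (VECTOR_DBS_PREFIX, f"{VECTOR_DBS_PREFIX}\xff")
```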

View file

@ -17,10 +17,10 @@ from numpy.typing import NDArray
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@ -28,7 +28,7 @@ from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF, RERANKER_TYPE_RRF,
ChunkForDeletion, ChunkForDeletion,
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorStoreWithIndex,
) )
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
@ -41,7 +41,7 @@ HYBRID_SEARCH = "hybrid"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH} SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_stores:sqlite_vec:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::" VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::" OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:sqlite_vec:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:sqlite_vec:{VERSION}::"
@ -374,32 +374,32 @@ class SQLiteVecIndex(EmbeddingIndex):
await asyncio.to_thread(_delete_chunks) await asyncio.to_thread(_delete_chunks)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
""" """
A VectorIO implementation using SQLite + sqlite_vec. A VectorIO implementation using SQLite + sqlite_vec.
This class handles vector database registration (with metadata stored in a table named `vector_dbs`) This class handles vector database registration (with metadata stored in a table named `vector_stores`)
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex). and creates a cache of VectorStoreWithIndex instances (each wrapping a SQLiteVecIndex).
""" """
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None: def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None) super().__init__(files_api=files_api, kvstore=None)
self.config = config self.config = config
self.inference_api = inference_api self.inference_api = inference_api
self.cache: dict[str, VectorDBWithIndex] = {} self.cache: dict[str, VectorStoreWithIndex] = {}
self.vector_db_store = None self.vector_store_table = None
async def initialize(self) -> None: async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.persistence) self.kvstore = await kvstore_impl(self.config.persistence)
start_key = VECTOR_DBS_PREFIX start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff" end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
for db_json in stored_vector_dbs: for db_json in stored_vector_stores:
vector_db = VectorDB.model_validate_json(db_json) vector_store = VectorStore.model_validate_json(db_json)
index = await SQLiteVecIndex.create( index = await SQLiteVecIndex.create(
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
) )
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
# Load existing OpenAI vector stores into the in-memory cache # Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores() await self.initialize_openai_vector_stores()
@ -408,63 +408,64 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
# Clean up mixin resources (file batch tasks) # Clean up mixin resources (file batch tasks)
await super().shutdown() await super().shutdown()
async def list_vector_dbs(self) -> list[VectorDB]: async def list_vector_stores(self) -> list[VectorStore]:
return [v.vector_db for v in self.cache.values()] return [v.vector_store for v in self.cache.values()]
async def register_vector_db(self, vector_db: VectorDB) -> None: async def register_vector_store(self, vector_store: VectorStore) -> None:
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier) index = await SQLiteVecIndex.create(
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
)
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_db_id in self.cache: if vector_store_id in self.cache:
return self.cache[vector_db_id] return self.cache[vector_store_id]
if self.vector_db_store is None: if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_store_id)
vector_db = self.vector_db_store.get_vector_db(vector_db_id) vector_store = self.vector_store_table.get_vector_store(vector_store_id)
if not vector_db: if not vector_store:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_store_id)
index = VectorDBWithIndex( index = VectorStoreWithIndex(
vector_db=vector_db, vector_store=vector_store,
index=SQLiteVecIndex( index=SQLiteVecIndex(
dimension=vector_db.embedding_dimension, dimension=vector_store.embedding_dimension,
db_path=self.config.db_path, db_path=self.config.db_path,
bank_id=vector_db.identifier, bank_id=vector_store.identifier,
kvstore=self.kvstore, kvstore=self.kvstore,
), ),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[vector_db_id] = index self.cache[vector_store_id] = index
return index return index
async def unregister_vector_db(self, vector_db_id: str) -> None: async def unregister_vector_store(self, vector_store_id: str) -> None:
if vector_db_id not in self.cache: if vector_store_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return return
await self.cache[vector_db_id].index.delete() await self.cache[vector_store_id].index.delete()
del self.cache[vector_db_id] del self.cache[vector_store_id]
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index: if not index:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_db_id)
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api # The VectorStoreWithIndex helper is expected to compute embeddings via the inference_api
# and then call our index's add_chunks. # and then call our index's add_chunks.
await index.insert_chunks(chunks) await index.insert_chunks(chunks)
async def query_chunks( async def query_chunks(
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
) -> QueryChunksResponse: ) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index: if not index:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a sqlite_vec index.""" """Delete chunks from a sqlite_vec index."""
index = await self._get_and_cache_vector_db_index(store_id) index = await self._get_and_cache_vector_store_index(store_id)
if not index: if not index:
raise VectorStoreNotFoundError(store_id) raise VectorStoreNotFoundError(store_id)
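All three chunk operations in the sqlite_vec adapter (`insert_chunks`, `query_chunks`, `delete_chunks`) resolve their index through the renamed helper. A condensed sketch of that resolution order, using only what the diff above shows:

```python
# Condensed sketch of the _get_and_cache_vector_store_index() resolution order.
from llama_stack.apis.common.errors import VectorStoreNotFoundError


async def resolve_index(adapter, vector_store_id: str):
    if vector_store_id in adapter.cache:            # 1. in-memory cache
        return adapter.cache[vector_store_id]
    if adapter.vector_store_table is None:          # 2. no routing table available
        raise VectorStoreNotFoundError(vector_store_id)
    vector_store = adapter.vector_store_table.get_vector_store(vector_store_id)
    if not vector_store:                            # 3. unknown store
        raise VectorStoreNotFoundError(vector_store_id)
    # 4. build a VectorStoreWithIndex, cache it, and return it (see diff above)
    ...
```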

View file

@ -42,6 +42,7 @@ def available_providers() -> list[ProviderSpec]:
# CrossEncoder depends on torchao.quantization # CrossEncoder depends on torchao.quantization
pip_packages=[ pip_packages=[
"torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu", "torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu",
"numpy tqdm transformers",
"sentence-transformers --no-deps", "sentence-transformers --no-deps",
# required by some SentenceTransformers architectures for tensor rearrange/merge ops # required by some SentenceTransformers architectures for tensor rearrange/merge ops
"einops", "einops",

View file

@ -7,33 +7,13 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
Api, Api,
InlineProviderSpec,
ProviderSpec, ProviderSpec,
RemoteProviderSpec, RemoteProviderSpec,
) )
from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS
def available_providers() -> list[ProviderSpec]: def available_providers() -> list[ProviderSpec]:
return [ return [
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=DEFAULT_VECTOR_IO_DEPS
+ [
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
],
module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
),
RemoteProviderSpec( RemoteProviderSpec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter_type="brave-search", adapter_type="brave-search",

View file

@ -119,7 +119,7 @@ Datasets that can fit in memory, frequent reads | Faiss | Optimized for speed, i
#### Empirical Example #### Empirical Example
Consider the histogram below in which 10,000 randomly generated strings were inserted Consider the histogram below in which 10,000 randomly generated strings were inserted
in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`. in batches of 100 into both Faiss and sqlite-vec.
```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png ```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png
:alt: Comparison of SQLite-Vec and Faiss write times :alt: Comparison of SQLite-Vec and Faiss write times

View file

@ -13,15 +13,15 @@ from numpy.typing import NDArray
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
@ -30,7 +30,7 @@ log = get_logger(name=__name__, category="vector_io::chroma")
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:chroma:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_stores:chroma:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:chroma:{VERSION}::" VECTOR_INDEX_PREFIX = f"vector_index:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::" OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
@ -114,7 +114,7 @@ class ChromaIndex(EmbeddingIndex):
raise NotImplementedError("Hybrid search is not supported in Chroma") raise NotImplementedError("Hybrid search is not supported in Chroma")
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__( def __init__(
self, self,
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig, config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
@ -127,11 +127,11 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.inference_api = inference_api self.inference_api = inference_api
self.client = None self.client = None
self.cache = {} self.cache = {}
self.vector_db_store = None self.vector_store_table = None
async def initialize(self) -> None: async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.persistence) self.kvstore = await kvstore_impl(self.config.persistence)
self.vector_db_store = self.kvstore self.vector_store_table = self.kvstore
if isinstance(self.config, RemoteChromaVectorIOConfig): if isinstance(self.config, RemoteChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}") log.info(f"Connecting to Chroma server at: {self.config.url}")
@ -151,26 +151,26 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks) # Clean up mixin resources (file batch tasks)
await super().shutdown() await super().shutdown()
async def register_vector_db(self, vector_db: VectorDB) -> None: async def register_vector_store(self, vector_store: VectorStore) -> None:
collection = await maybe_await( collection = await maybe_await(
self.client.get_or_create_collection( self.client.get_or_create_collection(
name=vector_db.identifier, metadata={"vector_db": vector_db.model_dump_json()} name=vector_store.identifier, metadata={"vector_store": vector_store.model_dump_json()}
) )
) )
self.cache[vector_db.identifier] = VectorDBWithIndex( self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_db, ChromaIndex(self.client, collection), self.inference_api vector_store, ChromaIndex(self.client, collection), self.inference_api
) )
async def unregister_vector_db(self, vector_db_id: str) -> None: async def unregister_vector_store(self, vector_store_id: str) -> None:
if vector_db_id not in self.cache: if vector_store_id not in self.cache:
log.warning(f"Vector DB {vector_db_id} not found") log.warning(f"Vector DB {vector_store_id} not found")
return return
await self.cache[vector_db_id].index.delete() await self.cache[vector_store_id].index.delete()
del self.cache[vector_db_id] del self.cache[vector_store_id]
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
if index is None: if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma") raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
@ -179,30 +179,30 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def query_chunks( async def query_chunks(
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse: ) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
if index is None: if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma") raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex: async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
if vector_db_id in self.cache: if vector_store_id in self.cache:
return self.cache[vector_db_id] return self.cache[vector_store_id]
vector_db = await self.vector_db_store.get_vector_db(vector_db_id) vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_db: if not vector_store:
raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack") raise ValueError(f"Vector DB {vector_store_id} not found in Llama Stack")
collection = await maybe_await(self.client.get_collection(vector_db_id)) collection = await maybe_await(self.client.get_collection(vector_store_id))
if not collection: if not collection:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma") raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api) index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_db_id] = index self.cache[vector_store_id] = index
return index return index
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Chroma vector store.""" """Delete chunks from a Chroma vector store."""
index = await self._get_and_cache_vector_db_index(store_id) index = await self._get_and_cache_vector_store_index(store_id)
if not index: if not index:
raise ValueError(f"Vector DB {store_id} not found") raise ValueError(f"Vector DB {store_id} not found")

View file

@ -14,10 +14,10 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
@ -26,7 +26,7 @@ from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_WEIGHTED, RERANKER_TYPE_WEIGHTED,
ChunkForDeletion, ChunkForDeletion,
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorStoreWithIndex,
) )
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
@ -35,7 +35,7 @@ from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = get_logger(name=__name__, category="vector_io::milvus") logger = get_logger(name=__name__, category="vector_io::milvus")
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_stores:milvus:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:milvus:{VERSION}::" VECTOR_INDEX_PREFIX = f"vector_index:milvus:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:milvus:{VERSION}::" OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:milvus:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:milvus:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:milvus:{VERSION}::"
@ -261,7 +261,7 @@ class MilvusIndex(EmbeddingIndex):
raise raise
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__( def __init__(
self, self,
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig, config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
@ -273,28 +273,28 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache = {} self.cache = {}
self.client = None self.client = None
self.inference_api = inference_api self.inference_api = inference_api
self.vector_db_store = None self.vector_store_table = None
self.metadata_collection_name = "openai_vector_stores_metadata" self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None: async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.persistence) self.kvstore = await kvstore_impl(self.config.persistence)
start_key = VECTOR_DBS_PREFIX start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff" end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs: for vector_store_data in stored_vector_stores:
vector_db = VectorDB.model_validate_json(vector_db_data) vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorDBWithIndex( index = VectorStoreWithIndex(
vector_db, vector_store,
index=MilvusIndex( index=MilvusIndex(
client=self.client, client=self.client,
collection_name=vector_db.identifier, collection_name=vector_store.identifier,
consistency_level=self.config.consistency_level, consistency_level=self.config.consistency_level,
kvstore=self.kvstore, kvstore=self.kvstore,
), ),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[vector_db.identifier] = index self.cache[vector_store.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig): if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}") logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True)) self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
@ -311,45 +311,45 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks) # Clean up mixin resources (file batch tasks)
await super().shutdown() await super().shutdown()
async def register_vector_db(self, vector_db: VectorDB) -> None: async def register_vector_store(self, vector_store: VectorStore) -> None:
if isinstance(self.config, RemoteMilvusVectorIOConfig): if isinstance(self.config, RemoteMilvusVectorIOConfig):
consistency_level = self.config.consistency_level consistency_level = self.config.consistency_level
else: else:
consistency_level = "Strong" consistency_level = "Strong"
index = VectorDBWithIndex( index = VectorStoreWithIndex(
vector_db=vector_db, vector_store=vector_store,
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level), index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[vector_db.identifier] = index self.cache[vector_store.identifier] = index
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_db_id in self.cache: if vector_store_id in self.cache:
return self.cache[vector_db_id] return self.cache[vector_store_id]
if self.vector_db_store is None: if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_store_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id) vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_db: if not vector_store:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_store_id)
index = VectorDBWithIndex( index = VectorStoreWithIndex(
vector_db=vector_db, vector_store=vector_store,
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore), index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[vector_db_id] = index self.cache[vector_store_id] = index
return index return index
async def unregister_vector_db(self, vector_db_id: str) -> None: async def unregister_vector_store(self, vector_store_id: str) -> None:
if vector_db_id in self.cache: if vector_store_id in self.cache:
await self.cache[vector_db_id].index.delete() await self.cache[vector_store_id].index.delete()
del self.cache[vector_db_id] del self.cache[vector_store_id]
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index: if not index:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_db_id)
@ -358,14 +358,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def query_chunks( async def query_chunks(
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse: ) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index: if not index:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a milvus vector store.""" """Delete a chunk from a milvus vector store."""
index = await self._get_and_cache_vector_db_index(store_id) index = await self._get_and_cache_vector_store_index(store_id)
if not index: if not index:
raise VectorStoreNotFoundError(store_id) raise VectorStoreNotFoundError(store_id)
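
The Milvus hunk changes the persistence key prefix from vector_dbs:milvus:v3:: to vector_stores:milvus:v3:: and, on startup, rehydrates a VectorStoreWithIndex per persisted record. A rough, stdlib-only sketch of that replay step; the dict-backed kv, the values_in_range helper and the record fields are illustrative, whereas the real adapter uses KVStore.values_in_range and VectorStore.model_validate_json.

import json

VERSION = "v3"
# new prefix per the Milvus hunk above; the old keys used "vector_dbs:milvus:v3::"
VECTOR_STORES_PREFIX = f"vector_stores:milvus:{VERSION}::"

# toy key/value snapshot standing in for the real kvstore backend
kv = {
    f"{VECTOR_STORES_PREFIX}vs_abc": json.dumps(
        {"identifier": "vs_abc", "embedding_model": "example-embedding-model", "embedding_dimension": 384}
    ),
}


def values_in_range(start_key: str, end_key: str) -> list[str]:
    # rough approximation of KVStore.values_in_range: a lexicographic key-range scan
    return [value for key, value in sorted(kv.items()) if start_key <= key <= end_key]


# replay persisted vector stores on startup, as the adapter's initialize() does
cache: dict[str, dict] = {}
for raw in values_in_range(VECTOR_STORES_PREFIX, f"{VECTOR_STORES_PREFIX}\xff"):
    record = json.loads(raw)  # the adapter validates this into a VectorStore model instead
    cache[record["identifier"]] = record

assert "vs_abc" in cache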

View file

@ -16,15 +16,15 @@ from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
from .config import PGVectorVectorIOConfig from .config import PGVectorVectorIOConfig
@ -32,7 +32,7 @@ from .config import PGVectorVectorIOConfig
log = get_logger(name=__name__, category="vector_io::pgvector") log = get_logger(name=__name__, category="vector_io::pgvector")
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_stores:pgvector:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::" VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::" OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
@ -79,13 +79,13 @@ class PGVectorIndex(EmbeddingIndex):
def __init__( def __init__(
self, self,
vector_db: VectorDB, vector_store: VectorStore,
dimension: int, dimension: int,
conn: psycopg2.extensions.connection, conn: psycopg2.extensions.connection,
kvstore: KVStore | None = None, kvstore: KVStore | None = None,
distance_metric: str = "COSINE", distance_metric: str = "COSINE",
): ):
self.vector_db = vector_db self.vector_store = vector_store
self.dimension = dimension self.dimension = dimension
self.conn = conn self.conn = conn
self.kvstore = kvstore self.kvstore = kvstore
@ -97,9 +97,9 @@ class PGVectorIndex(EmbeddingIndex):
try: try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
# Sanitize the table name by replacing hyphens with underscores # Sanitize the table name by replacing hyphens with underscores
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens # SQL doesn't allow hyphens in table names, and vector_store.identifier may contain hyphens
# when created with patterns like "test-vector-db-{uuid4()}" # when created with patterns like "test-vector-db-{uuid4()}"
sanitized_identifier = sanitize_collection_name(self.vector_db.identifier) sanitized_identifier = sanitize_collection_name(self.vector_store.identifier)
self.table_name = f"vs_{sanitized_identifier}" self.table_name = f"vs_{sanitized_identifier}"
cur.execute( cur.execute(
@ -122,8 +122,8 @@ class PGVectorIndex(EmbeddingIndex):
""" """
) )
except Exception as e: except Exception as e:
log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") log.exception(f"Error creating PGVectorIndex for vector_store: {self.vector_store.identifier}")
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e raise RuntimeError(f"Error creating PGVectorIndex for vector_store: {self.vector_store.identifier}") from e
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), ( assert len(chunks) == len(embeddings), (
@ -323,7 +323,7 @@ class PGVectorIndex(EmbeddingIndex):
) )
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__( def __init__(
self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
) -> None: ) -> None:
@ -332,7 +332,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
self.inference_api = inference_api self.inference_api = inference_api
self.conn = None self.conn = None
self.cache = {} self.cache = {}
self.vector_db_store = None self.vector_store_table = None
self.metadata_collection_name = "openai_vector_stores_metadata" self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None: async def initialize(self) -> None:
@ -375,59 +375,59 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
# Clean up mixin resources (file batch tasks) # Clean up mixin resources (file batch tasks)
await super().shutdown() await super().shutdown()
async def register_vector_db(self, vector_db: VectorDB) -> None: async def register_vector_store(self, vector_store: VectorStore) -> None:
# Persist vector DB metadata in the KV store # Persist vector DB metadata in the KV store
assert self.kvstore is not None assert self.kvstore is not None
# Upsert model metadata in Postgres # Upsert model metadata in Postgres
upsert_models(self.conn, [(vector_db.identifier, vector_db)]) upsert_models(self.conn, [(vector_store.identifier, vector_store)])
# Create and cache the PGVector index table for the vector DB # Create and cache the PGVector index table for the vector DB
pgvector_index = PGVectorIndex( pgvector_index = PGVectorIndex(
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore
) )
await pgvector_index.initialize() await pgvector_index.initialize()
index = VectorDBWithIndex(vector_db, index=pgvector_index, inference_api=self.inference_api) index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
self.cache[vector_db.identifier] = index self.cache[vector_store.identifier] = index
async def unregister_vector_db(self, vector_db_id: str) -> None: async def unregister_vector_store(self, vector_store_id: str) -> None:
# Remove provider index and cache # Remove provider index and cache
if vector_db_id in self.cache: if vector_store_id in self.cache:
await self.cache[vector_db_id].index.delete() await self.cache[vector_store_id].index.delete()
del self.cache[vector_db_id] del self.cache[vector_store_id]
# Delete vector DB metadata from KV store # Delete vector DB metadata from KV store
assert self.kvstore is not None assert self.kvstore is not None
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}") await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_store_id}")
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
await index.insert_chunks(chunks) await index.insert_chunks(chunks)
async def query_chunks( async def query_chunks(
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse: ) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex: async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
if vector_db_id in self.cache: if vector_store_id in self.cache:
return self.cache[vector_db_id] return self.cache[vector_store_id]
if self.vector_db_store is None: if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_store_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id) vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_db: if not vector_store:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_store_id)
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn) index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
await index.initialize() await index.initialize()
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
return self.cache[vector_db_id] return self.cache[vector_store_id]
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a chunk from a PostgreSQL vector store.""" """Delete a chunk from a PostgreSQL vector store."""
index = await self._get_and_cache_vector_db_index(store_id) index = await self._get_and_cache_vector_store_index(store_id)
if not index: if not index:
raise VectorStoreNotFoundError(store_id) raise VectorStoreNotFoundError(store_id)
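
PGVectorIndex derives its SQL table name from the store identifier, and the comment in the hunk notes that hyphens must be replaced because they are not valid in table names. A small sketch of that derivation, assuming sanitize_collection_name behaves roughly like a non-word-character replacement; the real helper lives in llama_stack.providers.utils.vector_io.vector_utils and may do more (length clamping, case handling, etc.).

import re


def sanitize_collection_name_sketch(name: str) -> str:
    # assumption: replaces anything that is not a word character with an underscore,
    # matching the behaviour the PGVector comment describes for hyphenated identifiers
    return re.sub(r"\W", "_", name)


identifier = "test-vector-db-1234"
table_name = f"vs_{sanitize_collection_name_sketch(identifier)}"
assert table_name == "vs_test_vector_db_1234"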

View file

@ -16,7 +16,6 @@ from qdrant_client.models import PointStruct
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import ( from llama_stack.apis.vector_io import (
Chunk, Chunk,
QueryChunksResponse, QueryChunksResponse,
@ -24,12 +23,13 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategy, VectorStoreChunkingStrategy,
VectorStoreFileObject, VectorStoreFileObject,
) )
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
@ -38,7 +38,7 @@ CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for vector databases # KV store prefixes for vector databases
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_stores:qdrant:{VERSION}::"
def convert_id(_id: str) -> str: def convert_id(_id: str) -> str:
@ -145,7 +145,7 @@ class QdrantIndex(EmbeddingIndex):
await self.client.delete_collection(collection_name=self.collection_name) await self.client.delete_collection(collection_name=self.collection_name)
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__( def __init__(
self, self,
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
@ -157,7 +157,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.client: AsyncQdrantClient = None self.client: AsyncQdrantClient = None
self.cache = {} self.cache = {}
self.inference_api = inference_api self.inference_api = inference_api
self.vector_db_store = None self.vector_store_table = None
self._qdrant_lock = asyncio.Lock() self._qdrant_lock = asyncio.Lock()
async def initialize(self) -> None: async def initialize(self) -> None:
@ -167,12 +167,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
start_key = VECTOR_DBS_PREFIX start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff" end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs: for vector_store_data in stored_vector_stores:
vector_db = VectorDB.model_validate_json(vector_db_data) vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorDBWithIndex(vector_db, QdrantIndex(self.client, vector_db.identifier), self.inference_api) index = VectorStoreWithIndex(
self.cache[vector_db.identifier] = index vector_store, QdrantIndex(self.client, vector_store.identifier), self.inference_api
)
self.cache[vector_store.identifier] = index
self.openai_vector_stores = await self._load_openai_vector_stores() self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None: async def shutdown(self) -> None:
@ -180,46 +182,48 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
# Clean up mixin resources (file batch tasks) # Clean up mixin resources (file batch tasks)
await super().shutdown() await super().shutdown()
async def register_vector_db(self, vector_db: VectorDB) -> None: async def register_vector_store(self, vector_store: VectorStore) -> None:
assert self.kvstore is not None assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}" key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
await self.kvstore.set(key=key, value=vector_db.model_dump_json()) await self.kvstore.set(key=key, value=vector_store.model_dump_json())
index = VectorDBWithIndex( index = VectorStoreWithIndex(
vector_db=vector_db, index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api vector_store=vector_store,
) index=QdrantIndex(self.client, vector_store.identifier),
self.cache[vector_db.identifier] = index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
assert self.kvstore is not None
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
if self.vector_db_store is None:
raise ValueError(f"Vector DB not found {vector_db_id}")
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[vector_db_id] = index
self.cache[vector_store.identifier] = index
async def unregister_vector_store(self, vector_store_id: str) -> None:
if vector_store_id in self.cache:
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
assert self.kvstore is not None
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_store_id in self.cache:
return self.cache[vector_store_id]
if self.vector_store_table is None:
raise ValueError(f"Vector DB not found {vector_store_id}")
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_store:
raise VectorStoreNotFoundError(vector_store_id)
index = VectorStoreWithIndex(
vector_store=vector_store,
index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
inference_api=self.inference_api,
)
self.cache[vector_store_id] = index
return index return index
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index: if not index:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_db_id)
@ -228,7 +232,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def query_chunks( async def query_chunks(
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse: ) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index: if not index:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_db_id)
@ -249,7 +253,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from a Qdrant vector store.""" """Delete chunks from a Qdrant vector store."""
index = await self._get_and_cache_vector_db_index(store_id) index = await self._get_and_cache_vector_store_index(store_id)
if not index: if not index:
raise ValueError(f"Vector DB {store_id} not found") raise ValueError(f"Vector DB {store_id} not found")

View file

@ -16,11 +16,11 @@ from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@ -28,7 +28,7 @@ from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF, RERANKER_TYPE_RRF,
ChunkForDeletion, ChunkForDeletion,
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorStoreWithIndex,
) )
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
@ -37,7 +37,7 @@ from .config import WeaviateVectorIOConfig
log = get_logger(name=__name__, category="vector_io::weaviate") log = get_logger(name=__name__, category="vector_io::weaviate")
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_stores:weaviate:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::" VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::" OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
@ -257,14 +257,14 @@ class WeaviateIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores) return QueryChunksResponse(chunks=chunks, scores=scores)
class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorDBsProtocolPrivate): class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None: def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None) super().__init__(files_api=files_api, kvstore=None)
self.config = config self.config = config
self.inference_api = inference_api self.inference_api = inference_api
self.client_cache = {} self.client_cache = {}
self.cache = {} self.cache = {}
self.vector_db_store = None self.vector_store_table = None
self.metadata_collection_name = "openai_vector_stores_metadata" self.metadata_collection_name = "openai_vector_stores_metadata"
def _get_client(self) -> weaviate.WeaviateClient: def _get_client(self) -> weaviate.WeaviateClient:
@ -300,11 +300,11 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
end_key = f"{VECTOR_DBS_PREFIX}\xff" end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored = await self.kvstore.values_in_range(start_key, end_key) stored = await self.kvstore.values_in_range(start_key, end_key)
for raw in stored: for raw in stored:
vector_db = VectorDB.model_validate_json(raw) vector_store = VectorStore.model_validate_json(raw)
client = self._get_client() client = self._get_client()
idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore) idx = WeaviateIndex(client=client, collection_name=vector_store.identifier, kvstore=self.kvstore)
self.cache[vector_db.identifier] = VectorDBWithIndex( self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_db=vector_db, index=idx, inference_api=self.inference_api vector_store=vector_store, index=idx, inference_api=self.inference_api
) )
# Load OpenAI vector stores metadata into cache # Load OpenAI vector stores metadata into cache
@ -316,9 +316,9 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
# Clean up mixin resources (file batch tasks) # Clean up mixin resources (file batch tasks)
await super().shutdown() await super().shutdown()
async def register_vector_db(self, vector_db: VectorDB) -> None: async def register_vector_store(self, vector_store: VectorStore) -> None:
client = self._get_client() client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True) sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
# Create collection if it doesn't exist # Create collection if it doesn't exist
if not client.collections.exists(sanitized_collection_name): if not client.collections.exists(sanitized_collection_name):
client.collections.create( client.collections.create(
@ -329,45 +329,45 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
], ],
) )
self.cache[vector_db.identifier] = VectorDBWithIndex( self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_db, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
) )
async def unregister_vector_db(self, vector_db_id: str) -> None: async def unregister_vector_store(self, vector_store_id: str) -> None:
client = self._get_client() client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) sanitized_collection_name = sanitize_collection_name(vector_store_id, weaviate_format=True)
if vector_db_id not in self.cache or client.collections.exists(sanitized_collection_name) is False: if vector_store_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
return return
client.collections.delete(sanitized_collection_name) client.collections.delete(sanitized_collection_name)
await self.cache[vector_db_id].index.delete() await self.cache[vector_store_id].index.delete()
del self.cache[vector_db_id] del self.cache[vector_store_id]
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
if vector_db_id in self.cache: if vector_store_id in self.cache:
return self.cache[vector_db_id] return self.cache[vector_store_id]
if self.vector_db_store is None: if self.vector_store_table is None:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_store_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id) vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
if not vector_db: if not vector_store:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_store_id)
client = self._get_client() client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True) sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
if not client.collections.exists(sanitized_collection_name): if not client.collections.exists(sanitized_collection_name):
raise ValueError(f"Collection with name `{sanitized_collection_name}` not found") raise ValueError(f"Collection with name `{sanitized_collection_name}` not found")
index = VectorDBWithIndex( index = VectorStoreWithIndex(
vector_db=vector_db, vector_store=vector_store,
index=WeaviateIndex(client=client, collection_name=vector_db.identifier), index=WeaviateIndex(client=client, collection_name=vector_store.identifier),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[vector_db_id] = index self.cache[vector_store_id] = index
return index return index
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index: if not index:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_db_id)
@ -376,14 +376,14 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
async def query_chunks( async def query_chunks(
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
) -> QueryChunksResponse: ) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_store_index(vector_db_id)
if not index: if not index:
raise VectorStoreNotFoundError(vector_db_id) raise VectorStoreNotFoundError(vector_db_id)
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
index = await self._get_and_cache_vector_db_index(store_id) index = await self._get_and_cache_vector_store_index(store_id)
if not index: if not index:
raise ValueError(f"Vector DB {store_id} not found") raise ValueError(f"Vector DB {store_id} not found")

View file

@ -17,7 +17,6 @@ from pydantic import TypeAdapter
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files, OpenAIFileObject from llama_stack.apis.files import Files, OpenAIFileObject
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import ( from llama_stack.apis.vector_io import (
Chunk, Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody, OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
@ -43,6 +42,7 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse, VectorStoreSearchResponse,
VectorStoreSearchResponsePage, VectorStoreSearchResponsePage,
) )
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.id_generation import generate_object_id from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
@ -63,7 +63,7 @@ MAX_CONCURRENT_FILES_PER_BATCH = 3 # Maximum concurrent file processing within
FILE_BATCH_CHUNK_SIZE = 10 # Process files in chunks of this size FILE_BATCH_CHUNK_SIZE = 10 # Process files in chunks of this size
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::" OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::" OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
@ -321,12 +321,12 @@ class OpenAIVectorStoreMixin(ABC):
pass pass
@abstractmethod @abstractmethod
async def register_vector_db(self, vector_db: VectorDB) -> None: async def register_vector_store(self, vector_store: VectorStore) -> None:
"""Register a vector database (provider-specific implementation).""" """Register a vector database (provider-specific implementation)."""
pass pass
@abstractmethod @abstractmethod
async def unregister_vector_db(self, vector_db_id: str) -> None: async def unregister_vector_store(self, vector_store_id: str) -> None:
"""Unregister a vector database (provider-specific implementation).""" """Unregister a vector database (provider-specific implementation)."""
pass pass
@ -358,7 +358,7 @@ class OpenAIVectorStoreMixin(ABC):
extra_body = params.model_extra or {} extra_body = params.model_extra or {}
metadata = params.metadata or {} metadata = params.metadata or {}
provider_vector_db_id = extra_body.get("provider_vector_db_id") provider_vector_store_id = extra_body.get("provider_vector_store_id")
# Use embedding info from metadata if available, otherwise from extra_body # Use embedding info from metadata if available, otherwise from extra_body
if metadata.get("embedding_model"): if metadata.get("embedding_model"):
@ -389,8 +389,8 @@ class OpenAIVectorStoreMixin(ABC):
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config # use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None) provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
# Derive the canonical vector_db_id (allow override, else generate) # Derive the canonical vector_store_id (allow override, else generate)
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}") vector_store_id = provider_vector_store_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
if embedding_model is None: if embedding_model is None:
raise ValueError("embedding_model is required") raise ValueError("embedding_model is required")
@ -398,19 +398,20 @@ class OpenAIVectorStoreMixin(ABC):
if embedding_dimension is None: if embedding_dimension is None:
raise ValueError("Embedding dimension is required") raise ValueError("Embedding dimension is required")
# Register the VectorDB backing this vector store # Register the VectorStore backing this vector store
if provider_id is None: if provider_id is None:
raise ValueError("Provider ID is required but was not provided") raise ValueError("Provider ID is required but was not provided")
vector_db = VectorDB( # call to the provider to create any index, etc.
identifier=vector_db_id, vector_store = VectorStore(
identifier=vector_store_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
embedding_model=embedding_model, embedding_model=embedding_model,
provider_id=provider_id, provider_id=provider_id,
provider_resource_id=vector_db_id, provider_resource_id=vector_store_id,
vector_db_name=params.name, vector_store_name=params.name,
) )
await self.register_vector_db(vector_db) await self.register_vector_store(vector_store)
# Create OpenAI vector store metadata # Create OpenAI vector store metadata
status = "completed" status = "completed"
@ -424,7 +425,7 @@ class OpenAIVectorStoreMixin(ABC):
total=0, total=0,
) )
store_info: dict[str, Any] = { store_info: dict[str, Any] = {
"id": vector_db_id, "id": vector_store_id,
"object": "vector_store", "object": "vector_store",
"created_at": created_at, "created_at": created_at,
"name": params.name, "name": params.name,
@ -441,23 +442,23 @@ class OpenAIVectorStoreMixin(ABC):
# Add provider information to metadata if provided # Add provider information to metadata if provided
if provider_id: if provider_id:
metadata["provider_id"] = provider_id metadata["provider_id"] = provider_id
if provider_vector_db_id: if provider_vector_store_id:
metadata["provider_vector_db_id"] = provider_vector_db_id metadata["provider_vector_store_id"] = provider_vector_store_id
store_info["metadata"] = metadata store_info["metadata"] = metadata
# Save to persistent storage (provider-specific) # Save to persistent storage (provider-specific)
await self._save_openai_vector_store(vector_db_id, store_info) await self._save_openai_vector_store(vector_store_id, store_info)
# Store in memory cache # Store in memory cache
self.openai_vector_stores[vector_db_id] = store_info self.openai_vector_stores[vector_store_id] = store_info
# Now that our vector store is created, attach any files that were provided # Now that our vector store is created, attach any files that were provided
file_ids = params.file_ids or [] file_ids = params.file_ids or []
tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids] tasks = [self.openai_attach_file_to_vector_store(vector_store_id, file_id) for file_id in file_ids]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
# Get the updated store info and return it # Get the updated store info and return it
store_info = self.openai_vector_stores[vector_db_id] store_info = self.openai_vector_stores[vector_store_id]
return VectorStoreObject.model_validate(store_info) return VectorStoreObject.model_validate(store_info)
async def openai_list_vector_stores( async def openai_list_vector_stores(
@ -567,7 +568,7 @@ class OpenAIVectorStoreMixin(ABC):
# Also delete the underlying vector DB # Also delete the underlying vector DB
try: try:
await self.unregister_vector_db(vector_store_id) await self.unregister_vector_store(vector_store_id)
except Exception as e: except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}") logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")

View file

@ -12,19 +12,16 @@ from dataclasses import dataclass
from typing import Any from typing import Any
from urllib.parse import unquote from urllib.parse import unquote
import httpx
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
URL,
InterleavedContent, InterleavedContent,
) )
from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
@ -129,31 +126,6 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en
return "" return ""
async def content_from_doc(doc: RAGDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
return r.text
elif isinstance(doc.content, str):
pattern = re.compile("^(https?://|file://|data:)")
if pattern.match(doc.content):
if doc.content.startswith("data:"):
return content_from_data(doc.content)
async with httpx.AsyncClient() as client:
r = await client.get(doc.content)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
return r.text
return doc.content
else:
# will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent
return interleaved_content_as_str(doc.content)
def make_overlapped_chunks( def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any] document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
) -> list[Chunk]: ) -> list[Chunk]:
@ -187,7 +159,7 @@ def make_overlapped_chunks(
updated_timestamp=int(time.time()), updated_timestamp=int(time.time()),
chunk_window=chunk_window, chunk_window=chunk_window,
chunk_tokenizer=default_tokenizer, chunk_tokenizer=default_tokenizer,
chunk_embedding_model=None, # This will be set in `VectorDBWithIndex.insert_chunks` chunk_embedding_model=None, # This will be set in `VectorStoreWithIndex.insert_chunks`
content_token_count=len(toks), content_token_count=len(toks),
metadata_token_count=len(metadata_tokens), metadata_token_count=len(metadata_tokens),
) )
@ -255,8 +227,8 @@ class EmbeddingIndex(ABC):
@dataclass @dataclass
class VectorDBWithIndex: class VectorStoreWithIndex:
vector_db: VectorDB vector_store: VectorStore
index: EmbeddingIndex index: EmbeddingIndex
inference_api: Api.inference inference_api: Api.inference
@ -269,14 +241,14 @@ class VectorDBWithIndex:
if c.embedding is None: if c.embedding is None:
chunks_to_embed.append(c) chunks_to_embed.append(c)
if c.chunk_metadata: if c.chunk_metadata:
c.chunk_metadata.chunk_embedding_model = self.vector_db.embedding_model c.chunk_metadata.chunk_embedding_model = self.vector_store.embedding_model
c.chunk_metadata.chunk_embedding_dimension = self.vector_db.embedding_dimension c.chunk_metadata.chunk_embedding_dimension = self.vector_store.embedding_dimension
else: else:
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension) _validate_embedding(c.embedding, i, self.vector_store.embedding_dimension)
if chunks_to_embed: if chunks_to_embed:
params = OpenAIEmbeddingsRequestWithExtraBody( params = OpenAIEmbeddingsRequestWithExtraBody(
model=self.vector_db.embedding_model, model=self.vector_store.embedding_model,
input=[c.content for c in chunks_to_embed], input=[c.content for c in chunks_to_embed],
) )
resp = await self.inference_api.openai_embeddings(params) resp = await self.inference_api.openai_embeddings(params)
@ -319,7 +291,7 @@ class VectorDBWithIndex:
return await self.index.query_keyword(query_string, k, score_threshold) return await self.index.query_keyword(query_string, k, score_threshold)
params = OpenAIEmbeddingsRequestWithExtraBody( params = OpenAIEmbeddingsRequestWithExtraBody(
model=self.vector_db.embedding_model, model=self.vector_store.embedding_model,
input=[query_string], input=[query_string],
) )
embeddings_response = await self.inference_api.openai_embeddings(params) embeddings_response = await self.inference_api.openai_embeddings(params)
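
VectorStoreWithIndex.insert_chunks splits incoming chunks by embedding state: chunks without an embedding are batched into a single embeddings request against the store's embedding_model, while pre-embedded chunks are only validated against embedding_dimension. A stripped-down sketch of that split; StubChunk and the helper are illustrative, and the real code additionally stamps chunk_metadata with the model and dimension before embedding.

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class StubChunk:
    content: str
    embedding: list[float] | None = None


def split_by_embedding_state(chunks: list[StubChunk], expected_dim: int) -> list[StubChunk]:
    # chunks without an embedding are collected for one batched embeddings request;
    # chunks that already carry one are only validated against the store's dimension
    to_embed: list[StubChunk] = []
    for chunk in chunks:
        if chunk.embedding is None:
            to_embed.append(chunk)
        elif len(chunk.embedding) != expected_dim:
            raise ValueError("embedding dimension mismatch")
    return to_embed


chunks = [StubChunk("a"), StubChunk("b", embedding=[0.0] * 384)]
assert [c.content for c in split_by_embedding_state(chunks, 384)] == ["a"]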

View file

@ -238,6 +238,8 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
echo "Stopping Docker container..." echo "Stopping Docker container..."
container_name="llama-stack-test-$DISTRO" container_name="llama-stack-test-$DISTRO"
if docker ps -a --format '{{.Names}}' | grep -q "^${container_name}$"; then if docker ps -a --format '{{.Names}}' | grep -q "^${container_name}$"; then
echo "Dumping container logs before stopping..."
docker logs "$container_name" > "docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true
echo "Stopping and removing container: $container_name" echo "Stopping and removing container: $container_name"
docker stop "$container_name" 2>/dev/null || true docker stop "$container_name" 2>/dev/null || true
docker rm "$container_name" 2>/dev/null || true docker rm "$container_name" 2>/dev/null || true
@ -408,6 +410,21 @@ elif [ $exit_code -eq 5 ]; then
echo "⚠️ No tests collected (pattern matched no tests)" echo "⚠️ No tests collected (pattern matched no tests)"
else else
echo "❌ Tests failed" echo "❌ Tests failed"
echo ""
echo "=== Dumping last 100 lines of logs for debugging ==="
# Output server or container logs based on stack config
if [[ "$STACK_CONFIG" == *"server:"* && -f "server.log" ]]; then
echo "--- Last 100 lines of server.log ---"
tail -100 server.log
elif [[ "$STACK_CONFIG" == *"docker:"* ]]; then
docker_log_file="docker-${DISTRO}-${INFERENCE_MODE}.log"
if [[ -f "$docker_log_file" ]]; then
echo "--- Last 100 lines of $docker_log_file ---"
tail -100 "$docker_log_file"
fi
fi
exit 1 exit 1
fi fi

View file

@ -37,6 +37,9 @@ def pytest_sessionstart(session):
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ: if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay" os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"
if "LLAMA_STACK_LOGGING" not in os.environ:
os.environ["LLAMA_STACK_LOGGING"] = "all=warning"
if "SQLITE_STORE_DIR" not in os.environ: if "SQLITE_STORE_DIR" not in os.environ:
os.environ["SQLITE_STORE_DIR"] = tempfile.mkdtemp() os.environ["SQLITE_STORE_DIR"] = tempfile.mkdtemp()

View file

@ -49,46 +49,50 @@ def client_with_empty_registry(client_with_models):
@vector_provider_wrapper @vector_provider_wrapper
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id): def test_vector_store_retrieve(
vector_db_name = "test_vector_db" client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
):
vector_store_name = "test_vector_store"
create_response = client_with_empty_registry.vector_stores.create( create_response = client_with_empty_registry.vector_stores.create(
name=vector_db_name, name=vector_store_name,
extra_body={ extra_body={
"provider_id": vector_io_provider_id, "provider_id": vector_io_provider_id,
}, },
) )
actual_vector_db_id = create_response.id actual_vector_store_id = create_response.id
# Retrieve the vector store and validate its properties # Retrieve the vector store and validate its properties
response = client_with_empty_registry.vector_stores.retrieve(vector_store_id=actual_vector_db_id) response = client_with_empty_registry.vector_stores.retrieve(vector_store_id=actual_vector_store_id)
assert response is not None assert response is not None
assert response.id == actual_vector_db_id assert response.id == actual_vector_store_id
assert response.name == vector_db_name assert response.name == vector_store_name
assert response.id.startswith("vs_") assert response.id.startswith("vs_")
@vector_provider_wrapper @vector_provider_wrapper
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id): def test_vector_store_register(
vector_db_name = "test_vector_db" client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
):
vector_store_name = "test_vector_store"
response = client_with_empty_registry.vector_stores.create( response = client_with_empty_registry.vector_stores.create(
name=vector_db_name, name=vector_store_name,
extra_body={ extra_body={
"provider_id": vector_io_provider_id, "provider_id": vector_io_provider_id,
}, },
) )
actual_vector_db_id = response.id actual_vector_store_id = response.id
assert actual_vector_db_id.startswith("vs_") assert actual_vector_store_id.startswith("vs_")
assert actual_vector_db_id != vector_db_name assert actual_vector_store_id != vector_store_name
vector_stores = client_with_empty_registry.vector_stores.list() vector_stores = client_with_empty_registry.vector_stores.list()
assert len(vector_stores.data) == 1 assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0] vector_store = vector_stores.data[0]
assert vector_store.id == actual_vector_db_id assert vector_store.id == actual_vector_store_id
assert vector_store.name == vector_db_name assert vector_store.name == vector_store_name
client_with_empty_registry.vector_stores.delete(vector_store_id=actual_vector_db_id) client_with_empty_registry.vector_stores.delete(vector_store_id=actual_vector_store_id)
vector_stores = client_with_empty_registry.vector_stores.list() vector_stores = client_with_empty_registry.vector_stores.list()
assert len(vector_stores.data) == 0 assert len(vector_stores.data) == 0
@ -108,23 +112,23 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
def test_insert_chunks( def test_insert_chunks(
client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case, vector_io_provider_id client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case, vector_io_provider_id
): ):
vector_db_name = "test_vector_db" vector_store_name = "test_vector_store"
create_response = client_with_empty_registry.vector_stores.create( create_response = client_with_empty_registry.vector_stores.create(
name=vector_db_name, name=vector_store_name,
extra_body={ extra_body={
"provider_id": vector_io_provider_id, "provider_id": vector_io_provider_id,
}, },
) )
actual_vector_db_id = create_response.id actual_vector_store_id = create_response.id
client_with_empty_registry.vector_io.insert( client_with_empty_registry.vector_io.insert(
vector_db_id=actual_vector_db_id, vector_db_id=actual_vector_store_id,
chunks=sample_chunks, chunks=sample_chunks,
) )
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=actual_vector_db_id, vector_db_id=actual_vector_store_id,
query="What is the capital of France?", query="What is the capital of France?",
) )
assert response is not None assert response is not None
@ -133,7 +137,7 @@ def test_insert_chunks(
query, expected_doc_id = test_case query, expected_doc_id = test_case
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=actual_vector_db_id, vector_db_id=actual_vector_store_id,
query=query, query=query,
) )
assert response is not None assert response is not None
@ -151,15 +155,15 @@ def test_insert_chunks_with_precomputed_embeddings(
"inline::qdrant": {"score_threshold": -1.0}, "inline::qdrant": {"score_threshold": -1.0},
"remote::qdrant": {"score_threshold": -1.0}, "remote::qdrant": {"score_threshold": -1.0},
} }
vector_db_name = "test_precomputed_embeddings_db" vector_store_name = "test_precomputed_embeddings_db"
register_response = client_with_empty_registry.vector_stores.create( register_response = client_with_empty_registry.vector_stores.create(
name=vector_db_name, name=vector_store_name,
extra_body={ extra_body={
"provider_id": vector_io_provider_id, "provider_id": vector_io_provider_id,
}, },
) )
actual_vector_db_id = register_response.id actual_vector_store_id = register_response.id
chunks_with_embeddings = [ chunks_with_embeddings = [
Chunk( Chunk(
@ -170,13 +174,13 @@ def test_insert_chunks_with_precomputed_embeddings(
] ]
client_with_empty_registry.vector_io.insert( client_with_empty_registry.vector_io.insert(
vector_db_id=actual_vector_db_id, vector_db_id=actual_vector_store_id,
chunks=chunks_with_embeddings, chunks=chunks_with_embeddings,
) )
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=actual_vector_db_id, vector_db_id=actual_vector_store_id,
query="precomputed embedding test", query="precomputed embedding test",
params=vector_io_provider_params_dict.get(provider, None), params=vector_io_provider_params_dict.get(provider, None),
) )
@ -200,16 +204,16 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
"remote::qdrant": {"score_threshold": 0.0}, "remote::qdrant": {"score_threshold": 0.0},
"inline::qdrant": {"score_threshold": 0.0}, "inline::qdrant": {"score_threshold": 0.0},
} }
vector_db_name = "test_precomputed_embeddings_db" vector_store_name = "test_precomputed_embeddings_db"
register_response = client_with_empty_registry.vector_stores.create( register_response = client_with_empty_registry.vector_stores.create(
name=vector_db_name, name=vector_store_name,
extra_body={ extra_body={
"embedding_model": embedding_model_id, "embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id, "provider_id": vector_io_provider_id,
}, },
) )
actual_vector_db_id = register_response.id actual_vector_store_id = register_response.id
chunks_with_embeddings = [ chunks_with_embeddings = [
Chunk( Chunk(
@ -220,13 +224,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
] ]
client_with_empty_registry.vector_io.insert( client_with_empty_registry.vector_io.insert(
vector_db_id=actual_vector_db_id, vector_db_id=actual_vector_store_id,
chunks=chunks_with_embeddings, chunks=chunks_with_embeddings,
) )
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=actual_vector_db_id, vector_db_id=actual_vector_store_id,
query="duplicate", query="duplicate",
params=vector_io_provider_params_dict.get(provider, None), params=vector_io_provider_params_dict.get(provider, None),
) )
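
Note: the renamed integration tests above reduce to the following minimal client sketch. The client import, base URL, provider id, and chunk payload shape are assumptions for illustration; the method names and parameters mirror the hunks.

from llama_stack_client import LlamaStackClient  # assumed client package/import

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

# Create a vector store; provider_id rides in extra_body, as in the tests above.
store = client.vector_stores.create(
    name="test_vector_store",
    extra_body={"provider_id": "inline::faiss"},  # assumed provider id
)
assert store.id.startswith("vs_")

# The lower-level vector_io surface still addresses the store via vector_db_id.
client.vector_io.insert(
    vector_db_id=store.id,
    # Payload shape assumed from the Chunk(content=..., metadata=...) usage above.
    chunks=[{"content": "Paris is the capital of France.", "metadata": {"document_id": "doc1"}}],
)
response = client.vector_io.query(vector_db_id=store.id, query="What is the capital of France?")
assert response is not None

client.vector_stores.delete(vector_store_id=store.id)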

View file

@ -21,7 +21,7 @@ async def test_single_provider_auto_selection():
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384}) Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
] ]
) )
mock_routing_table.register_vector_db = AsyncMock( mock_routing_table.register_vector_store = AsyncMock(
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123") return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
) )
mock_routing_table.get_provider_impl = AsyncMock( mock_routing_table.get_provider_impl = AsyncMock(

View file

@ -4,138 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import patch
import pytest import pytest
from llama_stack.apis.common.content_types import URL, TextContentItem from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type
from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
async def test_content_from_doc_with_url():
"""Test extracting content from RAGDocument with URL content."""
mock_url = URL(uri="https://example.com")
mock_doc = RAGDocument(document_id="foo", content=mock_url)
mock_response = MagicMock()
mock_response.text = "Sample content from URL"
with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance
result = await content_from_doc(mock_doc)
assert result == "Sample content from URL"
mock_instance.get.assert_called_once_with(mock_url.uri)
async def test_content_from_doc_with_pdf_url():
"""Test extracting content from RAGDocument with URL pointing to a PDF."""
mock_url = URL(uri="https://example.com/document.pdf")
mock_doc = RAGDocument(document_id="foo", content=mock_url, mime_type="application/pdf")
mock_response = MagicMock()
mock_response.content = b"PDF binary data"
with (
patch("httpx.AsyncClient") as mock_client,
patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf,
):
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance
mock_parse_pdf.return_value = "Extracted PDF content"
result = await content_from_doc(mock_doc)
assert result == "Extracted PDF content"
mock_instance.get.assert_called_once_with(mock_url.uri)
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
async def test_content_from_doc_with_data_url():
"""Test extracting content from RAGDocument with data URL content."""
data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded
mock_url = URL(uri=data_url)
mock_doc = RAGDocument(document_id="foo", content=mock_url)
with patch("llama_stack.providers.utils.memory.vector_store.content_from_data") as mock_content_from_data:
mock_content_from_data.return_value = "Hello World"
result = await content_from_doc(mock_doc)
assert result == "Hello World"
mock_content_from_data.assert_called_once_with(data_url)
async def test_content_from_doc_with_string():
"""Test extracting content from RAGDocument with string content."""
content_string = "This is plain text content"
mock_doc = RAGDocument(document_id="foo", content=content_string)
result = await content_from_doc(mock_doc)
assert result == content_string
async def test_content_from_doc_with_string_url():
"""Test extracting content from RAGDocument with string URL content."""
url_string = "https://example.com"
mock_doc = RAGDocument(document_id="foo", content=url_string)
mock_response = MagicMock()
mock_response.text = "Sample content from URL string"
with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance
result = await content_from_doc(mock_doc)
assert result == "Sample content from URL string"
mock_instance.get.assert_called_once_with(url_string)
async def test_content_from_doc_with_string_pdf_url():
"""Test extracting content from RAGDocument with string URL pointing to a PDF."""
url_string = "https://example.com/document.pdf"
mock_doc = RAGDocument(document_id="foo", content=url_string, mime_type="application/pdf")
mock_response = MagicMock()
mock_response.content = b"PDF binary data"
with (
patch("httpx.AsyncClient") as mock_client,
patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf,
):
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance
mock_parse_pdf.return_value = "Extracted PDF content from string URL"
result = await content_from_doc(mock_doc)
assert result == "Extracted PDF content from string URL"
mock_instance.get.assert_called_once_with(url_string)
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
async def test_content_from_doc_with_interleaved_content():
"""Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit)."""
interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")]
mock_doc = RAGDocument(document_id="foo", content=interleaved_content)
with patch("llama_stack.providers.utils.memory.vector_store.interleaved_content_as_str") as mock_interleaved:
mock_interleaved.return_value = "First item\nSecond item"
result = await content_from_doc(mock_doc)
assert result == "First item\nSecond item"
mock_interleaved.assert_called_once_with(interleaved_content)
def test_content_from_data_and_mime_type_success_utf8(): def test_content_from_data_and_mime_type_success_utf8():
@ -178,41 +51,3 @@ def test_content_from_data_and_mime_type_both_encodings_fail():
# Should raise an exception instead of returning empty string # Should raise an exception instead of returning empty string
with pytest.raises(UnicodeDecodeError): with pytest.raises(UnicodeDecodeError):
content_from_data_and_mime_type(data, mime_type) content_from_data_and_mime_type(data, mime_type)
async def test_memory_tool_error_handling():
"""Test that memory tool handles various failures gracefully without crashing."""
from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
config = RagToolRuntimeConfig()
memory_tool = MemoryToolRuntimeImpl(
config=config,
vector_io_api=AsyncMock(),
inference_api=AsyncMock(),
files_api=AsyncMock(),
)
docs = [
RAGDocument(document_id="good_doc", content="Good content", metadata={}),
RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}),
RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}),
]
mock_file1 = MagicMock()
mock_file1.id = "file_good1"
mock_file2 = MagicMock()
mock_file2.id = "file_good2"
memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2]
with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.side_effect = Exception("Bad URL")
mock_client.return_value.__aenter__.return_value = mock_instance
# won't raise exception despite one document failing
await memory_tool.insert(docs, "vector_store_123")
# processed 2 documents successfully, skipped 1
assert memory_tool.files_api.openai_upload_file.call_count == 2
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2

View file

@ -10,8 +10,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
@ -31,7 +31,7 @@ def vector_provider(request):
@pytest.fixture @pytest.fixture
def vector_db_id() -> str: def vector_store_id() -> str:
return f"test-vector-db-{random.randint(1, 100)}" return f"test-vector-db-{random.randint(1, 100)}"
@ -149,8 +149,8 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
) )
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}" collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
await adapter.initialize() await adapter.initialize()
await adapter.register_vector_db( await adapter.register_vector_store(
VectorDB( VectorStore(
identifier=collection_id, identifier=collection_id,
provider_id="test_provider", provider_id="test_provider",
embedding_model="test_model", embedding_model="test_model",
@ -186,8 +186,8 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
files_api=None, files_api=None,
) )
await adapter.initialize() await adapter.initialize()
await adapter.register_vector_db( await adapter.register_vector_store(
VectorDB( VectorStore(
identifier=f"faiss_test_collection_{np.random.randint(1e6)}", identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
provider_id="test_provider", provider_id="test_provider",
embedding_model="test_model", embedding_model="test_model",
@ -215,7 +215,7 @@ def mock_psycopg2_connection():
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection): async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection connection, cursor = mock_psycopg2_connection
vector_db = VectorDB( vector_store = VectorStore(
identifier="test-vector-db", identifier="test-vector-db",
embedding_model="test-model", embedding_model="test-model",
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
@ -225,7 +225,7 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"): with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"): with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE") index = PGVectorIndex(vector_store, embedding_dimension, connection, distance_metric="COSINE")
index._test_chunks = [] index._test_chunks = []
original_add_chunks = index.add_chunks original_add_chunks = index.add_chunks
@ -281,30 +281,30 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
await adapter.initialize() await adapter.initialize()
adapter.conn = mock_conn adapter.conn = mock_conn
async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None): async def mock_insert_chunks(vector_store_id, chunks, ttl_seconds=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id) index = await adapter._get_and_cache_vector_store_index(vector_store_id)
if not index: if not index:
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_store_id} not found")
await index.insert_chunks(chunks) await index.insert_chunks(chunks)
adapter.insert_chunks = mock_insert_chunks adapter.insert_chunks = mock_insert_chunks
async def mock_query_chunks(vector_db_id, query, params=None): async def mock_query_chunks(vector_store_id, query, params=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id) index = await adapter._get_and_cache_vector_store_index(vector_store_id)
if not index: if not index:
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_store_id} not found")
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
adapter.query_chunks = mock_query_chunks adapter.query_chunks = mock_query_chunks
test_vector_db = VectorDB( test_vector_store = VectorStore(
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}", identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
provider_id="test_provider", provider_id="test_provider",
embedding_model="test_model", embedding_model="test_model",
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
await adapter.register_vector_db(test_vector_db) await adapter.register_vector_store(test_vector_store)
adapter.test_collection_id = test_vector_db.identifier adapter.test_collection_id = test_vector_store.identifier
yield adapter yield adapter
await adapter.shutdown() await adapter.shutdown()
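
Note: the sqlite-vec, faiss, and pgvector fixtures above all register a store the same way after the rename. A minimal sketch of that registration step (adapter construction is elided; field values follow the fixtures, and the helper name and identifier are illustrative):

from llama_stack.apis.vector_stores import VectorStore

async def register_test_store(adapter, embedding_dimension: int) -> str:
    # Mirrors the fixture pattern above.
    store = VectorStore(
        identifier="test_collection_example",
        provider_id="test_provider",
        embedding_model="test_model",
        embedding_dimension=embedding_dimension,
    )
    await adapter.register_vector_store(store)  # renamed from register_vector_db
    return store.identifier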

View file

@ -11,8 +11,8 @@ import numpy as np
import pytest import pytest
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.datatypes import HealthStatus from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import ( from llama_stack.providers.inline.vector_io.faiss.faiss import (
@ -43,8 +43,8 @@ def embedding_dimension():
@pytest.fixture @pytest.fixture
def vector_db_id(): def vector_store_id():
return "test_vector_db" return "test_vector_store"
@pytest.fixture @pytest.fixture
@ -61,12 +61,12 @@ def sample_embeddings(embedding_dimension):
@pytest.fixture @pytest.fixture
def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock: def mock_vector_store(vector_store_id, embedding_dimension) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB) mock_vector_store = MagicMock(spec=VectorStore)
mock_vector_db.embedding_model = "mock_embedding_model" mock_vector_store.embedding_model = "mock_embedding_model"
mock_vector_db.identifier = vector_db_id mock_vector_store.identifier = vector_store_id
mock_vector_db.embedding_dimension = embedding_dimension mock_vector_store.embedding_dimension = embedding_dimension
return mock_vector_db return mock_vector_store
@pytest.fixture @pytest.fixture

View file

@ -12,7 +12,6 @@ import numpy as np
import pytest import pytest
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import ( from llama_stack.apis.vector_io import (
Chunk, Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody, OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategyAuto, VectorStoreChunkingStrategyAuto,
VectorStoreFileObject, VectorStoreFileObject,
) )
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
# This test is a unit test for the inline VectorIO providers. This should only contain # This test is a unit test for the inline VectorIO providers. This should only contain
@ -71,7 +71,7 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter): async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
key = f"{VECTOR_DBS_PREFIX}db1" key = f"{VECTOR_DBS_PREFIX}db1"
dummy = VectorDB( dummy = VectorStore(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128 identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
) )
await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump())) await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
@ -81,10 +81,10 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
async def test_persistence_across_adapter_restarts(vector_io_adapter): async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.initialize() await vector_io_adapter.initialize()
dummy = VectorDB( dummy = VectorStore(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128 identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
) )
await vector_io_adapter.register_vector_db(dummy) await vector_io_adapter.register_vector_store(dummy)
await vector_io_adapter.shutdown() await vector_io_adapter.shutdown()
await vector_io_adapter.initialize() await vector_io_adapter.initialize()
@ -92,15 +92,15 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.shutdown() await vector_io_adapter.shutdown()
async def test_register_and_unregister_vector_db(vector_io_adapter): async def test_register_and_unregister_vector_store(vector_io_adapter):
unique_id = f"foo_db_{np.random.randint(1e6)}" unique_id = f"foo_db_{np.random.randint(1e6)}"
dummy = VectorDB( dummy = VectorStore(
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128 identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
) )
await vector_io_adapter.register_vector_db(dummy) await vector_io_adapter.register_vector_store(dummy)
assert dummy.identifier in vector_io_adapter.cache assert dummy.identifier in vector_io_adapter.cache
await vector_io_adapter.unregister_vector_db(dummy.identifier) await vector_io_adapter.unregister_vector_store(dummy.identifier)
assert dummy.identifier not in vector_io_adapter.cache assert dummy.identifier not in vector_io_adapter.cache
@ -121,7 +121,7 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
async def test_insert_chunks_missing_db_raises(vector_io_adapter): async def test_insert_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
with pytest.raises(ValueError): with pytest.raises(ValueError):
await vector_io_adapter.insert_chunks("db_not_exist", []) await vector_io_adapter.insert_chunks("db_not_exist", [])
@ -170,7 +170,7 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
async def test_query_chunks_missing_db_raises(vector_io_adapter): async def test_query_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
with pytest.raises(ValueError): with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("db_missing", "q", None) await vector_io_adapter.query_chunks("db_missing", "q", None)
@ -182,7 +182,7 @@ async def test_save_openai_vector_store(vector_io_adapter):
"id": store_id, "id": store_id,
"name": "Test Store", "name": "Test Store",
"description": "A test OpenAI vector store", "description": "A test OpenAI vector store",
"vector_db_id": "test_db", "vector_store_id": "test_db",
"embedding_model": "test_model", "embedding_model": "test_model",
} }
@ -198,7 +198,7 @@ async def test_update_openai_vector_store(vector_io_adapter):
"id": store_id, "id": store_id,
"name": "Test Store", "name": "Test Store",
"description": "A test OpenAI vector store", "description": "A test OpenAI vector store",
"vector_db_id": "test_db", "vector_store_id": "test_db",
"embedding_model": "test_model", "embedding_model": "test_model",
} }
@ -214,7 +214,7 @@ async def test_delete_openai_vector_store(vector_io_adapter):
"id": store_id, "id": store_id,
"name": "Test Store", "name": "Test Store",
"description": "A test OpenAI vector store", "description": "A test OpenAI vector store",
"vector_db_id": "test_db", "vector_store_id": "test_db",
"embedding_model": "test_model", "embedding_model": "test_model",
} }
@ -229,7 +229,7 @@ async def test_load_openai_vector_stores(vector_io_adapter):
"id": store_id, "id": store_id,
"name": "Test Store", "name": "Test Store",
"description": "A test OpenAI vector store", "description": "A test OpenAI vector store",
"vector_db_id": "test_db", "vector_store_id": "test_db",
"embedding_model": "test_model", "embedding_model": "test_model",
} }
@ -998,8 +998,8 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
async def test_embedding_config_from_metadata(vector_io_adapter): async def test_embedding_config_from_metadata(vector_io_adapter):
"""Test that embedding configuration is correctly extracted from metadata.""" """Test that embedding configuration is correctly extracted from metadata."""
# Mock register_vector_db to avoid actual registration # Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock() vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter # Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider" vector_io_adapter.__provider_id__ = "test_provider"
@ -1015,9 +1015,9 @@ async def test_embedding_config_from_metadata(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params) await vector_io_adapter.openai_create_vector_store(params)
# Verify VectorDB was registered with correct embedding config from metadata # Verify VectorStore was registered with correct embedding config from metadata
vector_io_adapter.register_vector_db.assert_called_once() vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0] call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "test-embedding-model" assert call_args.embedding_model == "test-embedding-model"
assert call_args.embedding_dimension == 512 assert call_args.embedding_dimension == 512
@ -1025,8 +1025,8 @@ async def test_embedding_config_from_metadata(vector_io_adapter):
async def test_embedding_config_from_extra_body(vector_io_adapter): async def test_embedding_config_from_extra_body(vector_io_adapter):
"""Test that embedding configuration is correctly extracted from extra_body when metadata is empty.""" """Test that embedding configuration is correctly extracted from extra_body when metadata is empty."""
# Mock register_vector_db to avoid actual registration # Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock() vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter # Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider" vector_io_adapter.__provider_id__ = "test_provider"
@ -1042,9 +1042,9 @@ async def test_embedding_config_from_extra_body(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params) await vector_io_adapter.openai_create_vector_store(params)
# Verify VectorDB was registered with correct embedding config from extra_body # Verify VectorStore was registered with correct embedding config from extra_body
vector_io_adapter.register_vector_db.assert_called_once() vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0] call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "extra-body-model" assert call_args.embedding_model == "extra-body-model"
assert call_args.embedding_dimension == 1024 assert call_args.embedding_dimension == 1024
@ -1052,8 +1052,8 @@ async def test_embedding_config_from_extra_body(vector_io_adapter):
async def test_embedding_config_consistency_check_passes(vector_io_adapter): async def test_embedding_config_consistency_check_passes(vector_io_adapter):
"""Test that consistent embedding config in both metadata and extra_body passes validation.""" """Test that consistent embedding config in both metadata and extra_body passes validation."""
# Mock register_vector_db to avoid actual registration # Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock() vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter # Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider" vector_io_adapter.__provider_id__ = "test_provider"
@ -1073,8 +1073,8 @@ async def test_embedding_config_consistency_check_passes(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params) await vector_io_adapter.openai_create_vector_store(params)
# Should not raise any error and use metadata config # Should not raise any error and use metadata config
vector_io_adapter.register_vector_db.assert_called_once() vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0] call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "consistent-model" assert call_args.embedding_model == "consistent-model"
assert call_args.embedding_dimension == 768 assert call_args.embedding_dimension == 768
@ -1082,8 +1082,8 @@ async def test_embedding_config_consistency_check_passes(vector_io_adapter):
async def test_embedding_config_inconsistency_errors(vector_io_adapter): async def test_embedding_config_inconsistency_errors(vector_io_adapter):
"""Test that inconsistent embedding config between metadata and extra_body raises errors.""" """Test that inconsistent embedding config between metadata and extra_body raises errors."""
# Mock register_vector_db to avoid actual registration # Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock() vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter # Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider" vector_io_adapter.__provider_id__ = "test_provider"
@ -1104,7 +1104,7 @@ async def test_embedding_config_inconsistency_errors(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params) await vector_io_adapter.openai_create_vector_store(params)
# Reset mock for second test # Reset mock for second test
vector_io_adapter.register_vector_db.reset_mock() vector_io_adapter.register_vector_store.reset_mock()
# Test with inconsistent embedding dimension # Test with inconsistent embedding dimension
params = OpenAICreateVectorStoreRequestWithExtraBody( params = OpenAICreateVectorStoreRequestWithExtraBody(
@ -1126,8 +1126,8 @@ async def test_embedding_config_inconsistency_errors(vector_io_adapter):
async def test_embedding_config_defaults_when_missing(vector_io_adapter): async def test_embedding_config_defaults_when_missing(vector_io_adapter):
"""Test that embedding dimension defaults to 768 when not provided.""" """Test that embedding dimension defaults to 768 when not provided."""
# Mock register_vector_db to avoid actual registration # Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock() vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter # Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider" vector_io_adapter.__provider_id__ = "test_provider"
@ -1143,8 +1143,8 @@ async def test_embedding_config_defaults_when_missing(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params) await vector_io_adapter.openai_create_vector_store(params)
# Should default to 768 dimensions # Should default to 768 dimensions
vector_io_adapter.register_vector_db.assert_called_once() vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0] call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "model-without-dimension" assert call_args.embedding_model == "model-without-dimension"
assert call_args.embedding_dimension == 768 assert call_args.embedding_dimension == 768
@ -1152,8 +1152,8 @@ async def test_embedding_config_defaults_when_missing(vector_io_adapter):
async def test_embedding_config_required_model_missing(vector_io_adapter): async def test_embedding_config_required_model_missing(vector_io_adapter):
"""Test that missing embedding model raises error.""" """Test that missing embedding model raises error."""
# Mock register_vector_db to avoid actual registration # Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock() vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter # Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider" vector_io_adapter.__provider_id__ = "test_provider"
# Mock the default model lookup to return None (no default model available) # Mock the default model lookup to return None (no default model available)

View file

@ -1,138 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.tools.rag_tool import RAGQueryConfig
from llama_stack.apis.vector_io import (
Chunk,
ChunkMetadata,
QueryChunksResponse,
)
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
class TestRagQuery:
async def test_query_raises_on_empty_vector_db_ids(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
with pytest.raises(ValueError):
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
async def test_query_chunk_metadata_handling(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
content = "test query content"
vector_db_ids = ["db1"]
chunk_metadata = ChunkMetadata(
document_id="doc1",
chunk_id="chunk1",
source="test_source",
metadata_token_count=5,
)
interleaved_content = MagicMock()
chunk = Chunk(
content=interleaved_content,
metadata={
"key1": "value1",
"token_count": 10,
"metadata_token_count": 5,
# Note this is inserted into `metadata` during MemoryToolRuntimeImpl().insert()
"document_id": "doc1",
},
stored_chunk_id="chunk1",
chunk_metadata=chunk_metadata,
)
query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])
rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)
assert result is not None
expected_metadata_string = (
"Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}"
)
assert expected_metadata_string in result.content[1].text
assert result.content is not None
async def test_query_raises_incorrect_mode(self):
with pytest.raises(ValueError):
RAGQueryConfig(mode="invalid_mode")
async def test_query_accepts_valid_modes(self):
default_config = RAGQueryConfig() # Test default (vector)
assert default_config.mode == "vector"
vector_config = RAGQueryConfig(mode="vector") # Test vector
assert vector_config.mode == "vector"
keyword_config = RAGQueryConfig(mode="keyword") # Test keyword
assert keyword_config.mode == "keyword"
hybrid_config = RAGQueryConfig(mode="hybrid") # Test hybrid
assert hybrid_config.mode == "hybrid"
# Test that invalid mode raises an error
with pytest.raises(ValueError):
RAGQueryConfig(mode="wrong_mode")
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(),
vector_io_api=MagicMock(),
inference_api=MagicMock(),
files_api=MagicMock(),
)
vector_db_ids = ["db1", "db2"]
# Fake chunks from each DB
chunk_metadata1 = ChunkMetadata(
document_id="doc1",
chunk_id="chunk1",
source="test_source1",
metadata_token_count=5,
)
chunk1 = Chunk(
content="chunk from db1",
metadata={"vector_db_id": "db1", "document_id": "doc1"},
stored_chunk_id="c1",
chunk_metadata=chunk_metadata1,
)
chunk_metadata2 = ChunkMetadata(
document_id="doc2",
chunk_id="chunk2",
source="test_source2",
metadata_token_count=5,
)
chunk2 = Chunk(
content="chunk from db2",
metadata={"vector_db_id": "db2", "document_id": "doc2"},
stored_chunk_id="c2",
chunk_metadata=chunk_metadata2,
)
rag_tool.vector_io_api.query_chunks = AsyncMock(
side_effect=[
QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
]
)
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
returned_chunks = result.metadata["chunks"]
returned_scores = result.metadata["scores"]
returned_doc_ids = result.metadata["document_ids"]
returned_vector_db_ids = result.metadata["vector_db_ids"]
assert returned_chunks == ["chunk from db1", "chunk from db2"]
assert returned_scores == (0.9, 0.8)
assert returned_doc_ids == ["doc1", "doc2"]
assert returned_vector_db_ids == ["db1", "db2"]

View file

@ -4,10 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import base64
import mimetypes
import os
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import numpy as np import numpy as np
@ -17,37 +13,13 @@ from llama_stack.apis.inference.inference import (
OpenAIEmbeddingData, OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody, OpenAIEmbeddingsRequestWithExtraBody,
) )
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
URL, VectorStoreWithIndex,
VectorDBWithIndex,
_validate_embedding, _validate_embedding,
content_from_doc,
make_overlapped_chunks, make_overlapped_chunks,
) )
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
# Depending on the machine, this can get parsed a couple of ways
DUMMY_PDF_TEXT_CHOICES = ["Dummy PDF file", "Dumm y PDF file"]
def read_file(file_path: str) -> bytes:
with open(file_path, "rb") as file:
return file.read()
def data_url_from_file(file_path: str) -> str:
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url
class TestChunk: class TestChunk:
def test_chunk(self): def test_chunk(self):
@ -116,45 +88,6 @@ class TestValidateEmbedding:
class TestVectorStore: class TestVectorStore:
async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = RAGDocument(
document_id="dummy",
content=data_uri,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = RAGDocument(
document_id="dummy",
content=url,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = RAGDocument(
document_id="dummy",
content=URL(
uri=url,
),
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.parametrize( @pytest.mark.parametrize(
"window_len, overlap_len, expected_chunks", "window_len, overlap_len, expected_chunks",
[ [
@ -206,15 +139,15 @@ class TestVectorStore:
assert str(excinfo.value.__cause__) == "Cannot convert to string" assert str(excinfo.value.__cause__) == "Cannot convert to string"
class TestVectorDBWithIndex: class TestVectorStoreWithIndex:
async def test_insert_chunks_without_embeddings(self): async def test_insert_chunks_without_embeddings(self):
mock_vector_db = MagicMock() mock_vector_store = MagicMock()
mock_vector_db.embedding_model = "test-model without embeddings" mock_vector_store.embedding_model = "test-model without embeddings"
mock_index = AsyncMock() mock_index = AsyncMock()
mock_inference_api = AsyncMock() mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex( vector_store_with_index = VectorStoreWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
) )
chunks = [ chunks = [
@ -227,7 +160,7 @@ class TestVectorDBWithIndex:
OpenAIEmbeddingData(embedding=[0.4, 0.5, 0.6], index=1), OpenAIEmbeddingData(embedding=[0.4, 0.5, 0.6], index=1),
] ]
await vector_db_with_index.insert_chunks(chunks) await vector_store_with_index.insert_chunks(chunks)
# Verify openai_embeddings was called with correct params # Verify openai_embeddings was called with correct params
mock_inference_api.openai_embeddings.assert_called_once() mock_inference_api.openai_embeddings.assert_called_once()
@ -243,14 +176,14 @@ class TestVectorDBWithIndex:
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
async def test_insert_chunks_with_valid_embeddings(self): async def test_insert_chunks_with_valid_embeddings(self):
mock_vector_db = MagicMock() mock_vector_store = MagicMock()
mock_vector_db.embedding_model = "test-model with embeddings" mock_vector_store.embedding_model = "test-model with embeddings"
mock_vector_db.embedding_dimension = 3 mock_vector_store.embedding_dimension = 3
mock_index = AsyncMock() mock_index = AsyncMock()
mock_inference_api = AsyncMock() mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex( vector_store_with_index = VectorStoreWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
) )
chunks = [ chunks = [
@ -258,7 +191,7 @@ class TestVectorDBWithIndex:
Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}), Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
] ]
await vector_db_with_index.insert_chunks(chunks) await vector_store_with_index.insert_chunks(chunks)
mock_inference_api.openai_embeddings.assert_not_called() mock_inference_api.openai_embeddings.assert_not_called()
mock_index.add_chunks.assert_called_once() mock_index.add_chunks.assert_called_once()
@ -267,14 +200,14 @@ class TestVectorDBWithIndex:
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
async def test_insert_chunks_with_invalid_embeddings(self): async def test_insert_chunks_with_invalid_embeddings(self):
mock_vector_db = MagicMock() mock_vector_store = MagicMock()
mock_vector_db.embedding_dimension = 3 mock_vector_store.embedding_dimension = 3
mock_vector_db.embedding_model = "test-model with invalid embeddings" mock_vector_store.embedding_model = "test-model with invalid embeddings"
mock_index = AsyncMock() mock_index = AsyncMock()
mock_inference_api = AsyncMock() mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex( vector_store_with_index = VectorStoreWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
) )
# Verify Chunk raises ValueError for invalid embedding type # Verify Chunk raises ValueError for invalid embedding type
@ -283,7 +216,7 @@ class TestVectorDBWithIndex:
# Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called) # Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
with pytest.raises(ValueError, match="Input should be a valid list"): with pytest.raises(ValueError, match="Input should be a valid list"):
await vector_db_with_index.insert_chunks( await vector_store_with_index.insert_chunks(
[ [
Chunk(content="Test 1", embedding=None, metadata={}), Chunk(content="Test 1", embedding=None, metadata={}),
Chunk(content="Test 2", embedding="invalid_type", metadata={}), Chunk(content="Test 2", embedding="invalid_type", metadata={}),
@ -292,7 +225,7 @@ class TestVectorDBWithIndex:
# Verify Chunk raises ValueError for invalid embedding element type in insert_chunks (i.e., Chunk errors before insert_chunks is called) # Verify Chunk raises ValueError for invalid embedding element type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "): with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "):
await vector_db_with_index.insert_chunks( await vector_store_with_index.insert_chunks(
Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={}) Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})
) )
@ -300,20 +233,20 @@ class TestVectorDBWithIndex:
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}), Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
] ]
with pytest.raises(ValueError, match="has dimension 4, expected 3"): with pytest.raises(ValueError, match="has dimension 4, expected 3"):
await vector_db_with_index.insert_chunks(chunks_wrong_dim) await vector_store_with_index.insert_chunks(chunks_wrong_dim)
mock_inference_api.openai_embeddings.assert_not_called() mock_inference_api.openai_embeddings.assert_not_called()
mock_index.add_chunks.assert_not_called() mock_index.add_chunks.assert_not_called()
async def test_insert_chunks_with_partially_precomputed_embeddings(self): async def test_insert_chunks_with_partially_precomputed_embeddings(self):
mock_vector_db = MagicMock() mock_vector_store = MagicMock()
mock_vector_db.embedding_model = "test-model with partial embeddings" mock_vector_store.embedding_model = "test-model with partial embeddings"
mock_vector_db.embedding_dimension = 3 mock_vector_store.embedding_dimension = 3
mock_index = AsyncMock() mock_index = AsyncMock()
mock_inference_api = AsyncMock() mock_inference_api = AsyncMock()
vector_db_with_index = VectorDBWithIndex( vector_store_with_index = VectorStoreWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
) )
chunks = [ chunks = [
@ -327,7 +260,7 @@ class TestVectorDBWithIndex:
OpenAIEmbeddingData(embedding=[0.3, 0.3, 0.3], index=1), OpenAIEmbeddingData(embedding=[0.3, 0.3, 0.3], index=1),
] ]
await vector_db_with_index.insert_chunks(chunks) await vector_store_with_index.insert_chunks(chunks)
# Verify openai_embeddings was called with correct params # Verify openai_embeddings was called with correct params
mock_inference_api.openai_embeddings.assert_called_once() mock_inference_api.openai_embeddings.assert_called_once()
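
Note: a compact, self-contained version of the renamed wrapper test above. Mocks and chunk values follow the hunks; the asyncio entry point is added only so the sketch runs standalone.

import asyncio
from unittest.mock import AsyncMock, MagicMock

from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex

async def main():
    mock_vector_store = MagicMock()
    mock_vector_store.embedding_model = "test-model with embeddings"
    mock_vector_store.embedding_dimension = 3

    mock_index = AsyncMock()
    mock_inference_api = AsyncMock()

    wrapper = VectorStoreWithIndex(
        vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
    )

    # Chunks that already carry embeddings bypass the inference API entirely.
    await wrapper.insert_chunks(
        [
            Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3], metadata={}),
            Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
        ]
    )
    mock_inference_api.openai_embeddings.assert_not_called()
    mock_index.add_chunks.assert_called_once()

asyncio.run(main())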

View file

@ -8,8 +8,8 @@
import pytest import pytest
from llama_stack.apis.inference import Model from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.datatypes import VectorDBWithOwner from llama_stack.core.datatypes import VectorStoreWithOwner
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.store.registry import ( from llama_stack.core.store.registry import (
KEY_FORMAT, KEY_FORMAT,
@ -20,12 +20,12 @@ from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_b
@pytest.fixture @pytest.fixture
def sample_vector_db(): def sample_vector_store():
return VectorDB( return VectorStore(
identifier="test_vector_db", identifier="test_vector_store",
embedding_model="nomic-embed-text-v1.5", embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768, embedding_dimension=768,
provider_resource_id="test_vector_db", provider_resource_id="test_vector_store",
provider_id="test-provider", provider_id="test-provider",
) )
@ -45,17 +45,17 @@ async def test_registry_initialization(disk_dist_registry):
assert result is None assert result is None
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model): async def test_basic_registration(disk_dist_registry, sample_vector_store, sample_model):
print(f"Registering {sample_vector_db}") print(f"Registering {sample_vector_store}")
await disk_dist_registry.register(sample_vector_db) await disk_dist_registry.register(sample_vector_store)
print(f"Registering {sample_model}") print(f"Registering {sample_model}")
await disk_dist_registry.register(sample_model) await disk_dist_registry.register(sample_model)
print("Getting vector_db") print("Getting vector_store")
result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db") result_vector_store = await disk_dist_registry.get("vector_store", "test_vector_store")
assert result_vector_db is not None assert result_vector_store is not None
assert result_vector_db.identifier == sample_vector_db.identifier assert result_vector_store.identifier == sample_vector_store.identifier
assert result_vector_db.embedding_model == sample_vector_db.embedding_model assert result_vector_store.embedding_model == sample_vector_store.embedding_model
assert result_vector_db.provider_id == sample_vector_db.provider_id assert result_vector_store.provider_id == sample_vector_store.provider_id
result_model = await disk_dist_registry.get("model", "test_model") result_model = await disk_dist_registry.get("model", "test_model")
assert result_model is not None assert result_model is not None
@ -63,11 +63,11 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m
assert result_model.provider_id == sample_model.provider_id assert result_model.provider_id == sample_model.provider_id
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model): async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_store, sample_model):
# First populate the disk registry # First populate the disk registry
disk_registry = DiskDistributionRegistry(sqlite_kvstore) disk_registry = DiskDistributionRegistry(sqlite_kvstore)
await disk_registry.initialize() await disk_registry.initialize()
await disk_registry.register(sample_vector_db) await disk_registry.register(sample_vector_store)
await disk_registry.register(sample_model) await disk_registry.register(sample_model)
# Test cached version loads from disk # Test cached version loads from disk
@ -79,29 +79,29 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db,
) )
await cached_registry.initialize() await cached_registry.initialize()
result_vector_db = await cached_registry.get("vector_db", "test_vector_db") result_vector_store = await cached_registry.get("vector_store", "test_vector_store")
assert result_vector_db is not None assert result_vector_store is not None
assert result_vector_db.identifier == sample_vector_db.identifier assert result_vector_store.identifier == sample_vector_store.identifier
assert result_vector_db.embedding_model == sample_vector_db.embedding_model assert result_vector_store.embedding_model == sample_vector_store.embedding_model
assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension assert result_vector_store.embedding_dimension == sample_vector_store.embedding_dimension
assert result_vector_db.provider_id == sample_vector_db.provider_id assert result_vector_store.provider_id == sample_vector_store.provider_id
async def test_cached_registry_updates(cached_disk_dist_registry): async def test_cached_registry_updates(cached_disk_dist_registry):
new_vector_db = VectorDB( new_vector_store = VectorStore(
identifier="test_vector_db_2", identifier="test_vector_store_2",
embedding_model="nomic-embed-text-v1.5", embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768, embedding_dimension=768,
provider_resource_id="test_vector_db_2", provider_resource_id="test_vector_store_2",
provider_id="baz", provider_id="baz",
) )
await cached_disk_dist_registry.register(new_vector_db) await cached_disk_dist_registry.register(new_vector_store)
# Verify in cache # Verify in cache
result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") result_vector_store = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
assert result_vector_db is not None assert result_vector_store is not None
assert result_vector_db.identifier == new_vector_db.identifier assert result_vector_store.identifier == new_vector_store.identifier
assert result_vector_db.provider_id == new_vector_db.provider_id assert result_vector_store.provider_id == new_vector_store.provider_id
# Verify persisted to disk # Verify persisted to disk
db_path = cached_disk_dist_registry.kvstore.db_path db_path = cached_disk_dist_registry.kvstore.db_path
@@ -111,87 +111,89 @@ async def test_cached_registry_updates(cached_disk_dist_registry):
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry")) await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
) )
await new_registry.initialize() await new_registry.initialize()
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2") result_vector_store = await new_registry.get("vector_store", "test_vector_store_2")
assert result_vector_db is not None assert result_vector_store is not None
assert result_vector_db.identifier == new_vector_db.identifier assert result_vector_store.identifier == new_vector_store.identifier
assert result_vector_db.provider_id == new_vector_db.provider_id assert result_vector_store.provider_id == new_vector_store.provider_id
async def test_duplicate_provider_registration(cached_disk_dist_registry): async def test_duplicate_provider_registration(cached_disk_dist_registry):
original_vector_db = VectorDB( original_vector_store = VectorStore(
identifier="test_vector_db_2", identifier="test_vector_store_2",
embedding_model="nomic-embed-text-v1.5", embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768, embedding_dimension=768,
provider_resource_id="test_vector_db_2", provider_resource_id="test_vector_store_2",
provider_id="baz", provider_id="baz",
) )
assert await cached_disk_dist_registry.register(original_vector_db) assert await cached_disk_dist_registry.register(original_vector_store)
duplicate_vector_db = VectorDB( duplicate_vector_store = VectorStore(
identifier="test_vector_db_2", identifier="test_vector_store_2",
embedding_model="different-model", embedding_model="different-model",
embedding_dimension=768, embedding_dimension=768,
provider_resource_id="test_vector_db_2", provider_resource_id="test_vector_store_2",
provider_id="baz", # Same provider_id provider_id="baz", # Same provider_id
) )
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"): with pytest.raises(
await cached_disk_dist_registry.register(duplicate_vector_db) ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store_2' already exists"
):
await cached_disk_dist_registry.register(duplicate_vector_store)
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") result = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
assert result is not None assert result is not None
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved assert result.embedding_model == original_vector_store.embedding_model # Original values preserved
async def test_get_all_objects(cached_disk_dist_registry): async def test_get_all_objects(cached_disk_dist_registry):
# Create multiple test banks # Create multiple test banks
# Create multiple test banks # Create multiple test banks
test_vector_dbs = [ test_vector_stores = [
VectorDB( VectorStore(
identifier=f"test_vector_db_{i}", identifier=f"test_vector_store_{i}",
embedding_model="nomic-embed-text-v1.5", embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768, embedding_dimension=768,
provider_resource_id=f"test_vector_db_{i}", provider_resource_id=f"test_vector_store_{i}",
provider_id=f"provider_{i}", provider_id=f"provider_{i}",
) )
for i in range(3) for i in range(3)
] ]
# Register all vector_dbs # Register all vector_stores
for vector_db in test_vector_dbs: for vector_store in test_vector_stores:
await cached_disk_dist_registry.register(vector_db) await cached_disk_dist_registry.register(vector_store)
# Test get_all retrieval # Test get_all retrieval
all_results = await cached_disk_dist_registry.get_all() all_results = await cached_disk_dist_registry.get_all()
assert len(all_results) == 3 assert len(all_results) == 3
# Verify each vector_db was stored correctly # Verify each vector_store was stored correctly
for original_vector_db in test_vector_dbs: for original_vector_store in test_vector_stores:
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier] matching_vector_stores = [v for v in all_results if v.identifier == original_vector_store.identifier]
assert len(matching_vector_dbs) == 1 assert len(matching_vector_stores) == 1
stored_vector_db = matching_vector_dbs[0] stored_vector_store = matching_vector_stores[0]
assert stored_vector_db.embedding_model == original_vector_db.embedding_model assert stored_vector_store.embedding_model == original_vector_store.embedding_model
assert stored_vector_db.provider_id == original_vector_db.provider_id assert stored_vector_store.provider_id == original_vector_store.provider_id
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension assert stored_vector_store.embedding_dimension == original_vector_store.embedding_dimension
async def test_parse_registry_values_error_handling(sqlite_kvstore): async def test_parse_registry_values_error_handling(sqlite_kvstore):
valid_db = VectorDB( valid_db = VectorStore(
identifier="valid_vector_db", identifier="valid_vector_store",
embedding_model="nomic-embed-text-v1.5", embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768, embedding_dimension=768,
provider_resource_id="valid_vector_db", provider_resource_id="valid_vector_store",
provider_id="test-provider", provider_id="test-provider",
) )
await sqlite_kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json() KEY_FORMAT.format(type="vector_store", identifier="valid_vector_store"), valid_db.model_dump_json()
) )
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json") await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_store", identifier="corrupted_json"), "{not valid json")
await sqlite_kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), KEY_FORMAT.format(type="vector_store", identifier="missing_fields"),
'{"type": "vector_db", "identifier": "missing_fields"}', '{"type": "vector_store", "identifier": "missing_fields"}',
) )
test_registry = DiskDistributionRegistry(sqlite_kvstore) test_registry = DiskDistributionRegistry(sqlite_kvstore)
@@ -202,18 +204,18 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
# Should have filtered out the invalid entries # Should have filtered out the invalid entries
assert len(all_objects) == 1 assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_vector_db" assert all_objects[0].identifier == "valid_vector_store"
# Check that the get method also handles errors correctly # Check that the get method also handles errors correctly
invalid_obj = await test_registry.get("vector_db", "corrupted_json") invalid_obj = await test_registry.get("vector_store", "corrupted_json")
assert invalid_obj is None assert invalid_obj is None
invalid_obj = await test_registry.get("vector_db", "missing_fields") invalid_obj = await test_registry.get("vector_store", "missing_fields")
assert invalid_obj is None assert invalid_obj is None
async def test_cached_registry_error_handling(sqlite_kvstore): async def test_cached_registry_error_handling(sqlite_kvstore):
valid_db = VectorDB( valid_db = VectorStore(
identifier="valid_cached_db", identifier="valid_cached_db",
embedding_model="nomic-embed-text-v1.5", embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768, embedding_dimension=768,
@@ -222,12 +224,12 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
) )
await sqlite_kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json() KEY_FORMAT.format(type="vector_store", identifier="valid_cached_db"), valid_db.model_dump_json()
) )
await sqlite_kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"), KEY_FORMAT.format(type="vector_store", identifier="invalid_cached_db"),
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string '{"type": "vector_store", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
) )
cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore) cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
@@ -237,63 +239,65 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
assert len(all_objects) == 1 assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_cached_db" assert all_objects[0].identifier == "valid_cached_db"
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db") invalid_obj = await cached_registry.get("vector_store", "invalid_cached_db")
assert invalid_obj is None assert invalid_obj is None
async def test_double_registration_identical_objects(disk_dist_registry): async def test_double_registration_identical_objects(disk_dist_registry):
"""Test that registering identical objects succeeds (idempotent).""" """Test that registering identical objects succeeds (idempotent)."""
vector_db = VectorDBWithOwner( vector_store = VectorStoreWithOwner(
identifier="test_vector_db", identifier="test_vector_store",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384, embedding_dimension=384,
provider_resource_id="test_vector_db", provider_resource_id="test_vector_store",
provider_id="test-provider", provider_id="test-provider",
) )
# First registration should succeed # First registration should succeed
result1 = await disk_dist_registry.register(vector_db) result1 = await disk_dist_registry.register(vector_store)
assert result1 is True assert result1 is True
# Second registration of identical object should also succeed (idempotent) # Second registration of identical object should also succeed (idempotent)
result2 = await disk_dist_registry.register(vector_db) result2 = await disk_dist_registry.register(vector_store)
assert result2 is True assert result2 is True
# Verify object exists and is unchanged # Verify object exists and is unchanged
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db") retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
assert retrieved is not None assert retrieved is not None
assert retrieved.identifier == vector_db.identifier assert retrieved.identifier == vector_store.identifier
assert retrieved.embedding_model == vector_db.embedding_model assert retrieved.embedding_model == vector_store.embedding_model
async def test_double_registration_different_objects(disk_dist_registry): async def test_double_registration_different_objects(disk_dist_registry):
"""Test that registering different objects with same identifier fails.""" """Test that registering different objects with same identifier fails."""
vector_db1 = VectorDBWithOwner( vector_store1 = VectorStoreWithOwner(
identifier="test_vector_db", identifier="test_vector_store",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384, embedding_dimension=384,
provider_resource_id="test_vector_db", provider_resource_id="test_vector_store",
provider_id="test-provider", provider_id="test-provider",
) )
vector_db2 = VectorDBWithOwner( vector_store2 = VectorStoreWithOwner(
identifier="test_vector_db", # Same identifier identifier="test_vector_store", # Same identifier
embedding_model="different-model", # Different embedding model embedding_model="different-model", # Different embedding model
embedding_dimension=384, embedding_dimension=384,
provider_resource_id="test_vector_db", provider_resource_id="test_vector_store",
provider_id="test-provider", provider_id="test-provider",
) )
# First registration should succeed # First registration should succeed
result1 = await disk_dist_registry.register(vector_db1) result1 = await disk_dist_registry.register(vector_store1)
assert result1 is True assert result1 is True
# Second registration with different data should fail # Second registration with different data should fail
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"): with pytest.raises(
await disk_dist_registry.register(vector_db2) ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store' already exists"
):
await disk_dist_registry.register(vector_store2)
# Verify original object is unchanged # Verify original object is unchanged
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db") retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
assert retrieved is not None assert retrieved is not None
assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value

View file

@@ -41,7 +41,7 @@ class TestTranslateException:
self.identifier = identifier self.identifier = identifier
self.owner = owner self.owner = owner
resource = MockResource("vector_db", "test-db") resource = MockResource("vector_store", "test-db")
exc = AccessDeniedError("create", resource, user) exc = AccessDeniedError("create", resource, user)
result = translate_exception(exc) result = translate_exception(exc)
@@ -49,7 +49,7 @@ class TestTranslateException:
assert isinstance(result, HTTPException) assert isinstance(result, HTTPException)
assert result.status_code == 403 assert result.status_code == 403
assert "test-user" in result.detail assert "test-user" in result.detail
assert "vector_db::test-db" in result.detail assert "vector_store::test-db" in result.detail
assert "create" in result.detail assert "create" in result.detail
assert "roles=['user']" in result.detail assert "roles=['user']" in result.detail
assert "teams=['dev']" in result.detail assert "teams=['dev']" in result.detail