Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-04 10:10:36 +00:00)

Merge bf442eb3f3 into sapling-pr-archive-ehhuang

Commit a7c0ec991b: 97 changed files with 1986 additions and 5308 deletions
@@ -86,10 +86,9 @@ runs:
       if: ${{ always() }}
       shell: bash
       run: |
-        sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log || true
-        distro_name=$(echo "${{ inputs.stack-config }}" | sed 's/^docker://' | sed 's/^server://')
-        stack_container_name="llama-stack-test-$distro_name"
-        sudo docker logs $stack_container_name > docker-${distro_name}-${{ inputs.inference-mode }}.log || true
+        # Ollama logs (if ollama container exists)
+        sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log 2>&1 || true
+        # Note: distro container logs are now dumped in integration-tests.sh before container is removed

     - name: Upload logs
       if: ${{ always() }}
2 .github/workflows/precommit-trigger.yml vendored
@@ -99,7 +99,7 @@ jobs:
               owner: context.repo.owner,
               repo: context.repo.repo,
               issue_number: ${{ steps.check_author.outputs.pr_number }},
-              body: `⏳ Running pre-commit hooks on PR #${{ steps.check_author.outputs.pr_number }}...`
+              body: `⏳ Running [pre-commit hooks](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) on PR #${{ steps.check_author.outputs.pr_number }}...`
             });

       - name: Checkout PR branch (same-repo)
@@ -208,19 +208,6 @@ resources:
         type: http
         endpoint: post /v1/conversations/{conversation_id}/items

-  datasets:
-    models:
-      list_datasets_response: ListDatasetsResponse
-    methods:
-      register: post /v1beta/datasets
-      retrieve: get /v1beta/datasets/{dataset_id}
-      list:
-        endpoint: get /v1beta/datasets
-        paginated: false
-      unregister: delete /v1beta/datasets/{dataset_id}
-      iterrows: get /v1beta/datasetio/iterrows/{dataset_id}
-      appendrows: post /v1beta/datasetio/append-rows/{dataset_id}
-
   inspect:
     models:
       healthInfo: HealthInfo
@@ -521,6 +508,21 @@ resources:
       stream_event_model: alpha.agents.turn.agent_turn_response_stream_chunk
       param_discriminator: stream

+  beta:
+    subresources:
+      datasets:
+        models:
+          list_datasets_response: ListDatasetsResponse
+        methods:
+          register: post /v1beta/datasets
+          retrieve: get /v1beta/datasets/{dataset_id}
+          list:
+            endpoint: get /v1beta/datasets
+            paginated: false
+          unregister: delete /v1beta/datasets/{dataset_id}
+          iterrows: get /v1beta/datasetio/iterrows/{dataset_id}
+          appendrows: post /v1beta/datasetio/append-rows/{dataset_id}
+

 settings:
   license: MIT
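The two hunks above move the dataset resource out of the top-level `resources:` block and re-home it under `beta: subresources:`; the HTTP endpoints themselves stay on `/v1beta/...`. As a rough sketch of what this grouping means for a Stainless-generated Python client, the attribute names below are assumptions inferred from this config, not code from the repo:

```python
# Hypothetical sketch: how the relocated dataset methods could surface on the
# generated client. Attribute and parameter names are assumptions based on the
# resource config above, not verified against a published client release.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Previously exposed as a top-level `datasets` resource; after this change the
# same methods are grouped under the `beta` subresource.
datasets = client.beta.datasets.list()                          # get /v1beta/datasets
ds = client.beta.datasets.retrieve(dataset_id="my-dataset")     # get /v1beta/datasets/{dataset_id}
rows = client.beta.datasets.iterrows(dataset_id="my-dataset")   # get /v1beta/datasetio/iterrows/{dataset_id}
client.beta.datasets.unregister(dataset_id="my-dataset")        # delete /v1beta/datasets/{dataset_id}
```

Only the client-side grouping changes; register, retrieve, list, unregister, iterrows, and appendrows keep their `/v1beta` paths.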
|
|
@ -2039,69 +2039,6 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/URL'
|
||||
deprecated: false
|
||||
/v1/tool-runtime/rag-tool/insert:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ToolRuntime
|
||||
summary: >-
|
||||
Index documents so they can be used by the RAG system.
|
||||
description: >-
|
||||
Index documents so they can be used by the RAG system.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/InsertRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/tool-runtime/rag-tool/query:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
RAGQueryResult containing the retrieved content and metadata
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RAGQueryResult'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ToolRuntime
|
||||
summary: >-
|
||||
Query the RAG system for context; typically invoked by the agent.
|
||||
description: >-
|
||||
Query the RAG system for context; typically invoked by the agent.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/QueryRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/toolgroups:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -6440,7 +6377,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
@ -9132,7 +9069,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
@ -9440,7 +9377,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
@ -9921,274 +9858,6 @@ components:
|
|||
title: ListToolDefsResponse
|
||||
description: >-
|
||||
Response containing a list of tool definitions.
|
||||
RAGDocument:
|
||||
type: object
|
||||
properties:
|
||||
document_id:
|
||||
type: string
|
||||
description: The unique identifier for the document.
|
||||
content:
|
||||
oneOf:
|
||||
- type: string
|
||||
- $ref: '#/components/schemas/InterleavedContentItem'
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/InterleavedContentItem'
|
||||
- $ref: '#/components/schemas/URL'
|
||||
description: The content of the document.
|
||||
mime_type:
|
||||
type: string
|
||||
description: The MIME type of the document.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: Additional metadata for the document.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- document_id
|
||||
- content
|
||||
- metadata
|
||||
title: RAGDocument
|
||||
description: >-
|
||||
A document to be used for document ingestion in the RAG Tool.
|
||||
InsertRequest:
|
||||
type: object
|
||||
properties:
|
||||
documents:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/RAGDocument'
|
||||
description: >-
|
||||
List of documents to index in the RAG system
|
||||
vector_db_id:
|
||||
type: string
|
||||
description: >-
|
||||
ID of the vector database to store the document embeddings
|
||||
chunk_size_in_tokens:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Size in tokens for document chunking during indexing
|
||||
additionalProperties: false
|
||||
required:
|
||||
- documents
|
||||
- vector_db_id
|
||||
- chunk_size_in_tokens
|
||||
title: InsertRequest
|
||||
DefaultRAGQueryGeneratorConfig:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: default
|
||||
default: default
|
||||
description: >-
|
||||
Type of query generator, always 'default'
|
||||
separator:
|
||||
type: string
|
||||
default: ' '
|
||||
description: >-
|
||||
String separator used to join query terms
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- separator
|
||||
title: DefaultRAGQueryGeneratorConfig
|
||||
description: >-
|
||||
Configuration for the default RAG query generator.
|
||||
LLMRAGQueryGeneratorConfig:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: llm
|
||||
default: llm
|
||||
description: Type of query generator, always 'llm'
|
||||
model:
|
||||
type: string
|
||||
description: >-
|
||||
Name of the language model to use for query generation
|
||||
template:
|
||||
type: string
|
||||
description: >-
|
||||
Template string for formatting the query generation prompt
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- model
|
||||
- template
|
||||
title: LLMRAGQueryGeneratorConfig
|
||||
description: >-
|
||||
Configuration for the LLM-based RAG query generator.
|
||||
RAGQueryConfig:
|
||||
type: object
|
||||
properties:
|
||||
query_generator_config:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
|
||||
- $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
|
||||
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
|
||||
description: Configuration for the query generator.
|
||||
max_tokens_in_context:
|
||||
type: integer
|
||||
default: 4096
|
||||
description: Maximum number of tokens in the context.
|
||||
max_chunks:
|
||||
type: integer
|
||||
default: 5
|
||||
description: Maximum number of chunks to retrieve.
|
||||
chunk_template:
|
||||
type: string
|
||||
default: >
|
||||
Result {index}
|
||||
|
||||
Content: {chunk.content}
|
||||
|
||||
Metadata: {metadata}
|
||||
description: >-
|
||||
Template for formatting each retrieved chunk in the context. Available
|
||||
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
|
||||
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
|
||||
{chunk.content}\nMetadata: {metadata}\n"
|
||||
mode:
|
||||
$ref: '#/components/schemas/RAGSearchMode'
|
||||
default: vector
|
||||
description: >-
|
||||
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
|
||||
"vector".
|
||||
ranker:
|
||||
$ref: '#/components/schemas/Ranker'
|
||||
description: >-
|
||||
Configuration for the ranker to use in hybrid search. Defaults to RRF
|
||||
ranker.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- query_generator_config
|
||||
- max_tokens_in_context
|
||||
- max_chunks
|
||||
- chunk_template
|
||||
title: RAGQueryConfig
|
||||
description: >-
|
||||
Configuration for the RAG query generation.
|
||||
RAGSearchMode:
|
||||
type: string
|
||||
enum:
|
||||
- vector
|
||||
- keyword
|
||||
- hybrid
|
||||
title: RAGSearchMode
|
||||
description: >-
|
||||
Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
|
||||
for semantic matching - KEYWORD: Uses keyword-based search for exact matching
|
||||
- HYBRID: Combines both vector and keyword search for better results
|
||||
RRFRanker:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: rrf
|
||||
default: rrf
|
||||
description: The type of ranker, always "rrf"
|
||||
impact_factor:
|
||||
type: number
|
||||
default: 60.0
|
||||
description: >-
|
||||
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
|
||||
results. Must be greater than 0
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- impact_factor
|
||||
title: RRFRanker
|
||||
description: >-
|
||||
Reciprocal Rank Fusion (RRF) ranker configuration.
|
||||
Ranker:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/RRFRanker'
|
||||
- $ref: '#/components/schemas/WeightedRanker'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
rrf: '#/components/schemas/RRFRanker'
|
||||
weighted: '#/components/schemas/WeightedRanker'
|
||||
WeightedRanker:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: weighted
|
||||
default: weighted
|
||||
description: The type of ranker, always "weighted"
|
||||
alpha:
|
||||
type: number
|
||||
default: 0.5
|
||||
description: >-
|
||||
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
|
||||
only use vector scores, values in between blend both scores.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- alpha
|
||||
title: WeightedRanker
|
||||
description: >-
|
||||
Weighted ranker configuration that combines vector and keyword scores.
|
||||
QueryRequest:
|
||||
type: object
|
||||
properties:
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: >-
|
||||
The query content to search for in the indexed documents
|
||||
vector_db_ids:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
List of vector database IDs to search within
|
||||
query_config:
|
||||
$ref: '#/components/schemas/RAGQueryConfig'
|
||||
description: >-
|
||||
(Optional) Configuration parameters for the query operation
|
||||
additionalProperties: false
|
||||
required:
|
||||
- content
|
||||
- vector_db_ids
|
||||
title: QueryRequest
|
||||
RAGQueryResult:
|
||||
type: object
|
||||
properties:
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: >-
|
||||
(Optional) The retrieved content from the query
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
Additional metadata about the query result
|
||||
additionalProperties: false
|
||||
required:
|
||||
- metadata
|
||||
title: RAGQueryResult
|
||||
description: >-
|
||||
Result of a RAG query containing retrieved content and metadata.
|
||||
ToolGroup:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -10203,7 +9872,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
@ -11325,7 +10994,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
@ -12652,7 +12321,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
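The spec diff above also drops the RAG tool schemas (RAGQueryConfig, RAGSearchMode, RRFRanker, WeightedRanker, and the related request/response types). For reference, the two ranker configurations those schemas documented combine vector and keyword results roughly as sketched below; this is an illustration of the described scoring, not code from this commit, and the function names are made up:

```python
# Illustrative sketch of the scoring described by the removed RRFRanker and
# WeightedRanker schemas. Names and structure are invented for the example.

def rrf_contribution(rank: int, impact_factor: float = 60.0) -> float:
    """Reciprocal Rank Fusion term for a 1-based rank; the schema's
    impact_factor is the constant k in 1/(k + rank)."""
    return 1.0 / (impact_factor + rank)

def fuse_rrf(vector_ranking: list[str], keyword_ranking: list[str],
             impact_factor: float = 60.0) -> dict[str, float]:
    """Sum each chunk's RRF contributions across both result lists."""
    scores: dict[str, float] = {}
    for ranking in (vector_ranking, keyword_ranking):
        for rank, chunk_id in enumerate(ranking, start=1):
            scores[chunk_id] = scores.get(chunk_id, 0.0) + rrf_contribution(rank, impact_factor)
    return scores

def weighted_score(vector_score: float, keyword_score: float, alpha: float = 0.5) -> float:
    """alpha=0 uses only keyword scores, alpha=1 only vector scores,
    values in between blend the two (the WeightedRanker semantics)."""
    return alpha * vector_score + (1.0 - alpha) * keyword_score
```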
@@ -21,7 +21,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
 | inference | `inline::meta-reference` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::model-context-protocol` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@@ -16,7 +16,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 | post_training | `remote::nvidia` |
 | safety | `remote::nvidia` |
 | scoring | `inline::basic` |
-| tool_runtime | `inline::rag-runtime` |
+| tool_runtime | |
 | vector_io | `inline::faiss` |

@@ -28,7 +28,7 @@ description: |
   #### Empirical Example

   Consider the histogram below in which 10,000 randomly generated strings were inserted
-  in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`.
+  in batches of 100 into both Faiss and sqlite-vec.

   ```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png
   :alt: Comparison of SQLite-Vec and Faiss write times
@@ -233,7 +233,7 @@ Datasets that can fit in memory, frequent reads | Faiss | Optimized for speed, i
 #### Empirical Example

 Consider the histogram below in which 10,000 randomly generated strings were inserted
-in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`.
+in batches of 100 into both Faiss and sqlite-vec.

 ```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png
 :alt: Comparison of SQLite-Vec and Faiss write times
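A minimal sketch of the benchmark setup described in those docs (10,000 random strings written in batches of 100), with `insert_batch` standing in for whichever store's write path is being timed; the helper names are hypothetical, not APIs from this repo:

```python
# Hypothetical sketch of the batched-insert timing experiment described above.
import random
import string
import time

def random_strings(n: int, length: int = 64) -> list[str]:
    return ["".join(random.choices(string.ascii_lowercase, k=length)) for _ in range(n)]

def time_batched_inserts(insert_batch, docs: list[str], batch_size: int = 100) -> list[float]:
    """Return per-batch wall-clock write times for the given insert callable."""
    timings = []
    for i in range(0, len(docs), batch_size):
        start = time.perf_counter()
        insert_batch(docs[i : i + batch_size])
        timings.append(time.perf_counter() - start)
    return timings

docs = random_strings(10_000)
# faiss_times = time_batched_inserts(faiss_insert_batch, docs)          # hypothetical callable
# sqlitevec_times = time_batched_inserts(sqlitevec_insert_batch, docs)  # hypothetical callable
```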
@@ -32,7 +32,6 @@ Commands:
   scoring_functions  Manage scoring functions.
   shields            Manage safety shield services.
   toolgroups         Manage available tool groups.
-  vector_dbs         Manage vector databases.
 ```

 ### `llama-stack-client configure`
@@ -211,53 +210,6 @@ Unregister a model from distribution endpoint
 llama-stack-client models unregister <model_id>
 ```

-## Vector DB Management
-Manage vector databases.
-
-
-### `llama-stack-client vector_dbs list`
-Show available vector dbs on distribution endpoint
-```bash
-llama-stack-client vector_dbs list
-```
-```
-┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
-┃ identifier         ┃ provider_id ┃ provider_resource_id ┃ vector_db_type ┃ params                                  ┃
-┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
-│ my_demo_vector_db  │ faiss       │ my_demo_vector_db    │                │ embedding_dimension: 768                │
-│                    │             │                      │                │ embedding_model: nomic-embed-text-v1.5  │
-│                    │             │                      │                │ type: vector_db                         │
-│                    │             │                      │                │                                         │
-└────────────────────┴─────────────┴──────────────────────┴────────────────┴─────────────────────────────────────────┘
-```
-
-### `llama-stack-client vector_dbs register`
-Create a new vector db
-```bash
-llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>]
-```
-
-
-Required arguments:
-- `VECTOR_DB_ID`: Vector DB ID
-
-Optional arguments:
-- `--provider-id`: Provider ID for the vector db
-- `--provider-vector-db-id`: Provider's vector db ID
-- `--embedding-model`: Embedding model to use. Default: `nomic-embed-text-v1.5`
-- `--embedding-dimension`: Dimension of embeddings. Default: 768
-
-### `llama-stack-client vector_dbs unregister`
-Delete a vector db
-```bash
-llama-stack-client vector_dbs unregister <vector-db-id>
-```
-
-
-Required arguments:
-- `VECTOR_DB_ID`: Vector DB ID
-
-
 ## Shield Management
 Manage safety shield services.
 ### `llama-stack-client shields list`
File diff suppressed because one or more lines are too long
@@ -196,16 +196,10 @@ def _get_endpoint_functions(
 def _get_defining_class(member_fn: str, derived_cls: type) -> type:
     "Find the class in which a member function is first defined in a class inheritance hierarchy."

-    # This import must be dynamic here
-    from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime
-
     # iterate in reverse member resolution order to find most specific class first
     for cls in reversed(inspect.getmro(derived_cls)):
         for name, _ in inspect.getmembers(cls, inspect.isfunction):
             if name == member_fn:
-                # HACK ALERT
-                if cls == RAGToolRuntime:
-                    return ToolRuntime
                 return cls

     raise ValidationError(
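With the special case removed, `_get_defining_class` is a plain reversed-MRO lookup: iterating `inspect.getmro()` in reverse visits base classes first, so the class that introduces a member is returned rather than a subclass that overrides it. A standalone illustration with toy classes (not from the diff):

```python
# Standalone illustration of the reversed-MRO lookup pattern used above.
import inspect

class Base:
    def ping(self) -> str:
        return "base"

class Child(Base):
    def ping(self) -> str:  # overrides Base.ping
        return "child"

def defining_class(member_fn: str, derived_cls: type) -> type:
    # reversed(getmro(...)) yields object, then Base, then Child, so the
    # base-most class defining `member_fn` wins.
    for cls in reversed(inspect.getmro(derived_cls)):
        for name, _ in inspect.getmembers(cls, inspect.isfunction):
            if name == member_fn:
                return cls
    raise ValueError(f"{member_fn!r} not found on {derived_cls!r}")

assert defining_class("ping", Child) is Base
```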
4 docs/static/deprecated-llama-stack-spec.html vendored
@@ -5547,7 +5547,7 @@
             "enum": [
               "model",
               "shield",
-              "vector_db",
+              "vector_store",
               "dataset",
               "scoring_function",
               "benchmark",
@@ -5798,7 +5798,7 @@
             "enum": [
               "model",
               "shield",
-              "vector_db",
+              "vector_store",
               "dataset",
               "scoring_function",
               "benchmark",
docs/static/deprecated-llama-stack-spec.yaml
vendored
4
docs/static/deprecated-llama-stack-spec.yaml
vendored
|
|
@ -4114,7 +4114,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
@ -4303,7 +4303,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
|
|||
|
|
@@ -1850,7 +1850,7 @@
             "enum": [
               "model",
               "shield",
-              "vector_db",
+              "vector_store",
               "dataset",
               "scoring_function",
               "benchmark",
@@ -3983,7 +3983,7 @@
             "enum": [
               "model",
               "shield",
-              "vector_db",
+              "vector_store",
               "dataset",
               "scoring_function",
               "benchmark",
@@ -1320,7 +1320,7 @@ components:
       enum:
         - model
         - shield
-        - vector_db
+        - vector_store
         - dataset
         - scoring_function
         - benchmark
@@ -2927,7 +2927,7 @@ components:
       enum:
         - model
         - shield
-        - vector_db
+        - vector_store
         - dataset
         - scoring_function
         - benchmark
431 docs/static/llama-stack-spec.html vendored
|
|
@ -2624,89 +2624,6 @@
|
|||
"deprecated": false
|
||||
}
|
||||
},
|
||||
"/v1/tool-runtime/rag-tool/insert": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"ToolRuntime"
|
||||
],
|
||||
"summary": "Index documents so they can be used by the RAG system.",
|
||||
"description": "Index documents so they can be used by the RAG system.",
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/InsertRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"deprecated": false
|
||||
}
|
||||
},
|
||||
"/v1/tool-runtime/rag-tool/query": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "RAGQueryResult containing the retrieved content and metadata",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/RAGQueryResult"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"ToolRuntime"
|
||||
],
|
||||
"summary": "Query the RAG system for context; typically invoked by the agent.",
|
||||
"description": "Query the RAG system for context; typically invoked by the agent.",
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/QueryRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"deprecated": false
|
||||
}
|
||||
},
|
||||
"/v1/toolgroups": {
|
||||
"get": {
|
||||
"responses": {
|
||||
|
|
@ -6800,7 +6717,7 @@
|
|||
"enum": [
|
||||
"model",
|
||||
"shield",
|
||||
"vector_db",
|
||||
"vector_store",
|
||||
"dataset",
|
||||
"scoring_function",
|
||||
"benchmark",
|
||||
|
|
@ -10205,7 +10122,7 @@
|
|||
"enum": [
|
||||
"model",
|
||||
"shield",
|
||||
"vector_db",
|
||||
"vector_store",
|
||||
"dataset",
|
||||
"scoring_function",
|
||||
"benchmark",
|
||||
|
|
@ -10687,7 +10604,7 @@
|
|||
"enum": [
|
||||
"model",
|
||||
"shield",
|
||||
"vector_db",
|
||||
"vector_store",
|
||||
"dataset",
|
||||
"scoring_function",
|
||||
"benchmark",
|
||||
|
|
@ -11383,346 +11300,6 @@
|
|||
"title": "ListToolDefsResponse",
|
||||
"description": "Response containing a list of tool definitions."
|
||||
},
|
||||
"RAGDocument": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"document_id": {
|
||||
"type": "string",
|
||||
"description": "The unique identifier for the document."
|
||||
},
|
||||
"content": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/InterleavedContentItem"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/InterleavedContentItem"
|
||||
}
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/URL"
|
||||
}
|
||||
],
|
||||
"description": "The content of the document."
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"description": "The MIME type of the document."
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
},
|
||||
"description": "Additional metadata for the document."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"document_id",
|
||||
"content",
|
||||
"metadata"
|
||||
],
|
||||
"title": "RAGDocument",
|
||||
"description": "A document to be used for document ingestion in the RAG Tool."
|
||||
},
|
||||
"InsertRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"documents": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/RAGDocument"
|
||||
},
|
||||
"description": "List of documents to index in the RAG system"
|
||||
},
|
||||
"vector_db_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the vector database to store the document embeddings"
|
||||
},
|
||||
"chunk_size_in_tokens": {
|
||||
"type": "integer",
|
||||
"description": "(Optional) Size in tokens for document chunking during indexing"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"documents",
|
||||
"vector_db_id",
|
||||
"chunk_size_in_tokens"
|
||||
],
|
||||
"title": "InsertRequest"
|
||||
},
|
||||
"DefaultRAGQueryGeneratorConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "default",
|
||||
"default": "default",
|
||||
"description": "Type of query generator, always 'default'"
|
||||
},
|
||||
"separator": {
|
||||
"type": "string",
|
||||
"default": " ",
|
||||
"description": "String separator used to join query terms"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"separator"
|
||||
],
|
||||
"title": "DefaultRAGQueryGeneratorConfig",
|
||||
"description": "Configuration for the default RAG query generator."
|
||||
},
|
||||
"LLMRAGQueryGeneratorConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "llm",
|
||||
"default": "llm",
|
||||
"description": "Type of query generator, always 'llm'"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "Name of the language model to use for query generation"
|
||||
},
|
||||
"template": {
|
||||
"type": "string",
|
||||
"description": "Template string for formatting the query generation prompt"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"model",
|
||||
"template"
|
||||
],
|
||||
"title": "LLMRAGQueryGeneratorConfig",
|
||||
"description": "Configuration for the LLM-based RAG query generator."
|
||||
},
|
||||
"RAGQueryConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query_generator_config": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"default": "#/components/schemas/DefaultRAGQueryGeneratorConfig",
|
||||
"llm": "#/components/schemas/LLMRAGQueryGeneratorConfig"
|
||||
}
|
||||
},
|
||||
"description": "Configuration for the query generator."
|
||||
},
|
||||
"max_tokens_in_context": {
|
||||
"type": "integer",
|
||||
"default": 4096,
|
||||
"description": "Maximum number of tokens in the context."
|
||||
},
|
||||
"max_chunks": {
|
||||
"type": "integer",
|
||||
"default": 5,
|
||||
"description": "Maximum number of chunks to retrieve."
|
||||
},
|
||||
"chunk_template": {
|
||||
"type": "string",
|
||||
"default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
|
||||
"description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
|
||||
},
|
||||
"mode": {
|
||||
"$ref": "#/components/schemas/RAGSearchMode",
|
||||
"default": "vector",
|
||||
"description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
|
||||
},
|
||||
"ranker": {
|
||||
"$ref": "#/components/schemas/Ranker",
|
||||
"description": "Configuration for the ranker to use in hybrid search. Defaults to RRF ranker."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"query_generator_config",
|
||||
"max_tokens_in_context",
|
||||
"max_chunks",
|
||||
"chunk_template"
|
||||
],
|
||||
"title": "RAGQueryConfig",
|
||||
"description": "Configuration for the RAG query generation."
|
||||
},
|
||||
"RAGSearchMode": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"vector",
|
||||
"keyword",
|
||||
"hybrid"
|
||||
],
|
||||
"title": "RAGSearchMode",
|
||||
"description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results"
|
||||
},
|
||||
"RRFRanker": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "rrf",
|
||||
"default": "rrf",
|
||||
"description": "The type of ranker, always \"rrf\""
|
||||
},
|
||||
"impact_factor": {
|
||||
"type": "number",
|
||||
"default": 60.0,
|
||||
"description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"impact_factor"
|
||||
],
|
||||
"title": "RRFRanker",
|
||||
"description": "Reciprocal Rank Fusion (RRF) ranker configuration."
|
||||
},
|
||||
"Ranker": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/RRFRanker"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/WeightedRanker"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"rrf": "#/components/schemas/RRFRanker",
|
||||
"weighted": "#/components/schemas/WeightedRanker"
|
||||
}
|
||||
}
|
||||
},
|
||||
"WeightedRanker": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "weighted",
|
||||
"default": "weighted",
|
||||
"description": "The type of ranker, always \"weighted\""
|
||||
},
|
||||
"alpha": {
|
||||
"type": "number",
|
||||
"default": 0.5,
|
||||
"description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"alpha"
|
||||
],
|
||||
"title": "WeightedRanker",
|
||||
"description": "Weighted ranker configuration that combines vector and keyword scores."
|
||||
},
|
||||
"QueryRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"$ref": "#/components/schemas/InterleavedContent",
|
||||
"description": "The query content to search for in the indexed documents"
|
||||
},
|
||||
"vector_db_ids": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "List of vector database IDs to search within"
|
||||
},
|
||||
"query_config": {
|
||||
"$ref": "#/components/schemas/RAGQueryConfig",
|
||||
"description": "(Optional) Configuration parameters for the query operation"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"content",
|
||||
"vector_db_ids"
|
||||
],
|
||||
"title": "QueryRequest"
|
||||
},
|
||||
"RAGQueryResult": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"$ref": "#/components/schemas/InterleavedContent",
|
||||
"description": "(Optional) The retrieved content from the query"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
},
|
||||
"description": "Additional metadata about the query result"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"metadata"
|
||||
],
|
||||
"title": "RAGQueryResult",
|
||||
"description": "Result of a RAG query containing retrieved content and metadata."
|
||||
},
|
||||
"ToolGroup": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
@ -11740,7 +11317,7 @@
|
|||
"enum": [
|
||||
"model",
|
||||
"shield",
|
||||
"vector_db",
|
||||
"vector_store",
|
||||
"dataset",
|
||||
"scoring_function",
|
||||
"benchmark",
|
||||
|
|
|
|||
339 docs/static/llama-stack-spec.yaml vendored
|
|
@ -2036,69 +2036,6 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/URL'
|
||||
deprecated: false
|
||||
/v1/tool-runtime/rag-tool/insert:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ToolRuntime
|
||||
summary: >-
|
||||
Index documents so they can be used by the RAG system.
|
||||
description: >-
|
||||
Index documents so they can be used by the RAG system.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/InsertRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/tool-runtime/rag-tool/query:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
RAGQueryResult containing the retrieved content and metadata
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RAGQueryResult'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- ToolRuntime
|
||||
summary: >-
|
||||
Query the RAG system for context; typically invoked by the agent.
|
||||
description: >-
|
||||
Query the RAG system for context; typically invoked by the agent.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/QueryRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/toolgroups:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -5227,7 +5164,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
@ -7919,7 +7856,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
@ -8227,7 +8164,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
@ -8708,274 +8645,6 @@ components:
|
|||
title: ListToolDefsResponse
|
||||
description: >-
|
||||
Response containing a list of tool definitions.
|
||||
RAGDocument:
|
||||
type: object
|
||||
properties:
|
||||
document_id:
|
||||
type: string
|
||||
description: The unique identifier for the document.
|
||||
content:
|
||||
oneOf:
|
||||
- type: string
|
||||
- $ref: '#/components/schemas/InterleavedContentItem'
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/InterleavedContentItem'
|
||||
- $ref: '#/components/schemas/URL'
|
||||
description: The content of the document.
|
||||
mime_type:
|
||||
type: string
|
||||
description: The MIME type of the document.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: Additional metadata for the document.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- document_id
|
||||
- content
|
||||
- metadata
|
||||
title: RAGDocument
|
||||
description: >-
|
||||
A document to be used for document ingestion in the RAG Tool.
|
||||
InsertRequest:
|
||||
type: object
|
||||
properties:
|
||||
documents:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/RAGDocument'
|
||||
description: >-
|
||||
List of documents to index in the RAG system
|
||||
vector_db_id:
|
||||
type: string
|
||||
description: >-
|
||||
ID of the vector database to store the document embeddings
|
||||
chunk_size_in_tokens:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Size in tokens for document chunking during indexing
|
||||
additionalProperties: false
|
||||
required:
|
||||
- documents
|
||||
- vector_db_id
|
||||
- chunk_size_in_tokens
|
||||
title: InsertRequest
|
||||
DefaultRAGQueryGeneratorConfig:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: default
|
||||
default: default
|
||||
description: >-
|
||||
Type of query generator, always 'default'
|
||||
separator:
|
||||
type: string
|
||||
default: ' '
|
||||
description: >-
|
||||
String separator used to join query terms
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- separator
|
||||
title: DefaultRAGQueryGeneratorConfig
|
||||
description: >-
|
||||
Configuration for the default RAG query generator.
|
||||
LLMRAGQueryGeneratorConfig:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: llm
|
||||
default: llm
|
||||
description: Type of query generator, always 'llm'
|
||||
model:
|
||||
type: string
|
||||
description: >-
|
||||
Name of the language model to use for query generation
|
||||
template:
|
||||
type: string
|
||||
description: >-
|
||||
Template string for formatting the query generation prompt
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- model
|
||||
- template
|
||||
title: LLMRAGQueryGeneratorConfig
|
||||
description: >-
|
||||
Configuration for the LLM-based RAG query generator.
|
||||
RAGQueryConfig:
|
||||
type: object
|
||||
properties:
|
||||
query_generator_config:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
|
||||
- $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
|
||||
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
|
||||
description: Configuration for the query generator.
|
||||
max_tokens_in_context:
|
||||
type: integer
|
||||
default: 4096
|
||||
description: Maximum number of tokens in the context.
|
||||
max_chunks:
|
||||
type: integer
|
||||
default: 5
|
||||
description: Maximum number of chunks to retrieve.
|
||||
chunk_template:
|
||||
type: string
|
||||
default: >
|
||||
Result {index}
|
||||
|
||||
Content: {chunk.content}
|
||||
|
||||
Metadata: {metadata}
|
||||
description: >-
|
||||
Template for formatting each retrieved chunk in the context. Available
|
||||
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
|
||||
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
|
||||
{chunk.content}\nMetadata: {metadata}\n"
|
||||
mode:
|
||||
$ref: '#/components/schemas/RAGSearchMode'
|
||||
default: vector
|
||||
description: >-
|
||||
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
|
||||
"vector".
|
||||
ranker:
|
||||
$ref: '#/components/schemas/Ranker'
|
||||
description: >-
|
||||
Configuration for the ranker to use in hybrid search. Defaults to RRF
|
||||
ranker.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- query_generator_config
|
||||
- max_tokens_in_context
|
||||
- max_chunks
|
||||
- chunk_template
|
||||
title: RAGQueryConfig
|
||||
description: >-
|
||||
Configuration for the RAG query generation.
|
||||
RAGSearchMode:
|
||||
type: string
|
||||
enum:
|
||||
- vector
|
||||
- keyword
|
||||
- hybrid
|
||||
title: RAGSearchMode
|
||||
description: >-
|
||||
Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
|
||||
for semantic matching - KEYWORD: Uses keyword-based search for exact matching
|
||||
- HYBRID: Combines both vector and keyword search for better results
|
||||
RRFRanker:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: rrf
|
||||
default: rrf
|
||||
description: The type of ranker, always "rrf"
|
||||
impact_factor:
|
||||
type: number
|
||||
default: 60.0
|
||||
description: >-
|
||||
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
|
||||
results. Must be greater than 0
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- impact_factor
|
||||
title: RRFRanker
|
||||
description: >-
|
||||
Reciprocal Rank Fusion (RRF) ranker configuration.
|
||||
Ranker:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/RRFRanker'
|
||||
- $ref: '#/components/schemas/WeightedRanker'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
rrf: '#/components/schemas/RRFRanker'
|
||||
weighted: '#/components/schemas/WeightedRanker'
|
||||
WeightedRanker:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: weighted
|
||||
default: weighted
|
||||
description: The type of ranker, always "weighted"
|
||||
alpha:
|
||||
type: number
|
||||
default: 0.5
|
||||
description: >-
|
||||
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
|
||||
only use vector scores, values in between blend both scores.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- alpha
|
||||
title: WeightedRanker
|
||||
description: >-
|
||||
Weighted ranker configuration that combines vector and keyword scores.
|
||||
QueryRequest:
|
||||
type: object
|
||||
properties:
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: >-
|
||||
The query content to search for in the indexed documents
|
||||
vector_db_ids:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
List of vector database IDs to search within
|
||||
query_config:
|
||||
$ref: '#/components/schemas/RAGQueryConfig'
|
||||
description: >-
|
||||
(Optional) Configuration parameters for the query operation
|
||||
additionalProperties: false
|
||||
required:
|
||||
- content
|
||||
- vector_db_ids
|
||||
title: QueryRequest
|
||||
RAGQueryResult:
|
||||
type: object
|
||||
properties:
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: >-
|
||||
(Optional) The retrieved content from the query
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
Additional metadata about the query result
|
||||
additionalProperties: false
|
||||
required:
|
||||
- metadata
|
||||
title: RAGQueryResult
|
||||
description: >-
|
||||
Result of a RAG query containing retrieved content and metadata.
|
||||
ToolGroup:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -8990,7 +8659,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
|
|||
435 docs/static/stainless-llama-stack-spec.html vendored
|
|
@ -2624,89 +2624,6 @@
|
|||
"deprecated": false
|
||||
}
|
||||
},
|
||||
"/v1/tool-runtime/rag-tool/insert": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"ToolRuntime"
|
||||
],
|
||||
"summary": "Index documents so they can be used by the RAG system.",
|
||||
"description": "Index documents so they can be used by the RAG system.",
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/InsertRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"deprecated": false
|
||||
}
|
||||
},
|
||||
"/v1/tool-runtime/rag-tool/query": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "RAGQueryResult containing the retrieved content and metadata",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/RAGQueryResult"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"ToolRuntime"
|
||||
],
|
||||
"summary": "Query the RAG system for context; typically invoked by the agent.",
|
||||
"description": "Query the RAG system for context; typically invoked by the agent.",
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/QueryRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"deprecated": false
|
||||
}
|
||||
},
|
||||
"/v1/toolgroups": {
|
||||
"get": {
|
||||
"responses": {
|
||||
|
|
@ -8472,7 +8389,7 @@
|
|||
"enum": [
|
||||
"model",
|
||||
"shield",
|
||||
"vector_db",
|
||||
"vector_store",
|
||||
"dataset",
|
||||
"scoring_function",
|
||||
"benchmark",
|
||||
|
|
@ -11877,7 +11794,7 @@
|
|||
"enum": [
|
||||
"model",
|
||||
"shield",
|
||||
"vector_db",
|
||||
"vector_store",
|
||||
"dataset",
|
||||
"scoring_function",
|
||||
"benchmark",
|
||||
|
|
@ -12359,7 +12276,7 @@
|
|||
"enum": [
|
||||
"model",
|
||||
"shield",
|
||||
"vector_db",
|
||||
"vector_store",
|
||||
"dataset",
|
||||
"scoring_function",
|
||||
"benchmark",
|
||||
|
|
@ -13055,346 +12972,6 @@
|
|||
"title": "ListToolDefsResponse",
|
||||
"description": "Response containing a list of tool definitions."
|
||||
},
|
||||
"RAGDocument": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"document_id": {
|
||||
"type": "string",
|
||||
"description": "The unique identifier for the document."
|
||||
},
|
||||
"content": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/InterleavedContentItem"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/InterleavedContentItem"
|
||||
}
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/URL"
|
||||
}
|
||||
],
|
||||
"description": "The content of the document."
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"description": "The MIME type of the document."
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array"
|
||||
},
|
||||
{
|
||||
"type": "object"
|
||||
}
|
||||
]
|
||||
},
|
||||
"description": "Additional metadata for the document."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"document_id",
|
||||
"content",
|
||||
"metadata"
|
||||
],
|
||||
"title": "RAGDocument",
|
||||
"description": "A document to be used for document ingestion in the RAG Tool."
|
||||
},
|
||||
"InsertRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"documents": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/RAGDocument"
|
||||
},
|
||||
"description": "List of documents to index in the RAG system"
|
||||
},
|
||||
"vector_db_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the vector database to store the document embeddings"
|
||||
},
|
||||
"chunk_size_in_tokens": {
|
||||
"type": "integer",
|
||||
"description": "(Optional) Size in tokens for document chunking during indexing"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"documents",
|
||||
"vector_db_id",
|
||||
"chunk_size_in_tokens"
|
||||
],
|
||||
"title": "InsertRequest"
|
||||
},
|
||||
"DefaultRAGQueryGeneratorConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "default",
|
||||
"default": "default",
|
||||
"description": "Type of query generator, always 'default'"
|
||||
},
|
||||
"separator": {
|
||||
"type": "string",
|
||||
"default": " ",
|
||||
"description": "String separator used to join query terms"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"separator"
|
||||
],
|
||||
"title": "DefaultRAGQueryGeneratorConfig",
|
||||
"description": "Configuration for the default RAG query generator."
|
||||
},
|
||||
"LLMRAGQueryGeneratorConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "llm",
|
||||
"default": "llm",
|
||||
"description": "Type of query generator, always 'llm'"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "Name of the language model to use for query generation"
|
||||
},
|
||||
"template": {
|
||||
"type": "string",
|
||||
"description": "Template string for formatting the query generation prompt"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"model",
|
||||
"template"
|
||||
],
|
||||
"title": "LLMRAGQueryGeneratorConfig",
|
||||
"description": "Configuration for the LLM-based RAG query generator."
|
||||
},
|
||||
"RAGQueryConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query_generator_config": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/DefaultRAGQueryGeneratorConfig"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"default": "#/components/schemas/DefaultRAGQueryGeneratorConfig",
|
||||
"llm": "#/components/schemas/LLMRAGQueryGeneratorConfig"
|
||||
}
|
||||
},
|
||||
"description": "Configuration for the query generator."
|
||||
},
|
||||
"max_tokens_in_context": {
|
||||
"type": "integer",
|
||||
"default": 4096,
|
||||
"description": "Maximum number of tokens in the context."
|
||||
},
|
||||
"max_chunks": {
|
||||
"type": "integer",
|
||||
"default": 5,
|
||||
"description": "Maximum number of chunks to retrieve."
|
||||
},
|
||||
"chunk_template": {
|
||||
"type": "string",
          "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
          "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
        },
        "mode": {
          "$ref": "#/components/schemas/RAGSearchMode",
          "default": "vector",
          "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
        },
        "ranker": {
          "$ref": "#/components/schemas/Ranker",
          "description": "Configuration for the ranker to use in hybrid search. Defaults to RRF ranker."
        }
      },
      "additionalProperties": false,
      "required": [
        "query_generator_config",
        "max_tokens_in_context",
        "max_chunks",
        "chunk_template"
      ],
      "title": "RAGQueryConfig",
      "description": "Configuration for the RAG query generation."
    },
    "RAGSearchMode": {
      "type": "string",
      "enum": [
        "vector",
        "keyword",
        "hybrid"
      ],
      "title": "RAGSearchMode",
      "description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results"
    },
    "RRFRanker": {
      "type": "object",
      "properties": {
        "type": {
          "type": "string",
          "const": "rrf",
          "default": "rrf",
          "description": "The type of ranker, always \"rrf\""
        },
        "impact_factor": {
          "type": "number",
          "default": 60.0,
          "description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0"
        }
      },
      "additionalProperties": false,
      "required": [
        "type",
        "impact_factor"
      ],
      "title": "RRFRanker",
      "description": "Reciprocal Rank Fusion (RRF) ranker configuration."
    },
    "Ranker": {
      "oneOf": [
        {
          "$ref": "#/components/schemas/RRFRanker"
        },
        {
          "$ref": "#/components/schemas/WeightedRanker"
        }
      ],
      "discriminator": {
        "propertyName": "type",
        "mapping": {
          "rrf": "#/components/schemas/RRFRanker",
          "weighted": "#/components/schemas/WeightedRanker"
        }
      }
    },
    "WeightedRanker": {
      "type": "object",
      "properties": {
        "type": {
          "type": "string",
          "const": "weighted",
          "default": "weighted",
          "description": "The type of ranker, always \"weighted\""
        },
        "alpha": {
          "type": "number",
          "default": 0.5,
          "description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores."
        }
      },
      "additionalProperties": false,
      "required": [
        "type",
        "alpha"
      ],
      "title": "WeightedRanker",
      "description": "Weighted ranker configuration that combines vector and keyword scores."
    },
    "QueryRequest": {
      "type": "object",
      "properties": {
        "content": {
          "$ref": "#/components/schemas/InterleavedContent",
          "description": "The query content to search for in the indexed documents"
        },
        "vector_db_ids": {
          "type": "array",
          "items": {
            "type": "string"
          },
          "description": "List of vector database IDs to search within"
        },
        "query_config": {
          "$ref": "#/components/schemas/RAGQueryConfig",
          "description": "(Optional) Configuration parameters for the query operation"
        }
      },
      "additionalProperties": false,
      "required": [
        "content",
        "vector_db_ids"
      ],
      "title": "QueryRequest"
    },
    "RAGQueryResult": {
      "type": "object",
      "properties": {
        "content": {
          "$ref": "#/components/schemas/InterleavedContent",
          "description": "(Optional) The retrieved content from the query"
        },
        "metadata": {
          "type": "object",
          "additionalProperties": {
            "oneOf": [
              { "type": "null" },
              { "type": "boolean" },
              { "type": "number" },
              { "type": "string" },
              { "type": "array" },
              { "type": "object" }
            ]
          },
          "description": "Additional metadata about the query result"
        }
      },
      "additionalProperties": false,
      "required": [
        "metadata"
      ],
      "title": "RAGQueryResult",
      "description": "Result of a RAG query containing retrieved content and metadata."
    },
    "ToolGroup": {
      "type": "object",
      "properties": {
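For orientation, a payload that satisfies the RAGQueryConfig, Ranker, and QueryRequest schemas above might look like the following sketch; the query text and vector store ID are placeholders, not values from this change.

query_config = {
    "query_generator_config": {"type": "default", "separator": " "},
    "max_tokens_in_context": 4096,
    "max_chunks": 5,
    "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
    "mode": "hybrid",
    "ranker": {"type": "rrf", "impact_factor": 60.0},  # or {"type": "weighted", "alpha": 0.5}
}

query_request = {
    "content": "What is Reciprocal Rank Fusion?",  # placeholder query
    "vector_db_ids": ["vs_example_123"],            # placeholder vector store ID
    "query_config": query_config,
}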
|
@@ -13412,7 +12989,7 @@
          "enum": [
            "model",
            "shield",
            "vector_db",
            "vector_store",
            "dataset",
            "scoring_function",
            "benchmark",

@@ -14959,7 +14536,7 @@
          "enum": [
            "model",
            "shield",
            "vector_db",
            "vector_store",
            "dataset",
            "scoring_function",
            "benchmark",

@@ -16704,7 +16281,7 @@
          "enum": [
            "model",
            "shield",
            "vector_db",
            "vector_store",
            "dataset",
            "scoring_function",
            "benchmark",
|||
343
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@@ -2039,69 +2039,6 @@ paths:
            schema:
              $ref: '#/components/schemas/URL'
      deprecated: false
  /v1/tool-runtime/rag-tool/insert:
    post:
      responses:
        '200':
          description: OK
        '400':
          $ref: '#/components/responses/BadRequest400'
        '429':
          $ref: >-
            #/components/responses/TooManyRequests429
        '500':
          $ref: >-
            #/components/responses/InternalServerError500
        default:
          $ref: '#/components/responses/DefaultError'
      tags:
        - ToolRuntime
      summary: >-
        Index documents so they can be used by the RAG system.
      description: >-
        Index documents so they can be used by the RAG system.
      parameters: []
      requestBody:
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/InsertRequest'
        required: true
      deprecated: false
  /v1/tool-runtime/rag-tool/query:
    post:
      responses:
        '200':
          description: >-
            RAGQueryResult containing the retrieved content and metadata
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/RAGQueryResult'
        '400':
          $ref: '#/components/responses/BadRequest400'
        '429':
          $ref: >-
            #/components/responses/TooManyRequests429
        '500':
          $ref: >-
            #/components/responses/InternalServerError500
        default:
          $ref: '#/components/responses/DefaultError'
      tags:
        - ToolRuntime
      summary: >-
        Query the RAG system for context; typically invoked by the agent.
      description: >-
        Query the RAG system for context; typically invoked by the agent.
      parameters: []
      requestBody:
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/QueryRequest'
        required: true
      deprecated: false
  /v1/toolgroups:
    get:
      responses:
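A minimal sketch of how a client would have exercised these two now-removed endpoints, assuming a Llama Stack server at a placeholder address; the request bodies follow the InsertRequest and QueryRequest schemas defined later in this file, and all IDs are illustrative.

import requests

BASE = "http://localhost:8321/v1"  # placeholder server address

# Index documents (InsertRequest)
requests.post(
    f"{BASE}/tool-runtime/rag-tool/insert",
    json={
        "documents": [
            {"document_id": "doc-1", "content": "RRF merges ranked lists from several retrievers.", "metadata": {}},
        ],
        "vector_db_id": "vs_example_123",  # placeholder vector store ID
        "chunk_size_in_tokens": 512,
    },
).raise_for_status()

# Query for context (QueryRequest); the response body is a RAGQueryResult
resp = requests.post(
    f"{BASE}/tool-runtime/rag-tool/query",
    json={"content": "What is RRF?", "vector_db_ids": ["vs_example_123"]},
)
resp.raise_for_status()
print(resp.json()["metadata"])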
|
@@ -6440,7 +6377,7 @@ components:
        enum:
          - model
          - shield
          - vector_db
          - vector_store
          - dataset
          - scoring_function
          - benchmark

@@ -9132,7 +9069,7 @@ components:
        enum:
          - model
          - shield
          - vector_db
          - vector_store
          - dataset
          - scoring_function
          - benchmark

@@ -9440,7 +9377,7 @@ components:
        enum:
          - model
          - shield
          - vector_db
          - vector_store
          - dataset
          - scoring_function
          - benchmark
@ -9921,274 +9858,6 @@ components:
|
|||
title: ListToolDefsResponse
|
||||
description: >-
|
||||
Response containing a list of tool definitions.
|
||||
RAGDocument:
|
||||
type: object
|
||||
properties:
|
||||
document_id:
|
||||
type: string
|
||||
description: The unique identifier for the document.
|
||||
content:
|
||||
oneOf:
|
||||
- type: string
|
||||
- $ref: '#/components/schemas/InterleavedContentItem'
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/InterleavedContentItem'
|
||||
- $ref: '#/components/schemas/URL'
|
||||
description: The content of the document.
|
||||
mime_type:
|
||||
type: string
|
||||
description: The MIME type of the document.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: Additional metadata for the document.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- document_id
|
||||
- content
|
||||
- metadata
|
||||
title: RAGDocument
|
||||
description: >-
|
||||
A document to be used for document ingestion in the RAG Tool.
|
||||
InsertRequest:
|
||||
type: object
|
||||
properties:
|
||||
documents:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/RAGDocument'
|
||||
description: >-
|
||||
List of documents to index in the RAG system
|
||||
vector_db_id:
|
||||
type: string
|
||||
description: >-
|
||||
ID of the vector database to store the document embeddings
|
||||
chunk_size_in_tokens:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Size in tokens for document chunking during indexing
|
||||
additionalProperties: false
|
||||
required:
|
||||
- documents
|
||||
- vector_db_id
|
||||
- chunk_size_in_tokens
|
||||
title: InsertRequest
|
||||
DefaultRAGQueryGeneratorConfig:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: default
|
||||
default: default
|
||||
description: >-
|
||||
Type of query generator, always 'default'
|
||||
separator:
|
||||
type: string
|
||||
default: ' '
|
||||
description: >-
|
||||
String separator used to join query terms
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- separator
|
||||
title: DefaultRAGQueryGeneratorConfig
|
||||
description: >-
|
||||
Configuration for the default RAG query generator.
|
||||
LLMRAGQueryGeneratorConfig:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: llm
|
||||
default: llm
|
||||
description: Type of query generator, always 'llm'
|
||||
model:
|
||||
type: string
|
||||
description: >-
|
||||
Name of the language model to use for query generation
|
||||
template:
|
||||
type: string
|
||||
description: >-
|
||||
Template string for formatting the query generation prompt
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- model
|
||||
- template
|
||||
title: LLMRAGQueryGeneratorConfig
|
||||
description: >-
|
||||
Configuration for the LLM-based RAG query generator.
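To make the two query generator variants concrete, here is a hedged sketch of the payloads each schema accepts; the model name and the prompt template are illustrative placeholders, not values from this spec.

# Default generator: joins query terms with the configured separator.
default_generator = {"type": "default", "separator": " "}

# LLM-based generator: rewrites the query with a model and a prompt template.
llm_generator = {
    "type": "llm",
    "model": "llama3.2:3b",  # placeholder model name
    "template": "Rewrite the conversation below as a concise search query:\n{messages}",  # illustrative template
}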
|
||||
    RAGQueryConfig:
      type: object
      properties:
        query_generator_config:
          oneOf:
            - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
            - $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
          discriminator:
            propertyName: type
            mapping:
              default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
              llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
          description: Configuration for the query generator.
        max_tokens_in_context:
          type: integer
          default: 4096
          description: Maximum number of tokens in the context.
        max_chunks:
          type: integer
          default: 5
          description: Maximum number of chunks to retrieve.
        chunk_template:
          type: string
          default: >
            Result {index}

            Content: {chunk.content}

            Metadata: {metadata}
          description: >-
            Template for formatting each retrieved chunk in the context. Available
            placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
            content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
            {chunk.content}\nMetadata: {metadata}\n"
        mode:
          $ref: '#/components/schemas/RAGSearchMode'
          default: vector
          description: >-
            Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
            "vector".
        ranker:
          $ref: '#/components/schemas/Ranker'
          description: >-
            Configuration for the ranker to use in hybrid search. Defaults to RRF
            ranker.
      additionalProperties: false
      required:
        - query_generator_config
        - max_tokens_in_context
        - max_chunks
        - chunk_template
      title: RAGQueryConfig
      description: >-
        Configuration for the RAG query generation.
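The chunk_template uses Python str.format-style placeholders, including attribute access for {chunk.content}; a small sketch of how one retrieved chunk would be rendered (the chunk object here is a stand-in, not a provider class):

from types import SimpleNamespace

chunk_template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
chunk = SimpleNamespace(content="RRF merges ranked lists from several retrievers.")
metadata = {"document_id": "doc-1", "score": 0.42}

# str.format resolves {chunk.content} via attribute lookup on the passed object.
print(chunk_template.format(index=1, chunk=chunk, metadata=metadata))
# Result 1
# Content: RRF merges ranked lists from several retrievers.
# Metadata: {'document_id': 'doc-1', 'score': 0.42}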
|
||||
RAGSearchMode:
|
||||
type: string
|
||||
enum:
|
||||
- vector
|
||||
- keyword
|
||||
- hybrid
|
||||
title: RAGSearchMode
|
||||
description: >-
|
||||
Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
|
||||
for semantic matching - KEYWORD: Uses keyword-based search for exact matching
|
||||
- HYBRID: Combines both vector and keyword search for better results
|
||||
RRFRanker:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: rrf
|
||||
default: rrf
|
||||
description: The type of ranker, always "rrf"
|
||||
impact_factor:
|
||||
type: number
|
||||
default: 60.0
|
||||
description: >-
|
||||
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
|
||||
results. Must be greater than 0
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- impact_factor
|
||||
title: RRFRanker
|
||||
description: >-
|
||||
Reciprocal Rank Fusion (RRF) ranker configuration.
|
||||
Ranker:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/RRFRanker'
|
||||
- $ref: '#/components/schemas/WeightedRanker'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
rrf: '#/components/schemas/RRFRanker'
|
||||
weighted: '#/components/schemas/WeightedRanker'
|
||||
WeightedRanker:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: weighted
|
||||
default: weighted
|
||||
description: The type of ranker, always "weighted"
|
||||
alpha:
|
||||
type: number
|
||||
default: 0.5
|
||||
description: >-
|
||||
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
|
||||
only use vector scores, values in between blend both scores.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- alpha
|
||||
title: WeightedRanker
|
||||
description: >-
|
||||
Weighted ranker configuration that combines vector and keyword scores.
|
||||
QueryRequest:
|
||||
type: object
|
||||
properties:
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: >-
|
||||
The query content to search for in the indexed documents
|
||||
vector_db_ids:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
List of vector database IDs to search within
|
||||
query_config:
|
||||
$ref: '#/components/schemas/RAGQueryConfig'
|
||||
description: >-
|
||||
(Optional) Configuration parameters for the query operation
|
||||
additionalProperties: false
|
||||
required:
|
||||
- content
|
||||
- vector_db_ids
|
||||
title: QueryRequest
|
||||
RAGQueryResult:
|
||||
type: object
|
||||
properties:
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: >-
|
||||
(Optional) The retrieved content from the query
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
Additional metadata about the query result
|
||||
additionalProperties: false
|
||||
required:
|
||||
- metadata
|
||||
title: RAGQueryResult
|
||||
description: >-
|
||||
Result of a RAG query containing retrieved content and metadata.
|
||||
ToolGroup:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -10203,7 +9872,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
@ -11325,7 +10994,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
@ -12652,7 +12321,7 @@ components:
|
|||
enum:
|
||||
- model
|
||||
- shield
|
||||
- vector_db
|
||||
- vector_store
|
||||
- dataset
|
||||
- scoring_function
|
||||
- benchmark
|
||||
|
|
|
|||
|
|
@@ -121,7 +121,7 @@ class Api(Enum, metaclass=DynamicApiMeta):

    models = "models"
    shields = "shields"
    vector_dbs = "vector_dbs"  # only used for routing
    vector_stores = "vector_stores"  # only used for routing table
    datasets = "datasets"
    scoring_functions = "scoring_functions"
    benchmarks = "benchmarks"

@@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
class ResourceType(StrEnum):
    model = "model"
    shield = "shield"
    vector_db = "vector_db"
    vector_store = "vector_store"
    dataset = "dataset"
    scoring_function = "scoring_function"
    benchmark = "benchmark"

@@ -34,4 +34,4 @@ class Resource(BaseModel):

    provider_id: str = Field(description="ID of the provider that owns this resource")

    type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)")
    type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_store', etc.)")

@@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .rag_tool import *
from .tools import *
@@ -1,218 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol

from pydantic import BaseModel, Field, field_validator
from typing_extensions import runtime_checkable

from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


@json_schema_type
class RRFRanker(BaseModel):
    """
    Reciprocal Rank Fusion (RRF) ranker configuration.

    :param type: The type of ranker, always "rrf"
    :param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
        Must be greater than 0
    """

    type: Literal["rrf"] = "rrf"
    impact_factor: float = Field(default=60.0, gt=0.0)  # default of 60 for optimal performance


@json_schema_type
class WeightedRanker(BaseModel):
    """
    Weighted ranker configuration that combines vector and keyword scores.

    :param type: The type of ranker, always "weighted"
    :param alpha: Weight factor between 0 and 1.
        0 means only use keyword scores,
        1 means only use vector scores,
        values in between blend both scores.
    """

    type: Literal["weighted"] = "weighted"
    alpha: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
    )


Ranker = Annotated[
    RRFRanker | WeightedRanker,
    Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")


@json_schema_type
class RAGDocument(BaseModel):
    """
    A document to be used for document ingestion in the RAG Tool.

    :param document_id: The unique identifier for the document.
    :param content: The content of the document.
    :param mime_type: The MIME type of the document.
    :param metadata: Additional metadata for the document.
    """

    document_id: str
    content: InterleavedContent | URL
    mime_type: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class RAGQueryResult(BaseModel):
    """Result of a RAG query containing retrieved content and metadata.

    :param content: (Optional) The retrieved content from the query
    :param metadata: Additional metadata about the query result
    """

    content: InterleavedContent | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class RAGQueryGenerator(Enum):
    """Types of query generators for RAG systems.

    :cvar default: Default query generator using simple text processing
    :cvar llm: LLM-based query generator for enhanced query understanding
    :cvar custom: Custom query generator implementation
    """

    default = "default"
    llm = "llm"
    custom = "custom"


@json_schema_type
class RAGSearchMode(StrEnum):
    """
    Search modes for RAG query retrieval:
    - VECTOR: Uses vector similarity search for semantic matching
    - KEYWORD: Uses keyword-based search for exact matching
    - HYBRID: Combines both vector and keyword search for better results
    """

    VECTOR = "vector"
    KEYWORD = "keyword"
    HYBRID = "hybrid"


@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
    """Configuration for the default RAG query generator.

    :param type: Type of query generator, always 'default'
    :param separator: String separator used to join query terms
    """

    type: Literal["default"] = "default"
    separator: str = " "


@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
    """Configuration for the LLM-based RAG query generator.

    :param type: Type of query generator, always 'llm'
    :param model: Name of the language model to use for query generation
    :param template: Template string for formatting the query generation prompt
    """

    type: Literal["llm"] = "llm"
    model: str
    template: str


RAGQueryGeneratorConfig = Annotated[
    DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
    Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")


@json_schema_type
class RAGQueryConfig(BaseModel):
    """
    Configuration for the RAG query generation.

    :param query_generator_config: Configuration for the query generator.
    :param max_tokens_in_context: Maximum number of tokens in the context.
    :param max_chunks: Maximum number of chunks to retrieve.
    :param chunk_template: Template for formatting each retrieved chunk in the context.
        Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
        Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
    :param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector".
    :param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
    """

    # This config defines how a query is generated using the messages
    # for memory bank retrieval.
    query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
    max_tokens_in_context: int = 4096
    max_chunks: int = 5
    chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
    mode: RAGSearchMode | None = RAGSearchMode.VECTOR
    ranker: Ranker | None = Field(default=None)  # Only used for hybrid mode

    @field_validator("chunk_template")
    def validate_chunk_template(cls, v: str) -> str:
        if "{chunk.content}" not in v:
            raise ValueError("chunk_template must contain {chunk.content}")
        if "{index}" not in v:
            raise ValueError("chunk_template must contain {index}")
        if len(v) == 0:
            raise ValueError("chunk_template must not be empty")
        return v


@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
    @webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
    async def insert(
        self,
        documents: list[RAGDocument],
        vector_db_id: str,
        chunk_size_in_tokens: int = 512,
    ) -> None:
        """Index documents so they can be used by the RAG system.

        :param documents: List of documents to index in the RAG system
        :param vector_db_id: ID of the vector database to store the document embeddings
        :param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
        """
        ...

    @webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
    async def query(
        self,
        content: InterleavedContent,
        vector_db_ids: list[str],
        query_config: RAGQueryConfig | None = None,
    ) -> RAGQueryResult:
        """Query the RAG system for context; typically invoked by the agent.

        :param content: The query content to search for in the indexed documents
        :param vector_db_ids: List of vector database IDs to search within
        :param query_config: (Optional) Configuration parameters for the query operation
        :returns: RAGQueryResult containing the retrieved content and metadata
        """
        ...
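A short usage sketch for the removed models above, assuming pydantic v2 as used elsewhere in the repo; it shows the discriminated Ranker union and the chunk_template validator in action.

from pydantic import ValidationError

# Hybrid retrieval with a weighted ranker; alpha blends keyword and vector scores.
config = RAGQueryConfig(
    max_chunks=10,
    mode=RAGSearchMode.HYBRID,
    ranker=WeightedRanker(alpha=0.7),
)
assert config.ranker.type == "weighted"

# The field validator rejects templates that drop a required placeholder.
try:
    RAGQueryConfig(chunk_template="Content only: {chunk.content}")
except ValidationError as exc:
    print("rejected:", exc.errors()[0]["msg"])  # chunk_template must contain {index}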
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -16,8 +15,6 @@ from llama_stack.apis.version import LLAMA_STACK_API_V1
|
|||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from .rag_tool import RAGToolRuntime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDef(BaseModel):
|
||||
|
|
@ -181,22 +178,11 @@ class ToolGroups(Protocol):
|
|||
...
|
||||
|
||||
|
||||
class SpecialToolGroup(Enum):
|
||||
"""Special tool groups with predefined functionality.
|
||||
|
||||
:cvar rag_tool: Retrieval-Augmented Generation tool group for document search and retrieval
|
||||
"""
|
||||
|
||||
rag_tool = "rag_tool"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore | None = None
|
||||
|
||||
rag_tool: RAGToolRuntime | None = None
|
||||
|
||||
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
||||
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_runtime_tools(
|
||||
|
|
|
|||
|
|
@@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal, Protocol, runtime_checkable

from pydantic import BaseModel

from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class VectorDB(Resource):
    """Vector database resource for storing and querying vector embeddings.

    :param type: Type of resource, always 'vector_db' for vector databases
    :param embedding_model: Name of the embedding model to use for vector generation
    :param embedding_dimension: Dimension of the embedding vectors
    """

    type: Literal[ResourceType.vector_db] = ResourceType.vector_db

    embedding_model: str
    embedding_dimension: int
    vector_db_name: str | None = None

    @property
    def vector_db_id(self) -> str:
        return self.identifier

    @property
    def provider_vector_db_id(self) -> str | None:
        return self.provider_resource_id


class VectorDBInput(BaseModel):
    """Input parameters for creating or configuring a vector database.

    :param vector_db_id: Unique identifier for the vector database
    :param embedding_model: Name of the embedding model to use for vector generation
    :param embedding_dimension: Dimension of the embedding vectors
    :param provider_vector_db_id: (Optional) Provider-specific identifier for the vector database
    """

    vector_db_id: str
    embedding_model: str
    embedding_dimension: int
    provider_id: str | None = None
    provider_vector_db_id: str | None = None


class ListVectorDBsResponse(BaseModel):
    """Response from listing vector databases.

    :param data: List of vector databases
    """

    data: list[VectorDB]


@runtime_checkable
class VectorDBs(Protocol):
    """Internal protocol for vector_dbs routing - no public API endpoints."""

    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        """Internal method to list vector databases."""
        ...

    async def get_vector_db(
        self,
        vector_db_id: str,
    ) -> VectorDB:
        """Internal method to get a vector database by ID."""
        ...

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        vector_db_name: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorDB:
        """Internal method to register a vector database."""
        ...

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Internal method to unregister a vector database."""
        ...
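For context on what the internal-only protocol expected of implementers, here is a toy in-memory sketch; it is illustrative only, not an actual provider, and the defaulted provider_id is an assumption.

class InMemoryVectorDBs:
    """Illustrative stub of the VectorDBs protocol; stores registrations in a dict."""

    def __init__(self) -> None:
        self._dbs: dict[str, VectorDB] = {}

    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        return ListVectorDBsResponse(data=list(self._dbs.values()))

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        return self._dbs[vector_db_id]

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        vector_db_name: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorDB:
        db = VectorDB(
            identifier=vector_db_id,
            provider_id=provider_id or "inline",  # assumed default for the sketch
            provider_resource_id=provider_vector_db_id or vector_db_id,
            embedding_model=embedding_model,
            embedding_dimension=embedding_dimension or 384,
            vector_db_name=vector_db_name,
        )
        self._dbs[vector_db_id] = db
        return db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        self._dbs.pop(vector_db_id, None)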
@ -15,7 +15,7 @@ from fastapi import Body
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
|
|
@ -140,6 +140,7 @@ class VectorStoreFileCounts(BaseModel):
|
|||
total: int
|
||||
|
||||
|
||||
# TODO: rename this as OpenAIVectorStore
|
||||
@json_schema_type
|
||||
class VectorStoreObject(BaseModel):
|
||||
"""OpenAI Vector Store object.
|
||||
|
|
@ -517,17 +518,18 @@ class OpenAICreateVectorStoreFileBatchRequestWithExtraBody(BaseModel, extra="all
|
|||
chunking_strategy: VectorStoreChunkingStrategy | None = None
|
||||
|
||||
|
||||
class VectorDBStore(Protocol):
|
||||
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
|
||||
class VectorStoreTable(Protocol):
|
||||
def get_vector_store(self, vector_store_id: str) -> VectorStore | None: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class VectorIO(Protocol):
|
||||
vector_db_store: VectorDBStore | None = None
|
||||
vector_store_table: VectorStoreTable | None = None
|
||||
|
||||
# this will just block now until chunks are inserted, but it should
|
||||
# probably return a Job instance which can be polled for completion
|
||||
# TODO: rename vector_db_id to vector_store_id once Stainless is working
|
||||
@webmethod(route="/vector-io/insert", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def insert_chunks(
|
||||
self,
|
||||
|
|
@ -546,6 +548,7 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
# TODO: rename vector_db_id to vector_store_id once Stainless is working
|
||||
@webmethod(route="/vector-io/query", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def query_chunks(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -4,4 +4,4 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .vector_dbs import *
|
||||
from .vector_stores import *
|
||||
51
llama_stack/apis/vector_stores/vector_stores.py
Normal file
|
|
@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal

from pydantic import BaseModel

from llama_stack.apis.resource import Resource, ResourceType


# Internal resource type for storing the vector store routing and other information
class VectorStore(Resource):
    """Vector database resource for storing and querying vector embeddings.

    :param type: Type of resource, always 'vector_store' for vector stores
    :param embedding_model: Name of the embedding model to use for vector generation
    :param embedding_dimension: Dimension of the embedding vectors
    """

    type: Literal[ResourceType.vector_store] = ResourceType.vector_store

    embedding_model: str
    embedding_dimension: int
    vector_store_name: str | None = None

    @property
    def vector_store_id(self) -> str:
        return self.identifier

    @property
    def provider_vector_store_id(self) -> str | None:
        return self.provider_resource_id


class VectorStoreInput(BaseModel):
    """Input parameters for creating or configuring a vector database.

    :param vector_store_id: Unique identifier for the vector store
    :param embedding_model: Name of the embedding model to use for vector generation
    :param embedding_dimension: Dimension of the embedding vectors
    :param provider_vector_store_id: (Optional) Provider-specific identifier for the vector store
    """

    vector_store_id: str
    embedding_model: str
    embedding_dimension: int
    provider_id: str | None = None
    provider_vector_store_id: str | None = None
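A brief sketch of constructing the new resource and how its convenience properties map onto the generic Resource fields; the identifiers and embedding model below are placeholders.

store = VectorStore(
    identifier="vs_0001",                  # placeholder store ID
    provider_id="faiss",                   # placeholder provider
    provider_resource_id="faiss-vs-0001",  # placeholder provider-side ID
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
    vector_store_name="docs",
)
assert store.vector_store_id == "vs_0001"
assert store.provider_vector_store_id == "faiss-vs-0001"
assert store.type == ResourceType.vector_store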
|
|
@@ -41,7 +41,7 @@ class AccessRule(BaseModel):
    A rule defines a list of action either to permit or to forbid. It may specify a
    principal or a resource that must match for the rule to take effect. The resource
    to match should be specified in the form of a type qualified identifier, e.g.
    model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
    model::my-model or vector_store::some-db, or a wildcard for all resources of a type,
    e.g. model::*. If the principal or resource are not specified, they will match all
    requests.

@@ -79,9 +79,9 @@ class AccessRule(BaseModel):
    description: any user has read access to any resource created by a member of their team
    - forbid:
        actions: [create, read, delete]
        resource: vector_db::*
        resource: vector_store::*
      unless: user with admin in roles
      description: only user with admin role can use vector_db resources
      description: only user with admin role can use vector_store resources

    """
|
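As a concrete reading of the docstring example above, the forbid rule would be supplied in roughly this dict form before being parsed into an AccessRule; the field layout follows the docstring example and is illustrative rather than authoritative.

vector_store_admin_only = {
    "forbid": {
        "actions": ["create", "read", "delete"],
        "resource": "vector_store::*",
    },
    "unless": "user with admin in roles",
    "description": "only user with admin role can use vector_store resources",
}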
||||
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ from llama_stack.apis.scoring import Scoring
|
|||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
from llama_stack.apis.shields import Shield, ShieldInput
|
||||
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
KVStoreReference,
|
||||
|
|
@ -71,7 +71,7 @@ class ShieldWithOwner(Shield, ResourceWithOwner):
|
|||
pass
|
||||
|
||||
|
||||
class VectorDBWithOwner(VectorDB, ResourceWithOwner):
|
||||
class VectorStoreWithOwner(VectorStore, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -91,12 +91,12 @@ class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
|||
pass
|
||||
|
||||
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | ToolGroup
|
||||
RoutableObject = Model | Shield | VectorStore | Dataset | ScoringFn | Benchmark | ToolGroup
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
ModelWithOwner
|
||||
| ShieldWithOwner
|
||||
| VectorDBWithOwner
|
||||
| VectorStoreWithOwner
|
||||
| DatasetWithOwner
|
||||
| ScoringFnWithOwner
|
||||
| BenchmarkWithOwner
|
||||
|
|
@ -427,7 +427,7 @@ class RegisteredResources(BaseModel):
|
|||
|
||||
models: list[ModelInput] = Field(default_factory=list)
|
||||
shields: list[ShieldInput] = Field(default_factory=list)
|
||||
vector_dbs: list[VectorDBInput] = Field(default_factory=list)
|
||||
vector_stores: list[VectorStoreInput] = Field(default_factory=list)
|
||||
datasets: list[DatasetInput] = Field(default_factory=list)
|
||||
scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
|
||||
benchmarks: list[BenchmarkInput] = Field(default_factory=list)
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
|||
router_api=Api.tool_runtime,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.vector_dbs,
|
||||
routing_table_api=Api.vector_stores,
|
||||
router_api=Api.vector_io,
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
|
|||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.core.client import get_client_impl
|
||||
from llama_stack.core.datatypes import (
|
||||
|
|
@ -82,7 +82,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.inspect: Inspect,
|
||||
Api.batches: Batches,
|
||||
Api.vector_io: VectorIO,
|
||||
Api.vector_dbs: VectorDBs,
|
||||
Api.vector_stores: VectorStore,
|
||||
Api.models: Models,
|
||||
Api.safety: Safety,
|
||||
Api.shields: Shields,
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ async def get_routing_table_impl(
|
|||
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
||||
from ..routing_tables.shields import ShieldsRoutingTable
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
|
||||
from ..routing_tables.vector_stores import VectorStoresRoutingTable
|
||||
|
||||
api_to_tables = {
|
||||
"models": ModelsRoutingTable,
|
||||
|
|
@ -38,7 +38,7 @@ async def get_routing_table_impl(
|
|||
"scoring_functions": ScoringFunctionsRoutingTable,
|
||||
"benchmarks": BenchmarksRoutingTable,
|
||||
"tool_groups": ToolGroupsRoutingTable,
|
||||
"vector_dbs": VectorDBsRoutingTable,
|
||||
"vector_stores": VectorStoresRoutingTable,
|
||||
}
|
||||
|
||||
if api.value not in api_to_tables:
|
||||
|
|
|
|||
|
|
@ -8,16 +8,8 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
RAGToolRuntime,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolRuntime
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
|
|
@ -26,36 +18,6 @@ logger = get_logger(name=__name__, category="core::routers")
|
|||
|
||||
|
||||
class ToolRuntimeRouter(ToolRuntime):
|
||||
class RagToolImpl(RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: ToolGroupsRoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def query(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
vector_db_ids: list[str],
|
||||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
provider = await self.routing_table.get_provider_impl("knowledge_search")
|
||||
return await provider.query(content, vector_db_ids, query_config)
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
documents: list[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("insert_into_memory")
|
||||
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: ToolGroupsRoutingTable,
|
||||
|
|
@ -63,11 +25,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
logger.debug("Initializing ToolRuntimeRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
||||
self.rag_tool = self.RagToolImpl(routing_table)
|
||||
for method in ("query", "insert"):
|
||||
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("ToolRuntimeRouter.initialize")
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -71,25 +71,6 @@ class VectorIORouter(VectorIO):
|
|||
|
||||
raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model")
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> None:
|
||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
await self.routing_table.register_vector_db(
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
embedding_dimension,
|
||||
provider_id,
|
||||
vector_db_name,
|
||||
provider_vector_db_id,
|
||||
)
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
|
|
@ -165,22 +146,22 @@ class VectorIORouter(VectorIO):
|
|||
else:
|
||||
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
|
||||
|
||||
vector_db_id = f"vs_{uuid.uuid4()}"
|
||||
registered_vector_db = await self.routing_table.register_vector_db(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_store_id = f"vs_{uuid.uuid4()}"
|
||||
registered_vector_store = await self.routing_table.register_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=provider_id,
|
||||
provider_vector_db_id=vector_db_id,
|
||||
vector_db_name=params.name,
|
||||
provider_vector_store_id=vector_store_id,
|
||||
vector_store_name=params.name,
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier)
|
||||
|
||||
# Update model_extra with registered values so provider uses the already-registered vector_db
|
||||
# Update model_extra with registered values so provider uses the already-registered vector_store
|
||||
if params.model_extra is None:
|
||||
params.model_extra = {}
|
||||
params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id
|
||||
params.model_extra["provider_id"] = registered_vector_db.provider_id
|
||||
params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id
|
||||
params.model_extra["provider_id"] = registered_vector_store.provider_id
|
||||
if embedding_model is not None:
|
||||
params.model_extra["embedding_model"] = embedding_model
|
||||
if embedding_dimension is not None:
|
||||
|
|
@ -198,15 +179,15 @@ class VectorIORouter(VectorIO):
|
|||
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
|
||||
# Route to default provider for now - could aggregate from all providers in the future
|
||||
# call retrieve on each vector dbs to get list of vector stores
|
||||
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
|
||||
vector_stores = await self.routing_table.get_all_with_type("vector_store")
|
||||
all_stores = []
|
||||
for vector_db in vector_dbs:
|
||||
for vector_store in vector_stores:
|
||||
try:
|
||||
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
|
||||
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(vector_store.identifier)
|
||||
vector_store = await provider.openai_retrieve_vector_store(vector_store.identifier)
|
||||
all_stores.append(vector_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
|
||||
logger.error(f"Error retrieving vector store {vector_store.identifier}: {e}")
|
||||
continue
|
||||
|
||||
# Sort by created_at
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
elif api == Api.safety:
|
||||
return await p.register_shield(obj)
|
||||
elif api == Api.vector_io:
|
||||
return await p.register_vector_db(obj)
|
||||
return await p.register_vector_store(obj)
|
||||
elif api == Api.datasetio:
|
||||
return await p.register_dataset(obj)
|
||||
elif api == Api.scoring:
|
||||
|
|
@ -57,7 +57,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.vector_io:
|
||||
return await p.unregister_vector_db(obj.identifier)
|
||||
return await p.unregister_vector_store(obj.identifier)
|
||||
elif api == Api.inference:
|
||||
return await p.unregister_model(obj.identifier)
|
||||
elif api == Api.safety:
|
||||
|
|
@ -108,7 +108,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
elif api == Api.vector_io:
|
||||
p.vector_db_store = self
|
||||
p.vector_store_store = self
|
||||
elif api == Api.datasetio:
|
||||
p.dataset_store = self
|
||||
elif api == Api.scoring:
|
||||
|
|
@ -134,15 +134,15 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
from .scoring_functions import ScoringFunctionsRoutingTable
|
||||
from .shields import ShieldsRoutingTable
|
||||
from .toolgroups import ToolGroupsRoutingTable
|
||||
from .vector_dbs import VectorDBsRoutingTable
|
||||
from .vector_stores import VectorStoresRoutingTable
|
||||
|
||||
def apiname_object():
|
||||
if isinstance(self, ModelsRoutingTable):
|
||||
return ("Inference", "model")
|
||||
elif isinstance(self, ShieldsRoutingTable):
|
||||
return ("Safety", "shield")
|
||||
elif isinstance(self, VectorDBsRoutingTable):
|
||||
return ("VectorIO", "vector_db")
|
||||
elif isinstance(self, VectorStoresRoutingTable):
|
||||
return ("VectorIO", "vector_store")
|
||||
elif isinstance(self, DatasetsRoutingTable):
|
||||
return ("DatasetIO", "dataset")
|
||||
elif isinstance(self, ScoringFunctionsRoutingTable):
|
||||
|
|
|
|||
|
|
@ -6,15 +6,12 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
|
||||
# Removed VectorDBs import to avoid exposing public API
|
||||
# Removed VectorStores import to avoid exposing public API
|
||||
from llama_stack.apis.vector_io.vector_io import (
|
||||
OpenAICreateVectorStoreRequestWithExtraBody,
|
||||
SearchRankingOptions,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
|
|
@ -26,7 +23,7 @@ from llama_stack.apis.vector_io.vector_io import (
|
|||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.datatypes import (
|
||||
VectorDBWithOwner,
|
||||
VectorStoreWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
|
@ -35,23 +32,23 @@ from .common import CommonRoutingTableImpl, lookup_model
|
|||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
||||
"""Internal routing table for vector_db operations.
|
||||
class VectorStoresRoutingTable(CommonRoutingTableImpl):
|
||||
"""Internal routing table for vector_store operations.
|
||||
|
||||
Does not inherit from VectorDBs to avoid exposing public API endpoints.
|
||||
Does not inherit from VectorStores to avoid exposing public API endpoints.
|
||||
Only provides internal routing functionality for VectorIORouter.
|
||||
"""
|
||||
|
||||
# Internal methods only - no public API exposure
|
||||
|
||||
async def register_vector_db(
|
||||
async def register_vector_store(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
vector_store_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
provider_vector_store_id: str | None = None,
|
||||
vector_store_name: str | None = None,
|
||||
) -> Any:
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
|
|
@ -67,52 +64,24 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
raise ModelNotFoundError(embedding_model)
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||
|
||||
try:
|
||||
provider = self.impls_by_provider_id[provider_id]
|
||||
except KeyError:
|
||||
available_providers = list(self.impls_by_provider_id.keys())
|
||||
raise ValueError(
|
||||
f"Provider '{provider_id}' not found in routing table. Available providers: {available_providers}"
|
||||
) from None
|
||||
logger.warning(
|
||||
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
|
||||
)
|
||||
request = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name=vector_db_name or vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=model.metadata["embedding_dimension"],
|
||||
vector_store = VectorStoreWithOwner(
|
||||
identifier=vector_store_id,
|
||||
type=ResourceType.vector_store.value,
|
||||
provider_id=provider_id,
|
||||
provider_vector_db_id=provider_vector_db_id,
|
||||
provider_resource_id=provider_vector_store_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=embedding_dimension,
|
||||
vector_store_name=vector_store_name,
|
||||
)
|
||||
vector_store = await provider.openai_create_vector_store(request)
|
||||
|
||||
vector_store_id = vector_store.id
|
||||
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
|
||||
logger.warning(
|
||||
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
|
||||
)
|
||||
|
||||
vector_db_data = {
|
||||
"identifier": vector_store_id,
|
||||
"type": ResourceType.vector_db.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": actual_provider_vector_db_id,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
"vector_db_name": vector_store.name,
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
await self.register_object(vector_store)
|
||||
return vector_store
|
||||
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
|
|
@ -123,7 +92,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
|
|
@ -136,18 +105,18 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
result = await provider.openai_delete_vector_store(vector_store_id)
|
||||
await self.unregister_vector_db(vector_store_id)
|
||||
await self.unregister_vector_store(vector_store_id)
|
||||
return result
|
||||
|
||||
async def unregister_vector_db(self, vector_store_id: str) -> None:
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
"""Remove the vector store from the routing table registry."""
|
||||
try:
|
||||
vector_db_obj = await self.get_object_by_identifier("vector_db", vector_store_id)
|
||||
if vector_db_obj:
|
||||
await self.unregister_object(vector_db_obj)
|
||||
vector_store_obj = await self.get_object_by_identifier("vector_store", vector_store_id)
|
||||
if vector_store_obj:
|
||||
await self.unregister_object(vector_store_obj)
|
||||
except Exception as e:
|
||||
# Log the error but don't fail the operation
|
||||
logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")
|
||||
|
|
@ -162,7 +131,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
|
|
@ -181,7 +150,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
|
|
@ -199,7 +168,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
before: str | None = None,
|
||||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
|
|
@ -215,7 +184,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
|
|
@ -227,7 +196,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
|
|
@ -240,7 +209,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
file_id: str,
|
||||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
|
|
@ -253,7 +222,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
|
|
@ -267,7 +236,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: Any | None = None,
|
||||
):
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_create_vector_store_file_batch(
|
||||
vector_store_id=vector_store_id,
|
||||
|
|
@ -281,7 +250,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
):
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
|
|
@ -298,7 +267,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
):
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
|
|
@ -315,7 +284,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl):
|
|||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
):
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_cancel_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
|
|
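The hunks above move every authorization check in the routing table from the legacy `vector_db` resource type to `vector_store`. A minimal sketch of the resulting guard pattern follows; the method body is a simplified stand-in for the real `VectorDBsRoutingTable` implementations and omits the base class that supplies `assert_action_allowed` and `get_provider_impl`.

```python
# Illustrative sketch only: every read/update/delete is gated on the
# "vector_store" resource type (previously "vector_db") before the call is
# routed to the provider that owns the store.
class VectorStoresRoutingTableSketch:
    async def openai_retrieve_vector_store_file(self, vector_store_id: str, file_id: str):
        await self.assert_action_allowed("read", "vector_store", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )
```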
@ -13,7 +13,6 @@ from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod

@ -25,33 +24,16 @@ RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map():
return {
SpecialToolGroup.rag_tool: RAGToolRuntime,
}
def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {}
protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods:
# Get all webmethods for this method (supports multiple decorators)
webmethods = getattr(method, "__webmethods__", [])
@ -32,7 +32,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig

@ -80,7 +80,6 @@ class LlamaStack(
Inspect,
ToolGroups,
ToolRuntime,
RAGToolRuntime,
Files,
Prompts,
Conversations,
@ -32,7 +32,7 @@ def tool_chat_page():
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
selected_vector_dbs = []
selected_vector_stores = []
def reset_agent():
st.session_state.clear()

@ -55,13 +55,13 @@ def tool_chat_page():
)
if "builtin::rag" in toolgroup_selection:
vector_dbs = llama_stack_api.client.vector_dbs.list() or []
if not vector_dbs:
vector_stores = llama_stack_api.client.vector_stores.list() or []
if not vector_stores:
st.info("No vector databases available for selection.")
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect(
vector_stores = [vector_store.identifier for vector_store in vector_stores]
selected_vector_stores = st.multiselect(
label="Select Document Collections to use in RAG queries",
options=vector_dbs,
options=vector_stores,
on_change=reset_agent,
)

@ -119,7 +119,7 @@ def tool_chat_page():
tool_dict = dict(
name="builtin::rag",
args={
"vector_db_ids": list(selected_vector_dbs),
"vector_store_ids": list(selected_vector_stores),
},
)
toolgroup_selection[i] = tool_dict
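For reference, the client-side flow the updated playground page relies on looks roughly like the sketch below. It is hedged: `client` stands for the `llama_stack_api.client` wrapper used in the page, and the surrounding Streamlit and agent wiring is omitted.

```python
# Hedged sketch of the updated selection flow: list vector stores via the client
# and pass their identifiers to the builtin RAG tool as `vector_store_ids`.
vector_stores = client.vector_stores.list() or []
store_ids = [vs.identifier for vs in vector_stores]

rag_tool = dict(
    name="builtin::rag",
    args={"vector_store_ids": store_ids},
)
toolgroup_selection = ["builtin::websearch", rag_tool]
```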
@ -48,7 +48,6 @@ distribution_spec:
|
|||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_type: inline::reference
|
||||
|
|
|
|||
|
|
@ -216,8 +216,6 @@ providers:
|
|||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
|
|
@ -263,8 +261,6 @@ registered_resources:
|
|||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ distribution_spec:
|
|||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
|
|
|
|||
|
|
@ -45,7 +45,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"tool_runtime": [
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
],
|
||||
}
|
||||
name = "dell"
|
||||
|
|
@ -98,10 +97,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
toolgroup_id="builtin::websearch",
|
||||
provider_id="brave-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
|
|
|
|||
|
|
@ -87,8 +87,6 @@ providers:
|
|||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
|
|
@ -133,8 +131,6 @@ registered_resources:
|
|||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: brave-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -83,8 +83,6 @@ providers:
|
|||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
|
|
@ -124,8 +122,6 @@ registered_resources:
|
|||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: brave-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ distribution_spec:
|
|||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"tool_runtime": [
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
}
|
||||
|
|
@ -92,10 +91,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
|
|
|
|||
|
|
@ -98,8 +98,6 @@ providers:
|
|||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
storage:
|
||||
|
|
@ -146,8 +144,6 @@ registered_resources:
|
|||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -88,8 +88,6 @@ providers:
|
|||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
storage:
|
||||
|
|
@ -131,8 +129,6 @@ registered_resources:
|
|||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ distribution_spec:
|
|||
- provider_type: remote::nvidia
|
||||
scoring:
|
||||
- provider_type: inline::basic
|
||||
tool_runtime:
|
||||
- provider_type: inline::rag-runtime
|
||||
tool_runtime: []
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
image_type: venv
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
|
|||
BuildProvider(provider_type="remote::nvidia"),
|
||||
],
|
||||
"scoring": [BuildProvider(provider_type="inline::basic")],
|
||||
"tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")],
|
||||
"tool_runtime": [],
|
||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||
}
|
||||
|
||||
|
|
@ -66,12 +66,7 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
|
|||
provider_id="nvidia",
|
||||
)
|
||||
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
default_tool_groups: list[ToolGroupInput] = []
|
||||
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
|
|
|
|||
|
|
@ -80,9 +80,7 @@ providers:
|
|||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
tool_runtime:
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
tool_runtime: []
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
|
|
@ -128,9 +126,7 @@ registered_resources:
|
|||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
tool_groups: []
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -69,9 +69,7 @@ providers:
|
|||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
tool_runtime:
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
tool_runtime: []
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
|
|
@ -107,9 +105,7 @@ registered_resources:
|
|||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
tool_groups: []
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ distribution_spec:
|
|||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
|
|
|
|||
|
|
@ -118,7 +118,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"tool_runtime": [
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
}
|
||||
|
|
@ -154,10 +153,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
|
||||
models, _ = get_model_registry(available_models)
|
||||
|
|
|
|||
|
|
@ -118,8 +118,6 @@ providers:
|
|||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
storage:
|
||||
|
|
@ -244,8 +242,6 @@ registered_resources:
|
|||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ distribution_spec:
|
|||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
|
|
|
|||
|
|
@ -45,7 +45,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"tool_runtime": [
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
}
|
||||
|
|
@ -66,10 +65,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
|
||||
default_models = [
|
||||
|
|
|
|||
|
|
@ -54,8 +54,6 @@ providers:
|
|||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
storage:
|
||||
|
|
@ -107,8 +105,6 @@ registered_resources:
|
|||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,6 @@ distribution_spec:
|
|||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_type: inline::reference
|
||||
|
|
|
|||
|
|
@ -219,8 +219,6 @@ providers:
|
|||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
|
|
@ -266,8 +264,6 @@ registered_resources:
|
|||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,6 @@ distribution_spec:
|
|||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_type: inline::reference
|
||||
|
|
|
|||
|
|
@ -216,8 +216,6 @@ providers:
|
|||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
|
|
@ -263,8 +261,6 @@ registered_resources:
|
|||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -140,7 +140,6 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
"tool_runtime": [
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
"batches": [
|
||||
|
|
@ -162,10 +161,6 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
default_shields = [
|
||||
# if the
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ distribution_spec:
|
|||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
|
|
|
|||
|
|
@ -83,8 +83,6 @@ providers:
|
|||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
files:
|
||||
|
|
@ -125,8 +123,6 @@ registered_resources:
|
|||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
|||
"tool_runtime": [
|
||||
BuildProvider(provider_type="remote::brave-search"),
|
||||
BuildProvider(provider_type="remote::tavily-search"),
|
||||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||
|
|
@ -50,10 +49,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
|||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
|
||||
files_provider = Provider(
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import ToolGroup
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.schema_utils import json_schema_type

@ -68,10 +68,10 @@ class ShieldsProtocolPrivate(Protocol):
async def unregister_shield(self, identifier: str) -> None: ...
class VectorDBsProtocolPrivate(Protocol):
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
class VectorStoresProtocolPrivate(Protocol):
async def register_vector_store(self, vector_store: VectorStore) -> None: ...
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
async def unregister_vector_store(self, vector_store_id: str) -> None: ...
class DatasetsProtocolPrivate(Protocol):
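Providers now implement the renamed `VectorStoresProtocolPrivate` hooks. A minimal, hedged sketch of a provider satisfying the protocol follows; the in-memory storage is a placeholder, not how the real adapters (Faiss, sqlite-vec, Chroma) persist their metadata.

```python
from llama_stack.apis.vector_stores import VectorStore


class InMemoryVectorStoreProvider:
    """Toy example of the register/unregister hooks required by VectorStoresProtocolPrivate."""

    def __init__(self) -> None:
        self.cache: dict[str, VectorStore] = {}

    async def register_vector_store(self, vector_store: VectorStore) -> None:
        self.cache[vector_store.identifier] = vector_store

    async def unregister_vector_store(self, vector_store_id: str) -> None:
        self.cache.pop(vector_store_id, None)
```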
@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@ -1,19 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.providers.datatypes import Api
from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
await impl.initialize()
return impl

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel
class RagToolRuntimeConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}
@ -1,77 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
|
||||
from llama_stack.apis.tools.rag_tool import (
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
RAGQueryGenerator,
|
||||
RAGQueryGeneratorConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
|
||||
async def generate_rag_query(
|
||||
config: RAGQueryGeneratorConfig,
|
||||
content: InterleavedContent,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Generates a query that will be used for
|
||||
retrieving relevant information from the memory bank.
|
||||
"""
|
||||
if config.type == RAGQueryGenerator.default.value:
|
||||
query = await default_rag_query_generator(config, content, **kwargs)
|
||||
elif config.type == RAGQueryGenerator.llm.value:
|
||||
query = await llm_rag_query_generator(config, content, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
|
||||
return query
|
||||
|
||||
|
||||
async def default_rag_query_generator(
|
||||
config: DefaultRAGQueryGeneratorConfig,
|
||||
content: InterleavedContent,
|
||||
**kwargs,
|
||||
):
|
||||
return interleaved_content_as_str(content, sep=config.separator)
|
||||
|
||||
|
||||
async def llm_rag_query_generator(
|
||||
config: LLMRAGQueryGeneratorConfig,
|
||||
content: InterleavedContent,
|
||||
**kwargs,
|
||||
):
|
||||
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
|
||||
inference_api = kwargs["inference_api"]
|
||||
|
||||
messages = []
|
||||
if isinstance(content, list):
|
||||
messages = [interleaved_content_as_str(m) for m in content]
|
||||
else:
|
||||
messages = [interleaved_content_as_str(content)]
|
||||
|
||||
template = Template(config.template)
|
||||
rendered_content: str = template.render({"messages": messages})
|
||||
|
||||
model = config.model
|
||||
message = OpenAIUserMessageParam(content=rendered_content)
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=False,
|
||||
)
|
||||
response = await inference_api.openai_chat_completion(params)
|
||||
|
||||
query = response.choices[0].message.content
|
||||
|
||||
return query
|
||||
|
|
@ -1,332 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import mimetypes
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import UploadFile
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
RAGQueryResult,
|
||||
RAGToolRuntime,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolInvocationResult,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import (
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
VectorStoreChunkingStrategyStatic,
|
||||
VectorStoreChunkingStrategyStaticConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import parse_data_url
|
||||
|
||||
from .config import RagToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
||||
log = get_logger(name=__name__, category="tool_runtime")
|
||||
|
||||
|
||||
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
|
||||
"""Get raw binary data and mime type from a RAGDocument for file upload."""
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
parts = parse_data_url(doc.content.uri)
|
||||
mime_type = parts["mimetype"]
|
||||
data = parts["data"]
|
||||
|
||||
if parts["is_base64"]:
|
||||
file_data = base64.b64decode(data)
|
||||
else:
|
||||
file_data = data.encode("utf-8")
|
||||
|
||||
return file_data, mime_type
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(doc.content.uri)
|
||||
r.raise_for_status()
|
||||
mime_type = r.headers.get("content-type", "application/octet-stream")
|
||||
return r.content, mime_type
|
||||
else:
|
||||
if isinstance(doc.content, str):
|
||||
content_str = doc.content
|
||||
else:
|
||||
content_str = interleaved_content_as_str(doc.content)
|
||||
|
||||
if content_str.startswith("data:"):
|
||||
parts = parse_data_url(content_str)
|
||||
mime_type = parts["mimetype"]
|
||||
data = parts["data"]
|
||||
|
||||
if parts["is_base64"]:
|
||||
file_data = base64.b64decode(data)
|
||||
else:
|
||||
file_data = data.encode("utf-8")
|
||||
|
||||
return file_data, mime_type
|
||||
else:
|
||||
return content_str.encode("utf-8"), "text/plain"
|
||||
|
||||
|
||||
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
config: RagToolRuntimeConfig,
|
||||
vector_io_api: VectorIO,
|
||||
inference_api: Inference,
|
||||
files_api: Files,
|
||||
):
|
||||
self.config = config
|
||||
self.vector_io_api = vector_io_api
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
return
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
documents: list[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
if not documents:
|
||||
return
|
||||
|
||||
for doc in documents:
|
||||
try:
|
||||
try:
|
||||
file_data, mime_type = await raw_data_from_doc(doc)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to extract content from document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||
|
||||
file_obj = io.BytesIO(file_data)
|
||||
file_obj.name = filename
|
||||
|
||||
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||
|
||||
try:
|
||||
created_file = await self.files_api.openai_upload_file(
|
||||
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||
static=VectorStoreChunkingStrategyStaticConfig(
|
||||
max_chunk_size_tokens=chunk_size_in_tokens,
|
||||
chunk_overlap_tokens=chunk_size_in_tokens // 4,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_db_id,
|
||||
file_id=created_file.id,
|
||||
attributes=doc.metadata,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
async def query(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
vector_db_ids: list[str],
|
||||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
if not vector_db_ids:
|
||||
raise ValueError(
|
||||
"No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
|
||||
)
|
||||
|
||||
query_config = query_config or RAGQueryConfig()
|
||||
query = await generate_rag_query(
|
||||
query_config.query_generator_config,
|
||||
content,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
tasks = [
|
||||
self.vector_io_api.query_chunks(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query,
|
||||
params={
|
||||
"mode": query_config.mode,
|
||||
"max_chunks": query_config.max_chunks,
|
||||
"score_threshold": 0.0,
|
||||
"ranker": query_config.ranker,
|
||||
},
|
||||
)
|
||||
for vector_db_id in vector_db_ids
|
||||
]
|
||||
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
|
||||
for vector_db_id, result in zip(vector_db_ids, results, strict=False):
|
||||
for chunk, score in zip(result.chunks, result.scores, strict=False):
|
||||
if not hasattr(chunk, "metadata") or chunk.metadata is None:
|
||||
chunk.metadata = {}
|
||||
chunk.metadata["vector_db_id"] = vector_db_id
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
if not chunks:
|
||||
return RAGQueryResult(content=None)
|
||||
|
||||
# sort by score
|
||||
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) # type: ignore
|
||||
chunks = chunks[: query_config.max_chunks]
|
||||
|
||||
tokens = 0
|
||||
picked: list[InterleavedContentItem] = [
|
||||
TextContentItem(
|
||||
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
|
||||
)
|
||||
]
|
||||
for i, chunk in enumerate(chunks):
|
||||
metadata = chunk.metadata
|
||||
tokens += metadata.get("token_count", 0)
|
||||
tokens += metadata.get("metadata_token_count", 0)
|
||||
|
||||
if tokens > query_config.max_tokens_in_context:
|
||||
log.error(
|
||||
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
|
||||
)
|
||||
break
|
||||
|
||||
# Add useful keys from chunk_metadata to metadata and remove some from metadata
|
||||
chunk_metadata_keys_to_include_from_context = [
|
||||
"chunk_id",
|
||||
"document_id",
|
||||
"source",
|
||||
]
|
||||
metadata_keys_to_exclude_from_context = [
|
||||
"token_count",
|
||||
"metadata_token_count",
|
||||
"vector_db_id",
|
||||
]
|
||||
metadata_for_context = {}
|
||||
for k in chunk_metadata_keys_to_include_from_context:
|
||||
metadata_for_context[k] = getattr(chunk.chunk_metadata, k)
|
||||
for k in metadata:
|
||||
if k not in metadata_keys_to_exclude_from_context:
|
||||
metadata_for_context[k] = metadata[k]
|
||||
|
||||
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_for_context)
|
||||
picked.append(TextContentItem(text=text_content))
|
||||
|
||||
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
picked.append(
|
||||
TextContentItem(
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
|
||||
)
|
||||
)
|
||||
|
||||
return RAGQueryResult(
|
||||
content=picked,
|
||||
metadata={
|
||||
"document_ids": [c.document_id for c in chunks[: len(picked)]],
|
||||
"chunks": [c.content for c in chunks[: len(picked)]],
|
||||
"scores": scores[: len(picked)],
|
||||
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
|
||||
},
|
||||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
# Parameters are not listed since these methods are not yet invoked automatically
|
||||
# by the LLM. The method is only implemented so things like /tools can list without
|
||||
# encountering fatals.
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
ToolDef(
|
||||
name="insert_into_memory",
|
||||
description="Insert documents into memory",
|
||||
),
|
||||
ToolDef(
|
||||
name="knowledge_search",
|
||||
description="Search for information in a database.",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for. Can be a natural language sentence or keywords.",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
vector_db_ids = kwargs.get("vector_db_ids", [])
|
||||
query_config = kwargs.get("query_config")
|
||||
if query_config:
|
||||
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
|
||||
else:
|
||||
query_config = RAGQueryConfig()
|
||||
|
||||
query = kwargs["query"]
|
||||
result = await self.query(
|
||||
content=query,
|
||||
vector_db_ids=vector_db_ids,
|
||||
query_config=query_config,
|
||||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=result.content or [],
|
||||
metadata={
|
||||
**(result.metadata or {}),
|
||||
"citation_files": getattr(result, "citation_files", None),
|
||||
},
|
||||
)
|
||||
|
|
@ -17,21 +17,21 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="vector_io")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"
|
||||
FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
|
||||
|
|
@ -176,28 +176,28 @@ class FaissIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.cache: dict[str, VectorStoreWithIndex] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
# Load existing banks from kvstore
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier),
|
||||
for vector_store_data in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(vector_store_data)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store,
|
||||
await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
# Load existing OpenAI vector stores into the in-memory cache
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
|
@ -222,32 +222,31 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
assert self.kvstore is not None
|
||||
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
|
||||
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
|
||||
|
||||
# Store in cache
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=await FaissIndex.create(vector_db.embedding_dimension, self.kvstore, vector_db.identifier),
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [i.vector_db for i in self.cache.values()]
|
||||
async def list_vector_stores(self) -> list[VectorStore]:
|
||||
return [i.vector_store for i in self.cache.values()]
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
assert self.kvstore is not None
|
||||
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
if vector_store_id not in self.cache:
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = self.cache.get(vector_db_id)
|
||||
|
|
|
|||
|
|
@ -17,10 +17,10 @@ from numpy.typing import NDArray
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
|
@ -28,7 +28,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
RERANKER_TYPE_RRF,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
VectorStoreWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
|
||||
|
||||
|
|
@ -41,7 +41,7 @@ HYBRID_SEARCH = "hybrid"
|
|||
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:sqlite_vec:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:sqlite_vec:{VERSION}::"
|
||||
|
|
@ -374,32 +374,32 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
await asyncio.to_thread(_delete_chunks)
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
"""
|
||||
A VectorIO implementation using SQLite + sqlite_vec.
|
||||
This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
|
||||
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
This class handles vector database registration (with metadata stored in a table named `vector_stores`)
|
||||
and creates a cache of VectorStoreWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
"""
|
||||
|
||||
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.vector_db_store = None
|
||||
self.cache: dict[str, VectorStoreWithIndex] = {}
|
||||
self.vector_store_table = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
for db_json in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(db_json)
|
||||
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
|
||||
for db_json in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(db_json)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
|
||||
# Load existing OpenAI vector stores into the in-memory cache
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
|
@ -408,63 +408,64 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [v.vector_db for v in self.cache.values()]
|
||||
async def list_vector_stores(self) -> list[VectorStore]:
|
||||
return [v.vector_store for v in self.cache.values()]
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
|
||||
)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
if self.vector_store_table is None:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
vector_db = self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
vector_store = self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=SQLiteVecIndex(
|
||||
dimension=vector_db.embedding_dimension,
|
||||
dimension=vector_store.embedding_dimension,
|
||||
db_path=self.config.db_path,
|
||||
bank_id=vector_db.identifier,
|
||||
bank_id=vector_store.identifier,
|
||||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
if vector_store_id not in self.cache:
|
||||
return
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
|
||||
# The VectorStoreWithIndex helper is expected to compute embeddings via the inference_api
|
||||
# and then call our index's add_chunks.
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a sqlite_vec index."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ def available_providers() -> list[ProviderSpec]:
# CrossEncoder depends on torchao.quantization
pip_packages=[
"torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu",
"numpy tqdm transformers",
"sentence-transformers --no-deps",
# required by some SentenceTransformers architectures for tensor rearrange/merge ops
"einops",
@ -7,33 +7,13 @@
from llama_stack.providers.datatypes import (
Api,
InlineProviderSpec,
ProviderSpec,
RemoteProviderSpec,
)
from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=DEFAULT_VECTOR_IO_DEPS
+ [
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
],
module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
),
RemoteProviderSpec(
api=Api.tool_runtime,
adapter_type="brave-search",
@ -119,7 +119,7 @@ Datasets that can fit in memory, frequent reads | Faiss | Optimized for speed, i
#### Empirical Example

Consider the histogram below in which 10,000 randomly generated strings were inserted
in batches of 100 into both Faiss and sqlite-vec using `client.tool_runtime.rag_tool.insert()`.
in batches of 100 into both Faiss and sqlite-vec.

```{image} ../../../../_static/providers/vector_io/write_time_comparison_sqlite-vec-faiss.png
:alt: Comparison of SQLite-Vec and Faiss write times
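Since `client.tool_runtime.rag_tool.insert()` is removed in this change, an equivalent write workload would go through the OpenAI-compatible Files and Vector Stores APIs instead. A hedged sketch of that ingestion path follows; the method names mirror the OpenAI-style client surface and `random_strings` / `store` are assumed to already exist.

```python
# Hypothetical ingestion loop: upload each document as a file, then attach it to
# the vector store so the provider chunks and embeds it.
import io

for i, text in enumerate(random_strings):
    uploaded = client.files.create(
        file=(f"doc-{i}.txt", io.BytesIO(text.encode("utf-8"))),
        purpose="assistants",
    )
    client.vector_stores.files.create(
        vector_store_id=store.id,  # store created earlier via client.vector_stores.create(...)
        file_id=uploaded.id,
    )
```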
@ -13,15 +13,15 @@ from numpy.typing import NDArray
|
|||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
|
||||
|
||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||
|
||||
|
|
@ -30,7 +30,7 @@ log = get_logger(name=__name__, category="vector_io::chroma")
|
|||
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:chroma:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:chroma:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:chroma:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
|
||||
|
|
@ -114,7 +114,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
raise NotImplementedError("Hybrid search is not supported in Chroma")
|
||||
|
||||
|
||||
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
|
||||
|
|
@ -127,11 +127,11 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.inference_api = inference_api
|
||||
self.client = None
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
self.vector_store_table = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
self.vector_db_store = self.kvstore
|
||||
self.vector_store_table = self.kvstore
|
||||
|
||||
if isinstance(self.config, RemoteChromaVectorIOConfig):
|
||||
log.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
|
|
@ -151,26 +151,26 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
collection = await maybe_await(
|
||||
self.client.get_or_create_collection(
|
||||
name=vector_db.identifier, metadata={"vector_db": vector_db.model_dump_json()}
|
||||
name=vector_store.identifier, metadata={"vector_store": vector_store.model_dump_json()}
|
||||
)
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db, ChromaIndex(self.client, collection), self.inference_api
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store, ChromaIndex(self.client, collection), self.inference_api
|
||||
)
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
log.warning(f"Vector DB {vector_db_id} not found")
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
if vector_store_id not in self.cache:
|
||||
log.warning(f"Vector DB {vector_store_id} not found")
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
|
||||
|
||||
|
|
@ -179,30 +179,30 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
async def query_chunks(
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack")
|
||||
collection = await maybe_await(self.client.get_collection(vector_db_id))
|
||||
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found in Llama Stack")
|
||||
collection = await maybe_await(self.client.get_collection(vector_store_id))
|
||||
if not collection:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
|
||||
index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
|
||||
self.cache[vector_db_id] = index
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
|
||||
index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a Chroma vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
@ -14,10 +14,10 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
|
@ -26,7 +26,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
RERANKER_TYPE_WEIGHTED,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
VectorStoreWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
|||
logger = get_logger(name=__name__, category="vector_io::milvus")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:milvus:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:milvus:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:milvus:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:milvus:{VERSION}::"
|
||||
|
|
@ -261,7 +261,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
raise
|
||||
|
||||
|
||||
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
|
||||
|
|
@ -273,28 +273,28 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.cache = {}
|
||||
self.client = None
|
||||
self.inference_api = inference_api
|
||||
self.vector_db_store = None
|
||||
self.vector_store_table = None
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
for vector_store_data in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(vector_store_data)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store,
|
||||
index=MilvusIndex(
|
||||
client=self.client,
|
||||
collection_name=vector_db.identifier,
|
||||
collection_name=vector_store.identifier,
|
||||
consistency_level=self.config.consistency_level,
|
||||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
self.cache[vector_store.identifier] = index
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
|
|
@ -311,45 +311,45 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
consistency_level = self.config.consistency_level
|
||||
else:
|
||||
consistency_level = "Strong"
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = index
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
if self.vector_store_table is None:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id in self.cache:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
if vector_store_id in self.cache:
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
|
|
@ -358,14 +358,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
async def query_chunks(
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete a chunk from a milvus vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
@ -16,15 +16,15 @@ from pydantic import BaseModel, TypeAdapter
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
|
||||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
|
@ -32,7 +32,7 @@ from .config import PGVectorVectorIOConfig
|
|||
log = get_logger(name=__name__, category="vector_io::pgvector")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:pgvector:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
|
||||
|
|
@ -79,13 +79,13 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
vector_store: VectorStore,
|
||||
dimension: int,
|
||||
conn: psycopg2.extensions.connection,
|
||||
kvstore: KVStore | None = None,
|
||||
distance_metric: str = "COSINE",
|
||||
):
|
||||
self.vector_db = vector_db
|
||||
self.vector_store = vector_store
|
||||
self.dimension = dimension
|
||||
self.conn = conn
|
||||
self.kvstore = kvstore
|
||||
|
|
@ -97,9 +97,9 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
try:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Sanitize the table name by replacing hyphens with underscores
|
||||
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
|
||||
# SQL doesn't allow hyphens in table names, and vector_store.identifier may contain hyphens
|
||||
# when created with patterns like "test-vector-db-{uuid4()}"
|
||||
sanitized_identifier = sanitize_collection_name(self.vector_db.identifier)
|
||||
sanitized_identifier = sanitize_collection_name(self.vector_store.identifier)
|
||||
self.table_name = f"vs_{sanitized_identifier}"
|
||||
|
||||
cur.execute(
|
||||
|
|
@ -122,8 +122,8 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
|
||||
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e
|
||||
log.exception(f"Error creating PGVectorIndex for vector_store: {self.vector_store.identifier}")
|
||||
raise RuntimeError(f"Error creating PGVectorIndex for vector_store: {self.vector_store.identifier}") from e
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
|
|
@ -323,7 +323,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(
|
||||
self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
|
||||
) -> None:
|
||||
|
|
@ -332,7 +332,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
self.inference_api = inference_api
|
||||
self.conn = None
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
self.vector_store_table = None
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
@ -375,59 +375,59 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
# Persist vector DB metadata in the KV store
|
||||
assert self.kvstore is not None
|
||||
# Upsert model metadata in Postgres
|
||||
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
|
||||
upsert_models(self.conn, [(vector_store.identifier, vector_store)])
|
||||
|
||||
# Create and cache the PGVector index table for the vector DB
|
||||
pgvector_index = PGVectorIndex(
|
||||
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
|
||||
vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore
|
||||
)
|
||||
await pgvector_index.initialize()
|
||||
index = VectorDBWithIndex(vector_db, index=pgvector_index, inference_api=self.inference_api)
|
||||
self.cache[vector_db.identifier] = index
|
||||
index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
# Remove provider index and cache
|
||||
if vector_db_id in self.cache:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
if vector_store_id in self.cache:
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
# Delete vector DB metadata from KV store
|
||||
assert self.kvstore is not None
|
||||
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
if self.vector_store_table is None:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
|
||||
index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
|
||||
await index.initialize()
|
||||
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
return self.cache[vector_db_id]
|
||||
self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete a chunk from a PostgreSQL vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
@ -16,7 +16,6 @@ from qdrant_client.models import PointStruct
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
|
|
@ -24,12 +23,13 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreChunkingStrategy,
|
||||
VectorStoreFileObject,
|
||||
)
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
|
||||
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
|
||||
|
||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ CHUNK_ID_KEY = "_chunk_id"
|
|||
|
||||
# KV store prefixes for vector databases
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:qdrant:{VERSION}::"
|
||||
|
||||
|
||||
def convert_id(_id: str) -> str:
|
||||
|
|
@ -145,7 +145,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
await self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
|
||||
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
|
||||
|
|
@ -157,7 +157,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.client: AsyncQdrantClient = None
|
||||
self.cache = {}
|
||||
self.inference_api = inference_api
|
||||
self.vector_db_store = None
|
||||
self.vector_store_table = None
|
||||
self._qdrant_lock = asyncio.Lock()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
@ -167,12 +167,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(vector_db, QdrantIndex(self.client, vector_db.identifier), self.inference_api)
|
||||
self.cache[vector_db.identifier] = index
|
||||
for vector_store_data in stored_vector_stores:
|
||||
vector_store = VectorStore.model_validate_json(vector_store_data)
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store, QdrantIndex(self.client, vector_store.identifier), self.inference_api
|
||||
)
|
||||
self.cache[vector_store.identifier] = index
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
|
|
@ -180,46 +182,48 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
assert self.kvstore is not None
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
|
||||
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db, index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id in self.cache:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
assert self.kvstore is not None
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise ValueError(f"Vector DB not found {vector_db_id}")
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=QdrantIndex(self.client, vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
|
||||
self.cache[vector_store.identifier] = index
|
||||
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
if vector_store_id in self.cache:
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
assert self.kvstore is not None
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
||||
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
if self.vector_store_table is None:
|
||||
raise ValueError(f"Vector DB not found {vector_store_id}")
|
||||
|
||||
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
|
|
@ -228,7 +232,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
async def query_chunks(
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
|
|
@ -249,7 +253,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a Qdrant vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
@ -16,11 +16,11 @@ from llama_stack.apis.common.content_types import InterleavedContent
|
|||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorStoresProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
|
@ -28,7 +28,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
RERANKER_TYPE_RRF,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
VectorStoreWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
|
||||
|
||||
|
|
@ -37,7 +37,7 @@ from .config import WeaviateVectorIOConfig
|
|||
log = get_logger(name=__name__, category="vector_io::weaviate")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:weaviate:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
|
||||
|
|
@ -257,14 +257,14 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorDBsProtocolPrivate):
|
||||
class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorStoresProtocolPrivate):
|
||||
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
super().__init__(files_api=files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
self.vector_db_store = None
|
||||
self.vector_store_table = None
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
def _get_client(self) -> weaviate.WeaviateClient:
|
||||
|
|
@ -300,11 +300,11 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
|
|||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored = await self.kvstore.values_in_range(start_key, end_key)
|
||||
for raw in stored:
|
||||
vector_db = VectorDB.model_validate_json(raw)
|
||||
vector_store = VectorStore.model_validate_json(raw)
|
||||
client = self._get_client()
|
||||
idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db=vector_db, index=idx, inference_api=self.inference_api
|
||||
idx = WeaviateIndex(client=client, collection_name=vector_store.identifier, kvstore=self.kvstore)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store=vector_store, index=idx, inference_api=self.inference_api
|
||||
)
|
||||
|
||||
# Load OpenAI vector stores metadata into cache
|
||||
|
|
@ -316,9 +316,9 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
|
|||
# Clean up mixin resources (file batch tasks)
|
||||
await super().shutdown()
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
|
||||
sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
|
||||
# Create collection if it doesn't exist
|
||||
if not client.collections.exists(sanitized_collection_name):
|
||||
client.collections.create(
|
||||
|
|
@ -329,45 +329,45 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
|
|||
],
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
|
||||
)
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
|
||||
if vector_db_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
|
||||
sanitized_collection_name = sanitize_collection_name(vector_store_id, weaviate_format=True)
|
||||
if vector_store_id not in self.cache or client.collections.exists(sanitized_collection_name) is False:
|
||||
return
|
||||
client.collections.delete(sanitized_collection_name)
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
|
||||
if vector_store_id in self.cache:
|
||||
return self.cache[vector_store_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
if self.vector_store_table is None:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
|
||||
if not vector_store:
|
||||
raise VectorStoreNotFoundError(vector_store_id)
|
||||
|
||||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
|
||||
sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
|
||||
if not client.collections.exists(sanitized_collection_name):
|
||||
raise ValueError(f"Collection with name `{sanitized_collection_name}` not found")
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
|
||||
index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=WeaviateIndex(client=client, collection_name=vector_store.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
self.cache[vector_store_id] = index
|
||||
return index
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
|
|
@ -376,14 +376,14 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
|
|||
async def query_chunks(
|
||||
self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
index = await self._get_and_cache_vector_store_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
index = await self._get_and_cache_vector_store_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
@ -17,7 +17,6 @@ from pydantic import TypeAdapter
|
|||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files import Files, OpenAIFileObject
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||
|
|
@ -43,6 +42,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
|
@ -63,7 +63,7 @@ MAX_CONCURRENT_FILES_PER_BATCH = 3 # Maximum concurrent file processing within
|
|||
FILE_BATCH_CHUNK_SIZE = 10 # Process files in chunks of this size
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
||||
VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
|
||||
|
|
@ -321,12 +321,12 @@ class OpenAIVectorStoreMixin(ABC):
        pass

    @abstractmethod
    async def register_vector_db(self, vector_db: VectorDB) -> None:
    async def register_vector_store(self, vector_store: VectorStore) -> None:
        """Register a vector database (provider-specific implementation)."""
        pass

    @abstractmethod
    async def unregister_vector_db(self, vector_db_id: str) -> None:
    async def unregister_vector_store(self, vector_store_id: str) -> None:
        """Unregister a vector database (provider-specific implementation)."""
        pass

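Any out-of-tree provider subclassing this mixin has to rename its overrides to match. A minimal stub under that assumption is shown below; other abstract members of the mixin (persistence hooks and so on) are omitted, so the class is not instantiable as written.

```python
# Stub showing only the renamed abstract methods an out-of-tree subclass must now provide.
# The in-memory dict is purely illustrative; other abstract members are omitted.
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin


class MyVectorIOAdapter(OpenAIVectorStoreMixin):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._stores: dict[str, VectorStore] = {}

    async def register_vector_store(self, vector_store: VectorStore) -> None:
        # was: async def register_vector_db(self, vector_db: VectorDB) -> None
        self._stores[vector_store.identifier] = vector_store

    async def unregister_vector_store(self, vector_store_id: str) -> None:
        # was: async def unregister_vector_db(self, vector_db_id: str) -> None
        self._stores.pop(vector_store_id, None)
```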
@ -358,7 +358,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
extra_body = params.model_extra or {}
|
||||
metadata = params.metadata or {}
|
||||
|
||||
provider_vector_db_id = extra_body.get("provider_vector_db_id")
|
||||
provider_vector_store_id = extra_body.get("provider_vector_store_id")
|
||||
|
||||
# Use embedding info from metadata if available, otherwise from extra_body
|
||||
if metadata.get("embedding_model"):
|
||||
|
|
@ -389,8 +389,8 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
|
||||
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
|
||||
# Derive the canonical vector_db_id (allow override, else generate)
|
||||
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
||||
# Derive the canonical vector_store_id (allow override, else generate)
|
||||
vector_store_id = provider_vector_store_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
|
||||
|
||||
if embedding_model is None:
|
||||
raise ValueError("embedding_model is required")
|
||||
|
|
@ -398,19 +398,20 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
if embedding_dimension is None:
|
||||
raise ValueError("Embedding dimension is required")
|
||||
|
||||
# Register the VectorDB backing this vector store
|
||||
# Register the VectorStore backing this vector store
|
||||
if provider_id is None:
|
||||
raise ValueError("Provider ID is required but was not provided")
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier=vector_db_id,
|
||||
# call to the provider to create any index, etc.
|
||||
vector_store = VectorStore(
|
||||
identifier=vector_store_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
embedding_model=embedding_model,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=vector_db_id,
|
||||
vector_db_name=params.name,
|
||||
provider_resource_id=vector_store_id,
|
||||
vector_store_name=params.name,
|
||||
)
|
||||
await self.register_vector_db(vector_db)
|
||||
await self.register_vector_store(vector_store)
|
||||
|
||||
# Create OpenAI vector store metadata
|
||||
status = "completed"
|
||||
|
|
@ -424,7 +425,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
total=0,
|
||||
)
|
||||
store_info: dict[str, Any] = {
|
||||
"id": vector_db_id,
|
||||
"id": vector_store_id,
|
||||
"object": "vector_store",
|
||||
"created_at": created_at,
|
||||
"name": params.name,
|
||||
|
|
@ -441,23 +442,23 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
# Add provider information to metadata if provided
|
||||
if provider_id:
|
||||
metadata["provider_id"] = provider_id
|
||||
if provider_vector_db_id:
|
||||
metadata["provider_vector_db_id"] = provider_vector_db_id
|
||||
if provider_vector_store_id:
|
||||
metadata["provider_vector_store_id"] = provider_vector_store_id
|
||||
store_info["metadata"] = metadata
|
||||
|
||||
# Save to persistent storage (provider-specific)
|
||||
await self._save_openai_vector_store(vector_db_id, store_info)
|
||||
await self._save_openai_vector_store(vector_store_id, store_info)
|
||||
|
||||
# Store in memory cache
|
||||
self.openai_vector_stores[vector_db_id] = store_info
|
||||
self.openai_vector_stores[vector_store_id] = store_info
|
||||
|
||||
# Now that our vector store is created, attach any files that were provided
|
||||
file_ids = params.file_ids or []
|
||||
tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids]
|
||||
tasks = [self.openai_attach_file_to_vector_store(vector_store_id, file_id) for file_id in file_ids]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Get the updated store info and return it
|
||||
store_info = self.openai_vector_stores[vector_db_id]
|
||||
store_info = self.openai_vector_stores[vector_store_id]
|
||||
return VectorStoreObject.model_validate(store_info)
|
||||
|
||||
async def openai_list_vector_stores(
|
||||
|
|
@ -567,7 +568,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
# Also delete the underlying vector DB
|
||||
try:
|
||||
await self.unregister_vector_db(vector_store_id)
|
||||
await self.unregister_vector_store(vector_store_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
@ -12,19 +12,16 @@ from dataclasses import dataclass
|
|||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
|
||||
from llama_stack.apis.tools import RAGDocument
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
|
@ -129,31 +126,6 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en
|
|||
return ""
|
||||
|
||||
|
||||
async def content_from_doc(doc: RAGDocument) -> str:
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
return content_from_data(doc.content.uri)
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(doc.content.uri)
|
||||
if doc.mime_type == "application/pdf":
|
||||
return parse_pdf(r.content)
|
||||
return r.text
|
||||
elif isinstance(doc.content, str):
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
if pattern.match(doc.content):
|
||||
if doc.content.startswith("data:"):
|
||||
return content_from_data(doc.content)
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(doc.content)
|
||||
if doc.mime_type == "application/pdf":
|
||||
return parse_pdf(r.content)
|
||||
return r.text
|
||||
return doc.content
|
||||
else:
|
||||
# will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent
|
||||
return interleaved_content_as_str(doc.content)
|
||||
|
||||
|
||||
def make_overlapped_chunks(
|
||||
document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
|
||||
) -> list[Chunk]:
|
||||
|
|
@ -187,7 +159,7 @@ def make_overlapped_chunks(
|
|||
updated_timestamp=int(time.time()),
|
||||
chunk_window=chunk_window,
|
||||
chunk_tokenizer=default_tokenizer,
|
||||
chunk_embedding_model=None, # This will be set in `VectorDBWithIndex.insert_chunks`
|
||||
chunk_embedding_model=None, # This will be set in `VectorStoreWithIndex.insert_chunks`
|
||||
content_token_count=len(toks),
|
||||
metadata_token_count=len(metadata_tokens),
|
||||
)
|
||||
|
|
@ -255,8 +227,8 @@ class EmbeddingIndex(ABC):


@dataclass
class VectorDBWithIndex:
    vector_db: VectorDB
class VectorStoreWithIndex:
    vector_store: VectorStore
    index: EmbeddingIndex
    inference_api: Api.inference

@ -269,14 +241,14 @@ class VectorDBWithIndex:
|
|||
if c.embedding is None:
|
||||
chunks_to_embed.append(c)
|
||||
if c.chunk_metadata:
|
||||
c.chunk_metadata.chunk_embedding_model = self.vector_db.embedding_model
|
||||
c.chunk_metadata.chunk_embedding_dimension = self.vector_db.embedding_dimension
|
||||
c.chunk_metadata.chunk_embedding_model = self.vector_store.embedding_model
|
||||
c.chunk_metadata.chunk_embedding_dimension = self.vector_store.embedding_dimension
|
||||
else:
|
||||
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
|
||||
_validate_embedding(c.embedding, i, self.vector_store.embedding_dimension)
|
||||
|
||||
if chunks_to_embed:
|
||||
params = OpenAIEmbeddingsRequestWithExtraBody(
|
||||
model=self.vector_db.embedding_model,
|
||||
model=self.vector_store.embedding_model,
|
||||
input=[c.content for c in chunks_to_embed],
|
||||
)
|
||||
resp = await self.inference_api.openai_embeddings(params)
|
||||
|
|
@ -319,7 +291,7 @@ class VectorDBWithIndex:
|
|||
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
params = OpenAIEmbeddingsRequestWithExtraBody(
|
||||
model=self.vector_db.embedding_model,
|
||||
model=self.vector_store.embedding_model,
|
||||
input=[query_string],
|
||||
)
|
||||
embeddings_response = await self.inference_api.openai_embeddings(params)
@ -238,6 +238,8 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
|
|||
echo "Stopping Docker container..."
|
||||
container_name="llama-stack-test-$DISTRO"
|
||||
if docker ps -a --format '{{.Names}}' | grep -q "^${container_name}$"; then
|
||||
echo "Dumping container logs before stopping..."
|
||||
docker logs "$container_name" > "docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true
|
||||
echo "Stopping and removing container: $container_name"
|
||||
docker stop "$container_name" 2>/dev/null || true
|
||||
docker rm "$container_name" 2>/dev/null || true
|
||||
|
|
@ -408,6 +410,21 @@ elif [ $exit_code -eq 5 ]; then
|
|||
echo "⚠️ No tests collected (pattern matched no tests)"
|
||||
else
|
||||
echo "❌ Tests failed"
|
||||
echo ""
|
||||
echo "=== Dumping last 100 lines of logs for debugging ==="
|
||||
|
||||
# Output server or container logs based on stack config
|
||||
if [[ "$STACK_CONFIG" == *"server:"* && -f "server.log" ]]; then
|
||||
echo "--- Last 100 lines of server.log ---"
|
||||
tail -100 server.log
|
||||
elif [[ "$STACK_CONFIG" == *"docker:"* ]]; then
|
||||
docker_log_file="docker-${DISTRO}-${INFERENCE_MODE}.log"
|
||||
if [[ -f "$docker_log_file" ]]; then
|
||||
echo "--- Last 100 lines of $docker_log_file ---"
|
||||
tail -100 "$docker_log_file"
|
||||
fi
|
||||
fi
|
||||
|
||||
exit 1
|
||||
fi
@ -37,6 +37,9 @@ def pytest_sessionstart(session):
|
|||
if "LLAMA_STACK_TEST_INFERENCE_MODE" not in os.environ:
|
||||
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"
|
||||
|
||||
if "LLAMA_STACK_LOGGING" not in os.environ:
|
||||
os.environ["LLAMA_STACK_LOGGING"] = "all=warning"
|
||||
|
||||
if "SQLITE_STORE_DIR" not in os.environ:
|
||||
os.environ["SQLITE_STORE_DIR"] = tempfile.mkdtemp()
@ -49,46 +49,50 @@ def client_with_empty_registry(client_with_models):
|
|||
|
||||
|
||||
@vector_provider_wrapper
|
||||
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
|
||||
vector_db_name = "test_vector_db"
|
||||
def test_vector_store_retrieve(
|
||||
client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||
):
|
||||
vector_store_name = "test_vector_store"
|
||||
create_response = client_with_empty_registry.vector_stores.create(
|
||||
name=vector_db_name,
|
||||
name=vector_store_name,
|
||||
extra_body={
|
||||
"provider_id": vector_io_provider_id,
|
||||
},
|
||||
)
|
||||
|
||||
actual_vector_db_id = create_response.id
|
||||
actual_vector_store_id = create_response.id
|
||||
|
||||
# Retrieve the vector store and validate its properties
|
||||
response = client_with_empty_registry.vector_stores.retrieve(vector_store_id=actual_vector_db_id)
|
||||
response = client_with_empty_registry.vector_stores.retrieve(vector_store_id=actual_vector_store_id)
|
||||
assert response is not None
|
||||
assert response.id == actual_vector_db_id
|
||||
assert response.name == vector_db_name
|
||||
assert response.id == actual_vector_store_id
|
||||
assert response.name == vector_store_name
|
||||
assert response.id.startswith("vs_")
|
||||
|
||||
|
||||
@vector_provider_wrapper
|
||||
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id):
|
||||
vector_db_name = "test_vector_db"
|
||||
def test_vector_store_register(
|
||||
client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
|
||||
):
|
||||
vector_store_name = "test_vector_store"
|
||||
response = client_with_empty_registry.vector_stores.create(
|
||||
name=vector_db_name,
|
||||
name=vector_store_name,
|
||||
extra_body={
|
||||
"provider_id": vector_io_provider_id,
|
||||
},
|
||||
)
|
||||
|
||||
actual_vector_db_id = response.id
|
||||
assert actual_vector_db_id.startswith("vs_")
|
||||
assert actual_vector_db_id != vector_db_name
|
||||
actual_vector_store_id = response.id
|
||||
assert actual_vector_store_id.startswith("vs_")
assert actual_vector_store_id != vector_store_name

vector_stores = client_with_empty_registry.vector_stores.list()
assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0]
assert vector_store.id == actual_vector_store_id
assert vector_store.name == vector_store_name

client_with_empty_registry.vector_stores.delete(vector_store_id=actual_vector_store_id)

vector_stores = client_with_empty_registry.vector_stores.list()
assert len(vector_stores.data) == 0

@ -108,23 +112,23 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
def test_insert_chunks(
client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case, vector_io_provider_id
):
vector_db_name = "test_vector_db"
vector_store_name = "test_vector_store"
create_response = client_with_empty_registry.vector_stores.create(
name=vector_db_name,
name=vector_store_name,
extra_body={
"provider_id": vector_io_provider_id,
},
)

actual_vector_db_id = create_response.id
actual_vector_store_id = create_response.id

client_with_empty_registry.vector_io.insert(
vector_db_id=actual_vector_db_id,
vector_db_id=actual_vector_store_id,
chunks=sample_chunks,
)

response = client_with_empty_registry.vector_io.query(
vector_db_id=actual_vector_db_id,
vector_db_id=actual_vector_store_id,
query="What is the capital of France?",
)
assert response is not None

@ -133,7 +137,7 @@ def test_insert_chunks(
query, expected_doc_id = test_case
response = client_with_empty_registry.vector_io.query(
vector_db_id=actual_vector_db_id,
vector_db_id=actual_vector_store_id,
query=query,
)
assert response is not None

@ -151,15 +155,15 @@ def test_insert_chunks_with_precomputed_embeddings(
"inline::qdrant": {"score_threshold": -1.0},
"remote::qdrant": {"score_threshold": -1.0},
}
vector_db_name = "test_precomputed_embeddings_db"
vector_store_name = "test_precomputed_embeddings_db"
register_response = client_with_empty_registry.vector_stores.create(
name=vector_db_name,
name=vector_store_name,
extra_body={
"provider_id": vector_io_provider_id,
},
)

actual_vector_db_id = register_response.id
actual_vector_store_id = register_response.id

chunks_with_embeddings = [
Chunk(

@ -170,13 +174,13 @@ def test_insert_chunks_with_precomputed_embeddings(
]

client_with_empty_registry.vector_io.insert(
vector_db_id=actual_vector_db_id,
vector_db_id=actual_vector_store_id,
chunks=chunks_with_embeddings,
)

provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
response = client_with_empty_registry.vector_io.query(
vector_db_id=actual_vector_db_id,
vector_db_id=actual_vector_store_id,
query="precomputed embedding test",
params=vector_io_provider_params_dict.get(provider, None),
)

@ -200,16 +204,16 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
"remote::qdrant": {"score_threshold": 0.0},
"inline::qdrant": {"score_threshold": 0.0},
}
vector_db_name = "test_precomputed_embeddings_db"
vector_store_name = "test_precomputed_embeddings_db"
register_response = client_with_empty_registry.vector_stores.create(
name=vector_db_name,
name=vector_store_name,
extra_body={
"embedding_model": embedding_model_id,
"provider_id": vector_io_provider_id,
},
)

actual_vector_db_id = register_response.id
actual_vector_store_id = register_response.id

chunks_with_embeddings = [
Chunk(

@ -220,13 +224,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
]

client_with_empty_registry.vector_io.insert(
vector_db_id=actual_vector_db_id,
vector_db_id=actual_vector_store_id,
chunks=chunks_with_embeddings,
)

provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
response = client_with_empty_registry.vector_io.query(
vector_db_id=actual_vector_db_id,
vector_db_id=actual_vector_store_id,
query="duplicate",
params=vector_io_provider_params_dict.get(provider, None),
)
@ -21,7 +21,7 @@ async def test_single_provider_auto_selection():
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
]
)
mock_routing_table.register_vector_db = AsyncMock(
mock_routing_table.register_vector_store = AsyncMock(
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
)
mock_routing_table.get_provider_impl = AsyncMock(
@ -4,138 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import patch

import pytest

from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc


async def test_content_from_doc_with_url():
"""Test extracting content from RAGDocument with URL content."""
mock_url = URL(uri="https://example.com")
mock_doc = RAGDocument(document_id="foo", content=mock_url)

mock_response = MagicMock()
mock_response.text = "Sample content from URL"

with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance

result = await content_from_doc(mock_doc)

assert result == "Sample content from URL"
mock_instance.get.assert_called_once_with(mock_url.uri)


async def test_content_from_doc_with_pdf_url():
"""Test extracting content from RAGDocument with URL pointing to a PDF."""
mock_url = URL(uri="https://example.com/document.pdf")
mock_doc = RAGDocument(document_id="foo", content=mock_url, mime_type="application/pdf")

mock_response = MagicMock()
mock_response.content = b"PDF binary data"

with (
patch("httpx.AsyncClient") as mock_client,
patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf,
):
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance
mock_parse_pdf.return_value = "Extracted PDF content"

result = await content_from_doc(mock_doc)

assert result == "Extracted PDF content"
mock_instance.get.assert_called_once_with(mock_url.uri)
mock_parse_pdf.assert_called_once_with(b"PDF binary data")


async def test_content_from_doc_with_data_url():
"""Test extracting content from RAGDocument with data URL content."""
data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded
mock_url = URL(uri=data_url)
mock_doc = RAGDocument(document_id="foo", content=mock_url)

with patch("llama_stack.providers.utils.memory.vector_store.content_from_data") as mock_content_from_data:
mock_content_from_data.return_value = "Hello World"

result = await content_from_doc(mock_doc)

assert result == "Hello World"
mock_content_from_data.assert_called_once_with(data_url)


async def test_content_from_doc_with_string():
"""Test extracting content from RAGDocument with string content."""
content_string = "This is plain text content"
mock_doc = RAGDocument(document_id="foo", content=content_string)

result = await content_from_doc(mock_doc)

assert result == content_string


async def test_content_from_doc_with_string_url():
"""Test extracting content from RAGDocument with string URL content."""
url_string = "https://example.com"
mock_doc = RAGDocument(document_id="foo", content=url_string)

mock_response = MagicMock()
mock_response.text = "Sample content from URL string"

with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance

result = await content_from_doc(mock_doc)

assert result == "Sample content from URL string"
mock_instance.get.assert_called_once_with(url_string)


async def test_content_from_doc_with_string_pdf_url():
"""Test extracting content from RAGDocument with string URL pointing to a PDF."""
url_string = "https://example.com/document.pdf"
mock_doc = RAGDocument(document_id="foo", content=url_string, mime_type="application/pdf")

mock_response = MagicMock()
mock_response.content = b"PDF binary data"

with (
patch("httpx.AsyncClient") as mock_client,
patch("llama_stack.providers.utils.memory.vector_store.parse_pdf") as mock_parse_pdf,
):
mock_instance = AsyncMock()
mock_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_instance
mock_parse_pdf.return_value = "Extracted PDF content from string URL"

result = await content_from_doc(mock_doc)

assert result == "Extracted PDF content from string URL"
mock_instance.get.assert_called_once_with(url_string)
mock_parse_pdf.assert_called_once_with(b"PDF binary data")


async def test_content_from_doc_with_interleaved_content():
"""Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit)."""
interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")]
mock_doc = RAGDocument(document_id="foo", content=interleaved_content)

with patch("llama_stack.providers.utils.memory.vector_store.interleaved_content_as_str") as mock_interleaved:
mock_interleaved.return_value = "First item\nSecond item"

result = await content_from_doc(mock_doc)

assert result == "First item\nSecond item"
mock_interleaved.assert_called_once_with(interleaved_content)
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type


def test_content_from_data_and_mime_type_success_utf8():

@ -178,41 +51,3 @@ def test_content_from_data_and_mime_type_both_encodings_fail():
# Should raise an exception instead of returning empty string
with pytest.raises(UnicodeDecodeError):
content_from_data_and_mime_type(data, mime_type)


async def test_memory_tool_error_handling():
"""Test that memory tool handles various failures gracefully without crashing."""
from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl

config = RagToolRuntimeConfig()
memory_tool = MemoryToolRuntimeImpl(
config=config,
vector_io_api=AsyncMock(),
inference_api=AsyncMock(),
files_api=AsyncMock(),
)

docs = [
RAGDocument(document_id="good_doc", content="Good content", metadata={}),
RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}),
RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}),
]

mock_file1 = MagicMock()
mock_file1.id = "file_good1"
mock_file2 = MagicMock()
mock_file2.id = "file_good2"
memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2]

with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.side_effect = Exception("Bad URL")
mock_client.return_value.__aenter__.return_value = mock_instance

# won't raise exception despite one document failing
await memory_tool.insert(docs, "vector_store_123")

# processed 2 documents successfully, skipped 1
assert memory_tool.files_api.openai_upload_file.call_count == 2
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2
@ -10,8 +10,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter

@ -31,7 +31,7 @@ def vector_provider(request):


@pytest.fixture
def vector_db_id() -> str:
def vector_store_id() -> str:
return f"test-vector-db-{random.randint(1, 100)}"


@ -149,8 +149,8 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
)
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
await adapter.register_vector_store(
VectorStore(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",

@ -186,8 +186,8 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
await adapter.register_vector_store(
VectorStore(
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
provider_id="test_provider",
embedding_model="test_model",

@ -215,7 +215,7 @@ def mock_psycopg2_connection():
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection

vector_db = VectorDB(
vector_store = VectorStore(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,

@ -225,7 +225,7 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):

with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
index = PGVectorIndex(vector_store, embedding_dimension, connection, distance_metric="COSINE")
index._test_chunks = []
original_add_chunks = index.add_chunks


@ -281,30 +281,30 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
await adapter.initialize()
adapter.conn = mock_conn

async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
async def mock_insert_chunks(vector_store_id, chunks, ttl_seconds=None):
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
raise ValueError(f"Vector DB {vector_store_id} not found")
await index.insert_chunks(chunks)

adapter.insert_chunks = mock_insert_chunks

async def mock_query_chunks(vector_db_id, query, params=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
async def mock_query_chunks(vector_store_id, query, params=None):
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
raise ValueError(f"Vector DB {vector_store_id} not found")
return await index.query_chunks(query, params)

adapter.query_chunks = mock_query_chunks

test_vector_db = VectorDB(
test_vector_store = VectorStore(
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
await adapter.register_vector_db(test_vector_db)
adapter.test_collection_id = test_vector_db.identifier
await adapter.register_vector_store(test_vector_store)
adapter.test_collection_id = test_vector_store.identifier

yield adapter
await adapter.shutdown()
@ -11,8 +11,8 @@ import numpy as np
import pytest

from llama_stack.apis.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import (

@ -43,8 +43,8 @@ def embedding_dimension():


@pytest.fixture
def vector_db_id():
return "test_vector_db"
def vector_store_id():
return "test_vector_store"


@pytest.fixture

@ -61,12 +61,12 @@ def sample_embeddings(embedding_dimension):


@pytest.fixture
def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "mock_embedding_model"
mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = embedding_dimension
return mock_vector_db
def mock_vector_store(vector_store_id, embedding_dimension) -> MagicMock:
mock_vector_store = MagicMock(spec=VectorStore)
mock_vector_store.embedding_model = "mock_embedding_model"
mock_vector_store.identifier = vector_store_id
mock_vector_store.embedding_dimension = embedding_dimension
return mock_vector_store


@pytest.fixture
@ -12,7 +12,6 @@ import numpy as np
import pytest

from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,

@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategyAuto,
VectorStoreFileObject,
)
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX

# This test is a unit test for the inline VectorIO providers. This should only contain

@ -71,7 +71,7 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio

async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
key = f"{VECTOR_DBS_PREFIX}db1"
dummy = VectorDB(
dummy = VectorStore(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))

@ -81,10 +81,10 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):

async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.initialize()
dummy = VectorDB(
dummy = VectorStore(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.register_vector_db(dummy)
await vector_io_adapter.register_vector_store(dummy)
await vector_io_adapter.shutdown()

await vector_io_adapter.initialize()

@ -92,15 +92,15 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.shutdown()


async def test_register_and_unregister_vector_db(vector_io_adapter):
async def test_register_and_unregister_vector_store(vector_io_adapter):
unique_id = f"foo_db_{np.random.randint(1e6)}"
dummy = VectorDB(
dummy = VectorStore(
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)

await vector_io_adapter.register_vector_db(dummy)
await vector_io_adapter.register_vector_store(dummy)
assert dummy.identifier in vector_io_adapter.cache
await vector_io_adapter.unregister_vector_db(dummy.identifier)
await vector_io_adapter.unregister_vector_store(dummy.identifier)
assert dummy.identifier not in vector_io_adapter.cache


@ -121,7 +121,7 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):


async def test_insert_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)

with pytest.raises(ValueError):
await vector_io_adapter.insert_chunks("db_not_exist", [])

@ -170,7 +170,7 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter


async def test_query_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)

with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("db_missing", "q", None)

@ -182,7 +182,7 @@ async def test_save_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}


@ -198,7 +198,7 @@ async def test_update_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}


@ -214,7 +214,7 @@ async def test_delete_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}


@ -229,7 +229,7 @@ async def test_load_openai_vector_stores(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}


@ -998,8 +998,8 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
async def test_embedding_config_from_metadata(vector_io_adapter):
"""Test that embedding configuration is correctly extracted from metadata."""

# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"


@ -1015,9 +1015,9 @@ async def test_embedding_config_from_metadata(vector_io_adapter):

await vector_io_adapter.openai_create_vector_store(params)

# Verify VectorDB was registered with correct embedding config from metadata
vector_io_adapter.register_vector_db.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
# Verify VectorStore was registered with correct embedding config from metadata
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "test-embedding-model"
assert call_args.embedding_dimension == 512


@ -1025,8 +1025,8 @@ async def test_embedding_config_from_metadata(vector_io_adapter):
async def test_embedding_config_from_extra_body(vector_io_adapter):
"""Test that embedding configuration is correctly extracted from extra_body when metadata is empty."""

# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"


@ -1042,9 +1042,9 @@ async def test_embedding_config_from_extra_body(vector_io_adapter):

await vector_io_adapter.openai_create_vector_store(params)

# Verify VectorDB was registered with correct embedding config from extra_body
vector_io_adapter.register_vector_db.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
# Verify VectorStore was registered with correct embedding config from extra_body
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "extra-body-model"
assert call_args.embedding_dimension == 1024


@ -1052,8 +1052,8 @@ async def test_embedding_config_from_extra_body(vector_io_adapter):
async def test_embedding_config_consistency_check_passes(vector_io_adapter):
"""Test that consistent embedding config in both metadata and extra_body passes validation."""

# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"


@ -1073,8 +1073,8 @@ async def test_embedding_config_consistency_check_passes(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params)

# Should not raise any error and use metadata config
vector_io_adapter.register_vector_db.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "consistent-model"
assert call_args.embedding_dimension == 768


@ -1082,8 +1082,8 @@ async def test_embedding_config_consistency_check_passes(vector_io_adapter):
async def test_embedding_config_inconsistency_errors(vector_io_adapter):
"""Test that inconsistent embedding config between metadata and extra_body raises errors."""

# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"


@ -1104,7 +1104,7 @@ async def test_embedding_config_inconsistency_errors(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params)

# Reset mock for second test
vector_io_adapter.register_vector_db.reset_mock()
vector_io_adapter.register_vector_store.reset_mock()

# Test with inconsistent embedding dimension
params = OpenAICreateVectorStoreRequestWithExtraBody(

@ -1126,8 +1126,8 @@ async def test_embedding_config_inconsistency_errors(vector_io_adapter):
async def test_embedding_config_defaults_when_missing(vector_io_adapter):
"""Test that embedding dimension defaults to 768 when not provided."""

# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"


@ -1143,8 +1143,8 @@ async def test_embedding_config_defaults_when_missing(vector_io_adapter):
await vector_io_adapter.openai_create_vector_store(params)

# Should default to 768 dimensions
vector_io_adapter.register_vector_db.assert_called_once()
call_args = vector_io_adapter.register_vector_db.call_args[0][0]
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "model-without-dimension"
assert call_args.embedding_dimension == 768


@ -1152,8 +1152,8 @@ async def test_embedding_config_defaults_when_missing(vector_io_adapter):
async def test_embedding_config_required_model_missing(vector_io_adapter):
"""Test that missing embedding model raises error."""

# Mock register_vector_db to avoid actual registration
vector_io_adapter.register_vector_db = AsyncMock()
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Mock the default model lookup to return None (no default model available)
@ -1,138 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock, MagicMock

import pytest

from llama_stack.apis.tools.rag_tool import RAGQueryConfig
from llama_stack.apis.vector_io import (
Chunk,
ChunkMetadata,
QueryChunksResponse,
)
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl


class TestRagQuery:
async def test_query_raises_on_empty_vector_db_ids(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
with pytest.raises(ValueError):
await rag_tool.query(content=MagicMock(), vector_db_ids=[])

async def test_query_chunk_metadata_handling(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
content = "test query content"
vector_db_ids = ["db1"]

chunk_metadata = ChunkMetadata(
document_id="doc1",
chunk_id="chunk1",
source="test_source",
metadata_token_count=5,
)
interleaved_content = MagicMock()
chunk = Chunk(
content=interleaved_content,
metadata={
"key1": "value1",
"token_count": 10,
"metadata_token_count": 5,
# Note this is inserted into `metadata` during MemoryToolRuntimeImpl().insert()
"document_id": "doc1",
},
stored_chunk_id="chunk1",
chunk_metadata=chunk_metadata,
)

query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])

rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)

assert result is not None
expected_metadata_string = (
"Metadata: {'chunk_id': 'chunk1', 'document_id': 'doc1', 'source': 'test_source', 'key1': 'value1'}"
)
assert expected_metadata_string in result.content[1].text
assert result.content is not None

async def test_query_raises_incorrect_mode(self):
with pytest.raises(ValueError):
RAGQueryConfig(mode="invalid_mode")

async def test_query_accepts_valid_modes(self):
default_config = RAGQueryConfig() # Test default (vector)
assert default_config.mode == "vector"
vector_config = RAGQueryConfig(mode="vector") # Test vector
assert vector_config.mode == "vector"
keyword_config = RAGQueryConfig(mode="keyword") # Test keyword
assert keyword_config.mode == "keyword"
hybrid_config = RAGQueryConfig(mode="hybrid") # Test hybrid
assert hybrid_config.mode == "hybrid"

# Test that invalid mode raises an error
with pytest.raises(ValueError):
RAGQueryConfig(mode="wrong_mode")

async def test_query_adds_vector_db_id_to_chunk_metadata(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(),
vector_io_api=MagicMock(),
inference_api=MagicMock(),
files_api=MagicMock(),
)

vector_db_ids = ["db1", "db2"]

# Fake chunks from each DB
chunk_metadata1 = ChunkMetadata(
document_id="doc1",
chunk_id="chunk1",
source="test_source1",
metadata_token_count=5,
)
chunk1 = Chunk(
content="chunk from db1",
metadata={"vector_db_id": "db1", "document_id": "doc1"},
stored_chunk_id="c1",
chunk_metadata=chunk_metadata1,
)

chunk_metadata2 = ChunkMetadata(
document_id="doc2",
chunk_id="chunk2",
source="test_source2",
metadata_token_count=5,
)
chunk2 = Chunk(
content="chunk from db2",
metadata={"vector_db_id": "db2", "document_id": "doc2"},
stored_chunk_id="c2",
chunk_metadata=chunk_metadata2,
)

rag_tool.vector_io_api.query_chunks = AsyncMock(
side_effect=[
QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
]
)

result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
returned_chunks = result.metadata["chunks"]
returned_scores = result.metadata["scores"]
returned_doc_ids = result.metadata["document_ids"]
returned_vector_db_ids = result.metadata["vector_db_ids"]

assert returned_chunks == ["chunk from db1", "chunk from db2"]
assert returned_scores == (0.9, 0.8)
assert returned_doc_ids == ["doc1", "doc2"]
assert returned_vector_db_ids == ["db1", "db2"]
@ -4,10 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import mimetypes
import os
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock

import numpy as np

@ -17,37 +13,13 @@ from llama_stack.apis.inference.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.utils.memory.vector_store import (
URL,
VectorDBWithIndex,
VectorStoreWithIndex,
_validate_embedding,
content_from_doc,
make_overlapped_chunks,
)

DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
# Depending on the machine, this can get parsed a couple of ways
DUMMY_PDF_TEXT_CHOICES = ["Dummy PDF file", "Dumm y PDF file"]


def read_file(file_path: str) -> bytes:
with open(file_path, "rb") as file:
return file.read()


def data_url_from_file(file_path: str) -> str:
with open(file_path, "rb") as file:
file_content = file.read()

base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)

data_url = f"data:{mime_type};base64,{base64_content}"

return data_url


class TestChunk:
def test_chunk(self):

@ -116,45 +88,6 @@ class TestValidateEmbedding:


class TestVectorStore:
async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = RAGDocument(
document_id="dummy",
content=data_uri,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES

@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = RAGDocument(
document_id="dummy",
content=url,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES

@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = RAGDocument(
document_id="dummy",
content=URL(
uri=url,
),
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES

@pytest.mark.parametrize(
"window_len, overlap_len, expected_chunks",
[

@ -206,15 +139,15 @@ class TestVectorStore:
assert str(excinfo.value.__cause__) == "Cannot convert to string"


class TestVectorDBWithIndex:
class TestVectorStoreWithIndex:
async def test_insert_chunks_without_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model without embeddings"
mock_vector_store = MagicMock()
mock_vector_store.embedding_model = "test-model without embeddings"
mock_index = AsyncMock()
mock_inference_api = AsyncMock()

vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)

chunks = [

@ -227,7 +160,7 @@ class TestVectorDBWithIndex:
OpenAIEmbeddingData(embedding=[0.4, 0.5, 0.6], index=1),
]

await vector_db_with_index.insert_chunks(chunks)
await vector_store_with_index.insert_chunks(chunks)

# Verify openai_embeddings was called with correct params
mock_inference_api.openai_embeddings.assert_called_once()

@ -243,14 +176,14 @@ class TestVectorDBWithIndex:
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))

async def test_insert_chunks_with_valid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with embeddings"
mock_vector_db.embedding_dimension = 3
mock_vector_store = MagicMock()
mock_vector_store.embedding_model = "test-model with embeddings"
mock_vector_store.embedding_dimension = 3
mock_index = AsyncMock()
mock_inference_api = AsyncMock()

vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)

chunks = [

@ -258,7 +191,7 @@ class TestVectorDBWithIndex:
Chunk(content="Test 2", embedding=[0.4, 0.5, 0.6], metadata={}),
]

await vector_db_with_index.insert_chunks(chunks)
await vector_store_with_index.insert_chunks(chunks)

mock_inference_api.openai_embeddings.assert_not_called()
mock_index.add_chunks.assert_called_once()

@ -267,14 +200,14 @@ class TestVectorDBWithIndex:
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))

async def test_insert_chunks_with_invalid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_dimension = 3
mock_vector_db.embedding_model = "test-model with invalid embeddings"
mock_vector_store = MagicMock()
mock_vector_store.embedding_dimension = 3
mock_vector_store.embedding_model = "test-model with invalid embeddings"
mock_index = AsyncMock()
mock_inference_api = AsyncMock()

vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)

# Verify Chunk raises ValueError for invalid embedding type

@ -283,7 +216,7 @@ class TestVectorDBWithIndex:

# Verify Chunk raises ValueError for invalid embedding type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
with pytest.raises(ValueError, match="Input should be a valid list"):
await vector_db_with_index.insert_chunks(
await vector_store_with_index.insert_chunks(
[
Chunk(content="Test 1", embedding=None, metadata={}),
Chunk(content="Test 2", embedding="invalid_type", metadata={}),

@ -292,7 +225,7 @@ class TestVectorDBWithIndex:

# Verify Chunk raises ValueError for invalid embedding element type in insert_chunks (i.e., Chunk errors before insert_chunks is called)
with pytest.raises(ValueError, match=" Input should be a valid number, unable to parse string as a number "):
await vector_db_with_index.insert_chunks(
await vector_store_with_index.insert_chunks(
Chunk(content="Test 1", embedding=[0.1, "string", 0.3], metadata={})
)


@ -300,20 +233,20 @@ class TestVectorDBWithIndex:
Chunk(content="Test 1", embedding=[0.1, 0.2, 0.3, 0.4], metadata={}),
]
with pytest.raises(ValueError, match="has dimension 4, expected 3"):
await vector_db_with_index.insert_chunks(chunks_wrong_dim)
await vector_store_with_index.insert_chunks(chunks_wrong_dim)

mock_inference_api.openai_embeddings.assert_not_called()
mock_index.add_chunks.assert_not_called()

async def test_insert_chunks_with_partially_precomputed_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with partial embeddings"
mock_vector_db.embedding_dimension = 3
mock_vector_store = MagicMock()
mock_vector_store.embedding_model = "test-model with partial embeddings"
mock_vector_store.embedding_dimension = 3
mock_index = AsyncMock()
mock_inference_api = AsyncMock()

vector_db_with_index = VectorDBWithIndex(
vector_db=mock_vector_db, index=mock_index, inference_api=mock_inference_api
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store, index=mock_index, inference_api=mock_inference_api
)

chunks = [

@ -327,7 +260,7 @@ class TestVectorDBWithIndex:
OpenAIEmbeddingData(embedding=[0.3, 0.3, 0.3], index=1),
]

await vector_db_with_index.insert_chunks(chunks)
await vector_store_with_index.insert_chunks(chunks)

# Verify openai_embeddings was called with correct params
mock_inference_api.openai_embeddings.assert_called_once()
@ -8,8 +8,8 @@
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.core.datatypes import VectorDBWithOwner
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.core.datatypes import VectorStoreWithOwner
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.core.store.registry import (
|
||||
KEY_FORMAT,
|
||||
|
|
@ -20,12 +20,12 @@ from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_b
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_vector_db():
|
||||
return VectorDB(
|
||||
identifier="test_vector_db",
|
||||
def sample_vector_store():
|
||||
return VectorStore(
|
||||
identifier="test_vector_store",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_db",
|
||||
provider_resource_id="test_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
|
|
@ -45,17 +45,17 @@ async def test_registry_initialization(disk_dist_registry):
|
|||
assert result is None
|
||||
|
||||
|
||||
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
|
||||
print(f"Registering {sample_vector_db}")
|
||||
await disk_dist_registry.register(sample_vector_db)
|
||||
async def test_basic_registration(disk_dist_registry, sample_vector_store, sample_model):
|
||||
print(f"Registering {sample_vector_store}")
|
||||
await disk_dist_registry.register(sample_vector_store)
|
||||
print(f"Registering {sample_model}")
|
||||
await disk_dist_registry.register(sample_model)
|
||||
print("Getting vector_db")
|
||||
result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||
print("Getting vector_store")
|
||||
result_vector_store = await disk_dist_registry.get("vector_store", "test_vector_store")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == sample_vector_store.identifier
|
||||
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
|
||||
assert result_vector_store.provider_id == sample_vector_store.provider_id
|
||||
|
||||
result_model = await disk_dist_registry.get("model", "test_model")
|
||||
assert result_model is not None
|
||||
|
|
@ -63,11 +63,11 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m
|
|||
assert result_model.provider_id == sample_model.provider_id
|
||||
|
||||
|
||||
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
|
||||
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_store, sample_model):
|
||||
# First populate the disk registry
|
||||
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||
await disk_registry.initialize()
|
||||
await disk_registry.register(sample_vector_db)
|
||||
await disk_registry.register(sample_vector_store)
|
||||
await disk_registry.register(sample_model)
|
||||
|
||||
# Test cached version loads from disk
|
||||
|
|
@ -79,29 +79,29 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db,
|
|||
)
|
||||
await cached_registry.initialize()
|
||||
|
||||
result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == sample_vector_db.identifier
|
||||
assert result_vector_db.embedding_model == sample_vector_db.embedding_model
|
||||
assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
|
||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||
result_vector_store = await cached_registry.get("vector_store", "test_vector_store")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == sample_vector_store.identifier
|
||||
assert result_vector_store.embedding_model == sample_vector_store.embedding_model
|
||||
assert result_vector_store.embedding_dimension == sample_vector_store.embedding_dimension
|
||||
assert result_vector_store.provider_id == sample_vector_store.provider_id
|
||||
|
||||
|
||||
async def test_cached_registry_updates(cached_disk_dist_registry):
|
||||
new_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
new_vector_store = VectorStore(
|
||||
identifier="test_vector_store_2",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_resource_id="test_vector_store_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
await cached_disk_dist_registry.register(new_vector_db)
|
||||
await cached_disk_dist_registry.register(new_vector_store)
|
||||
|
||||
# Verify in cache
|
||||
result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == new_vector_db.identifier
|
||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||
result_vector_store = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == new_vector_store.identifier
|
||||
assert result_vector_store.provider_id == new_vector_store.provider_id
|
||||
|
||||
# Verify persisted to disk
|
||||
db_path = cached_disk_dist_registry.kvstore.db_path
|
||||
|
|
@ -111,87 +111,89 @@ async def test_cached_registry_updates(cached_disk_dist_registry):
|
|||
await kvstore_impl(KVStoreReference(backend=backend_name, namespace="registry"))
|
||||
)
|
||||
await new_registry.initialize()
|
||||
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result_vector_db is not None
|
||||
assert result_vector_db.identifier == new_vector_db.identifier
|
||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||
result_vector_store = await new_registry.get("vector_store", "test_vector_store_2")
|
||||
assert result_vector_store is not None
|
||||
assert result_vector_store.identifier == new_vector_store.identifier
|
||||
assert result_vector_store.provider_id == new_vector_store.provider_id
|
||||
|
||||
|
||||
async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
||||
original_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
original_vector_store = VectorStore(
|
||||
identifier="test_vector_store_2",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_resource_id="test_vector_store_2",
|
||||
provider_id="baz",
|
||||
)
|
||||
assert await cached_disk_dist_registry.register(original_vector_db)
|
||||
assert await cached_disk_dist_registry.register(original_vector_store)
|
||||
|
||||
duplicate_vector_db = VectorDB(
|
||||
identifier="test_vector_db_2",
|
||||
duplicate_vector_store = VectorStore(
|
||||
identifier="test_vector_store_2",
|
||||
embedding_model="different-model",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="test_vector_db_2",
|
||||
provider_resource_id="test_vector_store_2",
|
||||
provider_id="baz", # Same provider_id
|
||||
)
|
||||
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db_2' already exists"):
|
||||
await cached_disk_dist_registry.register(duplicate_vector_db)
|
||||
with pytest.raises(
|
||||
ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store_2' already exists"
|
||||
):
|
||||
await cached_disk_dist_registry.register(duplicate_vector_store)
|
||||
|
||||
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||
result = await cached_disk_dist_registry.get("vector_store", "test_vector_store_2")
|
||||
assert result is not None
|
||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||
assert result.embedding_model == original_vector_store.embedding_model # Original values preserved
|
||||
|
||||
|
||||
async def test_get_all_objects(cached_disk_dist_registry):
|
||||
# Create multiple test banks
|
||||
# Create multiple test banks
|
||||
test_vector_dbs = [
|
||||
VectorDB(
|
||||
identifier=f"test_vector_db_{i}",
|
||||
test_vector_stores = [
|
||||
VectorStore(
|
||||
identifier=f"test_vector_store_{i}",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id=f"test_vector_db_{i}",
|
||||
provider_resource_id=f"test_vector_store_{i}",
|
||||
provider_id=f"provider_{i}",
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
# Register all vector_dbs
|
||||
for vector_db in test_vector_dbs:
|
||||
await cached_disk_dist_registry.register(vector_db)
|
||||
# Register all vector_stores
|
||||
for vector_store in test_vector_stores:
|
||||
await cached_disk_dist_registry.register(vector_store)
|
||||
|
||||
# Test get_all retrieval
|
||||
all_results = await cached_disk_dist_registry.get_all()
|
||||
assert len(all_results) == 3
|
||||
|
||||
# Verify each vector_db was stored correctly
|
||||
for original_vector_db in test_vector_dbs:
|
||||
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
|
||||
assert len(matching_vector_dbs) == 1
|
||||
stored_vector_db = matching_vector_dbs[0]
|
||||
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
|
||||
assert stored_vector_db.provider_id == original_vector_db.provider_id
|
||||
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
||||
# Verify each vector_store was stored correctly
|
||||
for original_vector_store in test_vector_stores:
|
||||
matching_vector_stores = [v for v in all_results if v.identifier == original_vector_store.identifier]
|
||||
assert len(matching_vector_stores) == 1
|
||||
stored_vector_store = matching_vector_stores[0]
|
||||
assert stored_vector_store.embedding_model == original_vector_store.embedding_model
|
||||
assert stored_vector_store.provider_id == original_vector_store.provider_id
|
||||
assert stored_vector_store.embedding_dimension == original_vector_store.embedding_dimension
|
||||
|
||||
|
||||
async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
||||
valid_db = VectorDB(
|
||||
identifier="valid_vector_db",
|
||||
valid_db = VectorStore(
|
||||
identifier="valid_vector_store",
|
||||
embedding_model="nomic-embed-text-v1.5",
|
||||
embedding_dimension=768,
|
||||
provider_resource_id="valid_vector_db",
|
||||
provider_resource_id="valid_vector_store",
|
||||
provider_id="test-provider",
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
|
||||
KEY_FORMAT.format(type="vector_store", identifier="valid_vector_store"), valid_db.model_dump_json()
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
||||
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_store", identifier="corrupted_json"), "{not valid json")
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
||||
'{"type": "vector_db", "identifier": "missing_fields"}',
|
||||
KEY_FORMAT.format(type="vector_store", identifier="missing_fields"),
|
||||
'{"type": "vector_store", "identifier": "missing_fields"}',
|
||||
)
|
||||
|
||||
test_registry = DiskDistributionRegistry(sqlite_kvstore)
@@ -202,18 +204,18 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):

# Should have filtered out the invalid entries
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_vector_db"
assert all_objects[0].identifier == "valid_vector_store"

# Check that the get method also handles errors correctly
invalid_obj = await test_registry.get("vector_db", "corrupted_json")
invalid_obj = await test_registry.get("vector_store", "corrupted_json")
assert invalid_obj is None

invalid_obj = await test_registry.get("vector_db", "missing_fields")
invalid_obj = await test_registry.get("vector_store", "missing_fields")
assert invalid_obj is None


async def test_cached_registry_error_handling(sqlite_kvstore):
valid_db = VectorDB(
valid_db = VectorStore(
identifier="valid_cached_db",
embedding_model="nomic-embed-text-v1.5",
embedding_dimension=768,
@@ -222,12 +224,12 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
)

await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
KEY_FORMAT.format(type="vector_store", identifier="valid_cached_db"), valid_db.model_dump_json()
)

await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
KEY_FORMAT.format(type="vector_store", identifier="invalid_cached_db"),
'{"type": "vector_store", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
)

cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
@@ -237,63 +239,65 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_cached_db"

invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
invalid_obj = await cached_registry.get("vector_store", "invalid_cached_db")
assert invalid_obj is None


async def test_double_registration_identical_objects(disk_dist_registry):
"""Test that registering identical objects succeeds (idempotent)."""
vector_db = VectorDBWithOwner(
identifier="test_vector_db",
vector_store = VectorStoreWithOwner(
identifier="test_vector_store",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_resource_id="test_vector_store",
provider_id="test-provider",
)

# First registration should succeed
result1 = await disk_dist_registry.register(vector_db)
result1 = await disk_dist_registry.register(vector_store)
assert result1 is True

# Second registration of identical object should also succeed (idempotent)
result2 = await disk_dist_registry.register(vector_db)
result2 = await disk_dist_registry.register(vector_store)
assert result2 is True

# Verify object exists and is unchanged
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
assert retrieved is not None
assert retrieved.identifier == vector_db.identifier
assert retrieved.embedding_model == vector_db.embedding_model
assert retrieved.identifier == vector_store.identifier
assert retrieved.embedding_model == vector_store.embedding_model

async def test_double_registration_different_objects(disk_dist_registry):
"""Test that registering different objects with same identifier fails."""
vector_db1 = VectorDBWithOwner(
identifier="test_vector_db",
vector_store1 = VectorStoreWithOwner(
identifier="test_vector_store",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_resource_id="test_vector_store",
provider_id="test-provider",
)

vector_db2 = VectorDBWithOwner(
identifier="test_vector_db", # Same identifier
vector_store2 = VectorStoreWithOwner(
identifier="test_vector_store", # Same identifier
embedding_model="different-model", # Different embedding model
embedding_dimension=384,
provider_resource_id="test_vector_db",
provider_resource_id="test_vector_store",
provider_id="test-provider",
)

# First registration should succeed
result1 = await disk_dist_registry.register(vector_db1)
result1 = await disk_dist_registry.register(vector_store1)
assert result1 is True

# Second registration with different data should fail
with pytest.raises(ValueError, match="Object of type 'vector_db' and identifier 'test_vector_db' already exists"):
await disk_dist_registry.register(vector_db2)
with pytest.raises(
ValueError, match="Object of type 'vector_store' and identifier 'test_vector_store' already exists"
):
await disk_dist_registry.register(vector_store2)

# Verify original object is unchanged
retrieved = await disk_dist_registry.get("vector_db", "test_vector_db")
retrieved = await disk_dist_registry.get("vector_store", "test_vector_store")
assert retrieved is not None
assert retrieved.embedding_model == "all-MiniLM-L6-v2" # Original value
@@ -41,7 +41,7 @@ class TestTranslateException:
self.identifier = identifier
self.owner = owner

resource = MockResource("vector_db", "test-db")
resource = MockResource("vector_store", "test-db")

exc = AccessDeniedError("create", resource, user)
result = translate_exception(exc)
@@ -49,7 +49,7 @@ class TestTranslateException:
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert "test-user" in result.detail
assert "vector_db::test-db" in result.detail
assert "vector_store::test-db" in result.detail
assert "create" in result.detail
assert "roles=['user']" in result.detail
assert "teams=['dev']" in result.detail