Merge 960f5e4cd4 into sapling-pr-archive-ehhuang

ehhuang 2025-10-27 13:23:09 -07:00 committed by GitHub
commit bf8f6b6914
41 changed files with 346 additions and 183 deletions


@@ -9862,7 +9862,7 @@ components:
 $ref: '#/components/schemas/RAGDocument'
 description: >-
 List of documents to index in the RAG system
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 ID of the vector database to store the document embeddings
@@ -9873,7 +9873,7 @@ components:
 additionalProperties: false
 required:
 - documents
-- vector_db_id
+- vector_store_id
 - chunk_size_in_tokens
 title: InsertRequest
 DefaultRAGQueryGeneratorConfig:
@@ -10044,7 +10044,7 @@ components:
 $ref: '#/components/schemas/InterleavedContent'
 description: >-
 The query content to search for in the indexed documents
-vector_db_ids:
+vector_store_ids:
 type: array
 items:
 type: string
@@ -10057,7 +10057,7 @@ components:
 additionalProperties: false
 required:
 - content
-- vector_db_ids
+- vector_store_ids
 title: QueryRequest
 RAGQueryResult:
 type: object
@@ -10281,7 +10281,7 @@ components:
 InsertChunksRequest:
 type: object
 properties:
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 The identifier of the vector database to insert the chunks into.
@@ -10300,13 +10300,13 @@ components:
 description: The time to live of the chunks.
 additionalProperties: false
 required:
-- vector_db_id
+- vector_store_id
 - chunks
 title: InsertChunksRequest
 QueryChunksRequest:
 type: object
 properties:
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 The identifier of the vector database to query.
@@ -10326,7 +10326,7 @@ components:
 description: The parameters of the query.
 additionalProperties: false
 required:
-- vector_db_id
+- vector_store_id
 - query
 title: QueryChunksRequest
 QueryChunksResponse:
@@ -11844,7 +11844,7 @@ components:
 description: Type of the step in an agent turn.
 const: memory_retrieval
 default: memory_retrieval
-vector_db_ids:
+vector_store_ids:
 type: string
 description: >-
 The IDs of the vector databases to retrieve context from.
@@ -11857,7 +11857,7 @@ components:
 - turn_id
 - step_id
 - step_type
-- vector_db_ids
+- vector_store_ids
 - inserted_context
 title: MemoryRetrievalStep
 description: >-


@@ -72,14 +72,14 @@ description: |
 Example with hybrid search:
 ```python
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
 )
 # Using RRF ranker
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={
 "mode": "hybrid",
@@ -91,7 +91,7 @@ description: |
 # Using weighted ranker
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={
 "mode": "hybrid",
@@ -105,7 +105,7 @@ description: |
 Example with explicit vector search:
 ```python
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7},
 )
@@ -114,7 +114,7 @@ description: |
 Example with keyword search:
 ```python
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7},
 )
@@ -277,14 +277,14 @@ The SQLite-vec provider supports three search modes:
 Example with hybrid search:
 ```python
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
 )
 # Using RRF ranker
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={
 "mode": "hybrid",
@@ -296,7 +296,7 @@ response = await vector_io.query_chunks(
 # Using weighted ranker
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={
 "mode": "hybrid",
@@ -310,7 +310,7 @@ response = await vector_io.query_chunks(
 Example with explicit vector search:
 ```python
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7},
 )
@@ -319,7 +319,7 @@ response = await vector_io.query_chunks(
 Example with keyword search:
 ```python
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7},
 )


@@ -4390,7 +4390,7 @@
 "const": "memory_retrieval",
 "default": "memory_retrieval"
 },
-"vector_db_ids": {
+"vector_store_ids": {
 "type": "string",
 "description": "The IDs of the vector databases to retrieve context from."
 },
@@ -4404,7 +4404,7 @@
 "turn_id",
 "step_id",
 "step_type",
-"vector_db_ids",
+"vector_store_ids",
 "inserted_context"
 ],
 "title": "MemoryRetrievalStep",


@@ -3252,7 +3252,7 @@ components:
 description: Type of the step in an agent turn.
 const: memory_retrieval
 default: memory_retrieval
-vector_db_ids:
+vector_store_ids:
 type: string
 description: >-
 The IDs of the vector databases to retrieve context from.
@@ -3265,7 +3265,7 @@ components:
 - turn_id
 - step_id
 - step_type
-- vector_db_ids
+- vector_store_ids
 - inserted_context
 title: MemoryRetrievalStep
 description: >-


@@ -2865,7 +2865,7 @@
 "const": "memory_retrieval",
 "default": "memory_retrieval"
 },
-"vector_db_ids": {
+"vector_store_ids": {
 "type": "string",
 "description": "The IDs of the vector databases to retrieve context from."
 },
@@ -2879,7 +2879,7 @@
 "turn_id",
 "step_id",
 "step_type",
-"vector_db_ids",
+"vector_store_ids",
 "inserted_context"
 ],
 "title": "MemoryRetrievalStep",


@@ -2085,7 +2085,7 @@ components:
 description: Type of the step in an agent turn.
 const: memory_retrieval
 default: memory_retrieval
-vector_db_ids:
+vector_store_ids:
 type: string
 description: >-
 The IDs of the vector databases to retrieve context from.
@@ -2098,7 +2098,7 @@ components:
 - turn_id
 - step_id
 - step_type
-- vector_db_ids
+- vector_store_ids
 - inserted_context
 title: MemoryRetrievalStep
 description: >-


@@ -11412,7 +11412,7 @@
 },
 "description": "List of documents to index in the RAG system"
 },
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "ID of the vector database to store the document embeddings"
 },
@@ -11424,7 +11424,7 @@
 "additionalProperties": false,
 "required": [
 "documents",
-"vector_db_id",
+"vector_store_id",
 "chunk_size_in_tokens"
 ],
 "title": "InsertRequest"
@@ -11615,7 +11615,7 @@
 "$ref": "#/components/schemas/InterleavedContent",
 "description": "The query content to search for in the indexed documents"
 },
-"vector_db_ids": {
+"vector_store_ids": {
 "type": "array",
 "items": {
 "type": "string"
@@ -11630,7 +11630,7 @@
 "additionalProperties": false,
 "required": [
 "content",
-"vector_db_ids"
+"vector_store_ids"
 ],
 "title": "QueryRequest"
 },
@@ -11923,7 +11923,7 @@
 "InsertChunksRequest": {
 "type": "object",
 "properties": {
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "The identifier of the vector database to insert the chunks into."
 },
@@ -11941,7 +11941,7 @@
 },
 "additionalProperties": false,
 "required": [
-"vector_db_id",
+"vector_store_id",
 "chunks"
 ],
 "title": "InsertChunksRequest"
@@ -11949,7 +11949,7 @@
 "QueryChunksRequest": {
 "type": "object",
 "properties": {
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "The identifier of the vector database to query."
 },
@@ -11986,7 +11986,7 @@
 },
 "additionalProperties": false,
 "required": [
-"vector_db_id",
+"vector_store_id",
 "query"
 ],
 "title": "QueryChunksRequest"


@@ -8649,7 +8649,7 @@ components:
 $ref: '#/components/schemas/RAGDocument'
 description: >-
 List of documents to index in the RAG system
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 ID of the vector database to store the document embeddings
@@ -8660,7 +8660,7 @@ components:
 additionalProperties: false
 required:
 - documents
-- vector_db_id
+- vector_store_id
 - chunk_size_in_tokens
 title: InsertRequest
 DefaultRAGQueryGeneratorConfig:
@@ -8831,7 +8831,7 @@ components:
 $ref: '#/components/schemas/InterleavedContent'
 description: >-
 The query content to search for in the indexed documents
-vector_db_ids:
+vector_store_ids:
 type: array
 items:
 type: string
@@ -8844,7 +8844,7 @@ components:
 additionalProperties: false
 required:
 - content
-- vector_db_ids
+- vector_store_ids
 title: QueryRequest
 RAGQueryResult:
 type: object
@@ -9068,7 +9068,7 @@ components:
 InsertChunksRequest:
 type: object
 properties:
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 The identifier of the vector database to insert the chunks into.
@@ -9087,13 +9087,13 @@ components:
 description: The time to live of the chunks.
 additionalProperties: false
 required:
-- vector_db_id
+- vector_store_id
 - chunks
 title: InsertChunksRequest
 QueryChunksRequest:
 type: object
 properties:
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 The identifier of the vector database to query.
@@ -9113,7 +9113,7 @@ components:
 description: The parameters of the query.
 additionalProperties: false
 required:
-- vector_db_id
+- vector_store_id
 - query
 title: QueryChunksRequest
 QueryChunksResponse:


@@ -13084,7 +13084,7 @@
 },
 "description": "List of documents to index in the RAG system"
 },
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "ID of the vector database to store the document embeddings"
 },
@@ -13096,7 +13096,7 @@
 "additionalProperties": false,
 "required": [
 "documents",
-"vector_db_id",
+"vector_store_id",
 "chunk_size_in_tokens"
 ],
 "title": "InsertRequest"
@@ -13287,7 +13287,7 @@
 "$ref": "#/components/schemas/InterleavedContent",
 "description": "The query content to search for in the indexed documents"
 },
-"vector_db_ids": {
+"vector_store_ids": {
 "type": "array",
 "items": {
 "type": "string"
@@ -13302,7 +13302,7 @@
 "additionalProperties": false,
 "required": [
 "content",
-"vector_db_ids"
+"vector_store_ids"
 ],
 "title": "QueryRequest"
 },
@@ -13595,7 +13595,7 @@
 "InsertChunksRequest": {
 "type": "object",
 "properties": {
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "The identifier of the vector database to insert the chunks into."
 },
@@ -13613,7 +13613,7 @@
 },
 "additionalProperties": false,
 "required": [
-"vector_db_id",
+"vector_store_id",
 "chunks"
 ],
 "title": "InsertChunksRequest"
@@ -13621,7 +13621,7 @@
 "QueryChunksRequest": {
 "type": "object",
 "properties": {
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "The identifier of the vector database to query."
 },
@@ -13658,7 +13658,7 @@
 },
 "additionalProperties": false,
 "required": [
-"vector_db_id",
+"vector_store_id",
 "query"
 ],
 "title": "QueryChunksRequest"
@@ -15719,7 +15719,7 @@
 "const": "memory_retrieval",
 "default": "memory_retrieval"
 },
-"vector_db_ids": {
+"vector_store_ids": {
 "type": "string",
 "description": "The IDs of the vector databases to retrieve context from."
 },
@@ -15733,7 +15733,7 @@
 "turn_id",
 "step_id",
 "step_type",
-"vector_db_ids",
+"vector_store_ids",
 "inserted_context"
 ],
 "title": "MemoryRetrievalStep",


@@ -9862,7 +9862,7 @@ components:
 $ref: '#/components/schemas/RAGDocument'
 description: >-
 List of documents to index in the RAG system
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 ID of the vector database to store the document embeddings
@@ -9873,7 +9873,7 @@ components:
 additionalProperties: false
 required:
 - documents
-- vector_db_id
+- vector_store_id
 - chunk_size_in_tokens
 title: InsertRequest
 DefaultRAGQueryGeneratorConfig:
@@ -10044,7 +10044,7 @@ components:
 $ref: '#/components/schemas/InterleavedContent'
 description: >-
 The query content to search for in the indexed documents
-vector_db_ids:
+vector_store_ids:
 type: array
 items:
 type: string
@@ -10057,7 +10057,7 @@ components:
 additionalProperties: false
 required:
 - content
-- vector_db_ids
+- vector_store_ids
 title: QueryRequest
 RAGQueryResult:
 type: object
@@ -10281,7 +10281,7 @@ components:
 InsertChunksRequest:
 type: object
 properties:
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 The identifier of the vector database to insert the chunks into.
@@ -10300,13 +10300,13 @@ components:
 description: The time to live of the chunks.
 additionalProperties: false
 required:
-- vector_db_id
+- vector_store_id
 - chunks
 title: InsertChunksRequest
 QueryChunksRequest:
 type: object
 properties:
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 The identifier of the vector database to query.
@@ -10326,7 +10326,7 @@ components:
 description: The parameters of the query.
 additionalProperties: false
 required:
-- vector_db_id
+- vector_store_id
 - query
 title: QueryChunksRequest
 QueryChunksResponse:
@@ -11844,7 +11844,7 @@ components:
 description: Type of the step in an agent turn.
 const: memory_retrieval
 default: memory_retrieval
-vector_db_ids:
+vector_store_ids:
 type: string
 description: >-
 The IDs of the vector databases to retrieve context from.
@@ -11857,7 +11857,7 @@ components:
 - turn_id
 - step_id
 - step_type
-- vector_db_ids
+- vector_store_ids
 - inserted_context
 title: MemoryRetrievalStep
 description: >-


@@ -78,6 +78,8 @@ dev = [
 ]
 # These are the dependencies required for running unit tests.
 unit = [
+"anthropic",
+"databricks-sdk",
 "sqlite-vec",
 "ollama",
 "aiosqlite",


@@ -149,13 +149,13 @@ class ShieldCallStep(StepCommon):
 class MemoryRetrievalStep(StepCommon):
 """A memory retrieval step in an agent turn.
-:param vector_db_ids: The IDs of the vector databases to retrieve context from.
+:param vector_store_ids: The IDs of the vector databases to retrieve context from.
 :param inserted_context: The context retrieved from the vector databases.
 """
 step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
 # TODO: should this be List[str]?
-vector_db_ids: str
+vector_store_ids: str
 inserted_context: InterleavedContent


@@ -190,13 +190,13 @@ class RAGToolRuntime(Protocol):
 async def insert(
 self,
 documents: list[RAGDocument],
-vector_db_id: str,
+vector_store_id: str,
 chunk_size_in_tokens: int = 512,
 ) -> None:
 """Index documents so they can be used by the RAG system.
 :param documents: List of documents to index in the RAG system
-:param vector_db_id: ID of the vector database to store the document embeddings
+:param vector_store_id: ID of the vector database to store the document embeddings
 :param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
 """
 ...
@@ -205,13 +205,13 @@ class RAGToolRuntime(Protocol):
 async def query(
 self,
 content: InterleavedContent,
-vector_db_ids: list[str],
+vector_store_ids: list[str],
 query_config: RAGQueryConfig | None = None,
 ) -> RAGQueryResult:
 """Query the RAG system for context; typically invoked by the agent.
 :param content: The query content to search for in the indexed documents
-:param vector_db_ids: List of vector database IDs to search within
+:param vector_store_ids: List of vector database IDs to search within
 :param query_config: (Optional) Configuration parameters for the query operation
 :returns: RAGQueryResult containing the retrieved content and metadata
 """


@@ -529,17 +529,17 @@ class VectorIO(Protocol):
 # this will just block now until chunks are inserted, but it should
 # probably return a Job instance which can be polled for completion
-# TODO: rename vector_db_id to vector_store_id once Stainless is working
+# TODO: rename vector_store_id to vector_store_id once Stainless is working
 @webmethod(route="/vector-io/insert", method="POST", level=LLAMA_STACK_API_V1)
 async def insert_chunks(
 self,
-vector_db_id: str,
+vector_store_id: str,
 chunks: list[Chunk],
 ttl_seconds: int | None = None,
 ) -> None:
 """Insert chunks into a vector database.
-:param vector_db_id: The identifier of the vector database to insert the chunks into.
+:param vector_store_id: The identifier of the vector database to insert the chunks into.
 :param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
 `metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
 If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
@@ -548,17 +548,17 @@ class VectorIO(Protocol):
 """
 ...
-# TODO: rename vector_db_id to vector_store_id once Stainless is working
+# TODO: rename vector_store_id to vector_store_id once Stainless is working
 @webmethod(route="/vector-io/query", method="POST", level=LLAMA_STACK_API_V1)
 async def query_chunks(
 self,
-vector_db_id: str,
+vector_store_id: str,
 query: InterleavedContent,
 params: dict[str, Any] | None = None,
 ) -> QueryChunksResponse:
 """Query chunks from a vector database.
-:param vector_db_id: The identifier of the vector database to query.
+:param vector_store_id: The identifier of the vector database to query.
 :param query: The query to search for.
 :param params: The parameters of the query.
 :returns: A QueryChunksResponse.
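For reference, a short sketch of calling the renamed `insert_chunks` / `query_chunks` signatures above; the `vector_io_api` handle, store id, and chunk contents are placeholders rather than values from this change:

```python
from llama_stack.apis.vector_io import Chunk


async def chunk_roundtrip(vector_io_api) -> None:
    # vector_io_api is assumed to implement the VectorIO protocol shown above
    # (for example the VectorIORouter); the store id is a placeholder.
    chunks = [
        Chunk(
            content="Vector stores hold embedded document chunks.",
            metadata={"document_id": "doc-1"},
        )
    ]
    await vector_io_api.insert_chunks(
        vector_store_id="my-store",  # renamed from vector_db_id
        chunks=chunks,
    )
    response = await vector_io_api.query_chunks(
        vector_store_id="my-store",  # renamed from vector_db_id
        query="What do vector stores hold?",
        params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.5},
    )
    for chunk, score in zip(response.chunks, response.scores):
        print(score, chunk.content)
```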


@@ -73,27 +73,27 @@ class VectorIORouter(VectorIO):
 async def insert_chunks(
 self,
-vector_db_id: str,
+vector_store_id: str,
 chunks: list[Chunk],
 ttl_seconds: int | None = None,
 ) -> None:
 doc_ids = [chunk.document_id for chunk in chunks[:3]]
 logger.debug(
-f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, "
+f"VectorIORouter.insert_chunks: {vector_store_id}, {len(chunks)} chunks, "
 f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}"
 )
-provider = await self.routing_table.get_provider_impl(vector_db_id)
-return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
+provider = await self.routing_table.get_provider_impl(vector_store_id)
+return await provider.insert_chunks(vector_store_id, chunks, ttl_seconds)
 async def query_chunks(
 self,
-vector_db_id: str,
+vector_store_id: str,
 query: InterleavedContent,
 params: dict[str, Any] | None = None,
 ) -> QueryChunksResponse:
-logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-provider = await self.routing_table.get_provider_impl(vector_db_id)
-return await provider.query_chunks(vector_db_id, query, params)
+logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
+provider = await self.routing_table.get_provider_impl(vector_store_id)
+return await provider.query_chunks(vector_store_id, query, params)
 # OpenAI Vector Stores API endpoints
 async def openai_create_vector_store(


@@ -488,13 +488,13 @@ class ChatAgent(ShieldRunnerMixin):
 session_info = await self.storage.get_session_info(session_id)
 # if the session has a memory bank id, let the memory tool use it
-if session_info and session_info.vector_db_id:
+if session_info and session_info.vector_store_id:
 for tool_name in self.tool_name_to_args.keys():
 if tool_name == MEMORY_QUERY_TOOL:
-if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
-self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
+if "vector_store_ids" not in self.tool_name_to_args[tool_name]:
+self.tool_name_to_args[tool_name]["vector_store_ids"] = [session_info.vector_store_id]
 else:
-self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
+self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)
 output_attachments = []


@@ -22,7 +22,7 @@ log = get_logger(name=__name__, category="agents::meta_reference")
 class AgentSessionInfo(Session):
 # TODO: is this used anywhere?
-vector_db_id: str | None = None
+vector_store_id: str | None = None
 started_at: datetime
 owner: User | None = None
 identifier: str | None = None
@@ -93,12 +93,12 @@ class AgentPersistence:
 return session_info
-async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
+async def add_vector_db_to_session(self, session_id: str, vector_store_id: str):
 session_info = await self.get_session_if_accessible(session_id)
 if session_info is None:
 raise SessionNotFoundError(session_id)
-session_info.vector_db_id = vector_db_id
+session_info.vector_store_id = vector_store_id
 await self.kvstore.set(
 key=f"session:{self.agent_id}:{session_id}",
 value=session_info.model_dump_json(),


@@ -119,7 +119,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
 async def insert(
 self,
 documents: list[RAGDocument],
-vector_db_id: str,
+vector_store_id: str,
 chunk_size_in_tokens: int = 512,
 ) -> None:
 if not documents:
@@ -158,14 +158,14 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
 try:
 await self.vector_io_api.openai_attach_file_to_vector_store(
-vector_store_id=vector_db_id,
+vector_store_id=vector_store_id,
 file_id=created_file.id,
 attributes=doc.metadata,
 chunking_strategy=chunking_strategy,
 )
 except Exception as e:
 log.error(
-f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
+f"Failed to attach file {created_file.id} to vector store {vector_store_id} for document {doc.document_id}: {e}"
 )
 continue
@@ -176,10 +176,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
 async def query(
 self,
 content: InterleavedContent,
-vector_db_ids: list[str],
+vector_store_ids: list[str],
 query_config: RAGQueryConfig | None = None,
 ) -> RAGQueryResult:
-if not vector_db_ids:
+if not vector_store_ids:
 raise ValueError(
 "No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
 )
@@ -192,7 +192,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
 )
 tasks = [
 self.vector_io_api.query_chunks(
-vector_db_id=vector_db_id,
+vector_store_id=vector_store_id,
 query=query,
 params={
 "mode": query_config.mode,
@@ -201,18 +201,18 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
 "ranker": query_config.ranker,
 },
 )
-for vector_db_id in vector_db_ids
+for vector_store_id in vector_store_ids
 ]
 results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
 chunks = []
 scores = []
-for vector_db_id, result in zip(vector_db_ids, results, strict=False):
+for vector_store_id, result in zip(vector_store_ids, results, strict=False):
 for chunk, score in zip(result.chunks, result.scores, strict=False):
 if not hasattr(chunk, "metadata") or chunk.metadata is None:
 chunk.metadata = {}
-chunk.metadata["vector_db_id"] = vector_db_id
+chunk.metadata["vector_store_id"] = vector_store_id
 chunks.append(chunk)
 scores.append(score)
@@ -250,7 +250,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
 metadata_keys_to_exclude_from_context = [
 "token_count",
 "metadata_token_count",
-"vector_db_id",
+"vector_store_id",
 ]
 metadata_for_context = {}
 for k in chunk_metadata_keys_to_include_from_context:
@@ -275,7 +275,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
 "document_ids": [c.document_id for c in chunks[: len(picked)]],
 "chunks": [c.content for c in chunks[: len(picked)]],
 "scores": scores[: len(picked)],
-"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
+"vector_store_ids": [c.metadata["vector_store_id"] for c in chunks[: len(picked)]],
 },
 )
@@ -309,7 +309,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
 )
 async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
-vector_db_ids = kwargs.get("vector_db_ids", [])
+vector_store_ids = kwargs.get("vector_store_ids", [])
 query_config = kwargs.get("query_config")
 if query_config:
 query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
@@ -319,7 +319,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
 query = kwargs["query"]
 result = await self.query(
 content=query,
-vector_db_ids=vector_db_ids,
+vector_store_ids=vector_store_ids,
 query_config=query_config,
 )
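A brief sketch of how the renamed kwarg reaches the knowledge-search tool at invocation time, mirroring the `invoke_tool` handling above; the runtime handle, tool name string, and store id are assumptions for illustration:

```python
async def search_knowledge(tool_runtime_api):
    # tool_runtime_api is assumed to implement the ToolRuntime API; the builtin
    # knowledge-search tool reads vector_store_ids (previously vector_db_ids)
    # from its kwargs, as in the invoke_tool handler above.
    result = await tool_runtime_api.invoke_tool(
        tool_name="knowledge_search",
        kwargs={
            "query": "What is in my vector store?",
            "vector_store_ids": ["my-store"],
        },
    )
    return result.content
```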


@@ -248,19 +248,19 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
 del self.cache[vector_store_id]
 await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
-async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-index = self.cache.get(vector_db_id)
+async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+index = self.cache.get(vector_store_id)
 if index is None:
-raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
+raise ValueError(f"Vector DB {vector_store_id} not found. found: {self.cache.keys()}")
 await index.insert_chunks(chunks)
 async def query_chunks(
-self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
 ) -> QueryChunksResponse:
-index = self.cache.get(vector_db_id)
+index = self.cache.get(vector_store_id)
 if index is None:
-raise VectorStoreNotFoundError(vector_db_id)
+raise VectorStoreNotFoundError(vector_store_id)
 return await index.query_chunks(query, params)


@@ -447,20 +447,20 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
 await self.cache[vector_store_id].index.delete()
 del self.cache[vector_store_id]
-async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-index = await self._get_and_cache_vector_store_index(vector_db_id)
+async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+index = await self._get_and_cache_vector_store_index(vector_store_id)
 if not index:
-raise VectorStoreNotFoundError(vector_db_id)
+raise VectorStoreNotFoundError(vector_store_id)
 # The VectorStoreWithIndex helper is expected to compute embeddings via the inference_api
 # and then call our index's add_chunks.
 await index.insert_chunks(chunks)
 async def query_chunks(
-self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
+self, vector_store_id: str, query: Any, params: dict[str, Any] | None = None
 ) -> QueryChunksResponse:
-index = await self._get_and_cache_vector_store_index(vector_db_id)
+index = await self._get_and_cache_vector_store_index(vector_store_id)
 if not index:
-raise VectorStoreNotFoundError(vector_db_id)
+raise VectorStoreNotFoundError(vector_store_id)
 return await index.query_chunks(query, params)
 async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:


@@ -61,6 +61,7 @@ def available_providers() -> list[ProviderSpec]:
 pip_packages=[],
 module="llama_stack.providers.remote.inference.cerebras",
 config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
+provider_data_validator="llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
 description="Cerebras inference provider for running models on Cerebras Cloud platform.",
 ),
 RemoteProviderSpec(
@@ -149,6 +150,7 @@ def available_providers() -> list[ProviderSpec]:
 pip_packages=["databricks-sdk"],
 module="llama_stack.providers.remote.inference.databricks",
 config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
+provider_data_validator="llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
 description="Databricks inference provider for running models on Databricks' unified analytics platform.",
 ),
 RemoteProviderSpec(
@@ -158,6 +160,7 @@ def available_providers() -> list[ProviderSpec]:
 pip_packages=[],
 module="llama_stack.providers.remote.inference.nvidia",
 config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
+provider_data_validator="llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
 description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
 ),
 RemoteProviderSpec(
@@ -167,6 +170,7 @@ def available_providers() -> list[ProviderSpec]:
 pip_packages=[],
 module="llama_stack.providers.remote.inference.runpod",
 config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
+provider_data_validator="llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
 description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
 ),
 RemoteProviderSpec(


@@ -163,14 +163,14 @@ The SQLite-vec provider supports three search modes:
 Example with hybrid search:
 ```python
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
 )
 # Using RRF ranker
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={
 "mode": "hybrid",
@@ -182,7 +182,7 @@ response = await vector_io.query_chunks(
 # Using weighted ranker
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={
 "mode": "hybrid",
@@ -196,7 +196,7 @@ response = await vector_io.query_chunks(
 Example with explicit vector search:
 ```python
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7},
 )
@@ -205,7 +205,7 @@ response = await vector_io.query_chunks(
 Example with keyword search:
 ```python
 response = await vector_io.query_chunks(
-vector_db_id="my_db",
+vector_store_id="my_db",
 query="your query here",
 params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7},
 )


@@ -18,6 +18,8 @@ from .config import CerebrasImplConfig
 class CerebrasInferenceAdapter(OpenAIMixin):
 config: CerebrasImplConfig
+provider_data_api_key_field: str = "cerebras_api_key"
 def get_base_url(self) -> str:
 return urljoin(self.config.base_url, "v1")


@@ -7,7 +7,7 @@
 import os
 from typing import Any
-from pydantic import Field
+from pydantic import BaseModel, Field
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -15,6 +15,13 @@ from llama_stack.schema_utils import json_schema_type
 DEFAULT_BASE_URL = "https://api.cerebras.ai"
+class CerebrasProviderDataValidator(BaseModel):
+cerebras_api_key: str | None = Field(
+default=None,
+description="API key for Cerebras models",
+)
 @json_schema_type
 class CerebrasImplConfig(RemoteInferenceProviderConfig):
 base_url: str = Field(
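For context, a hedged sketch of what the new provider-data validator enables: supplying a per-request Cerebras API key from the client instead of static server configuration. The provider-data plumbing shown here follows the usual Llama Stack convention and the field name added in this change; the client constructor arguments and model id are assumptions:

```python
# Hedged sketch: the client is assumed to serialize provider_data into the
# X-LlamaStack-Provider-Data request header; on the server the new
# CerebrasProviderDataValidator validates cerebras_api_key, and the adapter
# reads it via provider_data_api_key_field = "cerebras_api_key".
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:8321",               # assumed local stack endpoint
    provider_data={"cerebras_api_key": "csk-..."},  # per-request key, placeholder value
)

response = client.chat.completions.create(
    model="llama-3.3-70b",  # placeholder model id
    messages=[{"role": "user", "content": "Hello from Cerebras"}],
)
print(response.choices[0].message.content)
```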


@@ -6,12 +6,19 @@
 from typing import Any
-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
+class DatabricksProviderDataValidator(BaseModel):
+databricks_api_token: str | None = Field(
+default=None,
+description="API token for Databricks models",
+)
 @json_schema_type
 class DatabricksImplConfig(RemoteInferenceProviderConfig):
 url: str | None = Field(


@@ -20,6 +20,8 @@ logger = get_logger(name=__name__, category="inference::databricks")
 class DatabricksInferenceAdapter(OpenAIMixin):
 config: DatabricksImplConfig
+provider_data_api_key_field: str = "databricks_api_token"
 # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
 embedding_model_metadata: dict[str, dict[str, int]] = {
 "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},


@@ -7,12 +7,19 @@
 import os
 from typing import Any
-from pydantic import Field
+from pydantic import BaseModel, Field
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
+class NVIDIAProviderDataValidator(BaseModel):
+nvidia_api_key: str | None = Field(
+default=None,
+description="API key for NVIDIA NIM models",
+)
 @json_schema_type
 class NVIDIAConfig(RemoteInferenceProviderConfig):
 """


@@ -17,6 +17,8 @@ logger = get_logger(name=__name__, category="inference::nvidia")
 class NVIDIAInferenceAdapter(OpenAIMixin):
 config: NVIDIAConfig
+provider_data_api_key_field: str = "nvidia_api_key"
 """
 NVIDIA Inference Adapter for Llama Stack.
 """


@@ -6,12 +6,19 @@
 from typing import Any
-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
+class RunpodProviderDataValidator(BaseModel):
+runpod_api_token: str | None = Field(
+default=None,
+description="API token for RunPod models",
+)
 @json_schema_type
 class RunpodImplConfig(RemoteInferenceProviderConfig):
 url: str | None = Field(


@@ -24,6 +24,7 @@ class RunpodInferenceAdapter(OpenAIMixin):
 """
 config: RunpodImplConfig
+provider_data_api_key_field: str = "runpod_api_token"
 def get_base_url(self) -> str:
 """Get base URL for OpenAI client."""


@@ -169,20 +169,20 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
 await self.cache[vector_store_id].index.delete()
 del self.cache[vector_store_id]
-async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-index = await self._get_and_cache_vector_store_index(vector_db_id)
+async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+index = await self._get_and_cache_vector_store_index(vector_store_id)
 if index is None:
-raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
+raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
 await index.insert_chunks(chunks)
 async def query_chunks(
-self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
 ) -> QueryChunksResponse:
-index = await self._get_and_cache_vector_store_index(vector_db_id)
+index = await self._get_and_cache_vector_store_index(vector_store_id)
 if index is None:
-raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
+raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
 return await index.query_chunks(query, params)


@@ -348,19 +348,19 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
 await self.cache[vector_store_id].index.delete()
 del self.cache[vector_store_id]
-async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-index = await self._get_and_cache_vector_store_index(vector_db_id)
+async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+index = await self._get_and_cache_vector_store_index(vector_store_id)
 if not index:
-raise VectorStoreNotFoundError(vector_db_id)
+raise VectorStoreNotFoundError(vector_store_id)
 await index.insert_chunks(chunks)
 async def query_chunks(
-self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
 ) -> QueryChunksResponse:
-index = await self._get_and_cache_vector_store_index(vector_db_id)
+index = await self._get_and_cache_vector_store_index(vector_store_id)
 if not index:
-raise VectorStoreNotFoundError(vector_db_id)
+raise VectorStoreNotFoundError(vector_store_id)
 return await index.query_chunks(query, params)
 async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:


@@ -399,14 +399,14 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
 assert self.kvstore is not None
 await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_store_id}")
-async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-index = await self._get_and_cache_vector_store_index(vector_db_id)
+async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+index = await self._get_and_cache_vector_store_index(vector_store_id)
 await index.insert_chunks(chunks)
 async def query_chunks(
-self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
 ) -> QueryChunksResponse:
-index = await self._get_and_cache_vector_store_index(vector_db_id)
+index = await self._get_and_cache_vector_store_index(vector_store_id)
 return await index.query_chunks(query, params)
 async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:


@@ -222,19 +222,19 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
 self.cache[vector_store_id] = index
 return index
-async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-index = await self._get_and_cache_vector_store_index(vector_db_id)
+async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+index = await self._get_and_cache_vector_store_index(vector_store_id)
 if not index:
-raise VectorStoreNotFoundError(vector_db_id)
+raise VectorStoreNotFoundError(vector_store_id)
 await index.insert_chunks(chunks)
 async def query_chunks(
-self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
 ) -> QueryChunksResponse:
-index = await self._get_and_cache_vector_store_index(vector_db_id)
+index = await self._get_and_cache_vector_store_index(vector_store_id)
 if not index:
-raise VectorStoreNotFoundError(vector_db_id)
+raise VectorStoreNotFoundError(vector_store_id)
 return await index.query_chunks(query, params)

View file

@@ -366,19 +366,19 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
        self.cache[vector_store_id] = index
        return index

-    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+    async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
        if not index:
-            raise VectorStoreNotFoundError(vector_db_id)
+            raise VectorStoreNotFoundError(vector_store_id)
        await index.insert_chunks(chunks)

    async def query_chunks(
-        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+        self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
    ) -> QueryChunksResponse:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
        if not index:
-            raise VectorStoreNotFoundError(vector_db_id)
+            raise VectorStoreNotFoundError(vector_store_id)
        return await index.query_chunks(query, params)

View file

@@ -333,7 +333,7 @@ class OpenAIVectorStoreMixin(ABC):
    @abstractmethod
    async def insert_chunks(
        self,
-        vector_db_id: str,
+        vector_store_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
@@ -342,7 +342,7 @@ class OpenAIVectorStoreMixin(ABC):
    @abstractmethod
    async def query_chunks(
-        self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
+        self, vector_store_id: str, query: Any, params: dict[str, Any] | None = None
    ) -> QueryChunksResponse:
        """Query chunks from a vector database (provider-specific implementation)."""
        pass
@@ -609,7 +609,7 @@ class OpenAIVectorStoreMixin(ABC):
        # TODO: Add support for ranking_options.ranker
        response = await self.query_chunks(
-            vector_db_id=vector_store_id,
+            vector_store_id=vector_store_id,
            query=search_query,
            params=params,
        )
@@ -803,7 +803,7 @@ class OpenAIVectorStoreMixin(ABC):
            )
        else:
            await self.insert_chunks(
-                vector_db_id=vector_store_id,
+                vector_store_id=vector_store_id,
                chunks=chunks,
            )
        vector_store_file_object.status = "completed"
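For context, a condensed sketch of how the OpenAI-compatible search path now reaches the provider hook; the wrapper name and params below are illustrative assumptions, not the mixin's actual signatures:

    async def openai_search_vector_store(self, vector_store_id: str, query: str):
        # Hypothetical wrapper: forwards the OpenAI-style store id straight into the
        # provider-specific hook renamed above.
        response = await self.query_chunks(
            vector_store_id=vector_store_id,
            query=query,
            params={"max_chunks": 10},
        )
        return [
            {"content": chunk.content, "score": score}
            for chunk, score in zip(response.chunks, response.scores)
        ]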

View file

@@ -367,7 +367,7 @@ def test_openai_vector_store_with_chunks(
    # Insert chunks using the native LlamaStack API (since OpenAI API doesn't have direct chunk insertion)
    llama_client.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
        chunks=sample_chunks,
    )
@@ -434,7 +434,7 @@ def test_openai_vector_store_search_relevance(
    # Insert chunks using native API
    llama_client.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
        chunks=sample_chunks,
    )
@@ -484,7 +484,7 @@ def test_openai_vector_store_search_with_ranking_options(
    # Insert chunks
    llama_client.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
        chunks=sample_chunks,
    )
@@ -544,7 +544,7 @@ def test_openai_vector_store_search_with_high_score_filter(
    # Insert chunks
    llama_client.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
        chunks=sample_chunks,
    )
@@ -610,7 +610,7 @@ def test_openai_vector_store_search_with_max_num_results(
    # Insert chunks
    llama_client.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
        chunks=sample_chunks,
    )
@@ -1175,7 +1175,7 @@ def test_openai_vector_store_search_modes(
    )
    client_with_models.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
        chunks=sample_chunks,
    )
    query = "Python programming language"

View file

@@ -123,12 +123,12 @@ def test_insert_chunks(
    actual_vector_store_id = create_response.id
    client_with_empty_registry.vector_io.insert(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
        chunks=sample_chunks,
    )
    response = client_with_empty_registry.vector_io.query(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
        query="What is the capital of France?",
    )
    assert response is not None
@@ -137,7 +137,7 @@ def test_insert_chunks(
        query, expected_doc_id = test_case
        response = client_with_empty_registry.vector_io.query(
-            vector_db_id=actual_vector_store_id,
+            vector_store_id=actual_vector_store_id,
            query=query,
        )
        assert response is not None
@@ -174,13 +174,13 @@ def test_insert_chunks_with_precomputed_embeddings(
    ]
    client_with_empty_registry.vector_io.insert(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
        chunks=chunks_with_embeddings,
    )
    provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
    response = client_with_empty_registry.vector_io.query(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
        query="precomputed embedding test",
        params=vector_io_provider_params_dict.get(provider, None),
    )
@@ -224,13 +224,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
    ]
    client_with_empty_registry.vector_io.insert(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
        chunks=chunks_with_embeddings,
    )
    provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
    response = client_with_empty_registry.vector_io.query(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
        query="duplicate",
        params=vector_io_provider_params_dict.get(provider, None),
    )
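The client-facing flow mirrors the provider change. A minimal sketch of the updated call sites, assuming an existing LlamaStack client (`client`), a previously created vector store id, and a `sample_chunks` list like the fixtures above:

    # Insert under the renamed keyword, then query the same store (sketch).
    client.vector_io.insert(
        vector_store_id=vector_store_id,
        chunks=sample_chunks,
    )
    response = client.vector_io.query(
        vector_store_id=vector_store_id,
        query="What is the capital of France?",
    )
    assert response is not None and len(response.chunks) > 0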

View file

@@ -10,47 +10,124 @@ from unittest.mock import MagicMock
import pytest

from llama_stack.core.request_headers import request_provider_data_context
+from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter
+from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
+from llama_stack.providers.remote.inference.cerebras.cerebras import CerebrasInferenceAdapter
+from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig
+from llama_stack.providers.remote.inference.databricks.config import DatabricksImplConfig
+from llama_stack.providers.remote.inference.databricks.databricks import DatabricksInferenceAdapter
+from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
+from llama_stack.providers.remote.inference.fireworks.fireworks import FireworksInferenceAdapter
+from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
+from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
+from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
+from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
+from llama_stack.providers.remote.inference.runpod.config import RunpodImplConfig
+from llama_stack.providers.remote.inference.runpod.runpod import RunpodInferenceAdapter
+from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
+from llama_stack.providers.remote.inference.sambanova.sambanova import SambaNovaInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
+from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
+from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter


@pytest.mark.parametrize(
-    "config_cls,adapter_cls,provider_data_validator",
+    "config_cls,adapter_cls,provider_data_validator,config_params",
    [
        (
            GroqConfig,
            GroqInferenceAdapter,
            "llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
+            {},
        ),
        (
            OpenAIConfig,
            OpenAIInferenceAdapter,
            "llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
+            {},
        ),
        (
            TogetherImplConfig,
            TogetherInferenceAdapter,
            "llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
+            {},
        ),
        (
            LlamaCompatConfig,
            LlamaCompatInferenceAdapter,
            "llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
+            {},
+        ),
+        (
+            CerebrasImplConfig,
+            CerebrasInferenceAdapter,
+            "llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
+            {},
+        ),
+        (
+            DatabricksImplConfig,
+            DatabricksInferenceAdapter,
+            "llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
+            {},
+        ),
+        (
+            NVIDIAConfig,
+            NVIDIAInferenceAdapter,
+            "llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
+            {},
+        ),
+        (
+            RunpodImplConfig,
+            RunpodInferenceAdapter,
+            "llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
+            {},
+        ),
+        (
+            FireworksImplConfig,
+            FireworksInferenceAdapter,
+            "llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
+            {},
+        ),
+        (
+            AnthropicConfig,
+            AnthropicInferenceAdapter,
+            "llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
+            {},
+        ),
+        (
+            GeminiConfig,
+            GeminiInferenceAdapter,
+            "llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
+            {},
+        ),
+        (
+            SambaNovaImplConfig,
+            SambaNovaInferenceAdapter,
+            "llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
+            {},
+        ),
+        (
+            VLLMInferenceAdapterConfig,
+            VLLMInferenceAdapter,
+            "llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
+            {
+                "url": "http://fake",
+            },
        ),
    ],
)
-def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
+def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str, config_params: dict):
    """Ensure the OpenAI provider does not cache api keys across client requests"""

-    inference_adapter = adapter_cls(config=config_cls())
+    inference_adapter = adapter_cls(config=config_cls(**config_params))

    inference_adapter.__provider_spec__ = MagicMock()
    inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
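The new config_params column lets adapters whose configs require fields share the same test; most entries pass an empty dict so the config is built with defaults, while the vLLM entry supplies the url its config needs. A small sketch of what the parametrized construction expands to for that row:

    # Equivalent of config_cls(**config_params) for the vLLM entry above.
    config = VLLMInferenceAdapterConfig(url="http://fake")
    adapter = VLLMInferenceAdapter(config=config)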

View file

@@ -23,14 +23,14 @@ class TestRagQuery:
            config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
        )
        with pytest.raises(ValueError):
-            await rag_tool.query(content=MagicMock(), vector_db_ids=[])
+            await rag_tool.query(content=MagicMock(), vector_store_ids=[])

    async def test_query_chunk_metadata_handling(self):
        rag_tool = MemoryToolRuntimeImpl(
            config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
        )
        content = "test query content"
-        vector_db_ids = ["db1"]
+        vector_store_ids = ["db1"]
        chunk_metadata = ChunkMetadata(
            document_id="doc1",
@@ -55,7 +55,7 @@ class TestRagQuery:
        query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])
        rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)

-        result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)
+        result = await rag_tool.query(content=content, vector_store_ids=vector_store_ids)
        assert result is not None
        expected_metadata_string = (
@@ -90,7 +90,7 @@ class TestRagQuery:
            files_api=MagicMock(),
        )
-        vector_db_ids = ["db1", "db2"]
+        vector_store_ids = ["db1", "db2"]

        # Fake chunks from each DB
        chunk_metadata1 = ChunkMetadata(
@@ -101,7 +101,7 @@ class TestRagQuery:
        )
        chunk1 = Chunk(
            content="chunk from db1",
-            metadata={"vector_db_id": "db1", "document_id": "doc1"},
+            metadata={"vector_store_id": "db1", "document_id": "doc1"},
            stored_chunk_id="c1",
            chunk_metadata=chunk_metadata1,
        )
@@ -114,7 +114,7 @@ class TestRagQuery:
        )
        chunk2 = Chunk(
            content="chunk from db2",
-            metadata={"vector_db_id": "db2", "document_id": "doc2"},
+            metadata={"vector_store_id": "db2", "document_id": "doc2"},
            stored_chunk_id="c2",
            chunk_metadata=chunk_metadata2,
        )
@@ -126,13 +126,13 @@ class TestRagQuery:
            ]
        )

-        result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
+        result = await rag_tool.query(content="test", vector_store_ids=vector_store_ids)

        returned_chunks = result.metadata["chunks"]
        returned_scores = result.metadata["scores"]
        returned_doc_ids = result.metadata["document_ids"]
-        returned_vector_db_ids = result.metadata["vector_db_ids"]
+        returned_vector_store_ids = result.metadata["vector_store_ids"]

        assert returned_chunks == ["chunk from db1", "chunk from db2"]
        assert returned_scores == (0.9, 0.8)
        assert returned_doc_ids == ["doc1", "doc2"]
-        assert returned_vector_db_ids == ["db1", "db2"]
+        assert returned_vector_store_ids == ["db1", "db2"]
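A minimal sketch of querying the RAG tool with the renamed keyword and reading the renamed metadata keys; `rag_tool` is assumed to be an initialized MemoryToolRuntimeImpl as in the tests above:

    # Query across two stores and walk the per-chunk metadata returned by the tool.
    result = await rag_tool.query(content="test", vector_store_ids=["db1", "db2"])
    for content, doc_id, store_id in zip(
        result.metadata["chunks"],
        result.metadata["document_ids"],
        result.metadata["vector_store_ids"],
    ):
        print(f"{store_id}/{doc_id}: {content}")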

uv.lock (generated, 45 changes)
View file

@@ -129,6 +129,25 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" },
]

+[[package]]
+name = "anthropic"
+version = "0.69.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "anyio" },
+    { name = "distro" },
+    { name = "docstring-parser" },
+    { name = "httpx" },
+    { name = "jiter" },
+    { name = "pydantic" },
+    { name = "sniffio" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c8/9d/9ad1778b95f15c5b04e7d328c1b5f558f1e893857b7c33cd288c19c0057a/anthropic-0.69.0.tar.gz", hash = "sha256:c604d287f4d73640f40bd2c0f3265a2eb6ce034217ead0608f6b07a8bc5ae5f2", size = 480622, upload-time = "2025-09-29T16:53:45.282Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/9b/38/75129688de5637eb5b383e5f2b1570a5cc3aecafa4de422da8eea4b90a6c/anthropic-0.69.0-py3-none-any.whl", hash = "sha256:1f73193040f33f11e27c2cd6ec25f24fe7c3f193dc1c5cde6b7a08b18a16bcc5", size = 337265, upload-time = "2025-09-29T16:53:43.686Z" },
+]
+
[[package]]
name = "anyio"
version = "4.9.0"
@@ -758,6 +777,19 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/79/b3/28ac139109d9005ad3f6b6f8976ffede6706a6478e21c889ce36c840918e/cryptography-45.0.5-cp37-abi3-win_amd64.whl", hash = "sha256:90cb0a7bb35959f37e23303b7eed0a32280510030daba3f7fdfbb65defde6a97", size = 3390016, upload-time = "2025-07-02T13:05:50.811Z" },
]

+[[package]]
+name = "databricks-sdk"
+version = "0.67.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "google-auth" },
+    { name = "requests" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b3/5b/df3e5424d833e4f3f9b42c409ef8b513e468c9cdf06c2a9935c6cbc4d128/databricks_sdk-0.67.0.tar.gz", hash = "sha256:f923227babcaad428b0c2eede2755ebe9deb996e2c8654f179eb37f486b37a36", size = 761000, upload-time = "2025-09-25T13:32:10.858Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/a0/ca/2aff3817041483fb8e4f75a74a36ff4ca3a826e276becd1179a591b6348f/databricks_sdk-0.67.0-py3-none-any.whl", hash = "sha256:ef49e49db45ed12c015a32a6f9d4ba395850f25bb3dcffdcaf31a5167fe03ee2", size = 718422, upload-time = "2025-09-25T13:32:09.011Z" },
+]
+
[[package]]
name = "datasets"
version = "4.0.0"
@@ -856,6 +888,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
]

+[[package]]
+name = "docstring-parser"
+version = "0.17.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" },
+]
+
[[package]]
name = "docutils"
version = "0.21.2"
@@ -1863,9 +1904,11 @@ test = [
unit = [
    { name = "aiohttp" },
    { name = "aiosqlite" },
+    { name = "anthropic" },
    { name = "blobfile" },
    { name = "chardet" },
    { name = "coverage" },
+    { name = "databricks-sdk" },
    { name = "faiss-cpu" },
    { name = "litellm" },
    { name = "mcp" },
@@ -1978,9 +2021,11 @@ test = [
unit = [
    { name = "aiohttp" },
    { name = "aiosqlite" },
+    { name = "anthropic" },
    { name = "blobfile" },
    { name = "chardet" },
    { name = "coverage" },
+    { name = "databricks-sdk" },
    { name = "faiss-cpu" },
    { name = "litellm" },
    { name = "mcp" },