Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 18:00:36 +00:00)

Merge 960f5e4cd4 into sapling-pr-archive-ehhuang
This commit is contained in commit bf8f6b6914

41 changed files with 346 additions and 183 deletions
@@ -72,14 +72,14 @@ description: |
 Example with hybrid search:
 ```python
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
 )

 # Using RRF ranker
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={
         "mode": "hybrid",

@@ -91,7 +91,7 @@ description: |

 # Using weighted ranker
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={
         "mode": "hybrid",

@@ -105,7 +105,7 @@ description: |
 Example with explicit vector search:
 ```python
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7},
 )

@@ -114,7 +114,7 @@ description: |
 Example with keyword search:
 ```python
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7},
 )

@@ -277,14 +277,14 @@ The SQLite-vec provider supports three search modes:
 Example with hybrid search:
 ```python
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
 )

 # Using RRF ranker
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={
         "mode": "hybrid",

@@ -296,7 +296,7 @@ response = await vector_io.query_chunks(

 # Using weighted ranker
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={
         "mode": "hybrid",

@@ -310,7 +310,7 @@ response = await vector_io.query_chunks(
 Example with explicit vector search:
 ```python
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7},
 )

@@ -319,7 +319,7 @@ response = await vector_io.query_chunks(
 Example with keyword search:
 ```python
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7},
 )
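The doc hunks above are truncated right after `"mode": "hybrid",`, so the ranker configuration is cut off. As a point of reference, a fuller hybrid-search call might look like the sketch below; the `ranker` sub-dictionary (`type`, `impact_factor`, `alpha`) is an assumption based on the surrounding text, not something this diff shows.

```python
# Hedged sketch only: fills in the ranker params the hunks above truncate.
# The exact ranker schema is assumed; check the provider docs before relying on it.
response = await vector_io.query_chunks(
    vector_store_id="my_db",
    query="your query here",
    params={
        "mode": "hybrid",
        "max_chunks": 3,
        "score_threshold": 0.7,
        # Reciprocal Rank Fusion: merge keyword and vector rankings
        "ranker": {"type": "rrf", "impact_factor": 60.0},
    },
)

# Weighted ranker: blend keyword and vector scores with a single alpha weight
response = await vector_io.query_chunks(
    vector_store_id="my_db",
    query="your query here",
    params={
        "mode": "hybrid",
        "max_chunks": 3,
        "score_threshold": 0.7,
        "ranker": {"type": "weighted", "alpha": 0.5},
    },
)
```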
docs/static/deprecated-llama-stack-spec.html (vendored): 4 changed lines
@@ -4390,7 +4390,7 @@
 "const": "memory_retrieval",
 "default": "memory_retrieval"
 },
-"vector_db_ids": {
+"vector_store_ids": {
 "type": "string",
 "description": "The IDs of the vector databases to retrieve context from."
 },

@@ -4404,7 +4404,7 @@
 "turn_id",
 "step_id",
 "step_type",
-"vector_db_ids",
+"vector_store_ids",
 "inserted_context"
 ],
 "title": "MemoryRetrievalStep",
docs/static/deprecated-llama-stack-spec.yaml (vendored): 4 changed lines
@@ -3252,7 +3252,7 @@ components:
 description: Type of the step in an agent turn.
 const: memory_retrieval
 default: memory_retrieval
-vector_db_ids:
+vector_store_ids:
 type: string
 description: >-
 The IDs of the vector databases to retrieve context from.

@@ -3265,7 +3265,7 @@ components:
 - turn_id
 - step_id
 - step_type
-- vector_db_ids
+- vector_store_ids
 - inserted_context
 title: MemoryRetrievalStep
 description: >-
@@ -2865,7 +2865,7 @@
 "const": "memory_retrieval",
 "default": "memory_retrieval"
 },
-"vector_db_ids": {
+"vector_store_ids": {
 "type": "string",
 "description": "The IDs of the vector databases to retrieve context from."
 },

@@ -2879,7 +2879,7 @@
 "turn_id",
 "step_id",
 "step_type",
-"vector_db_ids",
+"vector_store_ids",
 "inserted_context"
 ],
 "title": "MemoryRetrievalStep",
@@ -2085,7 +2085,7 @@ components:
 description: Type of the step in an agent turn.
 const: memory_retrieval
 default: memory_retrieval
-vector_db_ids:
+vector_store_ids:
 type: string
 description: >-
 The IDs of the vector databases to retrieve context from.

@@ -2098,7 +2098,7 @@ components:
 - turn_id
 - step_id
 - step_type
-- vector_db_ids
+- vector_store_ids
 - inserted_context
 title: MemoryRetrievalStep
 description: >-
docs/static/llama-stack-spec.html (vendored): 16 changed lines
@@ -11412,7 +11412,7 @@
 },
 "description": "List of documents to index in the RAG system"
 },
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "ID of the vector database to store the document embeddings"
 },

@@ -11424,7 +11424,7 @@
 "additionalProperties": false,
 "required": [
 "documents",
-"vector_db_id",
+"vector_store_id",
 "chunk_size_in_tokens"
 ],
 "title": "InsertRequest"

@@ -11615,7 +11615,7 @@
 "$ref": "#/components/schemas/InterleavedContent",
 "description": "The query content to search for in the indexed documents"
 },
-"vector_db_ids": {
+"vector_store_ids": {
 "type": "array",
 "items": {
 "type": "string"

@@ -11630,7 +11630,7 @@
 "additionalProperties": false,
 "required": [
 "content",
-"vector_db_ids"
+"vector_store_ids"
 ],
 "title": "QueryRequest"
 },

@@ -11923,7 +11923,7 @@
 "InsertChunksRequest": {
 "type": "object",
 "properties": {
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "The identifier of the vector database to insert the chunks into."
 },

@@ -11941,7 +11941,7 @@
 },
 "additionalProperties": false,
 "required": [
-"vector_db_id",
+"vector_store_id",
 "chunks"
 ],
 "title": "InsertChunksRequest"

@@ -11949,7 +11949,7 @@
 "QueryChunksRequest": {
 "type": "object",
 "properties": {
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "The identifier of the vector database to query."
 },

@@ -11986,7 +11986,7 @@
 },
 "additionalProperties": false,
 "required": [
-"vector_db_id",
+"vector_store_id",
 "query"
 ],
 "title": "QueryChunksRequest"
docs/static/llama-stack-spec.yaml (vendored): 16 changed lines
@@ -8649,7 +8649,7 @@ components:
 $ref: '#/components/schemas/RAGDocument'
 description: >-
 List of documents to index in the RAG system
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 ID of the vector database to store the document embeddings

@@ -8660,7 +8660,7 @@ components:
 additionalProperties: false
 required:
 - documents
-- vector_db_id
+- vector_store_id
 - chunk_size_in_tokens
 title: InsertRequest
 DefaultRAGQueryGeneratorConfig:

@@ -8831,7 +8831,7 @@ components:
 $ref: '#/components/schemas/InterleavedContent'
 description: >-
 The query content to search for in the indexed documents
-vector_db_ids:
+vector_store_ids:
 type: array
 items:
 type: string

@@ -8844,7 +8844,7 @@ components:
 additionalProperties: false
 required:
 - content
-- vector_db_ids
+- vector_store_ids
 title: QueryRequest
 RAGQueryResult:
 type: object

@@ -9068,7 +9068,7 @@ components:
 InsertChunksRequest:
 type: object
 properties:
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 The identifier of the vector database to insert the chunks into.

@@ -9087,13 +9087,13 @@ components:
 description: The time to live of the chunks.
 additionalProperties: false
 required:
-- vector_db_id
+- vector_store_id
 - chunks
 title: InsertChunksRequest
 QueryChunksRequest:
 type: object
 properties:
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 The identifier of the vector database to query.

@@ -9113,7 +9113,7 @@ components:
 description: The parameters of the query.
 additionalProperties: false
 required:
-- vector_db_id
+- vector_store_id
 - query
 title: QueryChunksRequest
 QueryChunksResponse:
docs/static/stainless-llama-stack-spec.html (vendored): 20 changed lines
@@ -13084,7 +13084,7 @@
 },
 "description": "List of documents to index in the RAG system"
 },
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "ID of the vector database to store the document embeddings"
 },

@@ -13096,7 +13096,7 @@
 "additionalProperties": false,
 "required": [
 "documents",
-"vector_db_id",
+"vector_store_id",
 "chunk_size_in_tokens"
 ],
 "title": "InsertRequest"

@@ -13287,7 +13287,7 @@
 "$ref": "#/components/schemas/InterleavedContent",
 "description": "The query content to search for in the indexed documents"
 },
-"vector_db_ids": {
+"vector_store_ids": {
 "type": "array",
 "items": {
 "type": "string"

@@ -13302,7 +13302,7 @@
 "additionalProperties": false,
 "required": [
 "content",
-"vector_db_ids"
+"vector_store_ids"
 ],
 "title": "QueryRequest"
 },

@@ -13595,7 +13595,7 @@
 "InsertChunksRequest": {
 "type": "object",
 "properties": {
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "The identifier of the vector database to insert the chunks into."
 },

@@ -13613,7 +13613,7 @@
 },
 "additionalProperties": false,
 "required": [
-"vector_db_id",
+"vector_store_id",
 "chunks"
 ],
 "title": "InsertChunksRequest"

@@ -13621,7 +13621,7 @@
 "QueryChunksRequest": {
 "type": "object",
 "properties": {
-"vector_db_id": {
+"vector_store_id": {
 "type": "string",
 "description": "The identifier of the vector database to query."
 },

@@ -13658,7 +13658,7 @@
 },
 "additionalProperties": false,
 "required": [
-"vector_db_id",
+"vector_store_id",
 "query"
 ],
 "title": "QueryChunksRequest"

@@ -15719,7 +15719,7 @@
 "const": "memory_retrieval",
 "default": "memory_retrieval"
 },
-"vector_db_ids": {
+"vector_store_ids": {
 "type": "string",
 "description": "The IDs of the vector databases to retrieve context from."
 },

@@ -15733,7 +15733,7 @@
 "turn_id",
 "step_id",
 "step_type",
-"vector_db_ids",
+"vector_store_ids",
 "inserted_context"
 ],
 "title": "MemoryRetrievalStep",
docs/static/stainless-llama-stack-spec.yaml (vendored): 20 changed lines
@@ -9862,7 +9862,7 @@ components:
 $ref: '#/components/schemas/RAGDocument'
 description: >-
 List of documents to index in the RAG system
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 ID of the vector database to store the document embeddings

@@ -9873,7 +9873,7 @@ components:
 additionalProperties: false
 required:
 - documents
-- vector_db_id
+- vector_store_id
 - chunk_size_in_tokens
 title: InsertRequest
 DefaultRAGQueryGeneratorConfig:

@@ -10044,7 +10044,7 @@ components:
 $ref: '#/components/schemas/InterleavedContent'
 description: >-
 The query content to search for in the indexed documents
-vector_db_ids:
+vector_store_ids:
 type: array
 items:
 type: string

@@ -10057,7 +10057,7 @@ components:
 additionalProperties: false
 required:
 - content
-- vector_db_ids
+- vector_store_ids
 title: QueryRequest
 RAGQueryResult:
 type: object

@@ -10281,7 +10281,7 @@ components:
 InsertChunksRequest:
 type: object
 properties:
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 The identifier of the vector database to insert the chunks into.

@@ -10300,13 +10300,13 @@ components:
 description: The time to live of the chunks.
 additionalProperties: false
 required:
-- vector_db_id
+- vector_store_id
 - chunks
 title: InsertChunksRequest
 QueryChunksRequest:
 type: object
 properties:
-vector_db_id:
+vector_store_id:
 type: string
 description: >-
 The identifier of the vector database to query.

@@ -10326,7 +10326,7 @@ components:
 description: The parameters of the query.
 additionalProperties: false
 required:
-- vector_db_id
+- vector_store_id
 - query
 title: QueryChunksRequest
 QueryChunksResponse:

@@ -11844,7 +11844,7 @@ components:
 description: Type of the step in an agent turn.
 const: memory_retrieval
 default: memory_retrieval
-vector_db_ids:
+vector_store_ids:
 type: string
 description: >-
 The IDs of the vector databases to retrieve context from.

@@ -11857,7 +11857,7 @@ components:
 - turn_id
 - step_id
 - step_type
-- vector_db_ids
+- vector_store_ids
 - inserted_context
 title: MemoryRetrievalStep
 description: >-
@@ -78,6 +78,8 @@ dev = [
 ]
 # These are the dependencies required for running unit tests.
 unit = [
     "anthropic",
     "databricks-sdk",
     "sqlite-vec",
     "ollama",
     "aiosqlite",
@@ -149,13 +149,13 @@ class ShieldCallStep(StepCommon):
 class MemoryRetrievalStep(StepCommon):
     """A memory retrieval step in an agent turn.

-    :param vector_db_ids: The IDs of the vector databases to retrieve context from.
+    :param vector_store_ids: The IDs of the vector databases to retrieve context from.
     :param inserted_context: The context retrieved from the vector databases.
     """

     step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
     # TODO: should this be List[str]?
-    vector_db_ids: str
+    vector_store_ids: str
     inserted_context: InterleavedContent
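For orientation, a minimal sketch of constructing the renamed step object follows; the values are invented, and `turn_id`/`step_id` are taken from the required field list in the spec hunks above.

```python
# Hypothetical construction of the renamed datatype (values are placeholders).
# vector_store_ids is still a single string here, matching the
# "TODO: should this be List[str]?" note in the class definition above.
step = MemoryRetrievalStep(
    turn_id="turn-123",
    step_id="step-456",
    vector_store_ids="my_store",
    inserted_context="Context text retrieved from the vector store.",
)
```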
@@ -190,13 +190,13 @@ class RAGToolRuntime(Protocol):
     async def insert(
         self,
         documents: list[RAGDocument],
-        vector_db_id: str,
+        vector_store_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
         """Index documents so they can be used by the RAG system.

         :param documents: List of documents to index in the RAG system
-        :param vector_db_id: ID of the vector database to store the document embeddings
+        :param vector_store_id: ID of the vector database to store the document embeddings
         :param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing
         """
         ...

@@ -205,13 +205,13 @@ class RAGToolRuntime(Protocol):
     async def query(
         self,
         content: InterleavedContent,
-        vector_db_ids: list[str],
+        vector_store_ids: list[str],
         query_config: RAGQueryConfig | None = None,
     ) -> RAGQueryResult:
         """Query the RAG system for context; typically invoked by the agent.

         :param content: The query content to search for in the indexed documents
-        :param vector_db_ids: List of vector database IDs to search within
+        :param vector_store_ids: List of vector database IDs to search within
         :param query_config: (Optional) Configuration parameters for the query operation
         :returns: RAGQueryResult containing the retrieved content and metadata
         """
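From a caller's point of view, the renamed protocol reads as below. This is a hedged sketch: `rag_tool` stands in for any `RAGToolRuntime` implementation, and the document contents are illustrative.

```python
# Hypothetical caller of the RAGToolRuntime protocol after the rename.
doc = RAGDocument(
    document_id="doc-1",
    content="Llama Stack ships a built-in RAG tool runtime.",
    mime_type="text/plain",
    metadata={},
)

await rag_tool.insert(
    documents=[doc],
    vector_store_id="my_store",       # was vector_db_id
    chunk_size_in_tokens=512,
)

result = await rag_tool.query(
    content="What does Llama Stack ship?",
    vector_store_ids=["my_store"],    # was vector_db_ids
)
print(result.content)
```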
@@ -529,17 +529,17 @@ class VectorIO(Protocol):

     # this will just block now until chunks are inserted, but it should
     # probably return a Job instance which can be polled for completion
-    # TODO: rename vector_db_id to vector_store_id once Stainless is working
+    # TODO: rename vector_store_id to vector_store_id once Stainless is working
     @webmethod(route="/vector-io/insert", method="POST", level=LLAMA_STACK_API_V1)
     async def insert_chunks(
         self,
-        vector_db_id: str,
+        vector_store_id: str,
         chunks: list[Chunk],
         ttl_seconds: int | None = None,
     ) -> None:
         """Insert chunks into a vector database.

-        :param vector_db_id: The identifier of the vector database to insert the chunks into.
+        :param vector_store_id: The identifier of the vector database to insert the chunks into.
         :param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
             `metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
             If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.

@@ -548,17 +548,17 @@ class VectorIO(Protocol):
         """
         ...

-    # TODO: rename vector_db_id to vector_store_id once Stainless is working
+    # TODO: rename vector_store_id to vector_store_id once Stainless is working
     @webmethod(route="/vector-io/query", method="POST", level=LLAMA_STACK_API_V1)
     async def query_chunks(
         self,
-        vector_db_id: str,
+        vector_store_id: str,
         query: InterleavedContent,
         params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
         """Query chunks from a vector database.

-        :param vector_db_id: The identifier of the vector database to query.
+        :param vector_store_id: The identifier of the vector database to query.
         :param query: The query to search for.
         :param params: The parameters of the query.
         :returns: A QueryChunksResponse.
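A hedged usage sketch of the two renamed VectorIO methods; the `Chunk` fields shown (content plus a `document_id` in metadata) follow the docstring above, and everything else is illustrative.

```python
# Hypothetical usage of the renamed VectorIO surface (placeholder values).
chunk = Chunk(
    content="Python is a high-level programming language.",
    metadata={"document_id": "doc-1"},
)

await vector_io.insert_chunks(
    vector_store_id="my_store",       # was vector_db_id
    chunks=[chunk],
    ttl_seconds=3600,
)

response = await vector_io.query_chunks(
    vector_store_id="my_store",       # was vector_db_id
    query="What is Python?",
    params={"max_chunks": 3, "score_threshold": 0.5},
)
for c, score in zip(response.chunks, response.scores):
    print(score, c.content)
```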
@@ -73,27 +73,27 @@ class VectorIORouter(VectorIO):

     async def insert_chunks(
         self,
-        vector_db_id: str,
+        vector_store_id: str,
         chunks: list[Chunk],
         ttl_seconds: int | None = None,
     ) -> None:
         doc_ids = [chunk.document_id for chunk in chunks[:3]]
         logger.debug(
-            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, "
+            f"VectorIORouter.insert_chunks: {vector_store_id}, {len(chunks)} chunks, "
             f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}"
         )
-        provider = await self.routing_table.get_provider_impl(vector_db_id)
-        return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
+        provider = await self.routing_table.get_provider_impl(vector_store_id)
+        return await provider.insert_chunks(vector_store_id, chunks, ttl_seconds)

     async def query_chunks(
         self,
-        vector_db_id: str,
+        vector_store_id: str,
         query: InterleavedContent,
         params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
-        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        provider = await self.routing_table.get_provider_impl(vector_db_id)
-        return await provider.query_chunks(vector_db_id, query, params)
+        logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
+        provider = await self.routing_table.get_provider_impl(vector_store_id)
+        return await provider.query_chunks(vector_store_id, query, params)

     # OpenAI Vector Stores API endpoints
     async def openai_create_vector_store(
@@ -488,13 +488,13 @@ class ChatAgent(ShieldRunnerMixin):

         session_info = await self.storage.get_session_info(session_id)
         # if the session has a memory bank id, let the memory tool use it
-        if session_info and session_info.vector_db_id:
+        if session_info and session_info.vector_store_id:
             for tool_name in self.tool_name_to_args.keys():
                 if tool_name == MEMORY_QUERY_TOOL:
-                    if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
-                        self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
+                    if "vector_store_ids" not in self.tool_name_to_args[tool_name]:
+                        self.tool_name_to_args[tool_name]["vector_store_ids"] = [session_info.vector_store_id]
                     else:
-                        self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
+                        self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)

         output_attachments = []
@@ -22,7 +22,7 @@ log = get_logger(name=__name__, category="agents::meta_reference")

 class AgentSessionInfo(Session):
     # TODO: is this used anywhere?
-    vector_db_id: str | None = None
+    vector_store_id: str | None = None
     started_at: datetime
     owner: User | None = None
     identifier: str | None = None

@@ -93,12 +93,12 @@ class AgentPersistence:

         return session_info

-    async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
+    async def add_vector_db_to_session(self, session_id: str, vector_store_id: str):
         session_info = await self.get_session_if_accessible(session_id)
         if session_info is None:
             raise SessionNotFoundError(session_id)

-        session_info.vector_db_id = vector_db_id
+        session_info.vector_store_id = vector_store_id
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
             value=session_info.model_dump_json(),
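A small, hedged sketch of driving the persistence helper above, assuming an already constructed `AgentPersistence` instance; only the renamed keyword is the point.

```python
# Hypothetical call site: associate a vector store with an agent session so the
# memory tool can pick it up later (see the ChatAgent hunk above).
await persistence.add_vector_db_to_session(
    session_id="session-42",
    vector_store_id="my_store",   # parameter renamed from vector_db_id
)

session_info = await persistence.get_session_if_accessible("session-42")
assert session_info is not None and session_info.vector_store_id == "my_store"
```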
@@ -119,7 +119,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
     async def insert(
         self,
         documents: list[RAGDocument],
-        vector_db_id: str,
+        vector_store_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
         if not documents:

@@ -158,14 +158,14 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti

             try:
                 await self.vector_io_api.openai_attach_file_to_vector_store(
-                    vector_store_id=vector_db_id,
+                    vector_store_id=vector_store_id,
                     file_id=created_file.id,
                     attributes=doc.metadata,
                     chunking_strategy=chunking_strategy,
                 )
             except Exception as e:
                 log.error(
-                    f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
+                    f"Failed to attach file {created_file.id} to vector store {vector_store_id} for document {doc.document_id}: {e}"
                 )
                 continue

@@ -176,10 +176,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
     async def query(
         self,
         content: InterleavedContent,
-        vector_db_ids: list[str],
+        vector_store_ids: list[str],
         query_config: RAGQueryConfig | None = None,
     ) -> RAGQueryResult:
-        if not vector_db_ids:
+        if not vector_store_ids:
             raise ValueError(
                 "No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
             )

@@ -192,7 +192,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
         )
         tasks = [
             self.vector_io_api.query_chunks(
-                vector_db_id=vector_db_id,
+                vector_store_id=vector_store_id,
                 query=query,
                 params={
                     "mode": query_config.mode,

@@ -201,18 +201,18 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
                     "ranker": query_config.ranker,
                 },
             )
-            for vector_db_id in vector_db_ids
+            for vector_store_id in vector_store_ids
         ]
         results: list[QueryChunksResponse] = await asyncio.gather(*tasks)

         chunks = []
         scores = []

-        for vector_db_id, result in zip(vector_db_ids, results, strict=False):
+        for vector_store_id, result in zip(vector_store_ids, results, strict=False):
             for chunk, score in zip(result.chunks, result.scores, strict=False):
                 if not hasattr(chunk, "metadata") or chunk.metadata is None:
                     chunk.metadata = {}
-                chunk.metadata["vector_db_id"] = vector_db_id
+                chunk.metadata["vector_store_id"] = vector_store_id

                 chunks.append(chunk)
                 scores.append(score)

@@ -250,7 +250,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
         metadata_keys_to_exclude_from_context = [
             "token_count",
             "metadata_token_count",
-            "vector_db_id",
+            "vector_store_id",
         ]
         metadata_for_context = {}
         for k in chunk_metadata_keys_to_include_from_context:

@@ -275,7 +275,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
                 "document_ids": [c.document_id for c in chunks[: len(picked)]],
                 "chunks": [c.content for c in chunks[: len(picked)]],
                 "scores": scores[: len(picked)],
-                "vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
+                "vector_store_ids": [c.metadata["vector_store_id"] for c in chunks[: len(picked)]],
             },
         )

@@ -309,7 +309,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
         )

     async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
-        vector_db_ids = kwargs.get("vector_db_ids", [])
+        vector_store_ids = kwargs.get("vector_store_ids", [])
         query_config = kwargs.get("query_config")
         if query_config:
             query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)

@@ -319,7 +319,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
         query = kwargs["query"]
         result = await self.query(
             content=query,
-            vector_db_ids=vector_db_ids,
+            vector_store_ids=vector_store_ids,
             query_config=query_config,
         )
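For context, a hedged sketch of invoking the tool runtime above with the renamed argument; the tool name and the `query_config` contents are illustrative, while the kwargs keys follow `invoke_tool`.

```python
# Hypothetical invocation of the memory tool runtime after the rename.
result = await memory_tool_runtime.invoke_tool(
    tool_name="knowledge_search",                 # illustrative tool name
    kwargs={
        "query": "What is the capital of France?",
        "vector_store_ids": ["my_store"],         # was vector_db_ids
        "query_config": {"mode": "vector", "max_chunks": 5},
    },
)
print(result.content)
```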
@@ -248,19 +248,19 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
         del self.cache[vector_store_id]
         await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")

-    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-        index = self.cache.get(vector_db_id)
+    async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+        index = self.cache.get(vector_store_id)
         if index is None:
-            raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
+            raise ValueError(f"Vector DB {vector_store_id} not found. found: {self.cache.keys()}")

         await index.insert_chunks(chunks)

     async def query_chunks(
-        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+        self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
-        index = self.cache.get(vector_db_id)
+        index = self.cache.get(vector_store_id)
         if index is None:
-            raise VectorStoreNotFoundError(vector_db_id)
+            raise VectorStoreNotFoundError(vector_store_id)

         return await index.query_chunks(query, params)
@@ -447,20 +447,20 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
         await self.cache[vector_store_id].index.delete()
         del self.cache[vector_store_id]

-    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+    async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
         if not index:
-            raise VectorStoreNotFoundError(vector_db_id)
+            raise VectorStoreNotFoundError(vector_store_id)
         # The VectorStoreWithIndex helper is expected to compute embeddings via the inference_api
         # and then call our index's add_chunks.
         await index.insert_chunks(chunks)

     async def query_chunks(
-        self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
+        self, vector_store_id: str, query: Any, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
         if not index:
-            raise VectorStoreNotFoundError(vector_db_id)
+            raise VectorStoreNotFoundError(vector_store_id)
         return await index.query_chunks(query, params)

     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
@@ -61,6 +61,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.cerebras",
             config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
             description="Cerebras inference provider for running models on Cerebras Cloud platform.",
         ),
         RemoteProviderSpec(

@@ -149,6 +150,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=["databricks-sdk"],
             module="llama_stack.providers.remote.inference.databricks",
             config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
             description="Databricks inference provider for running models on Databricks' unified analytics platform.",
         ),
         RemoteProviderSpec(

@@ -158,6 +160,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.nvidia",
             config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
             description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
         ),
         RemoteProviderSpec(

@@ -167,6 +170,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.runpod",
             config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
             description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
         ),
         RemoteProviderSpec(
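The validators registered above let callers supply provider API keys per request instead of baking them into the server config. A hedged sketch follows, assuming the usual `X-LlamaStack-Provider-Data` header convention and an OpenAI-compatible client surface; the model id and key value are placeholders.

```python
# Hypothetical per-request credential passing via provider data.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:8321",
    # Serialized into the X-LlamaStack-Provider-Data header; the field name is
    # what CerebrasProviderDataValidator above validates.
    provider_data={"cerebras_api_key": "csk-..."},
)

response = client.chat.completions.create(
    model="llama3.1-8b",   # placeholder model id
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```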
@@ -163,14 +163,14 @@ The SQLite-vec provider supports three search modes:
 Example with hybrid search:
 ```python
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
 )

 # Using RRF ranker
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={
         "mode": "hybrid",

@@ -182,7 +182,7 @@ response = await vector_io.query_chunks(

 # Using weighted ranker
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={
         "mode": "hybrid",

@@ -196,7 +196,7 @@ response = await vector_io.query_chunks(
 Example with explicit vector search:
 ```python
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7},
 )

@@ -205,7 +205,7 @@ response = await vector_io.query_chunks(
 Example with keyword search:
 ```python
 response = await vector_io.query_chunks(
-    vector_db_id="my_db",
+    vector_store_id="my_db",
     query="your query here",
     params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7},
 )
|
|||
class CerebrasInferenceAdapter(OpenAIMixin):
|
||||
config: CerebrasImplConfig
|
||||
|
||||
provider_data_api_key_field: str = "cerebras_api_key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return urljoin(self.config.base_url, "v1")
|
||||
|
||||
|
|
|
|||
|
|
@@ -7,7 +7,7 @@
 import os
 from typing import Any

-from pydantic import Field
+from pydantic import BaseModel, Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type

@@ -15,6 +15,13 @@ from llama_stack.schema_utils import json_schema_type
 DEFAULT_BASE_URL = "https://api.cerebras.ai"


+class CerebrasProviderDataValidator(BaseModel):
+    cerebras_api_key: str | None = Field(
+        default=None,
+        description="API key for Cerebras models",
+    )
+
+
 @json_schema_type
 class CerebrasImplConfig(RemoteInferenceProviderConfig):
     base_url: str = Field(
@@ -6,12 +6,19 @@

 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type


+class DatabricksProviderDataValidator(BaseModel):
+    databricks_api_token: str | None = Field(
+        default=None,
+        description="API token for Databricks models",
+    )
+
+
 @json_schema_type
 class DatabricksImplConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
@@ -20,6 +20,8 @@ logger = get_logger(name=__name__, category="inference::databricks")
 class DatabricksInferenceAdapter(OpenAIMixin):
     config: DatabricksImplConfig

+    provider_data_api_key_field: str = "databricks_api_token"
+
     # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
     embedding_model_metadata: dict[str, dict[str, int]] = {
         "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
@@ -7,12 +7,19 @@
 import os
 from typing import Any

-from pydantic import Field
+from pydantic import BaseModel, Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type


+class NVIDIAProviderDataValidator(BaseModel):
+    nvidia_api_key: str | None = Field(
+        default=None,
+        description="API key for NVIDIA NIM models",
+    )
+
+
 @json_schema_type
 class NVIDIAConfig(RemoteInferenceProviderConfig):
     """
@@ -17,6 +17,8 @@ logger = get_logger(name=__name__, category="inference::nvidia")
 class NVIDIAInferenceAdapter(OpenAIMixin):
     config: NVIDIAConfig

+    provider_data_api_key_field: str = "nvidia_api_key"
+
     """
     NVIDIA Inference Adapter for Llama Stack.
     """
@@ -6,12 +6,19 @@

 from typing import Any

-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type


+class RunpodProviderDataValidator(BaseModel):
+    runpod_api_token: str | None = Field(
+        default=None,
+        description="API token for RunPod models",
+    )
+
+
 @json_schema_type
 class RunpodImplConfig(RemoteInferenceProviderConfig):
     url: str | None = Field(
@@ -24,6 +24,7 @@ class RunpodInferenceAdapter(OpenAIMixin):
     """

     config: RunpodImplConfig
+    provider_data_api_key_field: str = "runpod_api_token"

     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
@@ -169,20 +169,20 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         await self.cache[vector_store_id].index.delete()
         del self.cache[vector_store_id]

-    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+    async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
         if index is None:
-            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
+            raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")

         await index.insert_chunks(chunks)

     async def query_chunks(
-        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+        self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+        index = await self._get_and_cache_vector_store_index(vector_store_id)

         if index is None:
-            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
+            raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")

         return await index.query_chunks(query, params)
@@ -348,19 +348,19 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         await self.cache[vector_store_id].index.delete()
         del self.cache[vector_store_id]

-    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+    async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
         if not index:
-            raise VectorStoreNotFoundError(vector_db_id)
+            raise VectorStoreNotFoundError(vector_store_id)

         await index.insert_chunks(chunks)

     async def query_chunks(
-        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+        self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
         if not index:
-            raise VectorStoreNotFoundError(vector_db_id)
+            raise VectorStoreNotFoundError(vector_store_id)
         return await index.query_chunks(query, params)

     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
@@ -399,14 +399,14 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
         assert self.kvstore is not None
         await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_store_id}")

-    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+    async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
         await index.insert_chunks(chunks)

     async def query_chunks(
-        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+        self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
         return await index.query_chunks(query, params)

     async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex:
@@ -222,19 +222,19 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         self.cache[vector_store_id] = index
         return index

-    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+    async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
         if not index:
-            raise VectorStoreNotFoundError(vector_db_id)
+            raise VectorStoreNotFoundError(vector_store_id)

         await index.insert_chunks(chunks)

     async def query_chunks(
-        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+        self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
         if not index:
-            raise VectorStoreNotFoundError(vector_db_id)
+            raise VectorStoreNotFoundError(vector_store_id)

         return await index.query_chunks(query, params)
@@ -366,19 +366,19 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
         self.cache[vector_store_id] = index
         return index

-    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+    async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
         if not index:
-            raise VectorStoreNotFoundError(vector_db_id)
+            raise VectorStoreNotFoundError(vector_store_id)

         await index.insert_chunks(chunks)

     async def query_chunks(
-        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+        self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
-        index = await self._get_and_cache_vector_store_index(vector_db_id)
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
         if not index:
-            raise VectorStoreNotFoundError(vector_db_id)
+            raise VectorStoreNotFoundError(vector_store_id)

         return await index.query_chunks(query, params)
@@ -333,7 +333,7 @@ class OpenAIVectorStoreMixin(ABC):
     @abstractmethod
     async def insert_chunks(
         self,
-        vector_db_id: str,
+        vector_store_id: str,
         chunks: list[Chunk],
         ttl_seconds: int | None = None,
     ) -> None:

@@ -342,7 +342,7 @@ class OpenAIVectorStoreMixin(ABC):

     @abstractmethod
     async def query_chunks(
-        self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
+        self, vector_store_id: str, query: Any, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         """Query chunks from a vector database (provider-specific implementation)."""
         pass

@@ -609,7 +609,7 @@ class OpenAIVectorStoreMixin(ABC):
         # TODO: Add support for ranking_options.ranker

         response = await self.query_chunks(
-            vector_db_id=vector_store_id,
+            vector_store_id=vector_store_id,
             query=search_query,
             params=params,
         )

@@ -803,7 +803,7 @@ class OpenAIVectorStoreMixin(ABC):
             )
         else:
             await self.insert_chunks(
-                vector_db_id=vector_store_id,
+                vector_store_id=vector_store_id,
                 chunks=chunks,
             )
             vector_store_file_object.status = "completed"
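Because the mixin above backs the OpenAI-compatible vector-store endpoints, the rename stays invisible to OpenAI-style clients. A hedged sketch of that client path follows; client construction and identifiers are illustrative, not taken from this diff.

```python
# Hypothetical OpenAI-compatible path that ends up in query_chunks() above.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

vector_store = client.vector_stores.create(name="my_store")
results = client.vector_stores.search(
    vector_store_id=vector_store.id,
    query="What is Python?",
    max_num_results=3,
)
for row in results.data:
    print(row)
```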
@@ -367,7 +367,7 @@ def test_openai_vector_store_with_chunks(

     # Insert chunks using the native LlamaStack API (since OpenAI API doesn't have direct chunk insertion)
     llama_client.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
         chunks=sample_chunks,
     )

@@ -434,7 +434,7 @@ def test_openai_vector_store_search_relevance(

     # Insert chunks using native API
     llama_client.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
         chunks=sample_chunks,
     )

@@ -484,7 +484,7 @@ def test_openai_vector_store_search_with_ranking_options(

     # Insert chunks
     llama_client.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
         chunks=sample_chunks,
     )

@@ -544,7 +544,7 @@ def test_openai_vector_store_search_with_high_score_filter(

     # Insert chunks
     llama_client.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
         chunks=sample_chunks,
     )

@@ -610,7 +610,7 @@ def test_openai_vector_store_search_with_max_num_results(

     # Insert chunks
     llama_client.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
         chunks=sample_chunks,
     )

@@ -1175,7 +1175,7 @@ def test_openai_vector_store_search_modes(
     )

     client_with_models.vector_io.insert(
-        vector_db_id=vector_store.id,
+        vector_store_id=vector_store.id,
         chunks=sample_chunks,
     )
     query = "Python programming language"
@@ -123,12 +123,12 @@ def test_insert_chunks(
     actual_vector_store_id = create_response.id

     client_with_empty_registry.vector_io.insert(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
         chunks=sample_chunks,
     )

     response = client_with_empty_registry.vector_io.query(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
         query="What is the capital of France?",
     )
     assert response is not None

@@ -137,7 +137,7 @@ def test_insert_chunks(

     query, expected_doc_id = test_case
     response = client_with_empty_registry.vector_io.query(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
         query=query,
     )
     assert response is not None

@@ -174,13 +174,13 @@ def test_insert_chunks_with_precomputed_embeddings(
     ]

     client_with_empty_registry.vector_io.insert(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
         chunks=chunks_with_embeddings,
     )

     provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
     response = client_with_empty_registry.vector_io.query(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
         query="precomputed embedding test",
         params=vector_io_provider_params_dict.get(provider, None),
     )

@@ -224,13 +224,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
     ]

     client_with_empty_registry.vector_io.insert(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
         chunks=chunks_with_embeddings,
     )

     provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
     response = client_with_empty_registry.vector_io.query(
-        vector_db_id=actual_vector_store_id,
+        vector_store_id=actual_vector_store_id,
         query="duplicate",
         params=vector_io_provider_params_dict.get(provider, None),
     )
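Outside the test harness, the same renamed client calls look roughly like this; the chunk payloads and store name are illustrative, and the client methods mirror the ones exercised by the integration tests above.

```python
# Hypothetical end-to-end client usage mirroring the integration tests above.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

vector_store = client.vector_stores.create(name="docs")

client.vector_io.insert(
    vector_store_id=vector_store.id,   # was vector_db_id
    chunks=[{"content": "Paris is the capital of France.", "metadata": {"document_id": "doc-1"}}],
)

response = client.vector_io.query(
    vector_store_id=vector_store.id,   # was vector_db_id
    query="What is the capital of France?",
)
print(response.chunks[0].content)
```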
@ -10,47 +10,124 @@ from unittest.mock import MagicMock
|
|||
import pytest
|
||||
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
|
||||
from llama_stack.providers.remote.inference.cerebras.cerebras import CerebrasInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig
|
||||
from llama_stack.providers.remote.inference.databricks.config import DatabricksImplConfig
|
||||
from llama_stack.providers.remote.inference.databricks.databricks import DatabricksInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
|
||||
from llama_stack.providers.remote.inference.fireworks.fireworks import FireworksInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
|
||||
from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
||||
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.runpod.config import RunpodImplConfig
|
||||
from llama_stack.providers.remote.inference.runpod.runpod import RunpodInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
|
||||
from llama_stack.providers.remote.inference.sambanova.sambanova import SambaNovaInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
|
||||
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
|
||||
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
|
||||
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "config_cls,adapter_cls,provider_data_validator",
    "config_cls,adapter_cls,provider_data_validator,config_params",
    [
        (
            GroqConfig,
            GroqInferenceAdapter,
            "llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
            {},
        ),
        (
            OpenAIConfig,
            OpenAIInferenceAdapter,
            "llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
            {},
        ),
        (
            TogetherImplConfig,
            TogetherInferenceAdapter,
            "llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
            {},
        ),
        (
            LlamaCompatConfig,
            LlamaCompatInferenceAdapter,
            "llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
            {},
        ),
        (
            CerebrasImplConfig,
            CerebrasInferenceAdapter,
            "llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
            {},
        ),
        (
            DatabricksImplConfig,
            DatabricksInferenceAdapter,
            "llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
            {},
        ),
        (
            NVIDIAConfig,
            NVIDIAInferenceAdapter,
            "llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
            {},
        ),
        (
            RunpodImplConfig,
            RunpodInferenceAdapter,
            "llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
            {},
        ),
        (
            FireworksImplConfig,
            FireworksInferenceAdapter,
            "llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
            {},
        ),
        (
            AnthropicConfig,
            AnthropicInferenceAdapter,
            "llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
            {},
        ),
        (
            GeminiConfig,
            GeminiInferenceAdapter,
            "llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
            {},
        ),
        (
            SambaNovaImplConfig,
            SambaNovaInferenceAdapter,
            "llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
            {},
        ),
        (
            VLLMInferenceAdapterConfig,
            VLLMInferenceAdapter,
            "llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
            {
                "url": "http://fake",
            },
        ),
    ],
)
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str, config_params: dict):
    """Ensure the OpenAI provider does not cache api keys across client requests"""

    inference_adapter = adapter_cls(config=config_cls())
    inference_adapter = adapter_cls(config=config_cls(**config_params))

    inference_adapter.__provider_spec__ = MagicMock()
    inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator

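One hedged observation: the new `config_params` column lets the test construct configs that cannot be instantiated with defaults, which appears to be why the vLLM entry carries a URL. A minimal sketch under that assumption, reusing only names that appear in the hunk above:

```python
# Sketch of the construction pattern the updated test uses for the one
# parametrize entry that needs extra config arguments (values copied from above).
from unittest.mock import MagicMock

from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter

config_params = {"url": "http://fake"}
adapter = VLLMInferenceAdapter(config=VLLMInferenceAdapterConfig(**config_params))
adapter.__provider_spec__ = MagicMock()
adapter.__provider_spec__.provider_data_validator = (
    "llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator"
)
```
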
@@ -23,14 +23,14 @@ class TestRagQuery:
            config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
        )
        with pytest.raises(ValueError):
            await rag_tool.query(content=MagicMock(), vector_db_ids=[])
            await rag_tool.query(content=MagicMock(), vector_store_ids=[])

    async def test_query_chunk_metadata_handling(self):
        rag_tool = MemoryToolRuntimeImpl(
            config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
        )
        content = "test query content"
        vector_db_ids = ["db1"]
        vector_store_ids = ["db1"]

        chunk_metadata = ChunkMetadata(
            document_id="doc1",

@@ -55,7 +55,7 @@ class TestRagQuery:
        query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])

        rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
        result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)
        result = await rag_tool.query(content=content, vector_store_ids=vector_store_ids)

        assert result is not None
        expected_metadata_string = (

@@ -90,7 +90,7 @@ class TestRagQuery:
            files_api=MagicMock(),
        )

        vector_db_ids = ["db1", "db2"]
        vector_store_ids = ["db1", "db2"]

        # Fake chunks from each DB
        chunk_metadata1 = ChunkMetadata(

@@ -101,7 +101,7 @@ class TestRagQuery:
        )
        chunk1 = Chunk(
            content="chunk from db1",
            metadata={"vector_db_id": "db1", "document_id": "doc1"},
            metadata={"vector_store_id": "db1", "document_id": "doc1"},
            stored_chunk_id="c1",
            chunk_metadata=chunk_metadata1,
        )

@@ -114,7 +114,7 @@ class TestRagQuery:
        )
        chunk2 = Chunk(
            content="chunk from db2",
            metadata={"vector_db_id": "db2", "document_id": "doc2"},
            metadata={"vector_store_id": "db2", "document_id": "doc2"},
            stored_chunk_id="c2",
            chunk_metadata=chunk_metadata2,
        )

@@ -126,13 +126,13 @@ class TestRagQuery:
            ]
        )

        result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
        result = await rag_tool.query(content="test", vector_store_ids=vector_store_ids)
        returned_chunks = result.metadata["chunks"]
        returned_scores = result.metadata["scores"]
        returned_doc_ids = result.metadata["document_ids"]
        returned_vector_db_ids = result.metadata["vector_db_ids"]
        returned_vector_store_ids = result.metadata["vector_store_ids"]

        assert returned_chunks == ["chunk from db1", "chunk from db2"]
        assert returned_scores == (0.9, 0.8)
        assert returned_doc_ids == ["doc1", "doc2"]
        assert returned_vector_db_ids == ["db1", "db2"]
        assert returned_vector_store_ids == ["db1", "db2"]

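For reference, a minimal sketch of the renamed multi-store query path these tests assert on; it assumes a `rag_tool` already wired up as in the fixtures above and runs inside an async test.

```python
# Sketch only: the RAG query call and result metadata keys after the rename.
# `rag_tool` is assumed to be a MemoryToolRuntimeImpl constructed as in the tests.
result = await rag_tool.query(content="test", vector_store_ids=["db1", "db2"])

chunks = result.metadata["chunks"]
scores = result.metadata["scores"]
document_ids = result.metadata["document_ids"]
store_ids = result.metadata["vector_store_ids"]  # formerly "vector_db_ids"
```
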
uv.lock (generated, 45 additions)

@@ -129,6 +129,25 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" },
]

[[package]]
name = "anthropic"
version = "0.69.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "anyio" },
    { name = "distro" },
    { name = "docstring-parser" },
    { name = "httpx" },
    { name = "jiter" },
    { name = "pydantic" },
    { name = "sniffio" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c8/9d/9ad1778b95f15c5b04e7d328c1b5f558f1e893857b7c33cd288c19c0057a/anthropic-0.69.0.tar.gz", hash = "sha256:c604d287f4d73640f40bd2c0f3265a2eb6ce034217ead0608f6b07a8bc5ae5f2", size = 480622, upload-time = "2025-09-29T16:53:45.282Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/9b/38/75129688de5637eb5b383e5f2b1570a5cc3aecafa4de422da8eea4b90a6c/anthropic-0.69.0-py3-none-any.whl", hash = "sha256:1f73193040f33f11e27c2cd6ec25f24fe7c3f193dc1c5cde6b7a08b18a16bcc5", size = 337265, upload-time = "2025-09-29T16:53:43.686Z" },
]

[[package]]
name = "anyio"
version = "4.9.0"

@@ -758,6 +777,19 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/79/b3/28ac139109d9005ad3f6b6f8976ffede6706a6478e21c889ce36c840918e/cryptography-45.0.5-cp37-abi3-win_amd64.whl", hash = "sha256:90cb0a7bb35959f37e23303b7eed0a32280510030daba3f7fdfbb65defde6a97", size = 3390016, upload-time = "2025-07-02T13:05:50.811Z" },
]

[[package]]
name = "databricks-sdk"
version = "0.67.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "google-auth" },
    { name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b3/5b/df3e5424d833e4f3f9b42c409ef8b513e468c9cdf06c2a9935c6cbc4d128/databricks_sdk-0.67.0.tar.gz", hash = "sha256:f923227babcaad428b0c2eede2755ebe9deb996e2c8654f179eb37f486b37a36", size = 761000, upload-time = "2025-09-25T13:32:10.858Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/a0/ca/2aff3817041483fb8e4f75a74a36ff4ca3a826e276becd1179a591b6348f/databricks_sdk-0.67.0-py3-none-any.whl", hash = "sha256:ef49e49db45ed12c015a32a6f9d4ba395850f25bb3dcffdcaf31a5167fe03ee2", size = 718422, upload-time = "2025-09-25T13:32:09.011Z" },
]

[[package]]
name = "datasets"
version = "4.0.0"

@@ -856,6 +888,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
]

[[package]]
name = "docstring-parser"
version = "0.17.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" },
]

[[package]]
name = "docutils"
version = "0.21.2"

@@ -1863,9 +1904,11 @@ test = [
unit = [
    { name = "aiohttp" },
    { name = "aiosqlite" },
    { name = "anthropic" },
    { name = "blobfile" },
    { name = "chardet" },
    { name = "coverage" },
    { name = "databricks-sdk" },
    { name = "faiss-cpu" },
    { name = "litellm" },
    { name = "mcp" },

@@ -1978,9 +2021,11 @@ test = [
unit = [
    { name = "aiohttp" },
    { name = "aiosqlite" },
    { name = "anthropic" },
    { name = "blobfile" },
    { name = "chardet" },
    { name = "coverage" },
    { name = "databricks-sdk" },
    { name = "faiss-cpu" },
    { name = "litellm" },
    { name = "mcp" },