Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-14 22:52:39 +00:00)

Merge branch 'main' into nvidia-e2e-notebook
Commit bd64bc99ea
69 changed files with 7913 additions and 2495 deletions
2  .github/CODEOWNERS (vendored)
@@ -2,4 +2,4 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning
* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist
26  .github/workflows/integration-auth-tests.yml (vendored)
@@ -52,30 +52,7 @@ jobs:
run: |
kubectl create namespace llama-stack
kubectl create serviceaccount llama-stack-auth -n llama-stack
kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
cat <<EOF | kubectl apply -f -
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
  name: allow-anonymous-openid
rules:
- nonResourceURLs: ["/openid/v1/jwks"]
  verbs: ["get"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
  name: allow-anonymous-openid
roleRef:
  apiGroup: rbac.authorization.k8s.io
  kind: ClusterRole
  name: allow-anonymous-openid
subjects:
- kind: User
  name: system:anonymous
  apiGroup: rbac.authorization.k8s.io
EOF

- name: Set Kubernetes Config
if: ${{ matrix.auth-provider == 'oauth2_token' }}

@@ -84,6 +61,7 @@ jobs:
echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV
echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV
echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV
echo "TOKEN=$(cat llama-stack-auth-token)" >> $GITHUB_ENV

- name: Set Kube Auth Config and run server
env:

@@ -101,7 +79,7 @@ jobs:
EOF
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
cat $run_dir/run.yaml

nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
2  .github/workflows/integration-tests.yml (vendored)
@@ -24,7 +24,7 @@ jobs:
matrix:
# Listing tests manually since some of them currently fail
# TODO: generate matrix list from tests/integration when fixed
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime, vector_io]
client-type: [library, http]
python-version: ["3.10", "3.11", "3.12"]
fail-fast: false # we want to run all tests regardless of failure
10  .github/workflows/test-external-providers.yml (vendored)
@@ -45,20 +45,22 @@ jobs:

- name: Build distro from config file
run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml

- name: Start Llama Stack server in background
if: ${{ matrix.image-type }} == 'venv'
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
run: |
uv run pip list
nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
# Use the virtual environment created by the build step (name comes from build config)
source ci-test/bin/activate
uv pip list
nohup llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &

- name: Wait for Llama Stack server to be ready
run: |
for i in {1..30}; do
if ! grep -q "remote::custom_ollama from /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml" server.log; then
if ! grep -q "Successfully loaded external provider remote::custom_ollama" server.log; then
echo "Waiting for Llama Stack server to load the provider..."
sleep 1
else
49  CHANGELOG.md
@@ -1,5 +1,54 @@
# Changelog

# v0.2.10.1
Published on: 2025-06-06T20:11:02Z

## Highlights
* ChromaDB provider fix

---

# v0.2.10
Published on: 2025-06-05T23:21:45Z

## Highlights

* OpenAI-compatible embeddings API
* OpenAI-compatible Files API
* Postgres support in starter distro
* Enable ingestion of precomputed embeddings
* Full multi-turn support in Responses API
* Fine-grained access control policy

---

# v0.2.9
Published on: 2025-05-30T20:01:56Z

## Highlights
* Added initial streaming support in Responses API
* UI view for Responses
* Postgres inference store support

---

# v0.2.8
Published on: 2025-05-27T21:03:47Z

# Release v0.2.8

## Highlights

* Server-side MCP with auth firewalls now works in the Stack - both for Agents and Responses
* Get chat completions APIs and UI to show chat completions
* Enable keyword search for sqlite-vec

---

# v0.2.7
Published on: 2025-05-16T20:38:10Z
1269  docs/_static/llama-stack-spec.html (vendored)
File diff suppressed because it is too large
835  docs/_static/llama-stack-spec.yaml (vendored)
|
|
@ -2263,6 +2263,43 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/LogEventRequest'
|
||||
required: true
|
||||
/v1/openai/v1/vector_stores/{vector_store_id}/files:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A VectorStoreFileObject representing the attached file.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/VectorStoreFileObject'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- VectorIO
|
||||
description: Attach a file to a vector store.
|
||||
parameters:
|
||||
- name: vector_store_id
|
||||
in: path
|
||||
description: >-
|
||||
The ID of the vector store to attach the file to.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenaiAttachFileToVectorStoreRequest'
|
||||
required: true
|
||||
/v1/openai/v1/completions:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -2294,6 +2331,91 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/OpenaiCompletionRequest'
|
||||
required: true
|
||||
/v1/openai/v1/vector_stores:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A VectorStoreListResponse containing the list of vector stores.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/VectorStoreListResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- VectorIO
|
||||
description: Returns a list of vector stores.
|
||||
parameters:
|
||||
- name: limit
|
||||
in: query
|
||||
description: >-
|
||||
A limit on the number of objects to be returned. Limit can range between
|
||||
1 and 100, and the default is 20.
|
||||
required: false
|
||||
schema:
|
||||
type: integer
|
||||
- name: order
|
||||
in: query
|
||||
description: >-
|
||||
Sort order by the `created_at` timestamp of the objects. `asc` for ascending
|
||||
order and `desc` for descending order.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
- name: after
|
||||
in: query
|
||||
description: >-
|
||||
A cursor for use in pagination. `after` is an object ID that defines your
|
||||
place in the list.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
- name: before
|
||||
in: query
|
||||
description: >-
|
||||
A cursor for use in pagination. `before` is an object ID that defines
|
||||
your place in the list.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A VectorStoreObject representing the created vector store.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/VectorStoreObject'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- VectorIO
|
||||
description: Creates a vector store.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenaiCreateVectorStoreRequest'
|
||||
required: true
|
||||
/v1/openai/v1/files/{file_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -2356,6 +2478,100 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/openai/v1/vector_stores/{vector_store_id}:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A VectorStoreObject representing the vector store.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/VectorStoreObject'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- VectorIO
|
||||
description: Retrieves a vector store.
|
||||
parameters:
|
||||
- name: vector_store_id
|
||||
in: path
|
||||
description: The ID of the vector store to retrieve.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A VectorStoreObject representing the updated vector store.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/VectorStoreObject'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- VectorIO
|
||||
description: Updates a vector store.
|
||||
parameters:
|
||||
- name: vector_store_id
|
||||
in: path
|
||||
description: The ID of the vector store to update.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenaiUpdateVectorStoreRequest'
|
||||
required: true
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A VectorStoreDeleteResponse indicating the deletion status.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/VectorStoreDeleteResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- VectorIO
|
||||
description: Delete a vector store.
|
||||
parameters:
|
||||
- name: vector_store_id
|
||||
in: path
|
||||
description: The ID of the vector store to delete.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/openai/v1/embeddings:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -2546,6 +2762,46 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/openai/v1/vector_stores/{vector_store_id}/search:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
A VectorStoreSearchResponse containing the search results.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/VectorStoreSearchResponsePage'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- VectorIO
|
||||
description: >-
|
||||
Search for chunks in a vector store.
|
||||
|
||||
Searches a vector store for relevant chunks based on a query and optional
|
||||
file attribute filters.
|
||||
parameters:
|
||||
- name: vector_store_id
|
||||
in: path
|
||||
description: The ID of the vector store to search.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenaiSearchVectorStoreRequest'
|
||||
required: true
|
||||
/v1/post-training/preference-optimize:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -4802,6 +5058,7 @@ components:
|
|||
OpenAIResponseInput:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseInputFunctionToolCallOutput'
|
||||
- $ref: '#/components/schemas/OpenAIResponseMessage'
|
||||
|
|
@ -4896,10 +5153,23 @@ components:
|
|||
type: string
|
||||
const: file_search
|
||||
default: file_search
|
||||
vector_store_id:
|
||||
vector_store_ids:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
filters:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
max_num_results:
|
||||
type: integer
|
||||
default: 10
|
||||
ranking_options:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -4913,7 +5183,7 @@ components:
|
|||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- vector_store_id
|
||||
- vector_store_ids
|
||||
title: OpenAIResponseInputToolFileSearch
|
||||
OpenAIResponseInputToolFunction:
|
||||
type: object
|
||||
|
|
@ -5075,6 +5345,41 @@ components:
|
|||
- type
|
||||
title: >-
|
||||
OpenAIResponseOutputMessageContentOutputText
|
||||
"OpenAIResponseOutputMessageFileSearchToolCall":
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
queries:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
status:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
const: file_search_call
|
||||
default: file_search_call
|
||||
results:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- queries
|
||||
- status
|
||||
- type
|
||||
title: >-
|
||||
OpenAIResponseOutputMessageFileSearchToolCall
|
||||
"OpenAIResponseOutputMessageFunctionToolCall":
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -5272,6 +5577,7 @@ components:
|
|||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseMessage'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
|
||||
|
|
@ -5280,6 +5586,7 @@ components:
|
|||
mapping:
|
||||
message: '#/components/schemas/OpenAIResponseMessage'
|
||||
web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
|
||||
file_search_call: '#/components/schemas/OpenAIResponseOutputMessageFileSearchToolCall'
|
||||
function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
|
||||
mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
|
||||
mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
|
||||
|
|
@ -7511,6 +7818,9 @@ components:
|
|||
type: boolean
|
||||
description: >-
|
||||
Whether there are more items available after this set
|
||||
url:
|
||||
type: string
|
||||
description: The URL for accessing this list
|
||||
additionalProperties: false
|
||||
required:
|
||||
- data
|
||||
|
|
@ -8032,6 +8342,148 @@ components:
|
|||
- event
|
||||
- ttl_seconds
|
||||
title: LogEventRequest
|
||||
VectorStoreChunkingStrategy:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/VectorStoreChunkingStrategyAuto'
|
||||
- $ref: '#/components/schemas/VectorStoreChunkingStrategyStatic'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
auto: '#/components/schemas/VectorStoreChunkingStrategyAuto'
|
||||
static: '#/components/schemas/VectorStoreChunkingStrategyStatic'
|
||||
VectorStoreChunkingStrategyAuto:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: auto
|
||||
default: auto
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
title: VectorStoreChunkingStrategyAuto
|
||||
VectorStoreChunkingStrategyStatic:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: static
|
||||
default: static
|
||||
static:
|
||||
$ref: '#/components/schemas/VectorStoreChunkingStrategyStaticConfig'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- static
|
||||
title: VectorStoreChunkingStrategyStatic
|
||||
VectorStoreChunkingStrategyStaticConfig:
|
||||
type: object
|
||||
properties:
|
||||
chunk_overlap_tokens:
|
||||
type: integer
|
||||
default: 400
|
||||
max_chunk_size_tokens:
|
||||
type: integer
|
||||
default: 800
|
||||
additionalProperties: false
|
||||
required:
|
||||
- chunk_overlap_tokens
|
||||
- max_chunk_size_tokens
|
||||
title: VectorStoreChunkingStrategyStaticConfig
|
||||
OpenaiAttachFileToVectorStoreRequest:
|
||||
type: object
|
||||
properties:
|
||||
file_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the file to attach to the vector store.
|
||||
attributes:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
The key-value attributes stored with the file, which can be used for filtering.
|
||||
chunking_strategy:
|
||||
$ref: '#/components/schemas/VectorStoreChunkingStrategy'
|
||||
description: >-
|
||||
The chunking strategy to use for the file.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- file_id
|
||||
title: OpenaiAttachFileToVectorStoreRequest
|
||||
VectorStoreFileLastError:
|
||||
type: object
|
||||
properties:
|
||||
code:
|
||||
oneOf:
|
||||
- type: string
|
||||
const: server_error
|
||||
- type: string
|
||||
const: rate_limit_exceeded
|
||||
message:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- code
|
||||
- message
|
||||
title: VectorStoreFileLastError
|
||||
VectorStoreFileObject:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
object:
|
||||
type: string
|
||||
default: vector_store.file
|
||||
attributes:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
chunking_strategy:
|
||||
$ref: '#/components/schemas/VectorStoreChunkingStrategy'
|
||||
created_at:
|
||||
type: integer
|
||||
last_error:
|
||||
$ref: '#/components/schemas/VectorStoreFileLastError'
|
||||
status:
|
||||
oneOf:
|
||||
- type: string
|
||||
const: completed
|
||||
- type: string
|
||||
const: in_progress
|
||||
- type: string
|
||||
const: cancelled
|
||||
- type: string
|
||||
const: failed
|
||||
usage_bytes:
|
||||
type: integer
|
||||
default: 0
|
||||
vector_store_id:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- object
|
||||
- attributes
|
||||
- chunking_strategy
|
||||
- created_at
|
||||
- status
|
||||
- usage_bytes
|
||||
- vector_store_id
|
||||
title: VectorStoreFileObject
|
||||
description: OpenAI Vector Store File object.
|
||||
OpenAIJSONSchema:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -8454,6 +8906,10 @@ components:
|
|||
type: string
|
||||
prompt_logprobs:
|
||||
type: integer
|
||||
suffix:
|
||||
type: string
|
||||
description: >-
|
||||
(Optional) The suffix that should be appended to the completion.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- model
|
||||
|
|
@ -8505,6 +8961,133 @@ components:
|
|||
title: OpenAICompletionChoice
|
||||
description: >-
|
||||
A choice from an OpenAI-compatible completion response.
|
||||
OpenaiCreateVectorStoreRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: A name for the vector store.
|
||||
file_ids:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
A list of File IDs that the vector store should use. Useful for tools
|
||||
like `file_search` that can access files.
|
||||
expires_after:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
The expiration policy for a vector store.
|
||||
chunking_strategy:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
The chunking strategy used to chunk the file(s). If not set, will use
|
||||
the `auto` strategy.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
Set of 16 key-value pairs that can be attached to an object.
|
||||
embedding_model:
|
||||
type: string
|
||||
description: >-
|
||||
The embedding model to use for this vector store.
|
||||
embedding_dimension:
|
||||
type: integer
|
||||
description: >-
|
||||
The dimension of the embedding vectors (default: 384).
|
||||
provider_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of the provider to use for this vector store.
|
||||
provider_vector_db_id:
|
||||
type: string
|
||||
description: >-
|
||||
The provider-specific vector database ID.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- name
|
||||
title: OpenaiCreateVectorStoreRequest
|
||||
VectorStoreObject:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
object:
|
||||
type: string
|
||||
default: vector_store
|
||||
created_at:
|
||||
type: integer
|
||||
name:
|
||||
type: string
|
||||
usage_bytes:
|
||||
type: integer
|
||||
default: 0
|
||||
file_counts:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: integer
|
||||
status:
|
||||
type: string
|
||||
default: completed
|
||||
expires_after:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
expires_at:
|
||||
type: integer
|
||||
last_active_at:
|
||||
type: integer
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- object
|
||||
- created_at
|
||||
- usage_bytes
|
||||
- file_counts
|
||||
- status
|
||||
- metadata
|
||||
title: VectorStoreObject
|
||||
description: OpenAI Vector Store object.
|
||||
OpenAIFileDeleteResponse:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -8528,6 +9111,24 @@ components:
|
|||
title: OpenAIFileDeleteResponse
|
||||
description: >-
|
||||
Response for deleting a file in OpenAI Files API.
|
||||
VectorStoreDeleteResponse:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
object:
|
||||
type: string
|
||||
default: vector_store.deleted
|
||||
deleted:
|
||||
type: boolean
|
||||
default: true
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- object
|
||||
- deleted
|
||||
title: VectorStoreDeleteResponse
|
||||
description: Response from deleting a vector store.
|
||||
OpenaiEmbeddingsRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -8751,9 +9352,179 @@ components:
|
|||
required:
|
||||
- data
|
||||
title: OpenAIListModelsResponse
|
||||
VectorStoreListResponse:
|
||||
type: object
|
||||
properties:
|
||||
object:
|
||||
type: string
|
||||
default: list
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/VectorStoreObject'
|
||||
first_id:
|
||||
type: string
|
||||
last_id:
|
||||
type: string
|
||||
has_more:
|
||||
type: boolean
|
||||
default: false
|
||||
additionalProperties: false
|
||||
required:
|
||||
- object
|
||||
- data
|
||||
- has_more
|
||||
title: VectorStoreListResponse
|
||||
description: Response from listing vector stores.
|
||||
Response:
|
||||
type: object
|
||||
title: Response
|
||||
OpenaiSearchVectorStoreRequest:
|
||||
type: object
|
||||
properties:
|
||||
query:
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: array
|
||||
items:
|
||||
type: string
|
||||
description: >-
|
||||
The query string or array for performing the search.
|
||||
filters:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
Filters based on file attributes to narrow the search results.
|
||||
max_num_results:
|
||||
type: integer
|
||||
description: >-
|
||||
Maximum number of results to return (1 to 50 inclusive, default 10).
|
||||
ranking_options:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
Ranking options for fine-tuning the search results.
|
||||
rewrite_query:
|
||||
type: boolean
|
||||
description: >-
|
||||
Whether to rewrite the natural language query for vector search (default
|
||||
false)
|
||||
additionalProperties: false
|
||||
required:
|
||||
- query
|
||||
title: OpenaiSearchVectorStoreRequest
|
||||
VectorStoreContent:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: text
|
||||
text:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- text
|
||||
title: VectorStoreContent
|
||||
VectorStoreSearchResponse:
|
||||
type: object
|
||||
properties:
|
||||
file_id:
|
||||
type: string
|
||||
filename:
|
||||
type: string
|
||||
score:
|
||||
type: number
|
||||
attributes:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: number
|
||||
- type: boolean
|
||||
content:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/VectorStoreContent'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- file_id
|
||||
- filename
|
||||
- score
|
||||
- content
|
||||
title: VectorStoreSearchResponse
|
||||
description: Response from searching a vector store.
|
||||
VectorStoreSearchResponsePage:
|
||||
type: object
|
||||
properties:
|
||||
object:
|
||||
type: string
|
||||
default: vector_store.search_results.page
|
||||
search_query:
|
||||
type: string
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/VectorStoreSearchResponse'
|
||||
has_more:
|
||||
type: boolean
|
||||
default: false
|
||||
next_page:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- object
|
||||
- search_query
|
||||
- data
|
||||
- has_more
|
||||
title: VectorStoreSearchResponsePage
|
||||
description: Response from searching a vector store.
|
||||
OpenaiUpdateVectorStoreRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: The name of the vector store.
|
||||
expires_after:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
The expiration policy for a vector store.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
Set of 16 key-value pairs that can be attached to an object.
|
||||
additionalProperties: false
|
||||
title: OpenaiUpdateVectorStoreRequest
|
||||
DPOAlignmentConfig:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -8992,7 +9763,13 @@ components:
|
|||
mode:
|
||||
type: string
|
||||
description: >-
|
||||
Search mode for retrieval—either "vector" or "keyword". Default "vector".
|
||||
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
|
||||
"vector".
|
||||
ranker:
|
||||
$ref: '#/components/schemas/Ranker'
|
||||
description: >-
|
||||
Configuration for the ranker to use in hybrid search. Defaults to RRF
|
||||
ranker.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- query_generator_config
|
||||
|
|
@ -9011,6 +9788,58 @@ components:
|
|||
mapping:
|
||||
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
|
||||
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
|
||||
RRFRanker:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: rrf
|
||||
default: rrf
|
||||
description: The type of ranker, always "rrf"
|
||||
impact_factor:
|
||||
type: number
|
||||
default: 60.0
|
||||
description: >-
|
||||
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
|
||||
results. Must be greater than 0. Default of 60 is from the original RRF
|
||||
paper (Cormack et al., 2009).
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- impact_factor
|
||||
title: RRFRanker
|
||||
description: >-
|
||||
Reciprocal Rank Fusion (RRF) ranker configuration.
|
||||
Ranker:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/RRFRanker'
|
||||
- $ref: '#/components/schemas/WeightedRanker'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
rrf: '#/components/schemas/RRFRanker'
|
||||
weighted: '#/components/schemas/WeightedRanker'
|
||||
WeightedRanker:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: weighted
|
||||
default: weighted
|
||||
description: The type of ranker, always "weighted"
|
||||
alpha:
|
||||
type: number
|
||||
default: 0.5
|
||||
description: >-
|
||||
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
|
||||
only use vector scores, values in between blend both scores.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- alpha
|
||||
title: WeightedRanker
|
||||
description: >-
|
||||
Weighted ranker configuration that combines vector and keyword scores.
|
||||
QueryRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
|||
@@ -56,10 +56,10 @@ shields: []
server:
  port: 8321
  auth:
    provider_type: "kubernetes"
    provider_type: "oauth2_token"
    config:
      api_server_url: "https://kubernetes.default.svc"
      ca_cert_path: "/path/to/ca.crt"
      jwks:
        uri: "https://my-token-issuing-svc.com/jwks"
```

Let's break this down into the different sections. The first section specifies the set of APIs that the stack server will serve:
@@ -132,16 +132,52 @@ The server supports multiple authentication providers:

#### OAuth 2.0/OpenID Connect Provider with Kubernetes

The Kubernetes cluster must be configured to use a service account for authentication.
The server can be configured to use service account tokens for authorization, validating these against the Kubernetes API server, e.g.:
```yaml
server:
  auth:
    provider_type: "oauth2_token"
    config:
      jwks:
        uri: "https://kubernetes.default.svc:8443/openid/v1/jwks"
        token: "${env.TOKEN:}"
        key_recheck_period: 3600
      tls_cafile: "/path/to/ca.crt"
      issuer: "https://kubernetes.default.svc"
      audience: "https://kubernetes.default.svc"
```

To find your cluster's jwks uri (from which the public key(s) to verify the token signature are obtained), run:
```
kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri
```

For the tls_cafile, you can use the CA certificate of the OIDC provider:
```bash
kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
```

For the issuer, you can use the OIDC provider's URL:
```bash
kubectl get --raw /.well-known/openid-configuration| jq .issuer
```

The audience can be obtained from a token, e.g. run:
```bash
kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
```

The jwks token is used to authorize access to the jwks endpoint. You can obtain a token by running:

```bash
kubectl create namespace llama-stack
kubectl create serviceaccount llama-stack-auth -n llama-stack
kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
export TOKEN=$(cat llama-stack-auth-token)
```
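
With the token in hand, clients authenticate by sending it as a Bearer token on each request, and the server validates it against the JWKS configured above. A minimal illustrative sketch in Python (the port comes from the `run.yaml` above, and the route is one of the endpoints added in this release's spec; both are assumptions to adjust for your deployment):

```python
# Illustrative only: send an authenticated request to a Llama Stack server
# that uses the oauth2_token provider configured above.
import os

import requests  # any HTTP client works; requests is used here for brevity

token = os.environ["TOKEN"]  # the service account token created above
resp = requests.get(
    "http://localhost:8321/v1/openai/v1/vector_stores",  # default port from run.yaml
    headers={"Authorization": f"Bearer {token}"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```
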

Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests
Alternatively, you can configure the jwks endpoint to allow anonymous access. To do this, make sure
the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests
and that the correct RoleBinding is created to allow the service account to access the necessary
resources. If that is not the case, you can create a RoleBinding for the service account to access
the necessary resources:
@@ -175,35 +211,6 @@ And then apply the configuration:
kubectl apply -f allow-anonymous-openid.yaml
```

Validates tokens against the Kubernetes API server through the OIDC provider:
```yaml
server:
  auth:
    provider_type: "oauth2_token"
    config:
      jwks:
        uri: "https://kubernetes.default.svc"
        key_recheck_period: 3600
      tls_cafile: "/path/to/ca.crt"
      issuer: "https://kubernetes.default.svc"
      audience: "https://kubernetes.default.svc"
```

To find your cluster's audience, run:
```bash
kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
```

For the issuer, you can use the OIDC provider's URL:
```bash
kubectl get --raw /.well-known/openid-configuration| jq .issuer
```

For the tls_cafile, you can use the CA certificate of the OIDC provider:
```bash
kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
```

The provider extracts user information from the JWT token:
- Username from the `sub` claim becomes a role
- Kubernetes groups become teams
@@ -18,6 +18,7 @@ The `llamastack/distribution-ollama` distribution consists of the following providers:
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| files | `inline::localfs` |
| inference | `remote::ollama` |
| post_training | `inline::huggingface` |
| safety | `inline::llama-guard` |
@@ -66,25 +66,126 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
2. Configure your Llama Stack project to use SQLite-Vec.
3. Start storing and querying vectors.

## Supported Search Modes
The SQLite-vec provider supports three search modes:

The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.

When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
`RAGQueryConfig`. For example:
1. **Vector Search** (`mode="vector"`): Performs pure vector similarity search using the embeddings.
2. **Keyword Search** (`mode="keyword"`): Performs full-text search using SQLite's FTS5.
3. **Hybrid Search** (`mode="hybrid"`): Combines both vector and keyword search for better results. First performs keyword search to get candidate matches, then applies vector similarity search on those candidates.

Example with hybrid search:
```python
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
response = await vector_io.query_chunks(
    vector_db_id="my_db",
    query="your query here",
    params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
)

query_config = RAGQueryConfig(max_chunks=6, mode="vector")
# Using RRF ranker
response = await vector_io.query_chunks(
    vector_db_id="my_db",
    query="your query here",
    params={
        "mode": "hybrid",
        "max_chunks": 3,
        "score_threshold": 0.7,
        "ranker": {"type": "rrf", "impact_factor": 60.0},
    },
)

results = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="what is torchtune",
    query_config=query_config,
# Using weighted ranker
response = await vector_io.query_chunks(
    vector_db_id="my_db",
    query="your query here",
    params={
        "mode": "hybrid",
        "max_chunks": 3,
        "score_threshold": 0.7,
        "ranker": {"type": "weighted", "alpha": 0.7},  # 70% vector, 30% keyword
    },
)
```

Example with explicit vector search:
```python
response = await vector_io.query_chunks(
    vector_db_id="my_db",
    query="your query here",
    params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7},
)
```

Example with keyword search:
```python
response = await vector_io.query_chunks(
    vector_db_id="my_db",
    query="your query here",
    params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7},
)
```

## Supported Search Modes

The SQLite vector store supports three search modes:

1. **Vector Search** (`mode="vector"`): Uses vector similarity to find relevant chunks
2. **Keyword Search** (`mode="keyword"`): Uses keyword matching to find relevant chunks
3. **Hybrid Search** (`mode="hybrid"`): Combines both vector and keyword scores using a ranker

### Hybrid Search

Hybrid search combines the strengths of both vector and keyword search by:
- Computing vector similarity scores
- Computing keyword match scores
- Using a ranker to combine these scores

Two ranker types are supported:

1. **RRF (Reciprocal Rank Fusion)**:
   - Combines ranks from both vector and keyword results
   - Uses an impact factor (default: 60.0) to control the weight of higher-ranked results
   - Good for balancing between vector and keyword results
   - The default impact factor of 60.0 comes from the original RRF paper by Cormack et al. (2009) [^1], which found this value to provide optimal performance across various retrieval tasks

2. **Weighted**:
   - Linearly combines normalized vector and keyword scores
   - Uses an alpha parameter (0-1) to control the blend:
     - alpha=0: Only use keyword scores
     - alpha=1: Only use vector scores
     - alpha=0.5: Equal weight to both (default)
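
To make the two fusion rules concrete, here is a small, self-contained sketch of the scoring they describe: RRF sums `1 / (impact_factor + rank)` across the keyword and vector rankings, while the weighted ranker blends normalized scores with `alpha`. This is an illustration of the formulas above, not the provider's actual implementation:

```python
# Toy illustration of the two hybrid-search fusion rules described above.

def rrf_fuse(vector_ranked: list[str], keyword_ranked: list[str], impact_factor: float = 60.0) -> dict[str, float]:
    """Reciprocal Rank Fusion: score(d) = sum over rankings of 1 / (impact_factor + rank)."""
    scores: dict[str, float] = {}
    for ranking in (vector_ranked, keyword_ranked):
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (impact_factor + rank)
    return scores

def weighted_fuse(vector_scores: dict[str, float], keyword_scores: dict[str, float], alpha: float = 0.5) -> dict[str, float]:
    """Weighted blend of (already normalized) scores: alpha * vector + (1 - alpha) * keyword."""
    doc_ids = set(vector_scores) | set(keyword_scores)
    return {
        d: alpha * vector_scores.get(d, 0.0) + (1.0 - alpha) * keyword_scores.get(d, 0.0)
        for d in doc_ids
    }

# Example: chunk "b" is ranked first by keywords and second by vectors.
print(rrf_fuse(["a", "b", "c"], ["b", "c", "a"]))
print(weighted_fuse({"a": 0.9, "b": 0.7}, {"b": 1.0, "c": 0.4}, alpha=0.7))
```
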
Example using RAGQueryConfig with different search modes:

```python
from llama_stack.apis.tools import RAGQueryConfig, RRFRanker, WeightedRanker

# Vector search
config = RAGQueryConfig(mode="vector", max_chunks=5)

# Keyword search
config = RAGQueryConfig(mode="keyword", max_chunks=5)

# Hybrid search with custom RRF ranker
config = RAGQueryConfig(
    mode="hybrid",
    max_chunks=5,
    ranker=RRFRanker(impact_factor=50.0),  # Custom impact factor
)

# Hybrid search with weighted ranker
config = RAGQueryConfig(
    mode="hybrid",
    max_chunks=5,
    ranker=WeightedRanker(alpha=0.7),  # 70% vector, 30% keyword
)

# Hybrid search with default RRF ranker
config = RAGQueryConfig(
    mode="hybrid", max_chunks=5
) # Will use RRF with impact_factor=60.0
```

Note: The ranker configuration is only used in hybrid mode. For vector or keyword modes, the ranker parameter is ignored.

## Installation

You can install SQLite-Vec using pip:

@@ -96,3 +197,5 @@ pip install sqlite-vec
## Documentation

See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.

[^1]: Cormack, G. V., Clarke, C. L., & Buettcher, S. (2009). [Reciprocal rank fusion outperforms condorcet and individual rank learning methods](https://dl.acm.org/doi/10.1145/1571941.1572114). In Proceedings of the 32nd international ACM SIGIR conference on Research and development in information retrieval (pp. 758-759).
@@ -81,6 +81,15 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
    type: Literal["web_search_call"] = "web_search_call"


@json_schema_type
class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
    id: str
    queries: list[str]
    status: str
    type: Literal["file_search_call"] = "file_search_call"
    results: list[dict[str, Any]] | None = None


@json_schema_type
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
    call_id: str

@@ -119,6 +128,7 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
OpenAIResponseOutput = Annotated[
    OpenAIResponseMessage
    | OpenAIResponseOutputMessageWebSearchToolCall
    | OpenAIResponseOutputMessageFileSearchToolCall
    | OpenAIResponseOutputMessageFunctionToolCall
    | OpenAIResponseOutputMessageMCPCall
    | OpenAIResponseOutputMessageMCPListTools,

@@ -362,6 +372,7 @@ class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
OpenAIResponseInput = Annotated[
    # Responses API allows output messages to be passed in as input
    OpenAIResponseOutputMessageWebSearchToolCall
    | OpenAIResponseOutputMessageFileSearchToolCall
    | OpenAIResponseOutputMessageFunctionToolCall
    | OpenAIResponseInputFunctionToolCallOutput

@@ -397,9 +408,10 @@ class FileSearchRankingOptions(BaseModel):
@json_schema_type
class OpenAIResponseInputToolFileSearch(BaseModel):
    type: Literal["file_search"] = "file_search"
    vector_store_id: list[str]
    vector_store_ids: list[str]
    filters: dict[str, Any] | None = None
    max_num_results: int | None = Field(default=10, ge=1, le=50)
    ranking_options: FileSearchRankingOptions | None = None
    # TODO: add filters
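
The renamed `vector_store_ids` field means a file-search tool definition now points at one or more vector stores. A hypothetical request-side payload matching this schema, shown as a plain dict (the surrounding client call that would carry it is not part of this diff):

```python
# Hypothetical tool definition matching OpenAIResponseInputToolFileSearch:
# note the plural vector_store_ids introduced by this change.
file_search_tool = {
    "type": "file_search",
    "vector_store_ids": ["vs_123"],   # placeholder vector store id
    "filters": {"category": "docs"},  # optional attribute filters
    "max_num_results": 5,             # 1..50, defaults to 10
}
```
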

class ApprovalFilter(BaseModel):
@@ -23,7 +23,9 @@ class PaginatedResponse(BaseModel):

    :param data: The list of items for the current page
    :param has_more: Whether there are more items available after this set
    :param url: The URL for accessing this list
    """

    data: list[dict[str, Any]]
    has_more: bool
    url: str | None = None
@@ -1038,6 +1038,8 @@ class InferenceProvider(Protocol):
        # vLLM-specific parameters
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        # for fill-in-the-middle type completion
        suffix: str | None = None,
    ) -> OpenAICompletion:
        """Generate an OpenAI-compatible completion for the given prompt using the specified model.

@@ -1058,6 +1060,7 @@ class InferenceProvider(Protocol):
        :param temperature: (Optional) The temperature to use.
        :param top_p: (Optional) The top p to use.
        :param user: (Optional) The user to use.
        :param suffix: (Optional) The suffix that should be appended to the completion.
        :returns: An OpenAICompletion.
        """
        ...
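
The new `suffix` parameter supports fill-in-the-middle style completions: the model is asked to generate text that fits between `prompt` and `suffix`. A hedged sketch against the OpenAI-compatible completions route from the spec above (the base URL, port, and model id are placeholders for a local deployment):

```python
# Illustrative fill-in-the-middle request using the new `suffix` parameter.
import requests

resp = requests.post(
    "http://localhost:8321/v1/openai/v1/completions",
    json={
        "model": "meta-llama/Llama-3.2-3B-Instruct",  # placeholder model id
        "prompt": "def add(a, b):\n    ",
        "suffix": "\n    return result\n",  # the model fills in the middle
        "max_tokens": 32,
    },
    timeout=60,
)
print(resp.json())
```
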
@@ -15,6 +15,48 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


@json_schema_type
class RRFRanker(BaseModel):
    """
    Reciprocal Rank Fusion (RRF) ranker configuration.

    :param type: The type of ranker, always "rrf"
    :param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
        Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009).
    """

    type: Literal["rrf"] = "rrf"
    impact_factor: float = Field(default=60.0, gt=0.0)  # default of 60 for optimal performance


@json_schema_type
class WeightedRanker(BaseModel):
    """
    Weighted ranker configuration that combines vector and keyword scores.

    :param type: The type of ranker, always "weighted"
    :param alpha: Weight factor between 0 and 1.
        0 means only use keyword scores,
        1 means only use vector scores,
        values in between blend both scores.
    """

    type: Literal["weighted"] = "weighted"
    alpha: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
    )


Ranker = Annotated[
    RRFRanker | WeightedRanker,
    Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")
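
Because `Ranker` is an `Annotated` union discriminated on `type`, a plain `{"type": ...}` payload validates into the matching model. A self-contained sketch of that pattern using trimmed-down replicas of the models above (only pydantic is required; this is an illustration, not the library's own code or tests):

```python
# Minimal replica of the Ranker discriminated union to show how
# {"type": ...} payloads select the right model at validation time.
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter

class RRFRanker(BaseModel):
    type: Literal["rrf"] = "rrf"
    impact_factor: float = Field(default=60.0, gt=0.0)

class WeightedRanker(BaseModel):
    type: Literal["weighted"] = "weighted"
    alpha: float = Field(default=0.5, ge=0.0, le=1.0)

Ranker = Annotated[RRFRanker | WeightedRanker, Field(discriminator="type")]

adapter = TypeAdapter(Ranker)
print(adapter.validate_python({"type": "rrf", "impact_factor": 50.0}))  # -> RRFRanker
print(adapter.validate_python({"type": "weighted", "alpha": 0.7}))      # -> WeightedRanker
```
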
@json_schema_type
class RAGDocument(BaseModel):
    """

@@ -76,7 +118,8 @@ class RAGQueryConfig(BaseModel):
    :param chunk_template: Template for formatting each retrieved chunk in the context.
        Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
        Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
    :param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector".
    :param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector".
    :param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
    """

    # This config defines how a query is generated using the messages

@@ -86,6 +129,7 @@ class RAGQueryConfig(BaseModel):
    max_chunks: int = 5
    chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
    mode: str | None = None
    ranker: Ranker | None = Field(default=None)  # Only used for hybrid mode

    @field_validator("chunk_template")
    def validate_chunk_template(cls, v: str) -> str:
@@ -8,7 +8,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from typing import Annotated, Any, Literal, Protocol, runtime_checkable

from pydantic import BaseModel, Field

@@ -16,6 +16,7 @@ from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.strong_typing.schema import register_schema


class Chunk(BaseModel):

@@ -37,6 +38,146 @@ class QueryChunksResponse(BaseModel):
    scores: list[float]


@json_schema_type
class VectorStoreObject(BaseModel):
    """OpenAI Vector Store object."""

    id: str
    object: str = "vector_store"
    created_at: int
    name: str | None = None
    usage_bytes: int = 0
    file_counts: dict[str, int] = Field(default_factory=dict)
    status: str = "completed"
    expires_after: dict[str, Any] | None = None
    expires_at: int | None = None
    last_active_at: int | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class VectorStoreCreateRequest(BaseModel):
    """Request to create a vector store."""

    name: str | None = None
    file_ids: list[str] = Field(default_factory=list)
    expires_after: dict[str, Any] | None = None
    chunking_strategy: dict[str, Any] | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class VectorStoreModifyRequest(BaseModel):
    """Request to modify a vector store."""

    name: str | None = None
    expires_after: dict[str, Any] | None = None
    metadata: dict[str, Any] | None = None


@json_schema_type
class VectorStoreListResponse(BaseModel):
    """Response from listing vector stores."""

    object: str = "list"
    data: list[VectorStoreObject]
    first_id: str | None = None
    last_id: str | None = None
    has_more: bool = False


@json_schema_type
class VectorStoreSearchRequest(BaseModel):
    """Request to search a vector store."""

    query: str | list[str]
    filters: dict[str, Any] | None = None
    max_num_results: int = 10
    ranking_options: dict[str, Any] | None = None
    rewrite_query: bool = False


@json_schema_type
class VectorStoreContent(BaseModel):
    type: Literal["text"]
    text: str


@json_schema_type
class VectorStoreSearchResponse(BaseModel):
    """Response from searching a vector store."""

    file_id: str
    filename: str
    score: float
    attributes: dict[str, str | float | bool] | None = None
    content: list[VectorStoreContent]


@json_schema_type
class VectorStoreSearchResponsePage(BaseModel):
    """Response from searching a vector store."""

    object: str = "vector_store.search_results.page"
    search_query: str
    data: list[VectorStoreSearchResponse]
    has_more: bool = False
    next_page: str | None = None


@json_schema_type
class VectorStoreDeleteResponse(BaseModel):
    """Response from deleting a vector store."""

    id: str
    object: str = "vector_store.deleted"
    deleted: bool = True


@json_schema_type
class VectorStoreChunkingStrategyAuto(BaseModel):
    type: Literal["auto"] = "auto"


@json_schema_type
class VectorStoreChunkingStrategyStaticConfig(BaseModel):
    chunk_overlap_tokens: int = 400
    max_chunk_size_tokens: int = Field(800, ge=100, le=4096)


@json_schema_type
class VectorStoreChunkingStrategyStatic(BaseModel):
    type: Literal["static"] = "static"
    static: VectorStoreChunkingStrategyStaticConfig


VectorStoreChunkingStrategy = Annotated[
    VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic, Field(discriminator="type")
]
register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")


@json_schema_type
class VectorStoreFileLastError(BaseModel):
    code: Literal["server_error"] | Literal["rate_limit_exceeded"]
    message: str


@json_schema_type
class VectorStoreFileObject(BaseModel):
    """OpenAI Vector Store File object."""

    id: str
    object: str = "vector_store.file"
    attributes: dict[str, Any] = Field(default_factory=dict)
    chunking_strategy: VectorStoreChunkingStrategy
    created_at: int
    last_error: VectorStoreFileLastError | None = None
    status: Literal["completed"] | Literal["in_progress"] | Literal["cancelled"] | Literal["failed"]
    usage_bytes: int = 0
    vector_store_id: str


class VectorDBStore(Protocol):
    def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...

@@ -81,3 +222,134 @@ class VectorIO(Protocol):
        :returns: A QueryChunksResponse.
        """
        ...

    # OpenAI Vector Stores API endpoints
    @webmethod(route="/openai/v1/vector_stores", method="POST")
    async def openai_create_vector_store(
        self,
        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorStoreObject:
        """Creates a vector store.

        :param name: A name for the vector store.
        :param file_ids: A list of File IDs that the vector store should use. Useful for tools like `file_search` that can access files.
        :param expires_after: The expiration policy for a vector store.
        :param chunking_strategy: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.
        :param metadata: Set of 16 key-value pairs that can be attached to an object.
        :param embedding_model: The embedding model to use for this vector store.
        :param embedding_dimension: The dimension of the embedding vectors (default: 384).
        :param provider_id: The ID of the provider to use for this vector store.
        :param provider_vector_db_id: The provider-specific vector database ID.
        :returns: A VectorStoreObject representing the created vector store.
        """
        ...
@webmethod(route="/openai/v1/vector_stores", method="GET")
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
) -> VectorStoreListResponse:
|
||||
"""Returns a list of vector stores.
|
||||
|
||||
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
|
||||
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
|
||||
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list.
|
||||
:param before: A cursor for use in pagination. `before` is an object ID that defines your place in the list.
|
||||
:returns: A VectorStoreListResponse containing the list of vector stores.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="GET")
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
"""Retrieves a vector store.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to retrieve.
|
||||
:returns: A VectorStoreObject representing the vector store.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="POST")
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
name: str | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
"""Updates a vector store.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to update.
|
||||
:param name: The name of the vector store.
|
||||
:param expires_after: The expiration policy for a vector store.
|
||||
:param metadata: Set of 16 key-value pairs that can be attached to an object.
|
||||
:returns: A VectorStoreObject representing the updated vector store.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE")
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
"""Delete a vector store.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to delete.
|
||||
:returns: A VectorStoreDeleteResponse indicating the deletion status.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/search", method="POST")
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: str | list[str],
|
||||
filters: dict[str, Any] | None = None,
|
||||
max_num_results: int | None = 10,
|
||||
ranking_options: dict[str, Any] | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
"""Search for chunks in a vector store.
|
||||
|
||||
Searches a vector store for relevant chunks based on a query and optional file attribute filters.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to search.
|
||||
:param query: The query string or array for performing the search.
|
||||
:param filters: Filters based on file attributes to narrow the search results.
|
||||
:param max_num_results: Maximum number of results to return (1 to 50 inclusive, default 10).
|
||||
:param ranking_options: Ranking options for fine-tuning the search results.
|
||||
:param rewrite_query: Whether to rewrite the natural language query for vector search (default false)
|
||||
:returns: A VectorStoreSearchResponse containing the search results.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files", method="POST")
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
"""Attach a file to a vector store.
|
||||
|
||||
:param vector_store_id: The ID of the vector store to attach the file to.
|
||||
:param file_id: The ID of the file to attach to the vector store.
|
||||
:param attributes: The key-value attributes stored with the file, which can be used for filtering.
|
||||
:param chunking_strategy: The chunking strategy to use for the file.
|
||||
:returns: A VectorStoreFileObject representing the attached file.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
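The protocol above mirrors the OpenAI vector stores REST surface under `/openai/v1/vector_stores`. A hedged client-side sketch of exercising those routes against a locally running stack follows; the base URL, port, and the assumption that the created store object exposes an `id` field are mine, not taken from this diff.

# Hypothetical HTTP usage sketch for the routes declared above.
import httpx

base = "http://localhost:8321/openai/v1"  # assumed local server address
with httpx.Client(base_url=base) as client:
    store = client.post("/vector_stores", json={"name": "my_docs"}).json()
    results = client.post(
        f"/vector_stores/{store['id']}/search",
        json={"query": "how do I configure auth?", "max_num_results": 5},
    ).json()
    for hit in results.get("data", []):
        print(hit["filename"], hit["score"])
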
@@ -180,6 +180,7 @@ def get_provider_registry(
                if provider_type_key in ret[api]:
                    logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
                ret[api][provider_type_key] = spec
                logger.info(f"Successfully loaded external provider {provider_type_key}")
            except yaml.YAMLError as yaml_err:
                logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
                raise yaml_err

@@ -394,9 +394,13 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
                logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
                missing_methods.append((name, "signature_mismatch"))
            else:
                # Check if the method is actually implemented in the class
                method_owner = next((cls for cls in mro if name in cls.__dict__), None)
                if method_owner is None or method_owner.__name__ == protocol.__name__:
                # Check if the method has a concrete implementation (not just a protocol stub)
                # Find all classes in MRO that define this method
                method_owners = [cls for cls in mro if name in cls.__dict__]

                # Allow methods from mixins/parents, only reject if ONLY the protocol defines it
                if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
                    # Only reject if the method is ONLY defined in the protocol itself (abstract stub)
                    missing_methods.append((name, "not_actually_implemented"))

    if missing_methods:
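The new check flags a method only when the protocol class is the sole definer of it in the MRO, so implementations inherited from mixins or parent classes now pass. A standalone illustration of that test, with made-up class names:

# Illustrative sketch of the MRO-based check used above (all names here are invented).
from typing import Protocol

class MyProto(Protocol):
    def ping(self) -> str: ...

class Mixin:
    def ping(self) -> str:
        return "pong"

class Impl(Mixin, MyProto):
    pass

mro = Impl.__mro__
owners = [cls for cls in mro if "ping" in cls.__dict__]
# Both Mixin and MyProto define "ping", so it is not "only defined by the protocol"
# and would not be reported as not_actually_implemented.
print([cls.__name__ for cls in owners])
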
@@ -163,6 +163,9 @@ class InferenceRouter(Inference):
        messages: list[Message] | InterleavedContent,
        tool_prompt_format: ToolPromptFormat | None = None,
    ) -> int | None:
        if not hasattr(self, "formatter") or self.formatter is None:
            return None

        if isinstance(messages, list):
            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
        else:

@@ -423,6 +426,7 @@ class InferenceRouter(Inference):
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        logger.debug(
            f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",

@@ -453,6 +457,7 @@ class InferenceRouter(Inference):
            user=user,
            guided_choice=guided_choice,
            prompt_logprobs=prompt_logprobs,
            suffix=suffix,
        )

        provider = self.routing_table.get_provider_impl(model_obj.identifier)

@@ -602,7 +607,7 @@ class InferenceRouter(Inference):

    async def health(self) -> dict[str, HealthResponse]:
        health_statuses = {}
        timeout = 0.5
        timeout = 1  # increasing the timeout to 1 second for health checks
        for provider_id, impl in self.routing_table.impls_by_provider_id.items():
            try:
                # check if the provider has a health method
@@ -9,7 +9,17 @@ from typing import Any
from llama_stack.apis.common.content_types import (
    InterleavedContent,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    VectorIO,
    VectorStoreDeleteResponse,
    VectorStoreListResponse,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

@@ -34,6 +44,31 @@ class VectorIORouter(VectorIO):
        logger.debug("VectorIORouter.shutdown")
        pass

    async def _get_first_embedding_model(self) -> tuple[str, int] | None:
        """Get the first available embedding model identifier."""
        try:
            # Get all models from the routing table
            all_models = await self.routing_table.get_all_with_type("model")

            # Filter for embedding models
            embedding_models = [
                model
                for model in all_models
                if hasattr(model, "model_type") and model.model_type == ModelType.embedding
            ]

            if embedding_models:
                dimension = embedding_models[0].metadata.get("embedding_dimension", None)
                if dimension is None:
                    raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
                return embedding_models[0].identifier, dimension
            else:
                logger.warning("No embedding models found in the routing table")
                return None
        except Exception as e:
            logger.error(f"Error getting embedding models: {e}")
            return None

    async def register_vector_db(
        self,
        vector_db_id: str,

@@ -70,3 +105,170 @@ class VectorIORouter(VectorIO):
    ) -> QueryChunksResponse:
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)

    # OpenAI Vector Stores API endpoints
    async def openai_create_vector_store(
        self,
        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        embedding_dimension: int | None = None,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")

        # If no embedding model is provided, use the first available one
        if embedding_model is None:
            embedding_model_info = await self._get_first_embedding_model()
            if embedding_model_info is None:
                raise ValueError("No embedding model provided and no embedding models available in the system")
            embedding_model, embedding_dimension = embedding_model_info
            logger.info(f"No embedding model specified, using first available: {embedding_model}")

        vector_db_id = name
        registered_vector_db = await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,
            embedding_dimension,
            provider_id,
            provider_vector_db_id,
        )

        return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
            vector_db_id,
            file_ids=file_ids,
            expires_after=expires_after,
            chunking_strategy=chunking_strategy,
            metadata=metadata,
            embedding_model=embedding_model,
            embedding_dimension=embedding_dimension,
            provider_id=registered_vector_db.provider_id,
            provider_vector_db_id=registered_vector_db.provider_resource_id,
        )

    async def openai_list_vector_stores(
        self,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
        logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
        # Route to default provider for now - could aggregate from all providers in the future
        # call retrieve on each vector db to get the list of vector stores
        vector_dbs = await self.routing_table.get_all_with_type("vector_db")
        all_stores = []
        for vector_db in vector_dbs:
            vector_store = await self.routing_table.get_provider_impl(
                vector_db.identifier
            ).openai_retrieve_vector_store(vector_db.identifier)
            all_stores.append(vector_store)

        # Sort by created_at
        reverse_order = order == "desc"
        all_stores.sort(key=lambda x: x.created_at, reverse=reverse_order)

        # Apply cursor-based pagination
        if after:
            after_index = next((i for i, store in enumerate(all_stores) if store.id == after), -1)
            if after_index >= 0:
                all_stores = all_stores[after_index + 1 :]

        if before:
            before_index = next((i for i, store in enumerate(all_stores) if store.id == before), len(all_stores))
            all_stores = all_stores[:before_index]

        # Apply limit
        limited_stores = all_stores[:limit]

        # Determine pagination info
        has_more = len(all_stores) > limit
        first_id = limited_stores[0].id if limited_stores else None
        last_id = limited_stores[-1].id if limited_stores else None

        return VectorStoreListResponse(
            data=limited_stores,
            has_more=has_more,
            first_id=first_id,
            last_id=last_id,
        )

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store(vector_store_id)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,
            metadata=metadata,
        )

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        result = await provider.openai_delete_vector_store(vector_store_id)
        # drop from registry
        await self.routing_table.unregister_vector_db(vector_store_id)
        return result

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: dict[str, Any] | None = None,
        rewrite_query: bool | None = False,
    ) -> VectorStoreSearchResponsePage:
        logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            filters=filters,
            max_num_results=max_num_results,
            ranking_options=ranking_options,
            rewrite_query=rewrite_query,
        )

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_attach_file_to_vector_store(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
        )

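The list endpoint above paginates with cursors: `after` drops everything up to and including the cursor id, `before` truncates at the cursor, then `limit` is applied and `has_more` reflects whatever was left beyond the page. A standalone sketch of that arithmetic on plain ids (the ids and cursor values are invented):

# Cursor arithmetic mirrored from openai_list_vector_stores, on bare ids.
ids = ["vs_1", "vs_2", "vs_3", "vs_4", "vs_5"]  # already sorted by created_at desc

after, before, limit = "vs_2", None, 2
if after:
    idx = next((i for i, x in enumerate(ids) if x == after), -1)
    if idx >= 0:
        ids = ids[idx + 1 :]
if before:
    idx = next((i for i, x in enumerate(ids) if x == before), len(ids))
    ids = ids[:idx]
page = ids[:limit]           # -> ["vs_3", "vs_4"]
has_more = len(ids) > limit  # -> True, "vs_5" remains beyond this page
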
@@ -84,6 +84,7 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
class OAuth2JWKSConfig(BaseModel):
    # The JWKS URI for collecting public keys
    uri: str
    token: str | None = Field(default=None, description="token to authorise access to jwks")
    key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")

@@ -246,9 +247,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
        if self.config.jwks is None:
            raise ValueError("JWKS is not configured")
        if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
            headers = {}
            if self.config.jwks.token:
                headers["Authorization"] = f"Bearer {self.config.jwks.token}"
            verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
            async with httpx.AsyncClient(verify=verify) as client:
                res = await client.get(self.config.jwks.uri, timeout=5)
                res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
                res.raise_for_status()
                jwks_data = res.json()["keys"]
                updated = {}

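With the new `token` field, the provider sends a bearer header when it refreshes the JWKS. A minimal sketch of the equivalent request in isolation; the URI and token values below are placeholders, not values from this diff.

# Standalone sketch of the authenticated JWKS fetch configured above.
import asyncio
import httpx

async def fetch_jwks(uri: str, token: str | None = None, verify: bool | str = True) -> list[dict]:
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    async with httpx.AsyncClient(verify=verify) as client:
        res = await client.get(uri, timeout=5, headers=headers)
        res.raise_for_status()
        return res.json()["keys"]

# keys = asyncio.run(fetch_jwks("https://auth.example.com/jwks", token="<assumed-token>"))
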
@@ -30,6 +30,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context

@@ -230,7 +231,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
                return StreamingResponse(gen, media_type="text/event-stream")
            else:
                value = func(**kwargs)
                return await maybe_await(value)
                result = await maybe_await(value)
                if isinstance(result, PaginatedResponse) and result.url is None:
                    result.url = route
                return result
        except Exception as e:
            logger.exception(f"Error executing endpoint {route=} {method=}")
            raise translate_exception(e) from e
@@ -115,7 +115,7 @@ def parse_environment_config(env_config: str) -> dict[str, int]:

class CustomRichHandler(RichHandler):
    def __init__(self, *args, **kwargs):
        kwargs["console"] = Console(width=120)
        kwargs["console"] = Console(width=150)
        super().__init__(*args, **kwargs)

    def emit(self, record):
@@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInputMessageContentImage,
    OpenAIResponseInputMessageContentText,
    OpenAIResponseInputTool,
    OpenAIResponseInputToolFileSearch,
    OpenAIResponseInputToolMCP,
    OpenAIResponseMessage,
    OpenAIResponseObject,

@@ -34,6 +35,7 @@ from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseOutput,
    OpenAIResponseOutputMessageContent,
    OpenAIResponseOutputMessageContentOutputText,
    OpenAIResponseOutputMessageFileSearchToolCall,
    OpenAIResponseOutputMessageFunctionToolCall,
    OpenAIResponseOutputMessageMCPListTools,
    OpenAIResponseOutputMessageWebSearchToolCall,

@@ -62,7 +64,7 @@ from llama_stack.apis.inference.inference import (
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
from llama_stack.apis.tools import RAGQueryConfig, ToolGroups, ToolRuntime
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool

@@ -198,7 +200,8 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
class ChatCompletionContext(BaseModel):
    model: str
    messages: list[OpenAIMessageParam]
    tools: list[ChatCompletionToolParam] | None = None
    response_tools: list[OpenAIResponseInputTool] | None = None
    chat_tools: list[ChatCompletionToolParam] | None = None
    mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
    temperature: float | None
    response_format: OpenAIResponseFormatParam

@@ -388,7 +391,8 @@ class OpenAIResponsesImpl:
        ctx = ChatCompletionContext(
            model=model,
            messages=messages,
            tools=chat_tools,
            response_tools=tools,
            chat_tools=chat_tools,
            mcp_tool_to_server=mcp_tool_to_server,
            temperature=temperature,
            response_format=response_format,

@@ -417,7 +421,7 @@ class OpenAIResponsesImpl:
            completion_result = await self.inference_api.openai_chat_completion(
                model=ctx.model,
                messages=messages,
                tools=ctx.tools,
                tools=ctx.chat_tools,
                stream=True,
                temperature=ctx.temperature,
                response_format=ctx.response_format,

@@ -606,6 +610,12 @@ class OpenAIResponsesImpl:
                if not tool:
                    raise ValueError(f"Tool {tool_name} not found")
                chat_tools.append(make_openai_tool(tool_name, tool))
            elif input_tool.type == "file_search":
                tool_name = "knowledge_search"
                tool = await self.tool_groups_api.get_tool(tool_name)
                if not tool:
                    raise ValueError(f"Tool {tool_name} not found")
                chat_tools.append(make_openai_tool(tool_name, tool))
            elif input_tool.type == "mcp":
                always_allowed = None
                never_allowed = None

@@ -667,6 +677,7 @@ class OpenAIResponsesImpl:

        tool_call_id = tool_call.id
        function = tool_call.function
        tool_kwargs = json.loads(function.arguments) if function.arguments else {}

        if not function or not tool_call_id or not function.name:
            return None, None

@@ -680,12 +691,26 @@ class OpenAIResponsesImpl:
                    endpoint=mcp_tool.server_url,
                    headers=mcp_tool.headers or {},
                    tool_name=function.name,
                    kwargs=json.loads(function.arguments) if function.arguments else {},
                    kwargs=tool_kwargs,
                )
            else:
                if function.name == "knowledge_search":
                    response_file_search_tool = next(
                        t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)
                    )
                    if response_file_search_tool:
                        if response_file_search_tool.filters:
                            logger.warning("Filters are not yet supported for file_search tool")
                        if response_file_search_tool.ranking_options:
                            logger.warning("Ranking options are not yet supported for file_search tool")
                        tool_kwargs["vector_db_ids"] = response_file_search_tool.vector_store_ids
                        tool_kwargs["query_config"] = RAGQueryConfig(
                            mode="vector",
                            max_chunks=response_file_search_tool.max_num_results,
                        )
                result = await self.tool_runtime_api.invoke_tool(
                    tool_name=function.name,
                    kwargs=json.loads(function.arguments) if function.arguments else {},
                    kwargs=tool_kwargs,
                )
        except Exception as e:
            error_exc = e

@@ -713,6 +738,27 @@ class OpenAIResponsesImpl:
            )
            if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
                message.status = "failed"
        elif function.name == "knowledge_search":
            message = OpenAIResponseOutputMessageFileSearchToolCall(
                id=tool_call_id,
                queries=[tool_kwargs.get("query", "")],
                status="completed",
            )
            if "document_ids" in result.metadata:
                message.results = []
                for i, doc_id in enumerate(result.metadata["document_ids"]):
                    text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
                    score = result.metadata["scores"][i] if "scores" in result.metadata else None
                    message.results.append(
                        {
                            "file_id": doc_id,
                            "filename": doc_id,
                            "text": text,
                            "score": score,
                        }
                    )
            if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
                message.status = "failed"
        else:
            raise ValueError(f"Unknown tool {function.name} called")

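The changes above route a `file_search` response tool onto the built-in `knowledge_search` tool, carrying `vector_store_ids` and `max_num_results` into the RAG query. A hedged sketch of the shape such a tool entry might take; the ids and limit are made-up values and the constructor call is an assumption about the model's fields.

# Hypothetical file_search tool entry as consumed by the mapping above.
file_search_tool = OpenAIResponseInputToolFileSearch(
    type="file_search",
    vector_store_ids=["my_docs"],   # assumed store id
    max_num_results=5,
)
# During tool execution this roughly becomes:
#   tool_kwargs["vector_db_ids"] = ["my_docs"]
#   tool_kwargs["query_config"] = RAGQueryConfig(mode="vector", max_chunks=5)
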
@@ -121,8 +121,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
                vector_db_id=vector_db_id,
                query=query,
                params={
                    "max_chunks": query_config.max_chunks,
                    "mode": query_config.mode,
                    "max_chunks": query_config.max_chunks,
                    "score_threshold": 0.0,
                    "ranker": query_config.ranker,
                },
            )
            for vector_db_id in vector_db_ids

@@ -170,6 +172,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
            content=picked,
            metadata={
                "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
                "chunks": [c.content for c in chunks[: len(picked)]],
                "scores": scores[: len(picked)],
            },
        )

@@ -16,6 +16,6 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):

    assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

    impl = FaissVectorIOAdapter(config, deps[Api.inference])
    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
    await impl.initialize()
    return impl

@@ -15,13 +15,19 @@ import faiss
import numpy as np
from numpy.typing import NDArray

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    VectorIO,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
    EmbeddingIndex,
    VectorDBWithIndex,

@@ -34,6 +40,7 @@ logger = logging.getLogger(__name__)
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"


class FaissIndex(EmbeddingIndex):

@@ -112,7 +119,7 @@ class FaissIndex(EmbeddingIndex):
            if i < 0:
                continue
            chunks.append(self.chunk_by_index[int(i)])
            scores.append(1.0 / float(d))
            scores.append(1.0 / float(d) if d != 0 else float("inf"))

        return QueryChunksResponse(chunks=chunks, scores=scores)

@@ -124,13 +131,26 @@ class FaissIndex(EmbeddingIndex):
    ) -> QueryChunksResponse:
        raise NotImplementedError("Keyword search is not supported in FAISS")

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        raise NotImplementedError("Hybrid search is not supported in FAISS")


class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:


class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
        self.config = config
        self.inference_api = inference_api
        self.files_api = files_api
        self.cache: dict[str, VectorDBWithIndex] = {}
        self.kvstore: KVStore | None = None
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}

    async def initialize(self) -> None:
        self.kvstore = await kvstore_impl(self.config.kvstore)

@@ -148,6 +168,9 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
            )
            self.cache[vector_db.identifier] = index

        # Load existing OpenAI vector stores using the mixin method
        self.openai_vector_stores = await self._load_openai_vector_stores()

    async def shutdown(self) -> None:
        # Cleanup if needed
        pass

@@ -208,3 +231,35 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
            raise ValueError(f"Vector DB {vector_db_id} not found")

        return await index.query_chunks(query, params)

    # OpenAI Vector Store Mixin abstract method implementations
    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        """Save vector store metadata to kvstore."""
        assert self.kvstore is not None
        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
        await self.kvstore.set(key=key, value=json.dumps(store_info))

    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
        """Load all vector store metadata from kvstore."""
        assert self.kvstore is not None
        start_key = OPENAI_VECTOR_STORES_PREFIX
        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
        stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)

        stores = {}
        for store_data in stored_openai_stores:
            store_info = json.loads(store_data)
            stores[store_info["id"]] = store_info
        return stores

    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        """Update vector store metadata in kvstore."""
        assert self.kvstore is not None
        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
        await self.kvstore.set(key=key, value=json.dumps(store_info))

    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
        """Delete vector store metadata from kvstore."""
        assert self.kvstore is not None
        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
        await self.kvstore.delete(key)

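The kvstore methods above namespace every store under the versioned `OPENAI_VECTOR_STORES_PREFIX`, and the range scan from the prefix to `prefix + "\xff"` picks up exactly the keys under that namespace. A tiny self-contained illustration of that convention (key values are invented):

# Sketch of the key-range convention used by _load_openai_vector_stores above.
prefix = "openai_vector_stores:v3::"
keys = [prefix + "vs_1", prefix + "vs_2", "faiss_index:v3::db1"]
start, end = prefix, prefix + "\xff"
in_range = [k for k in keys if start <= k <= end]
# -> only the two openai_vector_stores keys; the faiss_index key sorts below the range.
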
@@ -15,6 +15,6 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
    from .sqlite_vec import SQLiteVecVectorIOAdapter

    assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference])
    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
    await impl.initialize()
    return impl

@@ -6,6 +6,7 @@

import asyncio
import hashlib
import json
import logging
import sqlite3
import struct

@@ -16,18 +17,30 @@ import numpy as np
import sqlite_vec
from numpy.typing import NDArray

from llama_stack.apis.files.files import Files
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    VectorIO,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
    RERANKER_TYPE_RRF,
    RERANKER_TYPE_WEIGHTED,
    EmbeddingIndex,
    VectorDBWithIndex,
)

logger = logging.getLogger(__name__)

# Specifying search mode is dependent on the VectorIO provider.
VECTOR_SEARCH = "vector"
KEYWORD_SEARCH = "keyword"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
HYBRID_SEARCH = "hybrid"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}


def serialize_vector(vector: list[float]) -> bytes:

@@ -44,6 +57,59 @@ def _create_sqlite_connection(db_path):
    return connection


def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
    """Normalize scores to [0,1] range using min-max normalization."""
    if not scores:
        return {}
    min_score = min(scores.values())
    max_score = max(scores.values())
    score_range = max_score - min_score
    if score_range > 0:
        return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
    return {doc_id: 1.0 for doc_id in scores}


def _weighted_rerank(
    vector_scores: dict[str, float],
    keyword_scores: dict[str, float],
    alpha: float = 0.5,
) -> dict[str, float]:
    """ReRanker that uses weighted average of scores."""
    all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
    normalized_vector_scores = _normalize_scores(vector_scores)
    normalized_keyword_scores = _normalize_scores(keyword_scores)

    return {
        doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
        + ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
        for doc_id in all_ids
    }


def _rrf_rerank(
    vector_scores: dict[str, float],
    keyword_scores: dict[str, float],
    impact_factor: float = 60.0,
) -> dict[str, float]:
    """ReRanker that uses Reciprocal Rank Fusion."""
    # Convert scores to ranks
    vector_ranks = {
        doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
    }
    keyword_ranks = {
        doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
    }

    all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
    rrf_scores = {}
    for doc_id in all_ids:
        vector_rank = vector_ranks.get(doc_id, float("inf"))
        keyword_rank = keyword_ranks.get(doc_id, float("inf"))
        # RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
        rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
    return rrf_scores


class SQLiteVecIndex(EmbeddingIndex):
    """
    An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.

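A small worked example of the two fusion strategies defined above, on invented scores for two overlapping result sets:

# Worked example of _rrf_rerank and _weighted_rerank (scores are made up).
vector_scores = {"d1": 0.9, "d2": 0.7}
keyword_scores = {"d2": 2.0, "d3": 1.0}

rrf = _rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)
# d2 appears in both lists, so it wins: 1/62 + 1/61 ~= 0.0325,
# while d1 gets 1/61 ~= 0.0164 and d3 gets 1/62 ~= 0.0161.

weighted = _weighted_rerank(vector_scores, keyword_scores, alpha=0.5)
# After min-max normalization, d1 and d2 tie at 0.5 and d3 drops to 0.0,
# since alpha weights the keyword side and (1 - alpha) the vector side.
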
@@ -248,8 +314,6 @@ class SQLiteVecIndex(EmbeddingIndex):
        """
        Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
        """
        if query_string is None:
            raise ValueError("query_string is required for keyword search.")

        def _execute_query():
            connection = _create_sqlite_connection(self.db_path)

@@ -287,18 +351,95 @@ class SQLiteVecIndex(EmbeddingIndex):
            scores.append(score)
        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str = RERANKER_TYPE_RRF,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """
        Hybrid search using a configurable re-ranking strategy.

class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
        Args:
            embedding: The query embedding vector
            query_string: The text query for keyword search
            k: Number of results to return
            score_threshold: Minimum similarity score threshold
            reranker_type: Type of reranker to use ("rrf" or "weighted")
            reranker_params: Parameters for the reranker

        Returns:
            QueryChunksResponse with combined results
        """
        if reranker_params is None:
            reranker_params = {}

        # Get results from both search methods
        vector_response = await self.query_vector(embedding, k, score_threshold)
        keyword_response = await self.query_keyword(query_string, k, score_threshold)

        # Convert responses to score dictionaries using generate_chunk_id
        vector_scores = {
            generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
            for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
        }
        keyword_scores = {
            generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
            for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
        }

        # Combine scores using the specified reranker
        if reranker_type == RERANKER_TYPE_WEIGHTED:
            alpha = reranker_params.get("alpha", 0.5)
            combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
        else:
            # Default to RRF for None, RRF, or any unknown types
            impact_factor = reranker_params.get("impact_factor", 60.0)
            combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)

        # Sort by combined score and get top k results
        sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
        top_k_items = sorted_items[:k]

        # Filter by score threshold
        filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]

        # Create a map of chunk_id to chunk for both responses
        chunk_map = {}
        for c in vector_response.chunks:
            chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
            chunk_map[chunk_id] = c
        for c in keyword_response.chunks:
            chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
            chunk_map[chunk_id] = c

        # Use the map to look up chunks by their IDs
        chunks = []
        scores = []
        for doc_id, score in filtered_items:
            if doc_id in chunk_map:
                chunks.append(chunk_map[doc_id])
                scores.append(score)

        return QueryChunksResponse(chunks=chunks, scores=scores)


class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
    """
    A VectorIO implementation using SQLite + sqlite_vec.
    This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
    and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
    """

    def __init__(self, config, inference_api: Inference) -> None:
    def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
        self.config = config
        self.inference_api = inference_api
        self.files_api = files_api
        self.cache: dict[str, VectorDBWithIndex] = {}
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}

    async def initialize(self) -> None:
        def _setup_connection():

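For illustration, a hedged sketch of calling the hybrid query defined above directly on an index; the index and embedding are assumed to come from elsewhere, and the query text is invented.

# Hypothetical direct use of SQLiteVecIndex.query_hybrid (sketch only).
async def _demo_hybrid(index, query_embedding):
    # "index" is a SQLiteVecIndex and "query_embedding" an NDArray; both are assumptions here.
    response = await index.query_hybrid(
        embedding=query_embedding,
        query_string="how do I rotate my signing keys?",
        k=5,
        score_threshold=0.0,
        reranker_type=RERANKER_TYPE_WEIGHTED,
        reranker_params={"alpha": 0.7},  # alpha weights the keyword side of the fusion
    )
    for chunk, score in zip(response.chunks, response.scores, strict=False):
        print(score, chunk.metadata.get("document_id"))
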
@@ -313,24 +454,38 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
                        metadata TEXT
                    );
                """)
                # Create a table to persist OpenAI vector stores.
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS openai_vector_stores (
                        id TEXT PRIMARY KEY,
                        metadata TEXT
                    );
                """)
                connection.commit()
                # Load any existing vector DB registrations.
                cur.execute("SELECT metadata FROM vector_dbs")
                rows = cur.fetchall()
                return rows
                vector_db_rows = cur.fetchall()
                return vector_db_rows
            finally:
                cur.close()
                connection.close()

        rows = await asyncio.to_thread(_setup_connection)
        for row in rows:
        vector_db_rows = await asyncio.to_thread(_setup_connection)

        # Load existing vector DBs
        for row in vector_db_rows:
            vector_db_data = row[0]
            vector_db = VectorDB.model_validate_json(vector_db_data)
            index = await SQLiteVecIndex.create(
                vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
                vector_db.embedding_dimension,
                self.config.db_path,
                vector_db.identifier,
            )
            self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

        # Load existing OpenAI vector stores using the mixin method
        self.openai_vector_stores = await self._load_openai_vector_stores()

    async def shutdown(self) -> None:
        # nothing to do since we don't maintain a persistent connection
        pass

@@ -350,7 +505,11 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
            connection.close()

        await asyncio.to_thread(_register_db)
        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
        index = await SQLiteVecIndex.create(
            vector_db.embedding_dimension,
            self.config.db_path,
            vector_db.identifier,
        )
        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

    async def list_vector_dbs(self) -> list[VectorDB]:

@@ -375,6 +534,87 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

        await asyncio.to_thread(_delete_vector_db_from_registry)

    # OpenAI Vector Store Mixin abstract method implementations
    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        """Save vector store metadata to SQLite database."""

        def _store():
            connection = _create_sqlite_connection(self.config.db_path)
            cur = connection.cursor()
            try:
                cur.execute(
                    "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
                    (store_id, json.dumps(store_info)),
                )
                connection.commit()
            except Exception as e:
                logger.error(f"Error saving openai vector store {store_id}: {e}")
                raise
            finally:
                cur.close()
                connection.close()

        try:
            await asyncio.to_thread(_store)
        except Exception as e:
            logger.error(f"Error saving openai vector store {store_id}: {e}")
            raise

    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
        """Load all vector store metadata from SQLite database."""

        def _load():
            connection = _create_sqlite_connection(self.config.db_path)
            cur = connection.cursor()
            try:
                cur.execute("SELECT metadata FROM openai_vector_stores")
                rows = cur.fetchall()
                return rows
            finally:
                cur.close()
                connection.close()

        rows = await asyncio.to_thread(_load)
        stores = {}
        for row in rows:
            store_data = row[0]
            store_info = json.loads(store_data)
            stores[store_info["id"]] = store_info
        return stores

    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        """Update vector store metadata in SQLite database."""

        def _update():
            connection = _create_sqlite_connection(self.config.db_path)
            cur = connection.cursor()
            try:
                cur.execute(
                    "UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
                    (json.dumps(store_info), store_id),
                )
                connection.commit()
            finally:
                cur.close()
                connection.close()

        await asyncio.to_thread(_update)

    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
        """Delete vector store metadata from SQLite database."""

        def _delete():
            connection = _create_sqlite_connection(self.config.db_path)
            cur = connection.cursor()
            try:
                cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
                connection.commit()
            finally:
                cur.close()
                connection.close()

        await asyncio.to_thread(_delete)

    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
        if vector_db_id not in self.cache:
            raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")

@@ -24,6 +24,7 @@ def available_providers() -> list[ProviderSpec]:
            config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
            deprecation_warning="Please use the `inline::faiss` provider instead.",
            api_dependencies=[Api.inference],
            optional_api_dependencies=[Api.files],
        ),
        InlineProviderSpec(
            api=Api.vector_io,

@@ -32,6 +33,7 @@ def available_providers() -> list[ProviderSpec]:
            module="llama_stack.providers.inline.vector_io.faiss",
            config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
            api_dependencies=[Api.inference],
            optional_api_dependencies=[Api.files],
        ),
        # NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
        # source distribution and the wheels are not available for all platforms.

@@ -42,6 +44,7 @@ def available_providers() -> list[ProviderSpec]:
            module="llama_stack.providers.inline.vector_io.sqlite_vec",
            config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
            api_dependencies=[Api.inference],
            optional_api_dependencies=[Api.files],
        ),
        InlineProviderSpec(
            api=Api.vector_io,

@@ -51,6 +54,7 @@ def available_providers() -> list[ProviderSpec]:
            config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
            deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
            api_dependencies=[Api.inference],
            optional_api_dependencies=[Api.files],
        ),
        remote_provider_spec(
            Api.vector_io,

@@ -318,6 +318,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        model_obj = await self.model_store.get_model(model)

@@ -316,6 +316,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        provider_model_id = await self._get_provider_model_id(model)

@@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
    JsonSchemaResponseFormat,
    LogProbConfig,
    Message,
    OpenAIEmbeddingsResponse,
    ResponseFormat,
    SamplingParams,
    TextTruncation,

@@ -46,6 +45,8 @@ from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAICompletion,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
)

@@ -62,8 +63,10 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    b64_encode_openai_embeddings_response,
    get_sampling_options,
    prepare_openai_completion_params,
    prepare_openai_embeddings_params,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,

@@ -386,7 +389,35 @@ class OllamaInferenceAdapter(
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        raise NotImplementedError()
        model_obj = await self._get_model(model)
        if model_obj.model_type != ModelType.embedding:
            raise ValueError(f"Model {model} is not an embedding model")

        if model_obj.provider_resource_id is None:
            raise ValueError(f"Model {model} has no provider_resource_id set")

        # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
        params = prepare_openai_embeddings_params(
            model=model_obj.provider_resource_id,
            input=input,
            encoding_format=encoding_format,
            dimensions=dimensions,
            user=user,
        )

        response = await self.openai_client.embeddings.create(**params)
        data = b64_encode_openai_embeddings_response(response.data, encoding_format)

        usage = OpenAIEmbeddingUsage(
            prompt_tokens=response.usage.prompt_tokens,
            total_tokens=response.usage.total_tokens,
        )
        # TODO: Investigate why model_obj.identifier is used instead of response.model
        return OpenAIEmbeddingsResponse(
            data=data,
            model=model_obj.identifier,
            usage=usage,
        )

    async def openai_completion(
        self,

@@ -409,6 +440,7 @@ class OllamaInferenceAdapter(
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        if not isinstance(prompt, str):
            raise ValueError("Ollama does not support non-string prompts for completion")

@@ -432,6 +464,7 @@ class OllamaInferenceAdapter(
            temperature=temperature,
            top_p=top_p,
            user=user,
            suffix=suffix,
        )
        return await self.openai_client.completions.create(**params)  # type: ignore

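The Ollama adapter now serves embeddings through its OpenAI-compatible client instead of raising NotImplementedError. A hedged client-side sketch of calling the stack's OpenAI-compatible surface with the openai SDK; the base URL, port, and model id are assumptions, not values taken from this diff.

# Hypothetical client-side embeddings call (sketch only).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/openai/v1", api_key="none")
resp = client.embeddings.create(model="all-MiniLM-L6-v2", input=["hello world"])
print(len(resp.data[0].embedding))
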
|
|||
|
|
@ -90,6 +90,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
if guided_choice is not None:
|
||||
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
|
||||
|
|
@ -117,6 +118,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
suffix=suffix,
|
||||
)
|
||||
return await self._openai_client.completions.create(**params)
|
||||
|
||||
|
|
|
|||
|
|
@ -242,6 +242,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
client = self._get_client()
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
|
|
|||
|
|
@ -299,6 +299,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
|
|
|
|||
|
|
@ -56,7 +56,11 @@ from llama_stack.apis.inference.inference import (
|
|||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
|
|
@ -298,6 +302,22 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the remote vLLM server.
|
||||
This method is used by the Provider API to verify
|
||||
that the service is running correctly.
|
||||
Returns:
|
||||
|
||||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
_ = [m async for m in client.models.list()] # Ensure the client is initialized
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def _get_model(self, model_id: str) -> Model:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set")
|
||||
|
|
@ -539,6 +559,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
|
|
|
|||
|
|
@ -292,6 +292,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
|
|
|
|||
|
|
@@ -14,7 +14,16 @@ from numpy.typing import NDArray

from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    VectorIO,
    VectorStoreDeleteResponse,
    VectorStoreListResponse,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (

@@ -55,7 +64,7 @@ class ChromaIndex(EmbeddingIndex):
            )
        )

    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        results = await maybe_await(
            self.collection.query(
                query_embeddings=[embedding.tolist()],

@@ -76,8 +85,12 @@ class ChromaIndex(EmbeddingIndex):
                log.exception(f"Failed to parse document: {doc}")
                continue

            score = 1.0 / float(dist) if dist != 0 else float("inf")
            if score < score_threshold:
                continue

            chunks.append(chunk)
            scores.append(1.0 / float(dist))
            scores.append(score)

        return QueryChunksResponse(chunks=chunks, scores=scores)
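The hunk above converts a raw vector distance into a similarity-style score with score = 1.0 / distance (guarding the zero-distance case) and drops results below score_threshold before returning; the PGVector and Weaviate hunks later in this diff apply the same inverted-distance guard. A small standalone sketch of that conversion (illustrative only, not adapter code):

# Illustrative sketch of the inverted-distance scoring used above.
def distance_to_score(dist: float) -> float:
    # Smaller distance means a larger score; a zero distance maps to +inf.
    return 1.0 / dist if dist != 0 else float("inf")

def filter_by_threshold(results: list[tuple[str, float]], score_threshold: float) -> list[tuple[str, float]]:
    scored = [(doc, distance_to_score(d)) for doc, d in results]
    return [(doc, s) for doc, s in scored if s >= score_threshold]

print(filter_by_threshold([("a", 0.5), ("b", 4.0), ("c", 0.0)], score_threshold=1.0))
# [('a', 2.0), ('c', inf)]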
@@ -92,6 +105,17 @@ class ChromaIndex(EmbeddingIndex):
    ) -> QueryChunksResponse:
        raise NotImplementedError("Keyword search is not supported in Chroma")

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        raise NotImplementedError("Hybrid search is not supported in Chroma")


class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    def __init__(

@@ -174,3 +198,67 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
        index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
        self.cache[vector_db_id] = index
        return index

    async def openai_create_vector_store(
        self,
        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorStoreObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

    async def openai_list_vector_stores(
        self,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: dict[str, Any] | None = None,
        rewrite_query: bool | None = False,
    ) -> VectorStoreSearchResponsePage:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
@@ -16,7 +16,16 @@ from pymilvus import MilvusClient

from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    VectorIO,
    VectorStoreDeleteResponse,
    VectorStoreListResponse,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (

@@ -94,6 +103,17 @@ class MilvusIndex(EmbeddingIndex):
    ) -> QueryChunksResponse:
        raise NotImplementedError("Keyword search is not supported in Milvus")

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        raise NotImplementedError("Hybrid search is not supported in Milvus")


class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    def __init__(

@@ -177,6 +197,70 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

        return await index.query_chunks(query, params)

    async def openai_create_vector_store(
        self,
        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorStoreObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_list_vector_stores(
        self,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: dict[str, Any] | None = None,
        rewrite_query: bool | None = False,
    ) -> VectorStoreSearchResponsePage:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")


def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    """Generate a unique chunk ID using a hash of document ID and chunk text."""
@@ -116,7 +116,7 @@ class PGVectorIndex(EmbeddingIndex):
        scores = []
        for doc, dist in results:
            chunks.append(Chunk(**doc))
            scores.append(1.0 / float(dist))
            scores.append(1.0 / float(dist) if dist != 0 else float("inf"))

        return QueryChunksResponse(chunks=chunks, scores=scores)

@@ -128,6 +128,17 @@ class PGVectorIndex(EmbeddingIndex):
    ) -> QueryChunksResponse:
        raise NotImplementedError("Keyword search is not supported in PGVector")

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        raise NotImplementedError("Hybrid search is not supported in PGVector")

    async def delete(self):
        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
@@ -14,7 +14,16 @@ from qdrant_client.models import PointStruct

from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    VectorIO,
    VectorStoreDeleteResponse,
    VectorStoreListResponse,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (

@@ -103,6 +112,17 @@ class QdrantIndex(EmbeddingIndex):
    ) -> QueryChunksResponse:
        raise NotImplementedError("Keyword search is not supported in Qdrant")

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        raise NotImplementedError("Hybrid search is not supported in Qdrant")

    async def delete(self):
        await self.client.delete_collection(collection_name=self.collection_name)

@@ -178,3 +198,67 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
            raise ValueError(f"Vector DB {vector_db_id} not found")

        return await index.query_chunks(query, params)

    async def openai_create_vector_store(
        self,
        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorStoreObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_list_vector_stores(
        self,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: dict[str, Any] | None = None,
        rewrite_query: bool | None = False,
    ) -> VectorStoreSearchResponsePage:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
@@ -76,7 +76,7 @@ class WeaviateIndex(EmbeddingIndex):
                continue

            chunks.append(chunk)
            scores.append(1.0 / doc.metadata.distance)
            scores.append(1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf"))

        return QueryChunksResponse(chunks=chunks, scores=scores)

@@ -92,6 +92,17 @@ class WeaviateIndex(EmbeddingIndex):
    ) -> QueryChunksResponse:
        raise NotImplementedError("Keyword search is not supported in Weaviate")

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        raise NotImplementedError("Hybrid search is not supported in Weaviate")


class WeaviateVectorIOAdapter(
    VectorIO,
@@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import struct
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any

@@ -37,7 +35,6 @@ from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAICompletion,
    OpenAIEmbeddingData,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
    OpenAIMessageParam,

@@ -48,6 +45,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    b64_encode_openai_embeddings_response,
    convert_message_to_openai_dict_new,
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,

@@ -293,16 +291,7 @@ class LiteLLMOpenAIMixin(
        )

        # Convert response to OpenAI format
        data = []
        for i, embedding_data in enumerate(response["data"]):
            # we encode to base64 if the encoding format is base64 in the request
            if encoding_format == "base64":
                byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"])
                embedding = base64.b64encode(byte_data).decode("utf-8")
            else:
                embedding = embedding_data["embedding"]

            data.append(OpenAIEmbeddingData(embedding=embedding, index=i))
        data = b64_encode_openai_embeddings_response(response.data, encoding_format)

        usage = OpenAIEmbeddingUsage(
            prompt_tokens=response["usage"]["prompt_tokens"],

@@ -336,6 +325,7 @@ class LiteLLMOpenAIMixin(
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        model_obj = await self.model_store.get_model(model)
        params = await prepare_openai_completion_params(
@@ -3,8 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import json
import logging
import struct
import time
import uuid
import warnings

@@ -108,6 +110,7 @@ from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAICompletion,
    OpenAICompletionChoice,
    OpenAIEmbeddingData,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    ToolConfig,

@@ -1287,6 +1290,7 @@ class OpenAICompletionToLlamaStackMixin:
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        if stream:
            raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")

@@ -1483,3 +1487,55 @@ class OpenAIChatCompletionToLlamaStackMixin:
            model=model,
            object="chat.completion",
        )


def prepare_openai_embeddings_params(
    model: str,
    input: str | list[str],
    encoding_format: str | None = "float",
    dimensions: int | None = None,
    user: str | None = None,
):
    if model is None:
        raise ValueError("Model must be provided for embeddings")

    input_list = [input] if isinstance(input, str) else input

    params: dict[str, Any] = {
        "model": model,
        "input": input_list,
    }

    if encoding_format is not None:
        params["encoding_format"] = encoding_format
    if dimensions is not None:
        params["dimensions"] = dimensions
    if user is not None:
        params["user"] = user

    return params


def b64_encode_openai_embeddings_response(
    response_data: dict, encoding_format: str | None = "float"
) -> list[OpenAIEmbeddingData]:
    """
    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
    """
    data = []
    for i, embedding_data in enumerate(response_data):
        if encoding_format == "base64":
            byte_array = bytearray()
            for embedding_value in embedding_data.embedding:
                byte_array.extend(struct.pack("f", float(embedding_value)))

            response_embedding = base64.b64encode(byte_array).decode("utf-8")
        else:
            response_embedding = embedding_data.embedding
        data.append(
            OpenAIEmbeddingData(
                embedding=response_embedding,
                index=i,
            )
        )
    return data
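Note: b64_encode_openai_embeddings_response above packs each embedding value as a 4-byte float with struct.pack("f", ...) and base64-encodes the concatenated bytes when encoding_format is "base64". A hedged sketch of how a client could decode such a payload back into floats (the same decode pattern appears in the embeddings test later in this diff):

# Sketch: decoding a base64-encoded embedding back into a list of floats.
# Assumes the payload was produced by packing 4-byte floats as in the helper above.
import base64
import struct

def decode_base64_embedding(payload: str) -> list[float]:
    raw = base64.b64decode(payload)
    count = len(raw) // 4  # each value is a 4-byte float
    return list(struct.unpack(f"{count}f", raw))

# Round-trip example
values = [0.1, -0.5, 2.0]
encoded = base64.b64encode(b"".join(struct.pack("f", v) for v in values)).decode("utf-8")
print(decode_base64_embedding(encoded))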
482 llama_stack/providers/utils/memory/openai_vector_store_mixin.py Normal file

@@ -0,0 +1,482 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import mimetypes
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any

from llama_stack.apis.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
    QueryChunksResponse,
    VectorStoreContent,
    VectorStoreDeleteResponse,
    VectorStoreListResponse,
    VectorStoreObject,
    VectorStoreSearchResponse,
    VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import (
    Chunk,
    VectorStoreChunkingStrategy,
    VectorStoreChunkingStrategyAuto,
    VectorStoreChunkingStrategyStatic,
    VectorStoreFileLastError,
    VectorStoreFileObject,
)
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks

logger = logging.getLogger(__name__)

# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5


class OpenAIVectorStoreMixin(ABC):
    """
    Mixin class that provides common OpenAI Vector Store API implementation.
    Providers need to implement the abstract storage methods and maintain
    an openai_vector_stores in-memory cache.
    """

    # These should be provided by the implementing class
    openai_vector_stores: dict[str, dict[str, Any]]
    files_api: Files | None

    @abstractmethod
    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        """Save vector store metadata to persistent storage."""
        pass

    @abstractmethod
    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
        """Load all vector store metadata from persistent storage."""
        pass

    @abstractmethod
    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        """Update vector store metadata in persistent storage."""
        pass

    @abstractmethod
    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
        """Delete vector store metadata from persistent storage."""
        pass

    @abstractmethod
    async def register_vector_db(self, vector_db: VectorDB) -> None:
        """Register a vector database (provider-specific implementation)."""
        pass

    @abstractmethod
    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Unregister a vector database (provider-specific implementation)."""
        pass

    @abstractmethod
    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        """Insert chunks into a vector database (provider-specific implementation)."""
        pass

    @abstractmethod
    async def query_chunks(
        self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
    ) -> QueryChunksResponse:
        """Query chunks from a vector database (provider-specific implementation)."""
        pass
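Note: the mixin leaves persistence and vector-index access to the provider. A minimal, hedged sketch of what an implementing adapter could look like, using a plain dict as the "persistent" store purely for illustration (real providers would persist to their own storage, which this diff does not show):

# Hedged sketch: wiring the mixin's abstract storage hooks to an in-memory dict.
from typing import Any

from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin


class InMemoryVectorStoreProvider(OpenAIVectorStoreMixin):
    def __init__(self) -> None:
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
        self.files_api = None
        self._storage: dict[str, dict[str, Any]] = {}  # stand-in for persistent storage

    async def _save_openai_vector_store(self, store_id, store_info) -> None:
        self._storage[store_id] = store_info

    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
        return dict(self._storage)

    async def _update_openai_vector_store(self, store_id, store_info) -> None:
        self._storage[store_id] = store_info

    async def _delete_openai_vector_store_from_storage(self, store_id) -> None:
        self._storage.pop(store_id, None)

    # register_vector_db / unregister_vector_db / insert_chunks / query_chunks would
    # delegate to the provider's own vector index and are intentionally omitted here,
    # so this sketch cannot be instantiated as-is.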
    async def openai_create_vector_store(
        self,
        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorStoreObject:
        """Creates a vector store."""
        # store and vector_db have the same id
        store_id = name or str(uuid.uuid4())
        created_at = int(time.time())

        if provider_id is None:
            raise ValueError("Provider ID is required")

        if embedding_model is None:
            raise ValueError("Embedding model is required")

        # Use provided embedding dimension or default to 384
        if embedding_dimension is None:
            raise ValueError("Embedding dimension is required")

        provider_vector_db_id = provider_vector_db_id or store_id
        vector_db = VectorDB(
            identifier=store_id,
            embedding_dimension=embedding_dimension,
            embedding_model=embedding_model,
            provider_id=provider_id,
            provider_resource_id=provider_vector_db_id,
        )
        # Register the vector DB
        await self.register_vector_db(vector_db)

        # Create OpenAI vector store metadata
        store_info = {
            "id": store_id,
            "object": "vector_store",
            "created_at": created_at,
            "name": store_id,
            "usage_bytes": 0,
            "file_counts": {},
            "status": "completed",
            "expires_after": expires_after,
            "expires_at": None,
            "last_active_at": created_at,
            "file_ids": file_ids or [],
            "chunking_strategy": chunking_strategy,
        }

        # Add provider information to metadata if provided
        metadata = metadata or {}
        if provider_id:
            metadata["provider_id"] = provider_id
        if provider_vector_db_id:
            metadata["provider_vector_db_id"] = provider_vector_db_id
        store_info["metadata"] = metadata

        # Save to persistent storage (provider-specific)
        await self._save_openai_vector_store(store_id, store_info)

        # Store in memory cache
        self.openai_vector_stores[store_id] = store_info

        return VectorStoreObject(
            id=store_id,
            created_at=created_at,
            name=store_id,
            usage_bytes=0,
            file_counts={},
            status="completed",
            expires_after=expires_after,
            expires_at=None,
            last_active_at=created_at,
            metadata=metadata,
        )

    async def openai_list_vector_stores(
        self,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
        """Returns a list of vector stores."""
        limit = limit or 20
        order = order or "desc"

        # Get all vector stores
        all_stores = list(self.openai_vector_stores.values())

        # Sort by created_at
        reverse_order = order == "desc"
        all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)

        # Apply cursor-based pagination
        if after:
            after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
            if after_index >= 0:
                all_stores = all_stores[after_index + 1 :]

        if before:
            before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
            all_stores = all_stores[:before_index]

        # Apply limit
        limited_stores = all_stores[:limit]
        # Convert to VectorStoreObject instances
        data = [VectorStoreObject(**store) for store in limited_stores]

        # Determine pagination info
        has_more = len(all_stores) > limit
        first_id = data[0].id if data else None
        last_id = data[-1].id if data else None

        return VectorStoreListResponse(
            data=data,
            has_more=has_more,
            first_id=first_id,
            last_id=last_id,
        )
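Note: the listing method implements OpenAI-style cursor pagination: results are sorted by created_at, after/before narrow the window by store id, and has_more reports whether anything remains past the returned page. A hedged client-side sketch of paging through every store; the client object and its vector_stores.list method are assumptions about an OpenAI-compatible SDK, not part of this diff:

# Sketch: walking all vector stores with cursor pagination (client is assumed).
def list_all_vector_stores(client, page_size: int = 20) -> list:
    stores, after = [], None
    while True:
        page = client.vector_stores.list(limit=page_size, order="desc", after=after)
        stores.extend(page.data)
        if not page.has_more or not page.data:
            return stores
        after = page.data[-1].id  # continue after the last id of this page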
    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        """Retrieves a vector store."""
        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")

        store_info = self.openai_vector_stores[vector_store_id]
        return VectorStoreObject(**store_info)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        """Modifies a vector store."""
        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")

        store_info = self.openai_vector_stores[vector_store_id].copy()

        # Update fields if provided
        if name is not None:
            store_info["name"] = name
        if expires_after is not None:
            store_info["expires_after"] = expires_after
        if metadata is not None:
            store_info["metadata"] = metadata

        # Update last_active_at
        store_info["last_active_at"] = int(time.time())

        # Save to persistent storage (provider-specific)
        await self._update_openai_vector_store(vector_store_id, store_info)

        # Update in-memory cache
        self.openai_vector_stores[vector_store_id] = store_info

        return VectorStoreObject(**store_info)

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        """Delete a vector store."""
        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")

        # Delete from persistent storage (provider-specific)
        await self._delete_openai_vector_store_from_storage(vector_store_id)

        # Delete from in-memory cache
        del self.openai_vector_stores[vector_store_id]

        # Also delete the underlying vector DB
        try:
            await self.unregister_vector_db(vector_store_id)
        except Exception as e:
            logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")

        return VectorStoreDeleteResponse(
            id=vector_store_id,
            deleted=True,
        )

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: dict[str, Any] | None = None,
        rewrite_query: bool | None = False,
        # search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
    ) -> VectorStoreSearchResponsePage:
        """Search for chunks in a vector store."""
        # TODO: Add support in the API for this
        search_mode = "vector"
        max_num_results = max_num_results or 10

        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")

        if isinstance(query, list):
            search_query = " ".join(query)
        else:
            search_query = query

        try:
            score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
            params = {
                "max_chunks": max_num_results * CHUNK_MULTIPLIER,
                "score_threshold": score_threshold,
                "mode": search_mode,
            }
            # TODO: Add support for ranking_options.ranker

            response = await self.query_chunks(
                vector_db_id=vector_store_id,
                query=search_query,
                params=params,
            )

            # Convert response to OpenAI format
            data = []
            for chunk, score in zip(response.chunks, response.scores, strict=False):
                # Apply score based filtering
                if score < score_threshold:
                    continue

                # Apply filters if provided
                if filters:
                    # Simple metadata filtering
                    if not self._matches_filters(chunk.metadata, filters):
                        continue

                # content is InterleavedContent
                if isinstance(chunk.content, str):
                    content = [
                        VectorStoreContent(
                            type="text",
                            text=chunk.content,
                        )
                    ]
                elif isinstance(chunk.content, list):
                    # TODO: Add support for other types of content
                    content = [
                        VectorStoreContent(
                            type="text",
                            text=item.text,
                        )
                        for item in chunk.content
                        if item.type == "text"
                    ]
                else:
                    if chunk.content.type != "text":
                        raise ValueError(f"Unsupported content type: {chunk.content.type}")
                    content = [
                        VectorStoreContent(
                            type="text",
                            text=chunk.content.text,
                        )
                    ]

                response_data_item = VectorStoreSearchResponse(
                    file_id=chunk.metadata.get("file_id", ""),
                    filename=chunk.metadata.get("filename", ""),
                    score=score,
                    attributes=chunk.metadata,
                    content=content,
                )
                data.append(response_data_item)
                if len(data) >= max_num_results:
                    break

            return VectorStoreSearchResponsePage(
                search_query=search_query,
                data=data,
                has_more=False,  # For simplicity, we don't implement pagination here
                next_page=None,
            )

        except Exception as e:
            logger.error(f"Error searching vector store {vector_store_id}: {e}")
            # Return empty results on error
            return VectorStoreSearchResponsePage(
                search_query=search_query,
                data=[],
                has_more=False,
                next_page=None,
            )

    def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
        """Check if metadata matches the provided filters."""
        for key, value in filters.items():
            if key not in metadata:
                return False
            if metadata[key] != value:
                return False
        return True

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        attributes = attributes or {}
        chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()

        vector_store_file_object = VectorStoreFileObject(
            id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
            created_at=int(time.time()),
            status="in_progress",
            vector_store_id=vector_store_id,
        )

        if not hasattr(self, "files_api") or not self.files_api:
            vector_store_file_object.status = "failed"
            vector_store_file_object.last_error = VectorStoreFileLastError(
                code="server_error",
                message="Files API is not available",
            )
            return vector_store_file_object

        if isinstance(chunking_strategy, VectorStoreChunkingStrategyStatic):
            max_chunk_size_tokens = chunking_strategy.static.max_chunk_size_tokens
            chunk_overlap_tokens = chunking_strategy.static.chunk_overlap_tokens
        else:
            # Default values from OpenAI API spec
            max_chunk_size_tokens = 800
            chunk_overlap_tokens = 400

        try:
            file_response = await self.files_api.openai_retrieve_file(file_id)
            mime_type, _ = mimetypes.guess_type(file_response.filename)
            content_response = await self.files_api.openai_retrieve_file_content(file_id)

            content = content_from_data_and_mime_type(content_response.body, mime_type)

            chunks = make_overlapped_chunks(
                file_id,
                content,
                max_chunk_size_tokens,
                chunk_overlap_tokens,
                attributes,
            )

            if not chunks:
                vector_store_file_object.status = "failed"
                vector_store_file_object.last_error = VectorStoreFileLastError(
                    code="server_error",
                    message="No chunks were generated from the file",
                )
                return vector_store_file_object

            await self.insert_chunks(
                vector_db_id=vector_store_id,
                chunks=chunks,
            )
        except Exception as e:
            logger.error(f"Error attaching file to vector store: {e}")
            vector_store_file_object.status = "failed"
            vector_store_file_object.last_error = VectorStoreFileLastError(
                code="server_error",
                message=str(e),
            )
            return vector_store_file_object

        vector_store_file_object.status = "completed"

        return vector_store_file_object
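Note: openai_attach_file_to_vector_store retrieves the file through the Files API, guesses its MIME type, converts it to text, splits it into overlapping chunks (800-token chunks with 400-token overlap by default, per the OpenAI spec), and inserts the chunks into the underlying vector DB, recording any failure on the returned file object. A hedged client-side sketch of the upload-then-attach flow; the client object and its files/vector_stores methods are assumptions about an OpenAI-compatible SDK:

# Sketch: uploading a file and attaching it to a vector store (client is assumed).
def attach_document(client, vector_store_id: str, path: str):
    with open(path, "rb") as f:
        uploaded = client.files.create(file=f, purpose="assistants")
    return client.vector_stores.files.create(
        vector_store_id=vector_store_id,
        file_id=uploaded.id,
    )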
@@ -32,6 +32,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

log = logging.getLogger(__name__)

# Constants for reranker types
RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted"


def parse_pdf(data: bytes) -> str:
    # For PDF and DOC/DOCX files, we can't reliably convert to string

@@ -72,16 +76,18 @@ def content_from_data(data_url: str) -> str:
        data = unquote(data)
        encoding = parts["encoding"] or "utf-8"
        data = data.encode(encoding)
    return content_from_data_and_mime_type(data, parts["mimetype"], parts.get("encoding", None))

    encoding = parts["encoding"]
    if not encoding:
        import chardet

        detected = chardet.detect(data)
        encoding = detected["encoding"]
def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, encoding: str | None = None) -> str:
    if isinstance(data, bytes):
        if not encoding:
            import chardet

    mime_type = parts["mimetype"]
    mime_category = mime_type.split("/")[0]
            detected = chardet.detect(data)
            encoding = detected["encoding"]

    mime_category = mime_type.split("/")[0] if mime_type else None
    if mime_category == "text":
        # For text-based files (including CSV, MD)
        return data.decode(encoding)
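Note: the refactor above extracts the decode logic into content_from_data_and_mime_type, which takes raw bytes (or a string) plus an optional MIME type and encoding and falls back to chardet detection when no encoding is declared; content_from_data and the new vector-store mixin both route through it. A hedged usage sketch; the import path matches the one used by the mixin, while behavior beyond this hunk is assumed:

# Sketch: decoding fetched file bytes into text for chunking.
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type

csv_bytes = b"name,score\nalice,10\n"
text = content_from_data_and_mime_type(csv_bytes, "text/csv")  # encoding detected via chardet
utf8 = content_from_data_and_mime_type("resume".encode(), "text/plain", encoding="utf-8")
print(text, utf8)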
@@ -200,6 +206,18 @@ class EmbeddingIndex(ABC):
    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
        raise NotImplementedError()

    @abstractmethod
    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        raise NotImplementedError()

    @abstractmethod
    async def delete(self):
        raise NotImplementedError()

@@ -243,10 +261,29 @@ class VectorDBWithIndex:
        k = params.get("max_chunks", 3)
        mode = params.get("mode")
        score_threshold = params.get("score_threshold", 0.0)

        # Get ranker configuration
        ranker = params.get("ranker")
        if ranker is None:
            # Default to RRF with impact_factor=60.0
            reranker_type = RERANKER_TYPE_RRF
            reranker_params = {"impact_factor": 60.0}
        else:
            reranker_type = ranker.type
            reranker_params = (
                {"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
            )

        query_string = interleaved_content_as_str(query)
        if mode == "keyword":
            return await self.index.query_keyword(query_string, k, score_threshold)

        # Calculate embeddings for both vector and hybrid modes
        embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
        if mode == "hybrid":
            return await self.index.query_hybrid(
                query_vector, query_string, k, score_threshold, reranker_type, reranker_params
            )
        else:
            embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
            query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
            return await self.index.query_vector(query_vector, k, score_threshold)
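Note: VectorDBWithIndex.query_chunks now dispatches on params["mode"]: "keyword" uses the keyword index, "hybrid" computes the query embedding and calls query_hybrid with a reranker (defaulting to RRF with impact_factor=60.0, or a weighted reranker configured by alpha), and anything else falls back to plain vector search. A hedged sketch of the params a caller might pass; the ranker object only mimics the attributes the dispatch code reads, since the real ranker config types are not shown in this diff:

# Sketch: params dicts for the three retrieval modes dispatched by query_chunks.
from types import SimpleNamespace

keyword_params = {"mode": "keyword", "max_chunks": 5, "score_threshold": 0.0}
vector_params = {"mode": "vector", "max_chunks": 5, "score_threshold": 0.1}
hybrid_params = {
    "mode": "hybrid",
    "max_chunks": 5,
    "score_threshold": 0.1,
    # stand-in object exposing .type / .alpha (or .impact_factor for "rrf")
    "ranker": SimpleNamespace(type="weighted", alpha=0.7),
}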
@@ -23,6 +23,8 @@ distribution_spec:
    - inline::basic
    - inline::llm-as-judge
    - inline::braintrust
  files:
  - inline::localfs
  post_training:
  - inline::huggingface
  tool_runtime:

@@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import (
    ShieldInput,
    ToolGroupInput,
)
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig

@@ -29,6 +30,7 @@ def get_distribution_template() -> DistributionTemplate:
        "eval": ["inline::meta-reference"],
        "datasetio": ["remote::huggingface", "inline::localfs"],
        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
        "files": ["inline::localfs"],
        "post_training": ["inline::huggingface"],
        "tool_runtime": [
            "remote::brave-search",

@@ -49,6 +51,11 @@ def get_distribution_template() -> DistributionTemplate:
        provider_type="inline::faiss",
        config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
    )
    files_provider = Provider(
        provider_id="meta-reference-files",
        provider_type="inline::localfs",
        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
    )
    posttraining_provider = Provider(
        provider_id="huggingface",
        provider_type="inline::huggingface",

@@ -98,6 +105,7 @@ def get_distribution_template() -> DistributionTemplate:
            provider_overrides={
                "inference": [inference_provider],
                "vector_io": [vector_io_provider_faiss],
                "files": [files_provider],
                "post_training": [posttraining_provider],
            },
            default_models=[inference_model, embedding_model],

@@ -107,6 +115,7 @@ def get_distribution_template() -> DistributionTemplate:
            provider_overrides={
                "inference": [inference_provider],
                "vector_io": [vector_io_provider_faiss],
                "files": [files_provider],
                "post_training": [posttraining_provider],
                "safety": [
                    Provider(
@@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety

@@ -84,6 +85,14 @@ providers:
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  files:
  - provider_id: meta-reference-files
    provider_type: inline::localfs
    config:
      storage_dir: ${env.FILES_STORAGE_DIR:~/.llama/distributions/ollama/files}
      metadata_store:
        type: sqlite
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/files_metadata.db
  post_training:
  - provider_id: huggingface
    provider_type: inline::huggingface

@@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety

@@ -82,6 +83,14 @@ providers:
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  files:
  - provider_id: meta-reference-files
    provider_type: inline::localfs
    config:
      storage_dir: ${env.FILES_STORAGE_DIR:~/.llama/distributions/ollama/files}
      metadata_store:
        type: sqlite
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/files_metadata.db
  post_training:
  - provider_id: huggingface
    provider_type: inline::huggingface
@@ -17,6 +17,8 @@ distribution_spec:
    - inline::sqlite-vec
    - remote::chromadb
    - remote::pgvector
  files:
  - inline::localfs
  safety:
  - inline::llama-guard
  agents:

@@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- safety
- scoring

@@ -75,6 +76,14 @@ providers:
        db: ${env.PGVECTOR_DB:}
        user: ${env.PGVECTOR_USER:}
        password: ${env.PGVECTOR_PASSWORD:}
  files:
  - provider_id: meta-reference-files
    provider_type: inline::localfs
    config:
      storage_dir: ${env.FILES_STORAGE_DIR:~/.llama/distributions/starter/files}
      metadata_store:
        type: sqlite
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/files_metadata.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
@@ -12,6 +12,7 @@ from llama_stack.distribution.datatypes import (
    ShieldInput,
    ToolGroupInput,
)
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
    SentenceTransformersInferenceConfig,
)

@@ -134,6 +135,7 @@ def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
        "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
        "files": ["inline::localfs"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],

@@ -170,6 +172,11 @@ def get_distribution_template() -> DistributionTemplate:
            ),
        ),
    ]
    files_provider = Provider(
        provider_id="meta-reference-files",
        provider_type="inline::localfs",
        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
    )
    embedding_provider = Provider(
        provider_id="sentence-transformers",
        provider_type="inline::sentence-transformers",

@@ -212,6 +219,7 @@ def get_distribution_template() -> DistributionTemplate:
            provider_overrides={
                "inference": inference_providers + [embedding_provider],
                "vector_io": vector_io_providers,
                "files": [files_provider],
            },
            default_models=default_models + [embedding_model],
            default_tool_groups=default_tool_groups,
@@ -84,6 +84,7 @@ unit = [
    "sqlalchemy",
    "sqlalchemy[asyncio]>=2.0.41",
    "blobfile",
    "faiss-cpu"
]
# These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment

@@ -147,7 +147,7 @@ referencing==0.36.2
    # jsonschema-specifications
regex==2024.11.6
    # via tiktoken
requests==2.32.3
requests==2.32.4
    # via
    #   huggingface-hub
    #   llama-stack
@@ -22,9 +22,6 @@ def provider_from_model(client_with_models, model_id):


def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
    if isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI completions are not supported when testing with library client yet.")

    provider = provider_from_model(client_with_models, model_id)
    if provider.provider_type in (
        "inline::meta-reference",

@@ -44,6 +41,23 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")


def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
    # To test `fim` ( fill in the middle ) completion, we need to use a model that supports suffix.
    # Use this to specifically test this API functionality.

    # pytest -sv --stack-config="inference=ollama" \
    #   tests/integration/inference/test_openai_completion.py \
    #   --text-model qwen2.5-coder:1.5b \
    #   -k test_openai_completion_non_streaming_suffix

    if model_id != "qwen2.5-coder:1.5b":
        pytest.skip(f"Suffix is not supported for the model: {model_id}.")

    provider = provider_from_model(client_with_models, model_id)
    if provider.provider_type != "remote::ollama":
        pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")


def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
    if isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")

@@ -102,6 +116,32 @@ def test_openai_completion_non_streaming(llama_stack_client, client_with_models,
    assert len(choice.text) > 10


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:suffix",
    ],
)
def test_openai_completion_non_streaming_suffix(llama_stack_client, client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
    skip_if_model_doesnt_support_suffix(client_with_models, text_model_id)
    tc = TestCase(test_case)

    # ollama needs more verbose prompting for some reason here...
    response = llama_stack_client.completions.create(
        model=text_model_id,
        prompt=tc["content"],
        stream=False,
        suffix=tc["suffix"],
        max_tokens=10,
    )

    assert len(response.choices) > 0
    choice = response.choices[0]
    assert len(choice.text) > 5
    assert "france" in choice.text.lower()


@pytest.mark.parametrize(
    "test_case",
    [

@@ -372,10 +412,14 @@ def test_inference_store_tool_calls(compat_client, client_with_models, text_mode
    )
    assert input_content == message, retrieved_response
    tool_calls = retrieved_response.choices[0].message.tool_calls
    # sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called
    # sometimes model doesn't output tool calls, but we still want to test that the tool was called
    if tool_calls:
        # because we test with small models, just check that we retrieved
        # a tool call with a name and arguments string, but ignore contents
        assert len(tool_calls) == 1
        assert tool_calls[0].function.name == "get_weather"
        assert "tokyo" in tool_calls[0].function.arguments.lower()
        assert tool_calls[0].function.name
        assert tool_calls[0].function.arguments
    else:
        # failed tool call parses show up as a message with content, so ensure
        # that the retrieve response content matches the original request
        assert retrieved_response.choices[0].message.content == content
@@ -34,11 +34,15 @@ def skip_if_model_doesnt_support_variable_dimensions(model_id):
        pytest.skip("{model_id} does not support variable output embedding dimensions")


def skip_if_model_doesnt_support_openai_embeddings(client_with_models, model_id):
    if isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI embeddings are not supported when testing with library client yet.")
@pytest.fixture(params=["openai_client", "llama_stack_client"])
def compat_client(request, client_with_models):
    if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI client tests not supported with library client")
    return request.getfixturevalue(request.param)

    provider = provider_from_model(client_with_models, model_id)

def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
    provider = provider_from_model(client, model_id)
    if provider.provider_type in (
        "inline::meta-reference",
        "remote::bedrock",

@@ -47,7 +51,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client_with_models, model_id)
        "remote::runpod",
        "remote::sambanova",
        "remote::tgi",
        "remote::ollama",
    ):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")

@@ -58,13 +61,13 @@ def openai_client(client_with_models):
    return OpenAI(base_url=base_url, api_key="fake")


def test_openai_embeddings_single_string(openai_client, client_with_models, embedding_model_id):
def test_openai_embeddings_single_string(compat_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with a single string input."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    input_text = "Hello, world!"

    response = openai_client.embeddings.create(
    response = compat_client.embeddings.create(
        model=embedding_model_id,
        input=input_text,
        encoding_format="float",

@@ -80,13 +83,13 @@ def test_openai_embeddings_single_string(openai_client, client_with_models, embe
    assert all(isinstance(x, float) for x in response.data[0].embedding)


def test_openai_embeddings_multiple_strings(openai_client, client_with_models, embedding_model_id):
def test_openai_embeddings_multiple_strings(compat_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with multiple string inputs."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    input_texts = ["Hello, world!", "How are you today?", "This is a test."]

    response = openai_client.embeddings.create(
    response = compat_client.embeddings.create(
        model=embedding_model_id,
        input=input_texts,
    )

@@ -103,13 +106,13 @@ def test_openai_embeddings_multiple_strings(openai_client, client_with_models, e
    assert all(isinstance(x, float) for x in embedding_data.embedding)


def test_openai_embeddings_with_encoding_format_float(openai_client, client_with_models, embedding_model_id):
def test_openai_embeddings_with_encoding_format_float(compat_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with float encoding format."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    input_text = "Test encoding format"

    response = openai_client.embeddings.create(
    response = compat_client.embeddings.create(
        model=embedding_model_id,
        input=input_text,
        encoding_format="float",

@@ -121,7 +124,7 @@ def test_openai_embeddings_with_encoding_format_float(openai_client, client_with
    assert all(isinstance(x, float) for x in response.data[0].embedding)


def test_openai_embeddings_with_dimensions(openai_client, client_with_models, embedding_model_id):
def test_openai_embeddings_with_dimensions(compat_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with custom dimensions parameter."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
    skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)

@@ -129,7 +132,7 @@ def test_openai_embeddings_with_dimensions(openai_client, client_with_models, em
    input_text = "Test dimensions parameter"
    dimensions = 16

    response = openai_client.embeddings.create(
    response = compat_client.embeddings.create(
        model=embedding_model_id,
        input=input_text,
        dimensions=dimensions,

@@ -142,14 +145,14 @@ def test_openai_embeddings_with_dimensions(openai_client, client_with_models, em
    assert len(response.data[0].embedding) > 0


def test_openai_embeddings_with_user_parameter(openai_client, client_with_models, embedding_model_id):
def test_openai_embeddings_with_user_parameter(compat_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with user parameter."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    input_text = "Test user parameter"
    user_id = "test-user-123"

    response = openai_client.embeddings.create(
    response = compat_client.embeddings.create(
        model=embedding_model_id,
        input=input_text,
        user=user_id,

@@ -161,41 +164,41 @@ def test_openai_embeddings_with_user_parameter(openai_client, client_with_models
    assert len(response.data[0].embedding) > 0


def test_openai_embeddings_empty_list_error(openai_client, client_with_models, embedding_model_id):
def test_openai_embeddings_empty_list_error(compat_client, client_with_models, embedding_model_id):
    """Test that empty list input raises an appropriate error."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    with pytest.raises(Exception):  # noqa: B017
        openai_client.embeddings.create(
        compat_client.embeddings.create(
            model=embedding_model_id,
            input=[],
        )


def test_openai_embeddings_invalid_model_error(openai_client, client_with_models, embedding_model_id):
def test_openai_embeddings_invalid_model_error(compat_client, client_with_models, embedding_model_id):
    """Test that invalid model ID raises an appropriate error."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    with pytest.raises(Exception):  # noqa: B017
        openai_client.embeddings.create(
        compat_client.embeddings.create(
            model="invalid-model-id",
            input="Test text",
        )


def test_openai_embeddings_different_inputs_different_outputs(openai_client, client_with_models, embedding_model_id):
def test_openai_embeddings_different_inputs_different_outputs(compat_client, client_with_models, embedding_model_id):
    """Test that different inputs produce different embeddings."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

    input_text1 = "This is the first text"
    input_text2 = "This is completely different content"

    response1 = openai_client.embeddings.create(
    response1 = compat_client.embeddings.create(
        model=embedding_model_id,
        input=input_text1,
    )

    response2 = openai_client.embeddings.create(
    response2 = compat_client.embeddings.create(
        model=embedding_model_id,
        input=input_text2,
    )

@@ -208,7 +211,7 @@ def test_openai_embeddings_different_inputs_different_outputs(openai_client, cli
    assert embedding1 != embedding2


def test_openai_embeddings_with_encoding_format_base64(openai_client, client_with_models, embedding_model_id):
def test_openai_embeddings_with_encoding_format_base64(compat_client, client_with_models, embedding_model_id):
    """Test OpenAI embeddings endpoint with base64 encoding format."""
    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
    skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)

@@ -216,7 +219,7 @@ def test_openai_embeddings_with_encoding_format_base64(openai_client, client_wit
    input_text = "Test base64 encoding format"
    dimensions = 12

    response = openai_client.embeddings.create(
    response = compat_client.embeddings.create(
        model=embedding_model_id,
        input=input_text,
        encoding_format="base64",

@@ -241,13 +244,13 @@ def test_openai_embeddings_with_encoding_format_base64(openai_client, client_wit
    assert all(isinstance(x, float) for x in embedding_floats)
|
||||
|
||||
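The base64 assertions above imply the test decodes the returned payload back into floats before checking them. A minimal sketch of that conversion using only the standard library (the helper name is illustrative, not part of the test suite):

import base64
import struct

def decode_base64_embedding(b64_payload: str) -> list[float]:
    # OpenAI-compatible embeddings with encoding_format="base64" pack the
    # vector as little-endian float32 values; decode the bytes and unpack them.
    raw = base64.b64decode(b64_payload)
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))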
|
||||
def test_openai_embeddings_base64_batch_processing(openai_client, client_with_models, embedding_model_id):
|
||||
def test_openai_embeddings_base64_batch_processing(compat_client, client_with_models, embedding_model_id):
|
||||
"""Test OpenAI embeddings endpoint with base64 encoding for batch processing."""
|
||||
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
|
||||
|
||||
input_texts = ["First text for base64", "Second text for base64", "Third text for base64"]
|
||||
|
||||
response = openai_client.embeddings.create(
|
||||
response = compat_client.embeddings.create(
|
||||
model=embedding_model_id,
|
||||
input=input_texts,
|
||||
encoding_format="base64",
|
||||
|
|
|
|||
|
|
@ -4,6 +4,12 @@
        "content": "Complete the sentence using one word: Roses are red, violets are "
      }
    },
    "suffix": {
      "data": {
        "content": "The capital of ",
        "suffix": "is Paris."
      }
    },
    "non_streaming": {
      "data": {
        "content": "Michael Jordan is born in ",
|
|
|
|||
425 tests/integration/vector_io/test_openai_vector_stores.py Normal file
|
|
@ -0,0 +1,425 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.vector_io import Chunk
|
||||
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
|
||||
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
|
||||
for p in vector_io_providers:
|
||||
if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]:
|
||||
return
|
||||
|
||||
pytest.skip("OpenAI vector stores are not supported by any provider")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_client(client_with_models):
|
||||
base_url = f"{client_with_models.base_url}/v1/openai/v1"
|
||||
return OpenAI(base_url=base_url, api_key="fake")
|
||||
|
||||
|
||||
@pytest.fixture(params=["openai_client", "llama_stack_client"])
|
||||
def compat_client(request, client_with_models):
|
||||
if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("OpenAI client tests not supported with library client")
|
||||
return request.getfixturevalue(request.param)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def sample_chunks():
|
||||
return [
|
||||
Chunk(
|
||||
content="Python is a high-level programming language that emphasizes code readability and allows programmers to express concepts in fewer lines of code than would be possible in languages such as C++ or Java.",
|
||||
metadata={"document_id": "doc1", "topic": "programming"},
|
||||
),
|
||||
Chunk(
|
||||
content="Machine learning is a subset of artificial intelligence that enables systems to automatically learn and improve from experience without being explicitly programmed, using statistical techniques to give computer systems the ability to progressively improve performance on a specific task.",
|
||||
metadata={"document_id": "doc2", "topic": "ai"},
|
||||
),
|
||||
Chunk(
|
||||
content="Data structures are fundamental to computer science because they provide organized ways to store and access data efficiently, enable faster processing of data through optimized algorithms, and form the building blocks for more complex software systems.",
|
||||
metadata={"document_id": "doc3", "topic": "computer_science"},
|
||||
),
|
||||
Chunk(
|
||||
content="Neural networks are inspired by biological neural networks found in animal brains, using interconnected nodes called artificial neurons to process information through weighted connections that can be trained to recognize patterns and solve complex problems through iterative learning.",
|
||||
metadata={"document_id": "doc4", "topic": "ai"},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def compat_client_with_empty_stores(compat_client):
|
||||
def clear_vector_stores():
|
||||
# List and delete all existing vector stores
|
||||
try:
|
||||
response = compat_client.vector_stores.list()
|
||||
for store in response.data:
|
||||
compat_client.vector_stores.delete(vector_store_id=store.id)
|
||||
except Exception:
|
||||
# If the API is not available or fails, just continue
|
||||
logger.warning("Failed to clear vector stores")
|
||||
pass
|
||||
|
||||
clear_vector_stores()
|
||||
yield compat_client
|
||||
|
||||
# Clean up after the test
|
||||
clear_vector_stores()
|
||||
|
||||
|
||||
def test_openai_create_vector_store(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test creating a vector store using OpenAI API."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
client = compat_client_with_empty_stores
|
||||
|
||||
# Create a vector store
|
||||
vector_store = client.vector_stores.create(
|
||||
name="test_vector_store", metadata={"purpose": "testing", "environment": "integration"}
|
||||
)
|
||||
|
||||
assert vector_store is not None
|
||||
assert vector_store.name == "test_vector_store"
|
||||
assert vector_store.object == "vector_store"
|
||||
assert vector_store.status in ["completed", "in_progress"]
|
||||
assert vector_store.metadata["purpose"] == "testing"
|
||||
assert vector_store.metadata["environment"] == "integration"
|
||||
assert hasattr(vector_store, "id")
|
||||
assert hasattr(vector_store, "created_at")
|
||||
|
||||
|
||||
def test_openai_list_vector_stores(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test listing vector stores using OpenAI API."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
||||
client = compat_client_with_empty_stores
|
||||
|
||||
# Create a few vector stores
|
||||
store1 = client.vector_stores.create(name="store1", metadata={"type": "test"})
|
||||
store2 = client.vector_stores.create(name="store2", metadata={"type": "test"})
|
||||
|
||||
# List vector stores
|
||||
response = client.vector_stores.list()
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "data")
|
||||
assert len(response.data) >= 2
|
||||
|
||||
# Check that our stores are in the list
|
||||
store_ids = [store.id for store in response.data]
|
||||
assert store1.id in store_ids
|
||||
assert store2.id in store_ids
|
||||
|
||||
# Test pagination with limit
|
||||
limited_response = client.vector_stores.list(limit=1)
|
||||
assert len(limited_response.data) == 1
|
||||
|
||||
|
||||
def test_openai_retrieve_vector_store(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test retrieving a specific vector store using OpenAI API."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
||||
client = compat_client_with_empty_stores
|
||||
|
||||
# Create a vector store
|
||||
created_store = client.vector_stores.create(name="retrieve_test_store", metadata={"purpose": "retrieval_test"})
|
||||
|
||||
# Retrieve the store
|
||||
retrieved_store = client.vector_stores.retrieve(vector_store_id=created_store.id)
|
||||
|
||||
assert retrieved_store is not None
|
||||
assert retrieved_store.id == created_store.id
|
||||
assert retrieved_store.name == "retrieve_test_store"
|
||||
assert retrieved_store.metadata["purpose"] == "retrieval_test"
|
||||
assert retrieved_store.object == "vector_store"
|
||||
|
||||
|
||||
def test_openai_update_vector_store(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test modifying a vector store using OpenAI API."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
||||
client = compat_client_with_empty_stores
|
||||
|
||||
# Create a vector store
|
||||
created_store = client.vector_stores.create(name="original_name", metadata={"version": "1.0"})
|
||||
time.sleep(1)
|
||||
# Modify the store
|
||||
modified_store = client.vector_stores.update(
|
||||
vector_store_id=created_store.id, name="modified_name", metadata={"version": "1.1", "updated": "true"}
|
||||
)
|
||||
|
||||
assert modified_store is not None
|
||||
assert modified_store.id == created_store.id
|
||||
assert modified_store.name == "modified_name"
|
||||
assert modified_store.metadata["version"] == "1.1"
|
||||
assert modified_store.metadata["updated"] == "true"
|
||||
# last_active_at should be updated
|
||||
assert modified_store.last_active_at > created_store.last_active_at
|
||||
|
||||
|
||||
def test_openai_delete_vector_store(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test deleting a vector store using OpenAI API."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
||||
client = compat_client_with_empty_stores
|
||||
|
||||
# Create a vector store
|
||||
created_store = client.vector_stores.create(name="delete_test_store", metadata={"purpose": "deletion_test"})
|
||||
|
||||
# Delete the store
|
||||
delete_response = client.vector_stores.delete(vector_store_id=created_store.id)
|
||||
|
||||
assert delete_response is not None
|
||||
assert delete_response.id == created_store.id
|
||||
assert delete_response.deleted is True
|
||||
assert delete_response.object == "vector_store.deleted"
|
||||
|
||||
# Verify the store is deleted - attempting to retrieve should fail
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
client.vector_stores.retrieve(vector_store_id=created_store.id)
|
||||
|
||||
|
||||
def test_openai_vector_store_search_empty(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test searching an empty vector store using OpenAI API."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
||||
client = compat_client_with_empty_stores
|
||||
|
||||
# Create a vector store
|
||||
vector_store = client.vector_stores.create(name="search_test_store", metadata={"purpose": "search_testing"})
|
||||
|
||||
# Search the empty store
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id, query="test query", max_num_results=5
|
||||
)
|
||||
|
||||
assert search_response is not None
|
||||
assert hasattr(search_response, "data")
|
||||
assert len(search_response.data) == 0 # Empty store should return no results
|
||||
assert search_response.search_query == "test query"
|
||||
assert search_response.has_more is False
|
||||
|
||||
|
||||
def test_openai_vector_store_with_chunks(compat_client_with_empty_stores, client_with_models, sample_chunks):
|
||||
"""Test vector store functionality with actual chunks using both OpenAI and native APIs."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
||||
compat_client = compat_client_with_empty_stores
|
||||
llama_client = client_with_models
|
||||
|
||||
# Create a vector store using OpenAI API
|
||||
vector_store = compat_client.vector_stores.create(name="chunks_test_store", metadata={"purpose": "chunks_testing"})
|
||||
|
||||
# Insert chunks using the native LlamaStack API (since OpenAI API doesn't have direct chunk insertion)
|
||||
llama_client.vector_io.insert(
|
||||
vector_db_id=vector_store.id,
|
||||
chunks=sample_chunks,
|
||||
)
|
||||
|
||||
# Search using OpenAI API
|
||||
search_response = compat_client.vector_stores.search(
|
||||
vector_store_id=vector_store.id, query="What is Python programming language?", max_num_results=3
|
||||
)
|
||||
assert search_response is not None
|
||||
assert len(search_response.data) > 0
|
||||
|
||||
# The top result should be about Python (doc1)
|
||||
top_result = search_response.data[0]
|
||||
top_content = top_result.content[0].text
|
||||
assert "python" in top_content.lower() or "programming" in top_content.lower()
|
||||
assert top_result.attributes["document_id"] == "doc1"
|
||||
|
||||
# Test filtering by metadata
|
||||
filtered_search = compat_client.vector_stores.search(
|
||||
vector_store_id=vector_store.id, query="artificial intelligence", filters={"topic": "ai"}, max_num_results=5
|
||||
)
|
||||
|
||||
assert filtered_search is not None
|
||||
# All results should have topic "ai"
|
||||
for result in filtered_search.data:
|
||||
assert result.attributes["topic"] == "ai"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
("What makes Python different from other languages?", "doc1", "programming"),
|
||||
("How do systems learn automatically?", "doc2", "ai"),
|
||||
("Why are data structures important?", "doc3", "computer_science"),
|
||||
("What inspires neural networks?", "doc4", "ai"),
|
||||
],
|
||||
)
|
||||
def test_openai_vector_store_search_relevance(
|
||||
compat_client_with_empty_stores, client_with_models, sample_chunks, test_case
|
||||
):
|
||||
"""Test that OpenAI vector store search returns relevant results for different queries."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
||||
compat_client = compat_client_with_empty_stores
|
||||
llama_client = client_with_models
|
||||
|
||||
query, expected_doc_id, expected_topic = test_case
|
||||
|
||||
# Create a vector store
|
||||
vector_store = compat_client.vector_stores.create(
|
||||
name=f"relevance_test_{expected_doc_id}", metadata={"purpose": "relevance_testing"}
|
||||
)
|
||||
|
||||
# Insert chunks using native API
|
||||
llama_client.vector_io.insert(
|
||||
vector_db_id=vector_store.id,
|
||||
chunks=sample_chunks,
|
||||
)
|
||||
|
||||
# Search using OpenAI API
|
||||
search_response = compat_client.vector_stores.search(
|
||||
vector_store_id=vector_store.id, query=query, max_num_results=4
|
||||
)
|
||||
|
||||
assert search_response is not None
|
||||
assert len(search_response.data) > 0
|
||||
|
||||
# The top result should match the expected document
|
||||
top_result = search_response.data[0]
|
||||
|
||||
assert top_result.attributes["document_id"] == expected_doc_id
|
||||
assert top_result.attributes["topic"] == expected_topic
|
||||
|
||||
# Verify score is included and reasonable
|
||||
assert isinstance(top_result.score, int | float)
|
||||
assert top_result.score > 0
|
||||
|
||||
|
||||
def test_openai_vector_store_search_with_ranking_options(
|
||||
compat_client_with_empty_stores, client_with_models, sample_chunks
|
||||
):
|
||||
"""Test OpenAI vector store search with ranking options."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
||||
compat_client = compat_client_with_empty_stores
|
||||
llama_client = client_with_models
|
||||
|
||||
# Create a vector store
|
||||
vector_store = compat_client.vector_stores.create(
|
||||
name="ranking_test_store", metadata={"purpose": "ranking_testing"}
|
||||
)
|
||||
|
||||
# Insert chunks
|
||||
llama_client.vector_io.insert(
|
||||
vector_db_id=vector_store.id,
|
||||
chunks=sample_chunks,
|
||||
)
|
||||
|
||||
# Search with ranking options
|
||||
threshold = 0.1
|
||||
search_response = compat_client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="machine learning and artificial intelligence",
|
||||
max_num_results=3,
|
||||
ranking_options={"score_threshold": threshold},
|
||||
)
|
||||
|
||||
assert search_response is not None
|
||||
assert len(search_response.data) > 0
|
||||
|
||||
# All results should meet the score threshold
|
||||
for result in search_response.data:
|
||||
assert result.score >= threshold
|
||||
|
||||
|
||||
def test_openai_vector_store_search_with_high_score_filter(
|
||||
compat_client_with_empty_stores, client_with_models, sample_chunks
|
||||
):
|
||||
"""Test that searching with text very similar to a document and high score threshold returns only that document."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
||||
compat_client = compat_client_with_empty_stores
|
||||
llama_client = client_with_models
|
||||
|
||||
# Create a vector store
|
||||
vector_store = compat_client.vector_stores.create(
|
||||
name="high_score_filter_test", metadata={"purpose": "high_score_filtering"}
|
||||
)
|
||||
|
||||
# Insert chunks
|
||||
llama_client.vector_io.insert(
|
||||
vector_db_id=vector_store.id,
|
||||
chunks=sample_chunks,
|
||||
)
|
||||
|
||||
# Query with text very similar to the Python document (doc1)
|
||||
# This should match very closely to the first sample chunk about Python
|
||||
query = "Python is a high-level programming language with code readability and fewer lines than C++ or Java"
|
||||
|
||||
# pick a threshold slightly higher than the second result's score
|
||||
search_response = compat_client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query=query,
|
||||
max_num_results=3,
|
||||
)
|
||||
assert len(search_response.data) > 1, "Expected more than one result"
|
||||
threshold = search_response.data[1].score + 0.0001
|
||||
|
||||
# we expect only one result with the requested threshold
|
||||
search_response = compat_client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query=query,
|
||||
max_num_results=10, # Allow more results but expect filtering
|
||||
ranking_options={"score_threshold": threshold},
|
||||
)
|
||||
|
||||
# With high threshold and similar query, we should get only the Python document
|
||||
assert len(search_response.data) == 1, "Expected only one result with high threshold"
|
||||
|
||||
# The top result should be the Python document (doc1)
|
||||
top_result = search_response.data[0]
|
||||
assert top_result.attributes["document_id"] == "doc1"
|
||||
assert top_result.attributes["topic"] == "programming"
|
||||
assert top_result.score >= threshold
|
||||
|
||||
# Verify the content contains Python-related terms
|
||||
top_content = top_result.content[0].text
|
||||
assert "python" in top_content.lower() or "programming" in top_content.lower()
|
||||
|
||||
|
||||
def test_openai_vector_store_search_with_max_num_results(
|
||||
compat_client_with_empty_stores, client_with_models, sample_chunks
|
||||
):
|
||||
"""Test OpenAI vector store search with max_num_results."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
|
||||
compat_client = compat_client_with_empty_stores
|
||||
llama_client = client_with_models
|
||||
|
||||
# Create a vector store
|
||||
vector_store = compat_client.vector_stores.create(
|
||||
name="max_num_results_test_store", metadata={"purpose": "max_num_results_testing"}
|
||||
)
|
||||
|
||||
# Insert chunks
|
||||
llama_client.vector_io.insert(
|
||||
vector_db_id=vector_store.id,
|
||||
chunks=sample_chunks,
|
||||
)
|
||||
|
||||
# Search with max_num_results
|
||||
search_response = compat_client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="machine learning and artificial intelligence",
|
||||
max_num_results=2,
|
||||
)
|
||||
|
||||
assert search_response is not None
|
||||
assert len(search_response.data) == 2
|
||||
|
|
@ -154,3 +154,36 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
|
|||
assert len(response.chunks) > 0
|
||||
assert response.chunks[0].metadata["document_id"] == "doc1"
|
||||
assert response.chunks[0].metadata["source"] == "precomputed"
|
||||
|
||||
|
||||
def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(client_with_empty_registry, embedding_model_id):
|
||||
vector_db_id = "test_precomputed_embeddings_db"
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=384,
|
||||
)
|
||||
|
||||
chunks_with_embeddings = [
|
||||
Chunk(
|
||||
content="duplicate",
|
||||
metadata={"document_id": "doc1", "source": "precomputed"},
|
||||
embedding=[0.1] * 384,
|
||||
),
|
||||
]
|
||||
|
||||
client_with_empty_registry.vector_io.insert(
|
||||
vector_db_id=vector_db_id,
|
||||
chunks=chunks_with_embeddings,
|
||||
)
|
||||
|
||||
response = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query="duplicate",
|
||||
)
|
||||
|
||||
# Verify the top result is the expected document
|
||||
assert response is not None
|
||||
assert len(response.chunks) > 0
|
||||
assert response.chunks[0].metadata["document_id"] == "doc1"
|
||||
assert response.chunks[0].metadata["source"] == "precomputed"
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import threading
|
|||
import time
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
|
@ -44,6 +44,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.models.llama.datatypes import StopReason, ToolCall
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import (
|
||||
VLLMInferenceAdapter,
|
||||
|
|
@ -642,3 +643,70 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
|
|||
assert chunks[-2].event.delta.type == "tool_call"
|
||||
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
|
||||
assert chunks[-2].event.delta.tool_call.arguments == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_status_success(vllm_inference_adapter):
|
||||
"""
|
||||
Test the health method of VLLM InferenceAdapter when the connection is successful.
|
||||
|
||||
This test verifies that the health method returns a HealthResponse with status OK only
|
||||
when the connection to the vLLM server is successful.
|
||||
"""
|
||||
# Set vllm_inference_adapter.client to None to ensure _create_client is called
|
||||
vllm_inference_adapter.client = None
|
||||
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
|
||||
# Create mock client and models
|
||||
mock_client = MagicMock()
|
||||
mock_models = MagicMock()
|
||||
|
||||
# Create a mock async iterator that yields a model when iterated
|
||||
async def mock_list():
|
||||
for model in [MagicMock()]:
|
||||
yield model
|
||||
|
||||
# Set up the models.list to return our mock async iterator
|
||||
mock_models.list.return_value = mock_list()
|
||||
mock_client.models = mock_models
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
# Call the health method
|
||||
health_response = await vllm_inference_adapter.health()
|
||||
# Verify the response
|
||||
assert health_response["status"] == HealthStatus.OK
|
||||
|
||||
# Verify that models.list was called
|
||||
mock_models.list.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_status_failure(vllm_inference_adapter):
|
||||
"""
|
||||
Test the health method of VLLM InferenceAdapter when the connection fails.
|
||||
|
||||
This test verifies that the health method returns a HealthResponse with status ERROR
|
||||
and an appropriate error message when the connection to the vLLM server fails.
|
||||
"""
|
||||
vllm_inference_adapter.client = None
|
||||
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
|
||||
# Create mock client and models
|
||||
mock_client = MagicMock()
|
||||
mock_models = MagicMock()
|
||||
|
||||
# Create a mock async iterator that raises an exception when iterated
|
||||
async def mock_list():
|
||||
raise Exception("Connection failed")
|
||||
yield # Unreachable code
|
||||
|
||||
# Set up the models.list to return our mock async iterator
|
||||
mock_models.list.return_value = mock_list()
|
||||
mock_client.models = mock_models
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
# Call the health method
|
||||
health_response = await vllm_inference_adapter.health()
|
||||
# Verify the response
|
||||
assert health_response["status"] == HealthStatus.ERROR
|
||||
assert "Health check failed: Connection failed" in health_response["message"]
|
||||
|
||||
mock_models.list.assert_called_once()
|
||||
|
|
|
|||
120 tests/unit/providers/vector_io/test_faiss.py Normal file
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import (
|
||||
FaissIndex,
|
||||
FaissVectorIOAdapter,
|
||||
)
|
||||
|
||||
# This test is a unit test for the FaissVectorIOAdapter class. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
# tests/integration/vector_io/
|
||||
#
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest tests/unit/providers/vector_io/test_faiss.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
FAISS_PROVIDER = "faiss"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_dimension():
|
||||
return 384
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_db_id():
|
||||
return "test_vector_db"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_chunks():
|
||||
return [
|
||||
Chunk(content="MOCK text content 1", mime_type="text/plain", metadata={"document_id": "mock-doc-1"}),
|
||||
Chunk(content="MOCK text content 1", mime_type="text/plain", metadata={"document_id": "mock-doc-2"}),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embeddings(embedding_dimension):
|
||||
return np.random.rand(2, embedding_dimension).astype(np.float32)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
|
||||
mock_vector_db = MagicMock(spec=VectorDB)
|
||||
mock_vector_db.embedding_model = "mock_embedding_model"
|
||||
mock_vector_db.identifier = vector_db_id
|
||||
mock_vector_db.embedding_dimension = embedding_dimension
|
||||
return mock_vector_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inference_api(sample_embeddings):
|
||||
mock_api = MagicMock(spec=Inference)
|
||||
mock_api.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings))
|
||||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def faiss_config():
|
||||
config = MagicMock(spec=FaissVectorIOConfig)
|
||||
config.kvstore = None
|
||||
return config
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def faiss_index(embedding_dimension):
|
||||
index = await FaissIndex.create(dimension=embedding_dimension)
|
||||
yield index
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def faiss_adapter(faiss_config, mock_inference_api) -> FaissVectorIOAdapter:
|
||||
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api)
|
||||
await adapter.initialize()
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
|
||||
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
|
||||
):
|
||||
await faiss_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
|
||||
with patch.object(faiss_index.index, "search") as mock_search:
|
||||
mock_search.return_value = (np.array([[0.0, 0.1]]), np.array([[0, 1]]))
|
||||
|
||||
response = await faiss_index.query_vector(embedding=query_embedding, k=2, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
assert len(response.scores) == 2
|
||||
|
||||
assert response.scores[0] == float("inf") # infinity (1.0 / 0.0)
|
||||
assert response.scores[1] == 10.0 # (1.0 / 0.1 = 10.0)
|
||||
|
||||
assert response.chunks[0] == sample_chunks[0]
|
||||
assert response.chunks[1] == sample_chunks[1]
|
||||
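The score assertions above rely on the adapter converting FAISS distances into similarity scores as 1 / distance, so an exact match (distance 0.0) maps to +infinity. A minimal sketch of that convention, stated as an assumption (the helper below is illustrative, not the provider's actual code):

import numpy as np

def distances_to_scores(distances: np.ndarray) -> list[float]:
    # score = 1 / distance; a zero distance (identical vectors) becomes +inf
    return [float("inf") if d == 0.0 else 1.0 / float(d) for d in distances]

# distances_to_scores(np.array([0.0, 0.1])) -> [inf, 10.0], matching the asserts above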
|
|
@ -84,6 +84,28 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
|
|||
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Create a query embedding that's similar to the first chunk
|
||||
query_embedding = sample_embeddings[0]
|
||||
query_string = "Sentence 5"
|
||||
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=3,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
|
||||
assert len(response.chunks) == 3, f"Expected 3 results, got {len(response.chunks)}"
|
||||
# Verify scores are in descending order (higher is better)
|
||||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
# Re-initialize with a clean index
|
||||
|
|
@ -141,3 +163,355 @@ def test_generate_chunk_id():
|
|||
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
|
||||
"f68df25d-d9aa-ab4d-5684-64a233add20d",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
"""Test hybrid search when keyword search returns no matches - should still return vector results."""
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Use a non-existent keyword but a valid vector query
|
||||
query_embedding = sample_embeddings[0]
|
||||
query_string = "Sentence 499"
|
||||
|
||||
# First verify keyword search returns no results
|
||||
keyword_response = await sqlite_vec_index.query_keyword(query_string, k=5, score_threshold=0.0)
|
||||
assert len(keyword_response.chunks) == 0, "Keyword search should return no results"
|
||||
|
||||
# Get hybrid results
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=3,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
|
||||
# Should still get results from vector search
|
||||
assert len(response.chunks) > 0, "Should get results from vector search even with no keyword matches"
|
||||
# Verify scores are in descending order
|
||||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
"""Test hybrid search with a high score threshold."""
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Use a very high score threshold that no results will meet
|
||||
query_embedding = sample_embeddings[0]
|
||||
query_string = "Sentence 5"
|
||||
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=3,
|
||||
score_threshold=1000.0, # Very high threshold
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
|
||||
# Should return no results due to high threshold
|
||||
assert len(response.chunks) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_different_embedding(
|
||||
sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
|
||||
):
|
||||
"""Test hybrid search with a different embedding than the stored ones."""
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Create a random embedding that's different from stored ones
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "Sentence 5"
|
||||
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=3,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
|
||||
# Should still get results if keyword matches exist
|
||||
assert len(response.chunks) > 0
|
||||
# Verify scores are in descending order
|
||||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
"""Test that RRF properly combines rankings when documents appear in both search methods."""
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Create a query embedding that's similar to the first chunk
|
||||
query_embedding = sample_embeddings[0]
|
||||
# Use a keyword that appears in multiple documents
|
||||
query_string = "Sentence 5"
|
||||
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=5,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
|
||||
# Verify we get results from both search methods
|
||||
assert len(response.chunks) > 0
|
||||
# Verify scores are in descending order (RRF should maintain this)
|
||||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Create a query embedding that's similar to the first chunk
|
||||
query_embedding = sample_embeddings[0]
|
||||
# Use a keyword that appears in the first document
|
||||
query_string = "Sentence 0 from document 0"
|
||||
|
||||
# Test weighted re-ranking
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="weighted",
|
||||
reranker_params={"alpha": 0.5},
|
||||
)
|
||||
assert len(response.chunks) == 1
|
||||
# Score should be weighted average of normalized keyword score and vector score
|
||||
assert response.scores[0] > 0.5 # Both scores should be high
|
||||
|
||||
# Test RRF re-ranking
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
assert len(response.chunks) == 1
|
||||
# RRF score should be sum of reciprocal ranks
|
||||
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # 1/(60+1) + 1/(60+1)
|
||||
|
||||
# Test default re-ranking (should be RRF)
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
assert len(response.chunks) == 1
|
||||
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
|
||||
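The expected values in the RRF assertions follow from Reciprocal Rank Fusion: each retrieval method contributes 1 / (impact_factor + rank) for a document it returns. A minimal sketch of that scoring, assuming 1-based ranks (the function name is illustrative, not the index's API):

def rrf_score(ranks: list[int], impact_factor: float = 60.0) -> float:
    # Sum the reciprocal-rank contribution from each retrieval method
    # (keyword and vector search) that returned the document.
    return sum(1.0 / (impact_factor + rank) for rank in ranks)

# A chunk ranked first by both methods scores 1/61 + 1/61 == 2/61, which is
# the value pytest.approx checks above.

The weighted path, by contrast, appears to blend the normalized keyword and vector scores with an alpha-controlled convex combination, which is why the earlier weighted assertion only checks that the combined score stays above 0.5.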
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
"""Test hybrid search with documents that appear in only one search method."""
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Create a query embedding that's similar to the first chunk
|
||||
query_embedding = sample_embeddings[0]
|
||||
# Use a keyword that appears in a different document
|
||||
query_string = "Sentence 9 from document 2"
|
||||
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=3,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
|
||||
# Should get results from both search methods
|
||||
assert len(response.chunks) > 0
|
||||
# Verify scores are in descending order
|
||||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||
# Verify we get results from both the vector-similar document and keyword-matched document
|
||||
doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks}
|
||||
assert "document-0" in doc_ids # From vector search
|
||||
assert "document-2" in doc_ids # From keyword search
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_weighted_reranker_parametrization(
|
||||
sqlite_vec_index, sample_chunks, sample_embeddings
|
||||
):
|
||||
"""Test WeightedReRanker with different alpha values."""
|
||||
# Re-add data before each search to ensure test isolation
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_embedding = sample_embeddings[0]
|
||||
query_string = "Sentence 0 from document 0"
|
||||
|
||||
# alpha=1.0 (should behave like pure keyword)
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="weighted",
|
||||
reranker_params={"alpha": 1.0},
|
||||
)
|
||||
assert len(response.chunks) > 0 # Should get at least one result
|
||||
assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
|
||||
|
||||
# alpha=0.0 (should behave like pure vector)
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="weighted",
|
||||
reranker_params={"alpha": 0.0},
|
||||
)
|
||||
assert len(response.chunks) > 0 # Should get at least one result
|
||||
assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
|
||||
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
# alpha=0.7 (should be a mix)
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="weighted",
|
||||
reranker_params={"alpha": 0.7},
|
||||
)
|
||||
assert len(response.chunks) > 0 # Should get at least one result
|
||||
assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
"""Test RRFReRanker with different impact factors."""
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_embedding = sample_embeddings[0]
|
||||
query_string = "Sentence 0 from document 0"
|
||||
|
||||
# impact_factor=10
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 10.0},
|
||||
)
|
||||
assert len(response.chunks) == 1
|
||||
assert response.scores[0] == pytest.approx(2.0 / 11.0, rel=1e-6)
|
||||
|
||||
# impact_factor=100
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 100.0},
|
||||
)
|
||||
assert len(response.chunks) == 1
|
||||
assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# No results from either search - use a completely different embedding and a nonzero threshold
|
||||
query_embedding = np.ones_like(sample_embeddings[0]) * -1 # Very different from sample embeddings
|
||||
query_string = "no_such_keyword_that_will_never_match"
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=3,
|
||||
score_threshold=0.1, # Nonzero threshold to filter out low-similarity matches
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
assert len(response.chunks) == 0
|
||||
|
||||
# All results below threshold
|
||||
query_embedding = sample_embeddings[0]
|
||||
query_string = "Sentence 0 from document 0"
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=3,
|
||||
score_threshold=1000.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
assert len(response.chunks) == 0
|
||||
|
||||
# Large k value
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=100,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
# Should not error, should return all available results
|
||||
assert len(response.chunks) > 0
|
||||
assert len(response.chunks) <= 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_hybrid_tie_breaking(
|
||||
sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
|
||||
):
|
||||
"""Test tie-breaking and determinism when scores are equal."""
|
||||
# Create two chunks with the same content and embedding
|
||||
chunk1 = Chunk(content="identical", metadata={"document_id": "docA"})
|
||||
chunk2 = Chunk(content="identical", metadata={"document_id": "docB"})
|
||||
chunks = [chunk1, chunk2]
|
||||
# Use the same embedding for both chunks to ensure equal scores
|
||||
same_embedding = sample_embeddings[0]
|
||||
embeddings = np.array([same_embedding, same_embedding])
|
||||
|
||||
# Clear existing data and recreate index
|
||||
await sqlite_vec_index.delete()
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / "test_sqlite.db")
|
||||
sqlite_vec_index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
|
||||
await sqlite_vec_index.add_chunks(chunks, embeddings)
|
||||
|
||||
# Query with the same embedding and content to ensure equal scores
|
||||
query_embedding = same_embedding
|
||||
query_string = "identical"
|
||||
|
||||
# Run multiple queries to verify determinism
|
||||
responses = []
|
||||
for _ in range(3):
|
||||
response = await sqlite_vec_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=2,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
responses.append(response)
|
||||
|
||||
# Verify all responses are identical
|
||||
first_response = responses[0]
|
||||
for response in responses[1:]:
|
||||
assert response.chunks == first_response.chunks
|
||||
assert response.scores == first_response.scores
|
||||
|
||||
# Verify both chunks are returned with equal scores
|
||||
assert len(first_response.chunks) == 2
|
||||
assert first_response.scores[0] == first_response.scores[1]
|
||||
assert {chunk.metadata["document_id"] for chunk in first_response.chunks} == {"docA", "docB"}
|
||||
|
|
|
|||
|
|
@ -345,6 +345,56 @@ def test_invalid_oauth2_authentication(oauth2_client, invalid_token):
|
|||
assert "Invalid JWT token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_auth_jwks_response(*args, **kwargs):
|
||||
if "headers" not in kwargs or "Authorization" not in kwargs["headers"]:
|
||||
return MockResponse(401, {})
|
||||
authz = kwargs["headers"]["Authorization"]
|
||||
if authz != "Bearer my-jwks-token":
|
||||
return MockResponse(401, {})
|
||||
return await mock_jwks_response(args, kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_app_with_jwks_token():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
"key_recheck_period": "3600",
|
||||
"token": "my-jwks-token",
|
||||
},
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token):
|
||||
return TestClient(oauth2_app_with_jwks_token)
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
|
||||
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
|
||||
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
|
||||
response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
def test_get_attributes_from_claims():
|
||||
claims = {
|
||||
"sub": "my-user",
|
||||
|
|
|
|||
|
|
@ -5,10 +5,12 @@
# the root directory of this source tree.

import asyncio
from unittest.mock import AsyncMock, MagicMock

import pytest

from llama_stack.distribution.server.server import create_sse_event, sse_generator
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.server.server import create_dynamic_typed_route, create_sse_event, sse_generator


@pytest.mark.asyncio
@ -89,3 +91,24 @@ async def test_sse_generator_error_before_response_starts():
    # We should have 1 error event
    assert len(seen_events) == 1
    assert 'data: {"error":' in seen_events[0]


@pytest.mark.asyncio
async def test_paginated_response_url_setting():
    """Test that PaginatedResponse gets url set to route path."""

    async def mock_api_method():
        return PaginatedResponse(data=[], has_more=False, url=None)

    route_handler = create_dynamic_typed_route(mock_api_method, "get", "/test/route")

    # Mock minimal request
    request = MagicMock()
    request.scope = {"user_attributes": {}, "principal": ""}
    request.headers = {}
    request.body = AsyncMock(return_value=b"")

    result = await route_handler(request)

    assert isinstance(result, PaginatedResponse)
    assert result.url == "/test/route"
|
||||
|
|
|
|||
|
|
@ -31,6 +31,25 @@ test_response_web_search:
            search_context_size: "low"
        output: "128"

test_response_file_search:
  test_name: test_response_file_search
  test_params:
    case:
      - case_id: "llama_experts"
        input: "How many experts does the Llama 4 Maverick model have?"
        tools:
          - type: file_search
            # vector_store_ids param for file_search tool gets added by the test runner
        file_content: "Llama 4 Maverick has 128 experts"
        output: "128"
      - case_id: "llama_experts_pdf"
        input: "How many experts does the Llama 4 Maverick model have?"
        tools:
          - type: file_search
            # vector_store_ids param for file_search tool gets added by the test runner
        file_path: "pdfs/llama_stack_and_models.pdf"
        output: "128"

test_response_mcp_tool:
  test_name: test_response_mcp_tool
  test_params:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
|
|
@ -23,6 +25,31 @@ from tests.verifications.openai_api.fixtures.load import load_test_cases
|
|||
responses_test_cases = load_test_cases("responses")
|
||||
|
||||
|
||||
def _new_vector_store(openai_client, name):
|
||||
# Ensure we don't reuse an existing vector store
|
||||
vector_stores = openai_client.vector_stores.list()
|
||||
for vector_store in vector_stores:
|
||||
if vector_store.name == name:
|
||||
openai_client.vector_stores.delete(vector_store_id=vector_store.id)
|
||||
|
||||
# Create a new vector store
|
||||
vector_store = openai_client.vector_stores.create(
|
||||
name=name,
|
||||
)
|
||||
return vector_store
|
||||
|
||||
|
||||
def _upload_file(openai_client, name, file_path):
|
||||
# Ensure we don't reuse an existing file
|
||||
files = openai_client.files.list()
|
||||
for file in files:
|
||||
if file.filename == name:
|
||||
openai_client.files.delete(file_id=file.id)
|
||||
|
||||
# Upload a text file with our document content
|
||||
return openai_client.files.create(file=open(file_path, "rb"), purpose="assistants")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_basic"]["test_params"]["case"],
|
||||
|
|
@ -258,6 +285,111 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid
|
|||
assert case["output"].lower() in response.output_text.lower().strip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_file_search"]["test_params"]["case"],
|
||||
ids=case_id_generator,
|
||||
)
|
||||
def test_response_non_streaming_file_search(
|
||||
request, openai_client, model, provider, verification_config, tmp_path, case
|
||||
):
|
||||
if isinstance(openai_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
vector_store = _new_vector_store(openai_client, "test_vector_store")
|
||||
|
||||
if "file_content" in case:
|
||||
file_name = "test_response_non_streaming_file_search.txt"
|
||||
file_path = tmp_path / file_name
|
||||
file_path.write_text(case["file_content"])
|
||||
elif "file_path" in case:
|
||||
file_path = os.path.join(os.path.dirname(__file__), "fixtures", case["file_path"])
|
||||
file_name = os.path.basename(file_path)
|
||||
else:
|
||||
raise ValueError(f"No file content or path provided for case {case['case_id']}")
|
||||
|
||||
file_response = _upload_file(openai_client, file_name, file_path)
|
||||
|
||||
# Attach our file to the vector store
|
||||
file_attach_response = openai_client.vector_stores.files.create(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
)
|
||||
|
||||
# Wait for the file to be attached
|
||||
while file_attach_response.status == "in_progress":
|
||||
time.sleep(0.1)
|
||||
file_attach_response = openai_client.vector_stores.files.retrieve(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file_response.id,
|
||||
)
|
||||
assert file_attach_response.status == "completed", f"Expected file to be attached, got {file_attach_response}"
|
||||
assert not file_attach_response.last_error
|
||||
|
||||
# Update our tools with the right vector store id
|
||||
tools = case["tools"]
|
||||
for tool in tools:
|
||||
if tool["type"] == "file_search":
|
||||
tool["vector_store_ids"] = [vector_store.id]
|
||||
|
||||
# Create the response request, which should query our vector store
|
||||
response = openai_client.responses.create(
|
||||
model=model,
|
||||
input=case["input"],
|
||||
tools=tools,
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify the file_search_tool was called
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].queries # ensure it's some non-empty list
|
||||
assert response.output[0].results
|
||||
assert case["output"].lower() in response.output[0].results[0].text.lower()
|
||||
assert response.output[0].results[0].score > 0
|
||||
|
||||
# Verify the output_text generated by the response
|
||||
assert case["output"].lower() in response.output_text.lower().strip()
|
||||
|
||||
|
||||
def test_response_non_streaming_file_search_empty_vector_store(
|
||||
request, openai_client, model, provider, verification_config
|
||||
):
|
||||
if isinstance(openai_client, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Responses API file search is not yet supported in library client.")
|
||||
|
||||
test_name_base = get_base_test_name(request)
|
||||
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||
|
||||
vector_store = _new_vector_store(openai_client, "test_vector_store")
|
||||
|
||||
# Create the response request, which should query our vector store
|
||||
response = openai_client.responses.create(
|
||||
model=model,
|
||||
input="How many experts does the Llama 4 Maverick model have?",
|
||||
tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}],
|
||||
stream=False,
|
||||
include=["file_search_call.results"],
|
||||
)
|
||||
|
||||
# Verify the file_search_tool was called
|
||||
assert len(response.output) > 1
|
||||
assert response.output[0].type == "file_search_call"
|
||||
assert response.output[0].status == "completed"
|
||||
assert response.output[0].queries # ensure it's some non-empty list
|
||||
assert not response.output[0].results # ensure we don't get any results
|
||||
|
||||
# Verify some output_text was generated by the response
|
||||
assert response.output_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
|
||||
|
|
|
|||