Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-18 02:42:31 +00:00)

Merge branch 'main' into feat/add-url-to-paginated-response

Commit b5047db685: 24 changed files with 911 additions and 856 deletions

.github/workflows/integration-auth-tests.yml (vendored, 26 lines changed)
@@ -52,30 +52,7 @@ jobs:
        run: |
          kubectl create namespace llama-stack
          kubectl create serviceaccount llama-stack-auth -n llama-stack
-          kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
          kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
-          cat <<EOF | kubectl apply -f -
-          apiVersion: rbac.authorization.k8s.io/v1
-          kind: ClusterRole
-          metadata:
-            name: allow-anonymous-openid
-          rules:
-          - nonResourceURLs: ["/openid/v1/jwks"]
-            verbs: ["get"]
-          ---
-          apiVersion: rbac.authorization.k8s.io/v1
-          kind: ClusterRoleBinding
-          metadata:
-            name: allow-anonymous-openid
-          roleRef:
-            apiGroup: rbac.authorization.k8s.io
-            kind: ClusterRole
-            name: allow-anonymous-openid
-          subjects:
-          - kind: User
-            name: system:anonymous
-            apiGroup: rbac.authorization.k8s.io
-          EOF
-
      - name: Set Kubernetes Config
        if: ${{ matrix.auth-provider == 'oauth2_token' }}
@@ -84,6 +61,7 @@ jobs:
          echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV
          echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV
          echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV
+          echo "TOKEN=$(cat llama-stack-auth-token)" >> $GITHUB_ENV

      - name: Set Kube Auth Config and run server
        env:
@@ -101,7 +79,7 @@ jobs:
          EOF
          yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
          yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
-          yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml
+          yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
          cat $run_dir/run.yaml
          nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
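The net effect of the workflow changes above is that the job now exports the service-account token and hands it to the server's JWKS configuration, so the server can fetch signing keys from an endpoint that requires authentication. As a rough sanity check, the same token can be presented by a client as a Bearer header; the sketch below is illustrative only (the endpoint path is a placeholder, and port 8321 is the default shown in the docs changed elsewhere in this commit).

```python
# Illustrative sketch, not part of this commit: call the stack server using the
# token written to llama-stack-auth-token by the workflow above.
import httpx

with open("llama-stack-auth-token") as f:
    token = f.read().strip()

resp = httpx.get(
    "http://localhost:8321/placeholder-endpoint",  # hypothetical path
    headers={"Authorization": f"Bearer {token}"},
    timeout=5,
)
print(resp.status_code)
```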

.github/workflows/integration-tests.yml (vendored, 2 lines changed)
@@ -24,7 +24,7 @@ jobs:
      matrix:
        # Listing tests manually since some of them currently fail
        # TODO: generate matrix list from tests/integration when fixed
-        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
+        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime, vector_io]
        client-type: [library, http]
        python-version: ["3.10", "3.11", "3.12"]
      fail-fast: false # we want to run all tests regardless of failure

.github/workflows/test-external-providers.yml (vendored, 10 lines changed)
@@ -45,20 +45,22 @@ jobs:

      - name: Build distro from config file
        run: |
-          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml

      - name: Start Llama Stack server in background
        if: ${{ matrix.image-type }} == 'venv'
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
-          uv run pip list
-          nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
+          # Use the virtual environment created by the build step (name comes from build config)
+          source ci-test/bin/activate
+          uv pip list
+          nohup llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        run: |
          for i in {1..30}; do
-            if ! grep -q "remote::custom_ollama from /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml" server.log; then
+            if ! grep -q "Successfully loaded external provider remote::custom_ollama" server.log; then
              echo "Waiting for Llama Stack server to load the provider..."
              sleep 1
            else

docs/_static/llama-stack-spec.html (vendored, 102 lines changed)
@@ -3318,7 +3318,7 @@
            "name": "limit",
            "in": "query",
            "description": "A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.",
-            "required": true,
+            "required": false,
            "schema": {
              "type": "integer"
            }
@@ -3327,7 +3327,7 @@
            "name": "order",
            "in": "query",
            "description": "Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.",
-            "required": true,
+            "required": false,
            "schema": {
              "type": "string"
            }
@@ -3864,7 +3864,7 @@
            "content": {
              "application/json": {
                "schema": {
-                  "$ref": "#/components/schemas/VectorStoreSearchResponse"
+                  "$ref": "#/components/schemas/VectorStoreSearchResponsePage"
                }
              }
            }
@@ -12587,6 +12587,9 @@
          }
        },
        "additionalProperties": false,
+        "required": [
+          "name"
+        ],
        "title": "OpenaiCreateVectorStoreRequest"
      },
      "VectorStoreObject": {
@@ -13129,13 +13132,74 @@
        },
        "additionalProperties": false,
        "required": [
-          "query",
-          "max_num_results",
-          "rewrite_query"
+          "query"
        ],
        "title": "OpenaiSearchVectorStoreRequest"
      },
+      "VectorStoreContent": {
+        "type": "object",
+        "properties": {
+          "type": {
+            "type": "string",
+            "const": "text"
+          },
+          "text": {
+            "type": "string"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "type",
+          "text"
+        ],
+        "title": "VectorStoreContent"
+      },
      "VectorStoreSearchResponse": {
+        "type": "object",
+        "properties": {
+          "file_id": {
+            "type": "string"
+          },
+          "filename": {
+            "type": "string"
+          },
+          "score": {
+            "type": "number"
+          },
+          "attributes": {
+            "type": "object",
+            "additionalProperties": {
+              "oneOf": [
+                {
+                  "type": "string"
+                },
+                {
+                  "type": "number"
+                },
+                {
+                  "type": "boolean"
+                }
+              ]
+            }
+          },
+          "content": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/VectorStoreContent"
+            }
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "file_id",
+          "filename",
+          "score",
+          "content"
+        ],
+        "title": "VectorStoreSearchResponse",
+        "description": "Response from searching a vector store."
+      },
+      "VectorStoreSearchResponsePage": {
        "type": "object",
        "properties": {
          "object": {
@@ -13148,29 +13212,7 @@
          "data": {
            "type": "array",
            "items": {
-              "type": "object",
-              "additionalProperties": {
-                "oneOf": [
-                  {
-                    "type": "null"
-                  },
-                  {
-                    "type": "boolean"
-                  },
-                  {
-                    "type": "number"
-                  },
-                  {
-                    "type": "string"
-                  },
-                  {
-                    "type": "array"
-                  },
-                  {
-                    "type": "object"
-                  }
-                ]
-              }
+              "$ref": "#/components/schemas/VectorStoreSearchResponse"
            }
          },
          "has_more": {
@@ -13188,7 +13230,7 @@
          "data",
          "has_more"
        ],
-        "title": "VectorStoreSearchResponse",
+        "title": "VectorStoreSearchResponsePage",
        "description": "Response from searching a vector store."
      },
      "OpenaiUpdateVectorStoreRequest": {
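For orientation, a response body that satisfies the updated schemas above (a VectorStoreSearchResponsePage containing VectorStoreSearchResponse items, each carrying VectorStoreContent entries) would look roughly like the following sketch; all field values are invented.

```python
# Example payload shaped to the updated OpenAPI schemas; values are illustrative.
search_response_page = {
    "object": "vector_store.search_results.page",
    "search_query": "what does the auth provider validate?",
    "data": [
        {
            "file_id": "file_123",
            "filename": "auth.md",
            "score": 0.87,
            "attributes": {"lang": "en", "page": 2, "draft": False},  # str | number | bool values
            "content": [
                {"type": "text", "text": "Tokens are validated against the JWKS keys..."}
            ],
        }
    ],
    "has_more": False,
}
```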

docs/_static/llama-stack-spec.yaml (vendored, 63 lines changed)
@@ -2323,7 +2323,7 @@ paths:
        description: >-
          A limit on the number of objects to be returned. Limit can range between
          1 and 100, and the default is 20.
-        required: true
+        required: false
        schema:
          type: integer
      - name: order
@@ -2331,7 +2331,7 @@ paths:
        description: >-
          Sort order by the `created_at` timestamp of the objects. `asc` for ascending
          order and `desc` for descending order.
-        required: true
+        required: false
        schema:
          type: string
      - name: after
@@ -2734,7 +2734,7 @@ paths:
          content:
            application/json:
              schema:
-                $ref: '#/components/schemas/VectorStoreSearchResponse'
+                $ref: '#/components/schemas/VectorStoreSearchResponsePage'
        '400':
          $ref: '#/components/responses/BadRequest400'
        '429':
@@ -8794,6 +8794,8 @@ components:
        description: >-
          The provider-specific vector database ID.
      additionalProperties: false
+      required:
+        - name
      title: OpenaiCreateVectorStoreRequest
    VectorStoreObject:
      type: object
@@ -9190,10 +9192,49 @@ components:
      additionalProperties: false
      required:
        - query
-        - max_num_results
-        - rewrite_query
      title: OpenaiSearchVectorStoreRequest
+    VectorStoreContent:
+      type: object
+      properties:
+        type:
+          type: string
+          const: text
+        text:
+          type: string
+      additionalProperties: false
+      required:
+        - type
+        - text
+      title: VectorStoreContent
    VectorStoreSearchResponse:
+      type: object
+      properties:
+        file_id:
+          type: string
+        filename:
+          type: string
+        score:
+          type: number
+        attributes:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: string
+              - type: number
+              - type: boolean
+        content:
+          type: array
+          items:
+            $ref: '#/components/schemas/VectorStoreContent'
+      additionalProperties: false
+      required:
+        - file_id
+        - filename
+        - score
+        - content
+      title: VectorStoreSearchResponse
+      description: Response from searching a vector store.
+    VectorStoreSearchResponsePage:
      type: object
      properties:
        object:
@@ -9204,15 +9245,7 @@ components:
        data:
          type: array
          items:
-            type: object
-            additionalProperties:
-              oneOf:
-                - type: 'null'
-                - type: boolean
-                - type: number
-                - type: string
-                - type: array
-                - type: object
+            $ref: '#/components/schemas/VectorStoreSearchResponse'
        has_more:
          type: boolean
          default: false
@@ -9224,7 +9257,7 @@ components:
        - search_query
        - data
        - has_more
-      title: VectorStoreSearchResponse
+      title: VectorStoreSearchResponsePage
      description: Response from searching a vector store.
    OpenaiUpdateVectorStoreRequest:
      type: object
@@ -56,10 +56,10 @@ shields: []
server:
  port: 8321
  auth:
-    provider_type: "kubernetes"
+    provider_type: "oauth2_token"
    config:
-      api_server_url: "https://kubernetes.default.svc"
-      ca_cert_path: "/path/to/ca.crt"
+      jwks:
+        uri: "https://my-token-issuing-svc.com/jwks"
```

Let's break this down into the different sections. The first section specifies the set of APIs that the stack server will serve:
@@ -132,16 +132,52 @@ The server supports multiple authentication providers:

#### OAuth 2.0/OpenID Connect Provider with Kubernetes

-The Kubernetes cluster must be configured to use a service account for authentication.
+The server can be configured to use service account tokens for authorization, validating these against the Kubernetes API server, e.g.:
+```yaml
+server:
+  auth:
+    provider_type: "oauth2_token"
+    config:
+      jwks:
+        uri: "https://kubernetes.default.svc:8443/openid/v1/jwks"
+        token: "${env.TOKEN:}"
+      key_recheck_period: 3600
+      tls_cafile: "/path/to/ca.crt"
+      issuer: "https://kubernetes.default.svc"
+      audience: "https://kubernetes.default.svc"
+```
+
+To find your cluster's jwks uri (from which the public key(s) to verify the token signature are obtained), run:
+```
+kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri
+```
+
+For the tls_cafile, you can use the CA certificate of the OIDC provider:
+```bash
+kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
+```
+
+For the issuer, you can use the OIDC provider's URL:
+```bash
+kubectl get --raw /.well-known/openid-configuration| jq .issuer
+```
+
+The audience can be obtained from a token, e.g. run:
+```bash
+kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
+```
+
+The jwks token is used to authorize access to the jwks endpoint. You can obtain a token by running:

```bash
kubectl create namespace llama-stack
kubectl create serviceaccount llama-stack-auth -n llama-stack
-kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
+export TOKEN=$(cat llama-stack-auth-token)
```

-Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests
+Alternatively, you can configure the jwks endpoint to allow anonymous access. To do this, make sure
+the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests
and that the correct RoleBinding is created to allow the service account to access the necessary
resources. If that is not the case, you can create a RoleBinding for the service account to access
the necessary resources:
@@ -175,35 +211,6 @@ And then apply the configuration:
kubectl apply -f allow-anonymous-openid.yaml
```

-Validates tokens against the Kubernetes API server through the OIDC provider:
-```yaml
-server:
-  auth:
-    provider_type: "oauth2_token"
-    config:
-      jwks:
-        uri: "https://kubernetes.default.svc"
-      key_recheck_period: 3600
-      tls_cafile: "/path/to/ca.crt"
-      issuer: "https://kubernetes.default.svc"
-      audience: "https://kubernetes.default.svc"
-```
-
-To find your cluster's audience, run:
-```bash
-kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
-```
-
-For the issuer, you can use the OIDC provider's URL:
-```bash
-kubectl get --raw /.well-known/openid-configuration| jq .issuer
-```
-
-For the tls_cafile, you can use the CA certificate of the OIDC provider:
-```bash
-kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
-```
-
The provider extracts user information from the JWT token:
- Username from the `sub` claim becomes a role
- Kubernetes groups become teams
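The shell one-liners in the docs above read claims out of a service-account token by base64-decoding its payload segment. A Python equivalent, useful when jq is not available, is sketched below; it only inspects the unverified payload and does not check the signature.

```python
# Rough Python equivalent of `cut -d. -f2 | base64 -d | jq .aud` from the docs above.
import base64
import json

def jwt_claims(token: str) -> dict:
    payload = token.split(".")[1]
    payload += "=" * (-len(payload) % 4)  # restore padding stripped by JWT encoding
    return json.loads(base64.urlsafe_b64decode(payload))

claims = jwt_claims(open("llama-stack-auth-token").read().strip())
print(claims.get("iss"), claims.get("aud"))
```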
@@ -8,7 +8,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel, Field

@@ -96,13 +96,30 @@ class VectorStoreSearchRequest(BaseModel):
     rewrite_query: bool = False


+@json_schema_type
+class VectorStoreContent(BaseModel):
+    type: Literal["text"]
+    text: str
+
+
 @json_schema_type
 class VectorStoreSearchResponse(BaseModel):
     """Response from searching a vector store."""

+    file_id: str
+    filename: str
+    score: float
+    attributes: dict[str, str | float | bool] | None = None
+    content: list[VectorStoreContent]
+
+
+@json_schema_type
+class VectorStoreSearchResponsePage(BaseModel):
+    """Response from searching a vector store."""
+
     object: str = "vector_store.search_results.page"
     search_query: str
-    data: list[dict[str, Any]]
+    data: list[VectorStoreSearchResponse]
     has_more: bool = False
     next_page: str | None = None

@@ -165,7 +182,7 @@ class VectorIO(Protocol):
     @webmethod(route="/openai/v1/vector_stores", method="POST")
     async def openai_create_vector_store(
         self,
-        name: str | None = None,
+        name: str,
         file_ids: list[str] | None = None,
         expires_after: dict[str, Any] | None = None,
         chunking_strategy: dict[str, Any] | None = None,
@@ -193,8 +210,8 @@ class VectorIO(Protocol):
     @webmethod(route="/openai/v1/vector_stores", method="GET")
     async def openai_list_vector_stores(
         self,
-        limit: int = 20,
-        order: str = "desc",
+        limit: int | None = 20,
+        order: str | None = "desc",
         after: str | None = None,
         before: str | None = None,
     ) -> VectorStoreListResponse:
@@ -256,10 +273,10 @@ class VectorIO(Protocol):
         vector_store_id: str,
         query: str | list[str],
         filters: dict[str, Any] | None = None,
-        max_num_results: int = 10,
+        max_num_results: int | None = 10,
         ranking_options: dict[str, Any] | None = None,
-        rewrite_query: bool = False,
-    ) -> VectorStoreSearchResponse:
+        rewrite_query: bool | None = False,
+    ) -> VectorStoreSearchResponsePage:
         """Search for chunks in a vector store.

         Searches a vector store for relevant chunks based on a query and optional file attribute filters.
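To make the model changes concrete, the sketch below mirrors the Pydantic models added in this hunk and builds a typed search-results page from them. The @json_schema_type decorators and the surrounding VectorIO protocol are omitted; the usage at the bottom is illustrative, not taken from the codebase.

```python
# Minimal sketch mirroring the models added above (decorators omitted).
from typing import Literal
from pydantic import BaseModel

class VectorStoreContent(BaseModel):
    type: Literal["text"]
    text: str

class VectorStoreSearchResponse(BaseModel):
    file_id: str
    filename: str
    score: float
    attributes: dict[str, str | float | bool] | None = None
    content: list[VectorStoreContent]

class VectorStoreSearchResponsePage(BaseModel):
    object: str = "vector_store.search_results.page"
    search_query: str
    data: list[VectorStoreSearchResponse]
    has_more: bool = False
    next_page: str | None = None

page = VectorStoreSearchResponsePage(
    search_query="quarterly revenue",
    data=[
        VectorStoreSearchResponse(
            file_id="file_123",
            filename="report.pdf",
            score=0.91,
            content=[VectorStoreContent(type="text", text="Revenue grew 12%...")],
        )
    ],
)
print(page.model_dump_json(indent=2))
```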
@@ -180,6 +180,7 @@ def get_provider_registry(
                     if provider_type_key in ret[api]:
                         logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
                     ret[api][provider_type_key] = spec
+                    logger.info(f"Successfully loaded external provider {provider_type_key}")
                 except yaml.YAMLError as yaml_err:
                     logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
                     raise yaml_err
@@ -394,9 +394,13 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
             logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
             missing_methods.append((name, "signature_mismatch"))
         else:
-            # Check if the method is actually implemented in the class
-            method_owner = next((cls for cls in mro if name in cls.__dict__), None)
-            if method_owner is None or method_owner.__name__ == protocol.__name__:
+            # Check if the method has a concrete implementation (not just a protocol stub)
+            # Find all classes in MRO that define this method
+            method_owners = [cls for cls in mro if name in cls.__dict__]
+
+            # Allow methods from mixins/parents, only reject if ONLY the protocol defines it
+            if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
+                # Only reject if the method is ONLY defined in the protocol itself (abstract stub)
                 missing_methods.append((name, "not_actually_implemented"))

     if missing_methods:
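The behavioural difference is easiest to see with a toy example: when a protocol stub and a concrete base both define a method, collecting every owner in the MRO (instead of only the first) lets the concrete implementation be recognised. The class names below are invented for illustration.

```python
# Toy illustration of the relaxed check; names are invented, not from the codebase.
from typing import Protocol

class MyProtocol(Protocol):
    async def do_work(self) -> int: ...   # stub only

class WorkMixin:
    async def do_work(self) -> int:       # concrete implementation
        return 42

class Impl(MyProtocol, WorkMixin):        # protocol precedes the mixin in the MRO
    pass

name = "do_work"
method_owners = [cls for cls in Impl.__mro__ if name in cls.__dict__]
# Two owners (MyProtocol and WorkMixin), so the method counts as implemented.
# Looking only at the first owner would have reported "not_actually_implemented".
rejected = len(method_owners) == 1 and method_owners[0].__name__ == MyProtocol.__name__
print(method_owners, rejected)  # rejected is False
```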
@@ -163,6 +163,9 @@ class InferenceRouter(Inference):
         messages: list[Message] | InterleavedContent,
         tool_prompt_format: ToolPromptFormat | None = None,
     ) -> int | None:
+        if not hasattr(self, "formatter") or self.formatter is None:
+            return None
+
         if isinstance(messages, list):
             encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
         else:
@@ -17,7 +17,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreDeleteResponse,
     VectorStoreListResponse,
     VectorStoreObject,
-    VectorStoreSearchResponse,
+    VectorStoreSearchResponsePage,
 )
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
@@ -108,7 +108,7 @@ class VectorIORouter(VectorIO):
     # OpenAI Vector Stores API endpoints
     async def openai_create_vector_store(
         self,
-        name: str | None = None,
+        name: str,
         file_ids: list[str] | None = None,
         expires_after: dict[str, Any] | None = None,
         chunking_strategy: dict[str, Any] | None = None,
@@ -151,8 +151,8 @@ class VectorIORouter(VectorIO):

     async def openai_list_vector_stores(
         self,
-        limit: int = 20,
-        order: str = "desc",
+        limit: int | None = 20,
+        order: str | None = "desc",
         after: str | None = None,
         before: str | None = None,
     ) -> VectorStoreListResponse:
@@ -239,10 +239,10 @@ class VectorIORouter(VectorIO):
         vector_store_id: str,
         query: str | list[str],
         filters: dict[str, Any] | None = None,
-        max_num_results: int = 10,
+        max_num_results: int | None = 10,
         ranking_options: dict[str, Any] | None = None,
-        rewrite_query: bool = False,
-    ) -> VectorStoreSearchResponse:
+        rewrite_query: bool | None = False,
+    ) -> VectorStoreSearchResponsePage:
         logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
         # Route based on vector store ID
         provider = self.routing_table.get_provider_impl(vector_store_id)
@@ -84,6 +84,7 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
 class OAuth2JWKSConfig(BaseModel):
     # The JWKS URI for collecting public keys
     uri: str
+    token: str | None = Field(default=None, description="token to authorise access to jwks")
     key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")


@@ -246,9 +247,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
         if self.config.jwks is None:
             raise ValueError("JWKS is not configured")
         if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
+            headers = {}
+            if self.config.jwks.token:
+                headers["Authorization"] = f"Bearer {self.config.jwks.token}"
             verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
             async with httpx.AsyncClient(verify=verify) as client:
-                res = await client.get(self.config.jwks.uri, timeout=5)
+                res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
                 res.raise_for_status()
                 jwks_data = res.json()["keys"]
                 updated = {}
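Pulled out of the provider class, the new key-refresh path boils down to an httpx GET that optionally carries the configured token as a Bearer header. The standalone sketch below mirrors that logic; the URI and token are placeholders.

```python
# Standalone sketch of the new JWKS fetch with an optional bearer token.
import asyncio
import httpx

async def fetch_jwks(uri: str, token: str | None = None) -> list[dict]:
    headers = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    async with httpx.AsyncClient(verify=True) as client:
        res = await client.get(uri, timeout=5, headers=headers)
        res.raise_for_status()
        return res.json()["keys"]

# Example (needs a reachable JWKS endpoint and a valid token):
# keys = asyncio.run(fetch_jwks("https://kubernetes.default.svc:8443/openid/v1/jwks", token="<sa-token>"))
```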
@@ -115,7 +115,7 @@ def parse_environment_config(env_config: str) -> dict[str, int]:

 class CustomRichHandler(RichHandler):
     def __init__(self, *args, **kwargs):
-        kwargs["console"] = Console(width=120)
+        kwargs["console"] = Console(width=150)
         super().__init__(*args, **kwargs)

     def emit(self, record):
@@ -9,9 +9,7 @@ import base64
 import io
 import json
 import logging
-import time
-import uuid
-from typing import Any, Literal
+from typing import Any

 import faiss
 import numpy as np
@@ -24,14 +22,11 @@ from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
     VectorIO,
-    VectorStoreDeleteResponse,
-    VectorStoreListResponse,
-    VectorStoreObject,
-    VectorStoreSearchResponse,
 )
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -47,10 +42,6 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
 OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"

-
-# In faiss, since we do
-CHUNK_MULTIPLIER = 5
-

 class FaissIndex(EmbeddingIndex):
     def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         self.index = faiss.IndexFlatL2(dimension)
@@ -140,7 +131,7 @@ class FaissIndex(EmbeddingIndex):
         raise NotImplementedError("Keyword search is not supported in FAISS")


-class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
+class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
         self.config = config
         self.inference_api = inference_api
@@ -164,14 +155,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
             )
             self.cache[vector_db.identifier] = index

-        # Load existing OpenAI vector stores
-        start_key = OPENAI_VECTOR_STORES_PREFIX
-        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-        stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
-
-        for store_data in stored_openai_stores:
-            store_info = json.loads(store_data)
-            self.openai_vector_stores[store_info["id"]] = store_info
+        # Load existing OpenAI vector stores using the mixin method
+        self.openai_vector_stores = await self._load_openai_vector_stores()

     async def shutdown(self) -> None:
         # Cleanup if needed
@@ -234,285 +219,34 @@

         return await index.query_chunks(query, params)

-    # OpenAI Vector Stores API endpoints implementation
-    async def openai_create_vector_store(
-        self,
-        name: str | None = None,
-        file_ids: list[str] | None = None,
-        expires_after: dict[str, Any] | None = None,
-        chunking_strategy: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-        embedding_model: str | None = None,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-        provider_vector_db_id: str | None = None,
-    ) -> VectorStoreObject:
-        """Creates a vector store."""
-        assert self.kvstore is not None
-        # store and vector_db have the same id
-        store_id = name or str(uuid.uuid4())
-        created_at = int(time.time())
-
-        if provider_id is None:
-            raise ValueError("Provider ID is required")
-
-        if embedding_model is None:
-            raise ValueError("Embedding model is required")
-
-        # Use provided embedding dimension or default to 384
-        if embedding_dimension is None:
-            raise ValueError("Embedding dimension is required")
-
-        provider_vector_db_id = provider_vector_db_id or store_id
-        vector_db = VectorDB(
-            identifier=store_id,
-            embedding_dimension=embedding_dimension,
-            embedding_model=embedding_model,
-            provider_id=provider_id,
-            provider_resource_id=provider_vector_db_id,
-        )
-
-        # Register the vector DB
-        await self.register_vector_db(vector_db)
-
-        # Create OpenAI vector store metadata
-        store_info = {
-            "id": store_id,
-            "object": "vector_store",
-            "created_at": created_at,
-            "name": store_id,
-            "usage_bytes": 0,
-            "file_counts": {},
-            "status": "completed",
-            "expires_after": expires_after,
-            "expires_at": None,
-            "last_active_at": created_at,
-            "file_ids": file_ids or [],
-            "chunking_strategy": chunking_strategy,
-        }
-
-        # Add provider information to metadata if provided
-        metadata = metadata or {}
-        if provider_id:
-            metadata["provider_id"] = provider_id
-        if provider_vector_db_id:
-            metadata["provider_vector_db_id"] = provider_vector_db_id
-        store_info["metadata"] = metadata
-
-        # Store in kvstore
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-
-        # Store in memory cache
-        self.openai_vector_stores[store_id] = store_info
-
-        return VectorStoreObject(
-            id=store_id,
-            created_at=created_at,
-            name=store_id,
-            usage_bytes=0,
-            file_counts={},
-            status="completed",
-            expires_after=expires_after,
-            expires_at=None,
-            last_active_at=created_at,
-            metadata=metadata,
-        )
-
-    async def openai_list_vector_stores(
-        self,
-        limit: int = 20,
-        order: str = "desc",
-        after: str | None = None,
-        before: str | None = None,
-    ) -> VectorStoreListResponse:
-        """Returns a list of vector stores."""
-        # Get all vector stores
-        all_stores = list(self.openai_vector_stores.values())
-
-        # Sort by created_at
-        reverse_order = order == "desc"
-        all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
-
-        # Apply cursor-based pagination
-        if after:
-            after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
-            if after_index >= 0:
-                all_stores = all_stores[after_index + 1 :]
-
-        if before:
-            before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
-            all_stores = all_stores[:before_index]
-
-        # Apply limit
-        limited_stores = all_stores[:limit]
-        # Convert to VectorStoreObject instances
-        data = [VectorStoreObject(**store) for store in limited_stores]
-
-        # Determine pagination info
-        has_more = len(all_stores) > limit
-        first_id = data[0].id if data else None
-        last_id = data[-1].id if data else None
-
-        return VectorStoreListResponse(
-            data=data,
-            has_more=has_more,
-            first_id=first_id,
-            last_id=last_id,
-        )
-
-    async def openai_retrieve_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreObject:
-        """Retrieves a vector store."""
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        store_info = self.openai_vector_stores[vector_store_id]
-        return VectorStoreObject(**store_info)
-
-    async def openai_update_vector_store(
-        self,
-        vector_store_id: str,
-        name: str | None = None,
-        expires_after: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> VectorStoreObject:
-        """Modifies a vector store."""
-        assert self.kvstore is not None
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        store_info = self.openai_vector_stores[vector_store_id].copy()
-
-        # Update fields if provided
-        if name is not None:
-            store_info["name"] = name
-        if expires_after is not None:
-            store_info["expires_after"] = expires_after
-        if metadata is not None:
-            store_info["metadata"] = metadata
-
-        # Update last_active_at
-        store_info["last_active_at"] = int(time.time())
-
-        # Save to kvstore
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-
-        # Update in-memory cache
-        self.openai_vector_stores[vector_store_id] = store_info
-
-        return VectorStoreObject(**store_info)
-
-    async def openai_delete_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreDeleteResponse:
-        """Delete a vector store."""
-        assert self.kvstore is not None
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        # Delete from kvstore
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
-        await self.kvstore.delete(key)
-
-        # Delete from in-memory cache
-        del self.openai_vector_stores[vector_store_id]
-
-        # Also delete the underlying vector DB
-        try:
-            await self.unregister_vector_db(vector_store_id)
-        except Exception as e:
-            logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
-
-        return VectorStoreDeleteResponse(
-            id=vector_store_id,
-            deleted=True,
-        )
-
-    async def openai_search_vector_store(
-        self,
-        vector_store_id: str,
-        query: str | list[str],
-        filters: dict[str, Any] | None = None,
-        max_num_results: int = 10,
-        ranking_options: dict[str, Any] | None = None,
-        rewrite_query: bool = False,
-        search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
-    ) -> VectorStoreSearchResponse:
-        """Search for chunks in a vector store."""
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        if isinstance(query, list):
-            search_query = " ".join(query)
-        else:
-            search_query = query
-
-        try:
-            score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
-            params = {
-                "max_chunks": max_num_results * CHUNK_MULTIPLIER,
-                "score_threshold": score_threshold,
-                "mode": search_mode,
-            }
-            # TODO: Add support for ranking_options.ranker
-
-            response = await self.query_chunks(
-                vector_db_id=vector_store_id,
-                query=search_query,
-                params=params,
-            )
-
-            # Convert response to OpenAI format
-            data = []
-            for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
-                # Apply score based filtering
-                if score < score_threshold:
-                    continue
-
-                # Apply filters if provided
-                if filters:
-                    # Simple metadata filtering
-                    if not self._matches_filters(chunk.metadata, filters):
-                        continue
-
-                chunk_data = {
-                    "id": f"chunk_{i}",
-                    "object": "vector_store.search_result",
-                    "score": score,
-                    "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
-                    "metadata": chunk.metadata,
-                }
-                data.append(chunk_data)
-                if len(data) >= max_num_results:
-                    break
-
-            return VectorStoreSearchResponse(
-                search_query=search_query,
-                data=data,
-                has_more=False,  # For simplicity, we don't implement pagination here
-                next_page=None,
-            )
-
-        except Exception as e:
-            logger.error(f"Error searching vector store {vector_store_id}: {e}")
-            # Return empty results on error
-            return VectorStoreSearchResponse(
-                search_query=search_query,
-                data=[],
-                has_more=False,
-                next_page=None,
-            )
-
-    def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
-        """Check if metadata matches the provided filters."""
-        for key, value in filters.items():
-            if key not in metadata:
-                return False
-            if metadata[key] != value:
-                return False
-        return True
+    # OpenAI Vector Store Mixin abstract method implementations
+    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Save vector store metadata to kvstore."""
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+
+    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+        """Load all vector store metadata from kvstore."""
+        assert self.kvstore is not None
+        start_key = OPENAI_VECTOR_STORES_PREFIX
+        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
+        stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
+
+        stores = {}
+        for store_data in stored_openai_stores:
+            store_info = json.loads(store_data)
+            stores[store_info["id"]] = store_info
+        return stores
+
+    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Update vector store metadata in kvstore."""
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+
+    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+        """Delete vector store metadata from kvstore."""
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.delete(key)
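The refactor above reduces each provider to four persistence hooks that OpenAIVectorStoreMixin (defined elsewhere in the tree, not shown in this diff) drives for create, list, update, and delete. A toy in-memory backend implementing the same hooks, purely for illustration, looks like this:

```python
# Illustrative only: the same four hooks the faiss and sqlite_vec adapters implement,
# backed by a plain dict instead of a kvstore or SQLite.
from typing import Any

class InMemoryVectorStoreBackend:
    def __init__(self) -> None:
        self._stores: dict[str, dict[str, Any]] = {}

    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        self._stores[store_id] = store_info

    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
        return dict(self._stores)

    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        self._stores[store_id] = store_info

    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
        self._stores.pop(store_id, None)
```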
|
@ -10,9 +10,8 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import struct
|
import struct
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Literal
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sqlite_vec
|
import sqlite_vec
|
||||||
|
@ -24,12 +23,9 @@ from llama_stack.apis.vector_io import (
|
||||||
Chunk,
|
Chunk,
|
||||||
QueryChunksResponse,
|
QueryChunksResponse,
|
||||||
VectorIO,
|
VectorIO,
|
||||||
VectorStoreDeleteResponse,
|
|
||||||
VectorStoreListResponse,
|
|
||||||
VectorStoreObject,
|
|
||||||
VectorStoreSearchResponse,
|
|
||||||
)
|
)
|
||||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||||
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
|
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -39,11 +35,6 @@ VECTOR_SEARCH = "vector"
|
||||||
KEYWORD_SEARCH = "keyword"
|
KEYWORD_SEARCH = "keyword"
|
||||||
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
|
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
|
||||||
|
|
||||||
# Constants for OpenAI vector stores (similar to faiss)
|
|
||||||
VERSION = "v3"
|
|
||||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
|
|
||||||
CHUNK_MULTIPLIER = 5
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_vector(vector: list[float]) -> bytes:
|
def serialize_vector(vector: list[float]) -> bytes:
|
||||||
"""Serialize a list of floats into a compact binary representation."""
|
"""Serialize a list of floats into a compact binary representation."""
|
||||||
|
@ -303,7 +294,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||||
"""
|
"""
|
||||||
A VectorIO implementation using SQLite + sqlite_vec.
|
A VectorIO implementation using SQLite + sqlite_vec.
|
||||||
This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
|
This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
|
||||||
|
@ -340,15 +331,12 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
# Load any existing vector DB registrations.
|
# Load any existing vector DB registrations.
|
||||||
cur.execute("SELECT metadata FROM vector_dbs")
|
cur.execute("SELECT metadata FROM vector_dbs")
|
||||||
vector_db_rows = cur.fetchall()
|
vector_db_rows = cur.fetchall()
|
||||||
# Load any existing OpenAI vector stores.
|
return vector_db_rows
|
||||||
cur.execute("SELECT metadata FROM openai_vector_stores")
|
|
||||||
openai_store_rows = cur.fetchall()
|
|
||||||
return vector_db_rows, openai_store_rows
|
|
||||||
finally:
|
finally:
|
||||||
cur.close()
|
cur.close()
|
||||||
connection.close()
|
connection.close()
|
||||||
|
|
||||||
vector_db_rows, openai_store_rows = await asyncio.to_thread(_setup_connection)
|
vector_db_rows = await asyncio.to_thread(_setup_connection)
|
||||||
|
|
||||||
# Load existing vector DBs
|
# Load existing vector DBs
|
||||||
for row in vector_db_rows:
|
for row in vector_db_rows:
|
||||||
|
@ -359,11 +347,8 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
)
|
)
|
||||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||||
|
|
||||||
# Load existing OpenAI vector stores
|
# Load existing OpenAI vector stores using the mixin method
|
||||||
for row in openai_store_rows:
|
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||||
store_data = row[0]
|
|
||||||
store_info = json.loads(store_data)
|
|
||||||
self.openai_vector_stores[store_info["id"]] = store_info
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
# nothing to do since we don't maintain a persistent connection
|
# nothing to do since we don't maintain a persistent connection
|
||||||
|
@ -409,6 +394,87 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
|
|
||||||
await asyncio.to_thread(_delete_vector_db_from_registry)
|
await asyncio.to_thread(_delete_vector_db_from_registry)
|
||||||
|
|
||||||
+    # OpenAI Vector Store Mixin abstract method implementations
+    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Save vector store metadata to SQLite database."""
+
+        def _store():
+            connection = _create_sqlite_connection(self.config.db_path)
+            cur = connection.cursor()
+            try:
+                cur.execute(
+                    "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
+                    (store_id, json.dumps(store_info)),
+                )
+                connection.commit()
+            except Exception as e:
+                logger.error(f"Error saving openai vector store {store_id}: {e}")
+                raise
+            finally:
+                cur.close()
+                connection.close()
+
+        try:
+            await asyncio.to_thread(_store)
+        except Exception as e:
+            logger.error(f"Error saving openai vector store {store_id}: {e}")
+            raise
+
+    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+        """Load all vector store metadata from SQLite database."""
+
+        def _load():
+            connection = _create_sqlite_connection(self.config.db_path)
+            cur = connection.cursor()
+            try:
+                cur.execute("SELECT metadata FROM openai_vector_stores")
+                rows = cur.fetchall()
+                return rows
+            finally:
+                cur.close()
+                connection.close()
+
+        rows = await asyncio.to_thread(_load)
+        stores = {}
+        for row in rows:
+            store_data = row[0]
+            store_info = json.loads(store_data)
+            stores[store_info["id"]] = store_info
+        return stores
+
+    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Update vector store metadata in SQLite database."""
+
+        def _update():
+            connection = _create_sqlite_connection(self.config.db_path)
+            cur = connection.cursor()
+            try:
+                cur.execute(
+                    "UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
+                    (json.dumps(store_info), store_id),
+                )
+                connection.commit()
+            finally:
+                cur.close()
+                connection.close()
+
+        await asyncio.to_thread(_update)
+
+    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+        """Delete vector store metadata from SQLite database."""
+
+        def _delete():
+            connection = _create_sqlite_connection(self.config.db_path)
+            cur = connection.cursor()
+            try:
+                cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
+                connection.commit()
+            finally:
+                cur.close()
+                connection.close()
+
+        await asyncio.to_thread(_delete)
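These four helpers all round-trip JSON metadata through a single `openai_vector_stores(id, metadata)` table. A minimal sketch of that round-trip, using an in-memory SQLite database instead of the adapter's `_create_sqlite_connection(self.config.db_path)` helper (the table creation below is only illustrative; the adapter manages its own schema):

```python
import json
import sqlite3

# Stand-in for the adapter's own connection helper; illustrative only.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE IF NOT EXISTS openai_vector_stores (id TEXT PRIMARY KEY, metadata TEXT)")

store_info = {"id": "vs_demo", "object": "vector_store", "created_at": 0}
# Same statement _save_openai_vector_store() issues on a worker thread.
conn.execute(
    "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
    (store_info["id"], json.dumps(store_info)),
)
conn.commit()

# Same query _load_openai_vector_stores() uses to rebuild the in-memory cache.
rows = conn.execute("SELECT metadata FROM openai_vector_stores").fetchall()
stores = {json.loads(row[0])["id"]: json.loads(row[0]) for row in rows}
assert stores["vs_demo"]["object"] == "vector_store"
conn.close()
```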
     async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         if vector_db_id not in self.cache:
             raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
@@ -423,318 +489,6 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
             raise ValueError(f"Vector DB {vector_db_id} not found")
         return await self.cache[vector_db_id].query_chunks(query, params)
-    async def openai_create_vector_store(
-        self,
-        name: str | None = None,
-        file_ids: list[str] | None = None,
-        expires_after: dict[str, Any] | None = None,
-        chunking_strategy: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-        embedding_model: str | None = None,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-        provider_vector_db_id: str | None = None,
-    ) -> VectorStoreObject:
-        """Creates a vector store."""
-        # store and vector_db have the same id
-        store_id = name or str(uuid.uuid4())
-        created_at = int(time.time())
-
-        if provider_id is None:
-            raise ValueError("Provider ID is required")
-
-        if embedding_model is None:
-            raise ValueError("Embedding model is required")
-
-        # Use provided embedding dimension or default to 384
-        if embedding_dimension is None:
-            raise ValueError("Embedding dimension is required")
-
-        provider_vector_db_id = provider_vector_db_id or store_id
-        vector_db = VectorDB(
-            identifier=store_id,
-            embedding_dimension=embedding_dimension,
-            embedding_model=embedding_model,
-            provider_id=provider_id,
-            provider_resource_id=provider_vector_db_id,
-        )
-
-        # Register the vector DB
-        await self.register_vector_db(vector_db)
-
-        # Create OpenAI vector store metadata
-        store_info = {
-            "id": store_id,
-            "object": "vector_store",
-            "created_at": created_at,
-            "name": store_id,
-            "usage_bytes": 0,
-            "file_counts": {},
-            "status": "completed",
-            "expires_after": expires_after,
-            "expires_at": None,
-            "last_active_at": created_at,
-            "file_ids": file_ids or [],
-            "chunking_strategy": chunking_strategy,
-        }
-
-        # Add provider information to metadata if provided
-        metadata = metadata or {}
-        if provider_id:
-            metadata["provider_id"] = provider_id
-        if provider_vector_db_id:
-            metadata["provider_vector_db_id"] = provider_vector_db_id
-        store_info["metadata"] = metadata
-
-        # Store in SQLite database
-        def _store_openai_vector_store():
-            connection = _create_sqlite_connection(self.config.db_path)
-            cur = connection.cursor()
-            try:
-                cur.execute(
-                    "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
-                    (store_id, json.dumps(store_info)),
-                )
-                connection.commit()
-            finally:
-                cur.close()
-                connection.close()
-
-        await asyncio.to_thread(_store_openai_vector_store)
-
-        # Store in memory cache
-        self.openai_vector_stores[store_id] = store_info
-
-        return VectorStoreObject(
-            id=store_id,
-            created_at=created_at,
-            name=store_id,
-            usage_bytes=0,
-            file_counts={},
-            status="completed",
-            expires_after=expires_after,
-            expires_at=None,
-            last_active_at=created_at,
-            metadata=metadata,
-        )
-
-    async def openai_list_vector_stores(
-        self,
-        limit: int = 20,
-        order: str = "desc",
-        after: str | None = None,
-        before: str | None = None,
-    ) -> VectorStoreListResponse:
-        """Returns a list of vector stores."""
-        # Get all vector stores
-        all_stores = list(self.openai_vector_stores.values())
-
-        # Sort by created_at
-        reverse_order = order == "desc"
-        all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
-
-        # Apply cursor-based pagination
-        if after:
-            after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
-            if after_index >= 0:
-                all_stores = all_stores[after_index + 1 :]
-
-        if before:
-            before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
-            all_stores = all_stores[:before_index]
-
-        # Apply limit
-        limited_stores = all_stores[:limit]
-        # Convert to VectorStoreObject instances
-        data = [VectorStoreObject(**store) for store in limited_stores]
-
-        # Determine pagination info
-        has_more = len(all_stores) > limit
-        first_id = data[0].id if data else None
-        last_id = data[-1].id if data else None
-
-        return VectorStoreListResponse(
-            data=data,
-            has_more=has_more,
-            first_id=first_id,
-            last_id=last_id,
-        )
-
-    async def openai_retrieve_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreObject:
-        """Retrieves a vector store."""
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        store_info = self.openai_vector_stores[vector_store_id]
-        return VectorStoreObject(**store_info)
-
-    async def openai_update_vector_store(
-        self,
-        vector_store_id: str,
-        name: str | None = None,
-        expires_after: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> VectorStoreObject:
-        """Modifies a vector store."""
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        store_info = self.openai_vector_stores[vector_store_id].copy()
-
-        # Update fields if provided
-        if name is not None:
-            store_info["name"] = name
-        if expires_after is not None:
-            store_info["expires_after"] = expires_after
-        if metadata is not None:
-            store_info["metadata"] = metadata
-
-        # Update last_active_at
-        store_info["last_active_at"] = int(time.time())
-
-        # Save to SQLite database
-        def _update_openai_vector_store():
-            connection = _create_sqlite_connection(self.config.db_path)
-            cur = connection.cursor()
-            try:
-                cur.execute(
-                    "UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
-                    (json.dumps(store_info), vector_store_id),
-                )
-                connection.commit()
-            finally:
-                cur.close()
-                connection.close()
-
-        await asyncio.to_thread(_update_openai_vector_store)
-
-        # Update in-memory cache
-        self.openai_vector_stores[vector_store_id] = store_info
-
-        return VectorStoreObject(**store_info)
-
-    async def openai_delete_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreDeleteResponse:
-        """Delete a vector store."""
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        # Delete from SQLite database
-        def _delete_openai_vector_store():
-            connection = _create_sqlite_connection(self.config.db_path)
-            cur = connection.cursor()
-            try:
-                cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (vector_store_id,))
-                connection.commit()
-            finally:
-                cur.close()
-                connection.close()
-
-        await asyncio.to_thread(_delete_openai_vector_store)
-
-        # Delete from in-memory cache
-        del self.openai_vector_stores[vector_store_id]
-
-        # Also delete the underlying vector DB
-        try:
-            await self.unregister_vector_db(vector_store_id)
-        except Exception as e:
-            logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
-
-        return VectorStoreDeleteResponse(
-            id=vector_store_id,
-            deleted=True,
-        )
-
-    async def openai_search_vector_store(
-        self,
-        vector_store_id: str,
-        query: str | list[str],
-        filters: dict[str, Any] | None = None,
-        max_num_results: int = 10,
-        ranking_options: dict[str, Any] | None = None,
-        rewrite_query: bool = False,
-        search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
-    ) -> VectorStoreSearchResponse:
-        """Search for chunks in a vector store."""
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        if isinstance(query, list):
-            search_query = " ".join(query)
-        else:
-            search_query = query
-
-        try:
-            score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
-            params = {
-                "max_chunks": max_num_results * CHUNK_MULTIPLIER,
-                "score_threshold": score_threshold,
-                "mode": search_mode,
-            }
-            # TODO: Add support for ranking_options.ranker
-
-            response = await self.query_chunks(
-                vector_db_id=vector_store_id,
-                query=search_query,
-                params=params,
-            )
-
-            # Convert response to OpenAI format
-            data = []
-            for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
-                # Apply score based filtering
-                if score < score_threshold:
-                    continue
-
-                # Apply filters if provided
-                if filters:
-                    # Simple metadata filtering
-                    if not self._matches_filters(chunk.metadata, filters):
-                        continue
-
-                chunk_data = {
-                    "id": f"chunk_{i}",
-                    "object": "vector_store.search_result",
-                    "score": score,
-                    "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
-                    "metadata": chunk.metadata,
-                }
-                data.append(chunk_data)
-                if len(data) >= max_num_results:
-                    break
-
-            return VectorStoreSearchResponse(
-                search_query=search_query,
-                data=data,
-                has_more=False,  # For simplicity, we don't implement pagination here
-                next_page=None,
-            )
-
-        except Exception as e:
-            logger.error(f"Error searching vector store {vector_store_id}: {e}")
-            # Return empty results on error
-            return VectorStoreSearchResponse(
-                search_query=search_query,
-                data=[],
-                has_more=False,
-                next_page=None,
-            )
-
-    def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
-        """Check if metadata matches the provided filters."""
-        for key, value in filters.items():
-            if key not in metadata:
-                return False
-            if metadata[key] != value:
-                return False
-        return True

 def generate_chunk_id(document_id: str, chunk_text: str) -> str:
     """Generate a unique chunk ID using a hash of document ID and chunk text."""

@@ -6,7 +6,7 @@
 import asyncio
 import json
 import logging
-from typing import Any, Literal
+from typing import Any
 from urllib.parse import urlparse

 import chromadb
@@ -21,7 +21,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreDeleteResponse,
     VectorStoreListResponse,
     VectorStoreObject,
-    VectorStoreSearchResponse,
+    VectorStoreSearchResponsePage,
 )
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
@@ -189,7 +189,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

     async def openai_create_vector_store(
         self,
-        name: str | None = None,
+        name: str,
         file_ids: list[str] | None = None,
         expires_after: dict[str, Any] | None = None,
         chunking_strategy: dict[str, Any] | None = None,
@@ -203,8 +203,8 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

     async def openai_list_vector_stores(
         self,
-        limit: int = 20,
-        order: str = "desc",
+        limit: int | None = 20,
+        order: str | None = "desc",
         after: str | None = None,
         before: str | None = None,
     ) -> VectorStoreListResponse:
@@ -236,9 +236,8 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         vector_store_id: str,
         query: str | list[str],
         filters: dict[str, Any] | None = None,
-        max_num_results: int = 10,
+        max_num_results: int | None = 10,
         ranking_options: dict[str, Any] | None = None,
-        rewrite_query: bool = False,
-        search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
-    ) -> VectorStoreSearchResponse:
+        rewrite_query: bool | None = False,
+    ) -> VectorStoreSearchResponsePage:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

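The same signature change repeats for the Milvus and Qdrant adapters below: OpenAI-facing parameters become `... | None`, the explicit `search_mode` parameter is dropped, and search now returns `VectorStoreSearchResponsePage`. With optional parameters, defaulting has to happen inside the implementation because callers may pass `None` explicitly; a minimal sketch of that pattern (the function name is illustrative, not from the diff):

```python
def list_stores(limit: int | None = 20, order: str | None = "desc") -> tuple[int, str]:
    # Callers may explicitly pass None, so normalize inside the function
    # rather than relying on the default value alone.
    limit = limit or 20
    order = order or "desc"
    return limit, order


assert list_stores(None, None) == (20, "desc")
```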
@@ -9,7 +9,7 @@ import hashlib
 import logging
 import os
 import uuid
-from typing import Any, Literal
+from typing import Any

 from numpy.typing import NDArray
 from pymilvus import MilvusClient
@@ -23,7 +23,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreDeleteResponse,
     VectorStoreListResponse,
     VectorStoreObject,
-    VectorStoreSearchResponse,
+    VectorStoreSearchResponsePage,
 )
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
@@ -187,7 +187,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

     async def openai_create_vector_store(
         self,
-        name: str | None = None,
+        name: str,
         file_ids: list[str] | None = None,
         expires_after: dict[str, Any] | None = None,
         chunking_strategy: dict[str, Any] | None = None,
@@ -201,8 +201,8 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

     async def openai_list_vector_stores(
         self,
-        limit: int = 20,
-        order: str = "desc",
+        limit: int | None = 20,
+        order: str | None = "desc",
         after: str | None = None,
         before: str | None = None,
     ) -> VectorStoreListResponse:
@@ -234,11 +234,10 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         vector_store_id: str,
         query: str | list[str],
         filters: dict[str, Any] | None = None,
-        max_num_results: int = 10,
+        max_num_results: int | None = 10,
         ranking_options: dict[str, Any] | None = None,
-        rewrite_query: bool = False,
-        search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
-    ) -> VectorStoreSearchResponse:
+        rewrite_query: bool | None = False,
+    ) -> VectorStoreSearchResponsePage:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

@@ -6,7 +6,7 @@
 import logging
 import uuid
-from typing import Any, Literal
+from typing import Any

 from numpy.typing import NDArray
 from qdrant_client import AsyncQdrantClient, models
@@ -21,7 +21,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreDeleteResponse,
     VectorStoreListResponse,
     VectorStoreObject,
-    VectorStoreSearchResponse,
+    VectorStoreSearchResponsePage,
 )
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
@@ -189,7 +189,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

     async def openai_create_vector_store(
         self,
-        name: str | None = None,
+        name: str,
         file_ids: list[str] | None = None,
         expires_after: dict[str, Any] | None = None,
         chunking_strategy: dict[str, Any] | None = None,
@@ -203,8 +203,8 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

     async def openai_list_vector_stores(
         self,
-        limit: int = 20,
-        order: str = "desc",
+        limit: int | None = 20,
+        order: str | None = "desc",
         after: str | None = None,
         before: str | None = None,
     ) -> VectorStoreListResponse:
@@ -236,9 +236,8 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         vector_store_id: str,
         query: str | list[str],
         filters: dict[str, Any] | None = None,
-        max_num_results: int = 10,
+        max_num_results: int | None = 10,
         ranking_options: dict[str, Any] | None = None,
-        rewrite_query: bool = False,
-        search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
-    ) -> VectorStoreSearchResponse:
+        rewrite_query: bool | None = False,
+    ) -> VectorStoreSearchResponsePage:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

@@ -76,7 +76,7 @@ class WeaviateIndex(EmbeddingIndex):
                 continue

             chunks.append(chunk)
-            scores.append(1.0 / doc.metadata.distance)
+            scores.append(1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf"))

         return QueryChunksResponse(chunks=chunks, scores=scores)

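The Weaviate change guards the distance-to-score inversion against an exact-match distance of zero, which would otherwise raise `ZeroDivisionError`. A small sketch of the behaviour with hypothetical values:

```python
def distance_to_score(distance: float) -> float:
    # Inverse distance as a similarity score; an exact match (distance 0)
    # maps to infinity instead of raising ZeroDivisionError.
    return 1.0 / distance if distance != 0 else float("inf")


assert distance_to_score(0.5) == 2.0
assert distance_to_score(0.0) == float("inf")
```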
 385  llama_stack/providers/utils/memory/openai_vector_store_mixin.py  (new file)
@@ -0,0 +1,385 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+import time
+import uuid
+from abc import ABC, abstractmethod
+from typing import Any
+
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import (
+    QueryChunksResponse,
+    VectorStoreContent,
+    VectorStoreDeleteResponse,
+    VectorStoreListResponse,
+    VectorStoreObject,
+    VectorStoreSearchResponse,
+    VectorStoreSearchResponsePage,
+)
+
+logger = logging.getLogger(__name__)
+
+# Constants for OpenAI vector stores
+CHUNK_MULTIPLIER = 5
+
+
+class OpenAIVectorStoreMixin(ABC):
+    """
+    Mixin class that provides common OpenAI Vector Store API implementation.
+    Providers need to implement the abstract storage methods and maintain
+    an openai_vector_stores in-memory cache.
+    """
+
+    # These should be provided by the implementing class
+    openai_vector_stores: dict[str, dict[str, Any]]
+
+    @abstractmethod
+    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Save vector store metadata to persistent storage."""
+        pass
+
+    @abstractmethod
+    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+        """Load all vector store metadata from persistent storage."""
+        pass
+
+    @abstractmethod
+    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Update vector store metadata in persistent storage."""
+        pass
+
+    @abstractmethod
+    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+        """Delete vector store metadata from persistent storage."""
+        pass
+
+    @abstractmethod
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
+        """Register a vector database (provider-specific implementation)."""
+        pass
+
+    @abstractmethod
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        """Unregister a vector database (provider-specific implementation)."""
+        pass
+
+    @abstractmethod
+    async def query_chunks(
+        self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
+    ) -> QueryChunksResponse:
+        """Query chunks from a vector database (provider-specific implementation)."""
+        pass
+
+    async def openai_create_vector_store(
+        self,
+        name: str,
+        file_ids: list[str] | None = None,
+        expires_after: dict[str, Any] | None = None,
+        chunking_strategy: dict[str, Any] | None = None,
+        metadata: dict[str, Any] | None = None,
+        embedding_model: str | None = None,
+        embedding_dimension: int | None = 384,
+        provider_id: str | None = None,
+        provider_vector_db_id: str | None = None,
+    ) -> VectorStoreObject:
+        """Creates a vector store."""
+        # store and vector_db have the same id
+        store_id = name or str(uuid.uuid4())
+        created_at = int(time.time())
+
+        if provider_id is None:
+            raise ValueError("Provider ID is required")
+
+        if embedding_model is None:
+            raise ValueError("Embedding model is required")
+
+        # Use provided embedding dimension or default to 384
+        if embedding_dimension is None:
+            raise ValueError("Embedding dimension is required")
+
+        provider_vector_db_id = provider_vector_db_id or store_id
+        vector_db = VectorDB(
+            identifier=store_id,
+            embedding_dimension=embedding_dimension,
+            embedding_model=embedding_model,
+            provider_id=provider_id,
+            provider_resource_id=provider_vector_db_id,
+        )
+        # Register the vector DB
+        await self.register_vector_db(vector_db)
+
+        # Create OpenAI vector store metadata
+        store_info = {
+            "id": store_id,
+            "object": "vector_store",
+            "created_at": created_at,
+            "name": store_id,
+            "usage_bytes": 0,
+            "file_counts": {},
+            "status": "completed",
+            "expires_after": expires_after,
+            "expires_at": None,
+            "last_active_at": created_at,
+            "file_ids": file_ids or [],
+            "chunking_strategy": chunking_strategy,
+        }
+
+        # Add provider information to metadata if provided
+        metadata = metadata or {}
+        if provider_id:
+            metadata["provider_id"] = provider_id
+        if provider_vector_db_id:
+            metadata["provider_vector_db_id"] = provider_vector_db_id
+        store_info["metadata"] = metadata
+
+        # Save to persistent storage (provider-specific)
+        await self._save_openai_vector_store(store_id, store_info)
+
+        # Store in memory cache
+        self.openai_vector_stores[store_id] = store_info
+
+        return VectorStoreObject(
+            id=store_id,
+            created_at=created_at,
+            name=store_id,
+            usage_bytes=0,
+            file_counts={},
+            status="completed",
+            expires_after=expires_after,
+            expires_at=None,
+            last_active_at=created_at,
+            metadata=metadata,
+        )
+
+    async def openai_list_vector_stores(
+        self,
+        limit: int | None = 20,
+        order: str | None = "desc",
+        after: str | None = None,
+        before: str | None = None,
+    ) -> VectorStoreListResponse:
+        """Returns a list of vector stores."""
+        limit = limit or 20
+        order = order or "desc"
+
+        # Get all vector stores
+        all_stores = list(self.openai_vector_stores.values())
+
+        # Sort by created_at
+        reverse_order = order == "desc"
+        all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
+
+        # Apply cursor-based pagination
+        if after:
+            after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
+            if after_index >= 0:
+                all_stores = all_stores[after_index + 1 :]
+
+        if before:
+            before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
+            all_stores = all_stores[:before_index]
+
+        # Apply limit
+        limited_stores = all_stores[:limit]
+        # Convert to VectorStoreObject instances
+        data = [VectorStoreObject(**store) for store in limited_stores]
+
+        # Determine pagination info
+        has_more = len(all_stores) > limit
+        first_id = data[0].id if data else None
+        last_id = data[-1].id if data else None
+
+        return VectorStoreListResponse(
+            data=data,
+            has_more=has_more,
+            first_id=first_id,
+            last_id=last_id,
+        )
+
+    async def openai_retrieve_vector_store(
+        self,
+        vector_store_id: str,
+    ) -> VectorStoreObject:
+        """Retrieves a vector store."""
+        if vector_store_id not in self.openai_vector_stores:
+            raise ValueError(f"Vector store {vector_store_id} not found")
+
+        store_info = self.openai_vector_stores[vector_store_id]
+        return VectorStoreObject(**store_info)
+
+    async def openai_update_vector_store(
+        self,
+        vector_store_id: str,
+        name: str | None = None,
+        expires_after: dict[str, Any] | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> VectorStoreObject:
+        """Modifies a vector store."""
+        if vector_store_id not in self.openai_vector_stores:
+            raise ValueError(f"Vector store {vector_store_id} not found")
+
+        store_info = self.openai_vector_stores[vector_store_id].copy()
+
+        # Update fields if provided
+        if name is not None:
+            store_info["name"] = name
+        if expires_after is not None:
+            store_info["expires_after"] = expires_after
+        if metadata is not None:
+            store_info["metadata"] = metadata
+
+        # Update last_active_at
+        store_info["last_active_at"] = int(time.time())
+
+        # Save to persistent storage (provider-specific)
+        await self._update_openai_vector_store(vector_store_id, store_info)
+
+        # Update in-memory cache
+        self.openai_vector_stores[vector_store_id] = store_info
+
+        return VectorStoreObject(**store_info)
+
+    async def openai_delete_vector_store(
+        self,
+        vector_store_id: str,
+    ) -> VectorStoreDeleteResponse:
+        """Delete a vector store."""
+        if vector_store_id not in self.openai_vector_stores:
+            raise ValueError(f"Vector store {vector_store_id} not found")
+
+        # Delete from persistent storage (provider-specific)
+        await self._delete_openai_vector_store_from_storage(vector_store_id)
+
+        # Delete from in-memory cache
+        del self.openai_vector_stores[vector_store_id]
+
+        # Also delete the underlying vector DB
+        try:
+            await self.unregister_vector_db(vector_store_id)
+        except Exception as e:
+            logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
+
+        return VectorStoreDeleteResponse(
+            id=vector_store_id,
+            deleted=True,
+        )
+
+    async def openai_search_vector_store(
+        self,
+        vector_store_id: str,
+        query: str | list[str],
+        filters: dict[str, Any] | None = None,
+        max_num_results: int | None = 10,
+        ranking_options: dict[str, Any] | None = None,
+        rewrite_query: bool | None = False,
+        # search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
+    ) -> VectorStoreSearchResponsePage:
+        """Search for chunks in a vector store."""
+        # TODO: Add support in the API for this
+        search_mode = "vector"
+        max_num_results = max_num_results or 10
+
+        if vector_store_id not in self.openai_vector_stores:
+            raise ValueError(f"Vector store {vector_store_id} not found")
+
+        if isinstance(query, list):
+            search_query = " ".join(query)
+        else:
+            search_query = query
+
+        try:
+            score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
+            params = {
+                "max_chunks": max_num_results * CHUNK_MULTIPLIER,
+                "score_threshold": score_threshold,
+                "mode": search_mode,
+            }
+            # TODO: Add support for ranking_options.ranker
+
+            response = await self.query_chunks(
+                vector_db_id=vector_store_id,
+                query=search_query,
+                params=params,
+            )
+
+            # Convert response to OpenAI format
+            data = []
+            for chunk, score in zip(response.chunks, response.scores, strict=False):
+                # Apply score based filtering
+                if score < score_threshold:
+                    continue
+
+                # Apply filters if provided
+                if filters:
+                    # Simple metadata filtering
+                    if not self._matches_filters(chunk.metadata, filters):
+                        continue
+
+                # content is InterleavedContent
+                if isinstance(chunk.content, str):
+                    content = [
+                        VectorStoreContent(
+                            type="text",
+                            text=chunk.content,
+                        )
+                    ]
+                elif isinstance(chunk.content, list):
+                    # TODO: Add support for other types of content
+                    content = [
+                        VectorStoreContent(
+                            type="text",
+                            text=item.text,
+                        )
+                        for item in chunk.content
+                        if item.type == "text"
+                    ]
+                else:
+                    if chunk.content.type != "text":
+                        raise ValueError(f"Unsupported content type: {chunk.content.type}")
+                    content = [
+                        VectorStoreContent(
+                            type="text",
+                            text=chunk.content.text,
+                        )
+                    ]
+
+                response_data_item = VectorStoreSearchResponse(
+                    file_id=chunk.metadata.get("file_id", ""),
+                    filename=chunk.metadata.get("filename", ""),
+                    score=score,
+                    attributes=chunk.metadata,
+                    content=content,
+                )
+                data.append(response_data_item)
+                if len(data) >= max_num_results:
+                    break
+
+            return VectorStoreSearchResponsePage(
+                search_query=search_query,
+                data=data,
+                has_more=False,  # For simplicity, we don't implement pagination here
+                next_page=None,
+            )
+
+        except Exception as e:
+            logger.error(f"Error searching vector store {vector_store_id}: {e}")
+            # Return empty results on error
+            return VectorStoreSearchResponsePage(
+                search_query=search_query,
+                data=[],
+                has_more=False,
+                next_page=None,
+            )
+
+    def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
+        """Check if metadata matches the provided filters."""
+        for key, value in filters.items():
+            if key not in metadata:
+                return False
+            if metadata[key] != value:
+                return False
+        return True
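To see how the new mixin is meant to be consumed, here is a minimal, hypothetical in-memory adapter that satisfies the abstract hooks; real providers (sqlite-vec above, faiss, etc.) back these hooks with their own storage instead. The class name and dict-based storage are assumptions for illustration only:

```python
from typing import Any

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin


class InMemoryVectorIOAdapter(OpenAIVectorStoreMixin):
    """Hypothetical adapter: storage hooks are backed by plain dicts."""

    def __init__(self) -> None:
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
        self._storage: dict[str, dict[str, Any]] = {}

    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        self._storage[store_id] = store_info

    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
        return dict(self._storage)

    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        self._storage[store_id] = store_info

    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
        self._storage.pop(store_id, None)

    async def register_vector_db(self, vector_db: VectorDB) -> None:
        pass  # a real provider would create the underlying index here

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        pass  # and tear it down here

    async def query_chunks(
        self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
    ) -> QueryChunksResponse:
        # No actual index in this sketch; return an empty result set.
        return QueryChunksResponse(chunks=[], scores=[])
```

With those hooks in place, all of the `openai_*` endpoint methods above work unchanged against the in-memory backend.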
@@ -34,11 +34,15 @@ def skip_if_model_doesnt_support_variable_dimensions(model_id):
         pytest.skip("{model_id} does not support variable output embedding dimensions")


-def skip_if_model_doesnt_support_openai_embeddings(client_with_models, model_id):
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI embeddings are not supported when testing with library client yet.")
+@pytest.fixture(params=["openai_client", "llama_stack_client"])
+def compat_client(request, client_with_models):
+    if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI client tests not supported with library client")
+    return request.getfixturevalue(request.param)

-    provider = provider_from_model(client_with_models, model_id)
+
+def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
+    provider = provider_from_model(client, model_id)
     if provider.provider_type in (
         "inline::meta-reference",
         "remote::bedrock",
@@ -58,13 +62,13 @@ def openai_client(client_with_models):
     return OpenAI(base_url=base_url, api_key="fake")


-def test_openai_embeddings_single_string(openai_client, client_with_models, embedding_model_id):
+def test_openai_embeddings_single_string(compat_client, client_with_models, embedding_model_id):
     """Test OpenAI embeddings endpoint with a single string input."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

     input_text = "Hello, world!"

-    response = openai_client.embeddings.create(
+    response = compat_client.embeddings.create(
         model=embedding_model_id,
         input=input_text,
         encoding_format="float",
@@ -80,13 +84,13 @@ def test_openai_embeddings_single_string(openai_client, client_with_models, embe
     assert all(isinstance(x, float) for x in response.data[0].embedding)


-def test_openai_embeddings_multiple_strings(openai_client, client_with_models, embedding_model_id):
+def test_openai_embeddings_multiple_strings(compat_client, client_with_models, embedding_model_id):
     """Test OpenAI embeddings endpoint with multiple string inputs."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

     input_texts = ["Hello, world!", "How are you today?", "This is a test."]

-    response = openai_client.embeddings.create(
+    response = compat_client.embeddings.create(
         model=embedding_model_id,
         input=input_texts,
     )
@@ -103,13 +107,13 @@ def test_openai_embeddings_multiple_strings(openai_client, client_with_models, e
     assert all(isinstance(x, float) for x in embedding_data.embedding)


-def test_openai_embeddings_with_encoding_format_float(openai_client, client_with_models, embedding_model_id):
+def test_openai_embeddings_with_encoding_format_float(compat_client, client_with_models, embedding_model_id):
     """Test OpenAI embeddings endpoint with float encoding format."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

     input_text = "Test encoding format"

-    response = openai_client.embeddings.create(
+    response = compat_client.embeddings.create(
         model=embedding_model_id,
         input=input_text,
         encoding_format="float",
@@ -121,7 +125,7 @@ def test_openai_embeddings_with_encoding_format_float(openai_client, client_with
     assert all(isinstance(x, float) for x in response.data[0].embedding)


-def test_openai_embeddings_with_dimensions(openai_client, client_with_models, embedding_model_id):
+def test_openai_embeddings_with_dimensions(compat_client, client_with_models, embedding_model_id):
     """Test OpenAI embeddings endpoint with custom dimensions parameter."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
     skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)
@@ -129,7 +133,7 @@ def test_openai_embeddings_with_dimensions(openai_client, client_with_models, em
     input_text = "Test dimensions parameter"
     dimensions = 16

-    response = openai_client.embeddings.create(
+    response = compat_client.embeddings.create(
         model=embedding_model_id,
         input=input_text,
         dimensions=dimensions,
@@ -142,14 +146,14 @@ def test_openai_embeddings_with_dimensions(openai_client, client_with_models, em
     assert len(response.data[0].embedding) > 0


-def test_openai_embeddings_with_user_parameter(openai_client, client_with_models, embedding_model_id):
+def test_openai_embeddings_with_user_parameter(compat_client, client_with_models, embedding_model_id):
     """Test OpenAI embeddings endpoint with user parameter."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

     input_text = "Test user parameter"
     user_id = "test-user-123"

-    response = openai_client.embeddings.create(
+    response = compat_client.embeddings.create(
         model=embedding_model_id,
         input=input_text,
         user=user_id,
@@ -161,41 +165,41 @@ def test_openai_embeddings_with_user_parameter(openai_client, client_with_models
     assert len(response.data[0].embedding) > 0


-def test_openai_embeddings_empty_list_error(openai_client, client_with_models, embedding_model_id):
+def test_openai_embeddings_empty_list_error(compat_client, client_with_models, embedding_model_id):
     """Test that empty list input raises an appropriate error."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

     with pytest.raises(Exception):  # noqa: B017
-        openai_client.embeddings.create(
+        compat_client.embeddings.create(
             model=embedding_model_id,
             input=[],
         )


-def test_openai_embeddings_invalid_model_error(openai_client, client_with_models, embedding_model_id):
+def test_openai_embeddings_invalid_model_error(compat_client, client_with_models, embedding_model_id):
     """Test that invalid model ID raises an appropriate error."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

     with pytest.raises(Exception):  # noqa: B017
-        openai_client.embeddings.create(
+        compat_client.embeddings.create(
             model="invalid-model-id",
             input="Test text",
         )


-def test_openai_embeddings_different_inputs_different_outputs(openai_client, client_with_models, embedding_model_id):
+def test_openai_embeddings_different_inputs_different_outputs(compat_client, client_with_models, embedding_model_id):
     """Test that different inputs produce different embeddings."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

     input_text1 = "This is the first text"
     input_text2 = "This is completely different content"

-    response1 = openai_client.embeddings.create(
+    response1 = compat_client.embeddings.create(
         model=embedding_model_id,
         input=input_text1,
     )

-    response2 = openai_client.embeddings.create(
+    response2 = compat_client.embeddings.create(
         model=embedding_model_id,
         input=input_text2,
     )
@@ -208,7 +212,7 @@ def test_openai_embeddings_different_inputs_different_outputs(openai_client, cli
     assert embedding1 != embedding2


-def test_openai_embeddings_with_encoding_format_base64(openai_client, client_with_models, embedding_model_id):
+def test_openai_embeddings_with_encoding_format_base64(compat_client, client_with_models, embedding_model_id):
     """Test OpenAI embeddings endpoint with base64 encoding format."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
     skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)
@@ -216,7 +220,7 @@ def test_openai_embeddings_with_encoding_format_base64(openai_client, client_wit
     input_text = "Test base64 encoding format"
     dimensions = 12

-    response = openai_client.embeddings.create(
+    response = compat_client.embeddings.create(
         model=embedding_model_id,
         input=input_text,
         encoding_format="base64",
@@ -241,13 +245,13 @@ def test_openai_embeddings_with_encoding_format_base64(openai_client, client_wit
     assert all(isinstance(x, float) for x in embedding_floats)


-def test_openai_embeddings_base64_batch_processing(openai_client, client_with_models, embedding_model_id):
+def test_openai_embeddings_base64_batch_processing(compat_client, client_with_models, embedding_model_id):
     """Test OpenAI embeddings endpoint with base64 encoding for batch processing."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)

     input_texts = ["First text for base64", "Second text for base64", "Third text for base64"]

-    response = openai_client.embeddings.create(
+    response = compat_client.embeddings.create(
         model=embedding_model_id,
         input=input_texts,
         encoding_format="base64",
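The `compat_client` fixture introduced above is the standard parametrized-fixture pattern: every test that requests it runs once per client flavour (OpenAI client and llama-stack client). A stripped-down sketch of the mechanism, with purely illustrative fixture names:

```python
import pytest


@pytest.fixture
def fast_client():
    return "fast"


@pytest.fixture
def slow_client():
    return "slow"


@pytest.fixture(params=["fast_client", "slow_client"])
def any_client(request):
    # Resolve the fixture whose name matches the current parameter,
    # so each dependent test runs once per client flavour.
    return request.getfixturevalue(request.param)


def test_runs_twice(any_client):
    assert any_client in ("fast", "slow")
```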
@@ -17,9 +17,6 @@ logger = logging.getLogger(__name__)


 def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI vector stores are not supported when testing with library client yet.")
-
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
         if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]:
@@ -34,6 +31,13 @@ def openai_client(client_with_models):
     return OpenAI(base_url=base_url, api_key="fake")


+@pytest.fixture(params=["openai_client", "llama_stack_client"])
+def compat_client(request, client_with_models):
+    if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI client tests not supported with library client")
+    return request.getfixturevalue(request.param)
+
+
 @pytest.fixture(scope="session")
 def sample_chunks():
     return [
@@ -57,29 +61,29 @@ def sample_chunks():


 @pytest.fixture(scope="function")
-def openai_client_with_empty_stores(openai_client):
+def compat_client_with_empty_stores(compat_client):
     def clear_vector_stores():
         # List and delete all existing vector stores
         try:
-            response = openai_client.vector_stores.list()
+            response = compat_client.vector_stores.list()
             for store in response.data:
-                openai_client.vector_stores.delete(vector_store_id=store.id)
+                compat_client.vector_stores.delete(vector_store_id=store.id)
         except Exception:
             # If the API is not available or fails, just continue
             logger.warning("Failed to clear vector stores")
             pass

     clear_vector_stores()
-    yield openai_client
+    yield compat_client

     # Clean up after the test
     clear_vector_stores()


-def test_openai_create_vector_store(openai_client_with_empty_stores, client_with_models):
+def test_openai_create_vector_store(compat_client_with_empty_stores, client_with_models):
     """Test creating a vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
-    client = openai_client_with_empty_stores
+    client = compat_client_with_empty_stores

     # Create a vector store
     vector_store = client.vector_stores.create(
@@ -96,11 +100,11 @@ def test_openai_create_vector_store(openai_client_with_empty_stores, client_with
     assert hasattr(vector_store, "created_at")


-def test_openai_list_vector_stores(openai_client_with_empty_stores, client_with_models):
+def test_openai_list_vector_stores(compat_client_with_empty_stores, client_with_models):
     """Test listing vector stores using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

-    client = openai_client_with_empty_stores
+    client = compat_client_with_empty_stores

     # Create a few vector stores
     store1 = client.vector_stores.create(name="store1", metadata={"type": "test"})
@@ -123,11 +127,11 @@ def test_openai_list_vector_stores(openai_client_with_empty_stores, client_with_
     assert len(limited_response.data) == 1


-def test_openai_retrieve_vector_store(openai_client_with_empty_stores, client_with_models):
+def test_openai_retrieve_vector_store(compat_client_with_empty_stores, client_with_models):
     """Test retrieving a specific vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

-    client = openai_client_with_empty_stores
+    client = compat_client_with_empty_stores

     # Create a vector store
     created_store = client.vector_stores.create(name="retrieve_test_store", metadata={"purpose": "retrieval_test"})
@@ -142,11 +146,11 @@ def test_openai_retrieve_vector_store(openai_client_with_empty_stores, client_wi
     assert retrieved_store.object == "vector_store"


-def test_openai_update_vector_store(openai_client_with_empty_stores, client_with_models):
+def test_openai_update_vector_store(compat_client_with_empty_stores, client_with_models):
     """Test modifying a vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

-    client = openai_client_with_empty_stores
+    client = compat_client_with_empty_stores

     # Create a vector store
     created_store = client.vector_stores.create(name="original_name", metadata={"version": "1.0"})
@@ -165,11 +169,11 @@ def test_openai_update_vector_store(openai_client_with_empty_stores, client_with
     assert modified_store.last_active_at > created_store.last_active_at


-def test_openai_delete_vector_store(openai_client_with_empty_stores, client_with_models):
+def test_openai_delete_vector_store(compat_client_with_empty_stores, client_with_models):
     """Test deleting a vector store using OpenAI API."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

-    client = openai_client_with_empty_stores
+    client = compat_client_with_empty_stores

     # Create a vector store
     created_store = client.vector_stores.create(name="delete_test_store", metadata={"purpose": "deletion_test"})
@@ -187,11 +191,11 @@ def test_openai_delete_vector_store(openai_client_with_empty_stores, client_with
     client.vector_stores.retrieve(vector_store_id=created_store.id)


-def test_openai_vector_store_search_empty(openai_client_with_empty_stores, client_with_models):
+def test_openai_vector_store_search_empty(compat_client_with_empty_stores, client_with_models):
     """Test searching an empty vector store using OpenAI API."""
|
"""Test searching an empty vector store using OpenAI API."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
||||||
client = openai_client_with_empty_stores
|
client = compat_client_with_empty_stores
|
||||||
|
|
||||||
# Create a vector store
|
# Create a vector store
|
||||||
vector_store = client.vector_stores.create(name="search_test_store", metadata={"purpose": "search_testing"})
|
vector_store = client.vector_stores.create(name="search_test_store", metadata={"purpose": "search_testing"})
|
||||||
|
@ -208,15 +212,15 @@ def test_openai_vector_store_search_empty(openai_client_with_empty_stores, clien
|
||||||
assert search_response.has_more is False
|
assert search_response.has_more is False
|
||||||
|
|
||||||
|
|
||||||
def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client_with_models, sample_chunks):
|
def test_openai_vector_store_with_chunks(compat_client_with_empty_stores, client_with_models, sample_chunks):
|
||||||
"""Test vector store functionality with actual chunks using both OpenAI and native APIs."""
|
"""Test vector store functionality with actual chunks using both OpenAI and native APIs."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
||||||
openai_client = openai_client_with_empty_stores
|
compat_client = compat_client_with_empty_stores
|
||||||
llama_client = client_with_models
|
llama_client = client_with_models
|
||||||
|
|
||||||
# Create a vector store using OpenAI API
|
# Create a vector store using OpenAI API
|
||||||
vector_store = openai_client.vector_stores.create(name="chunks_test_store", metadata={"purpose": "chunks_testing"})
|
vector_store = compat_client.vector_stores.create(name="chunks_test_store", metadata={"purpose": "chunks_testing"})
|
||||||
|
|
||||||
# Insert chunks using the native LlamaStack API (since OpenAI API doesn't have direct chunk insertion)
|
# Insert chunks using the native LlamaStack API (since OpenAI API doesn't have direct chunk insertion)
|
||||||
llama_client.vector_io.insert(
|
llama_client.vector_io.insert(
|
||||||
|
@ -225,7 +229,7 @@ def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client
|
||||||
)
|
)
|
||||||
|
|
||||||
# Search using OpenAI API
|
# Search using OpenAI API
|
||||||
search_response = openai_client.vector_stores.search(
|
search_response = compat_client.vector_stores.search(
|
||||||
vector_store_id=vector_store.id, query="What is Python programming language?", max_num_results=3
|
vector_store_id=vector_store.id, query="What is Python programming language?", max_num_results=3
|
||||||
)
|
)
|
||||||
assert search_response is not None
|
assert search_response is not None
|
||||||
|
@ -233,18 +237,19 @@ def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client
|
||||||
|
|
||||||
# The top result should be about Python (doc1)
|
# The top result should be about Python (doc1)
|
||||||
top_result = search_response.data[0]
|
top_result = search_response.data[0]
|
||||||
assert "python" in top_result.content.lower() or "programming" in top_result.content.lower()
|
top_content = top_result.content[0].text
|
||||||
assert top_result.metadata["document_id"] == "doc1"
|
assert "python" in top_content.lower() or "programming" in top_content.lower()
|
||||||
|
assert top_result.attributes["document_id"] == "doc1"
|
||||||
|
|
||||||
# Test filtering by metadata
|
# Test filtering by metadata
|
||||||
filtered_search = openai_client.vector_stores.search(
|
filtered_search = compat_client.vector_stores.search(
|
||||||
vector_store_id=vector_store.id, query="artificial intelligence", filters={"topic": "ai"}, max_num_results=5
|
vector_store_id=vector_store.id, query="artificial intelligence", filters={"topic": "ai"}, max_num_results=5
|
||||||
)
|
)
|
||||||
|
|
||||||
assert filtered_search is not None
|
assert filtered_search is not None
|
||||||
# All results should have topic "ai"
|
# All results should have topic "ai"
|
||||||
for result in filtered_search.data:
|
for result in filtered_search.data:
|
||||||
assert result.metadata["topic"] == "ai"
|
assert result.attributes["topic"] == "ai"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -257,18 +262,18 @@ def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_openai_vector_store_search_relevance(
|
def test_openai_vector_store_search_relevance(
|
||||||
openai_client_with_empty_stores, client_with_models, sample_chunks, test_case
|
compat_client_with_empty_stores, client_with_models, sample_chunks, test_case
|
||||||
):
|
):
|
||||||
"""Test that OpenAI vector store search returns relevant results for different queries."""
|
"""Test that OpenAI vector store search returns relevant results for different queries."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
||||||
openai_client = openai_client_with_empty_stores
|
compat_client = compat_client_with_empty_stores
|
||||||
llama_client = client_with_models
|
llama_client = client_with_models
|
||||||
|
|
||||||
query, expected_doc_id, expected_topic = test_case
|
query, expected_doc_id, expected_topic = test_case
|
||||||
|
|
||||||
# Create a vector store
|
# Create a vector store
|
||||||
vector_store = openai_client.vector_stores.create(
|
vector_store = compat_client.vector_stores.create(
|
||||||
name=f"relevance_test_{expected_doc_id}", metadata={"purpose": "relevance_testing"}
|
name=f"relevance_test_{expected_doc_id}", metadata={"purpose": "relevance_testing"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -279,7 +284,7 @@ def test_openai_vector_store_search_relevance(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Search using OpenAI API
|
# Search using OpenAI API
|
||||||
search_response = openai_client.vector_stores.search(
|
search_response = compat_client.vector_stores.search(
|
||||||
vector_store_id=vector_store.id, query=query, max_num_results=4
|
vector_store_id=vector_store.id, query=query, max_num_results=4
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -288,8 +293,9 @@ def test_openai_vector_store_search_relevance(
|
||||||
|
|
||||||
# The top result should match the expected document
|
# The top result should match the expected document
|
||||||
top_result = search_response.data[0]
|
top_result = search_response.data[0]
|
||||||
assert top_result.metadata["document_id"] == expected_doc_id
|
|
||||||
assert top_result.metadata["topic"] == expected_topic
|
assert top_result.attributes["document_id"] == expected_doc_id
|
||||||
|
assert top_result.attributes["topic"] == expected_topic
|
||||||
|
|
||||||
# Verify score is included and reasonable
|
# Verify score is included and reasonable
|
||||||
assert isinstance(top_result.score, int | float)
|
assert isinstance(top_result.score, int | float)
|
||||||
|
@ -297,16 +303,16 @@ def test_openai_vector_store_search_relevance(
|
||||||
|
|
||||||
|
|
||||||
def test_openai_vector_store_search_with_ranking_options(
|
def test_openai_vector_store_search_with_ranking_options(
|
||||||
openai_client_with_empty_stores, client_with_models, sample_chunks
|
compat_client_with_empty_stores, client_with_models, sample_chunks
|
||||||
):
|
):
|
||||||
"""Test OpenAI vector store search with ranking options."""
|
"""Test OpenAI vector store search with ranking options."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
||||||
openai_client = openai_client_with_empty_stores
|
compat_client = compat_client_with_empty_stores
|
||||||
llama_client = client_with_models
|
llama_client = client_with_models
|
||||||
|
|
||||||
# Create a vector store
|
# Create a vector store
|
||||||
vector_store = openai_client.vector_stores.create(
|
vector_store = compat_client.vector_stores.create(
|
||||||
name="ranking_test_store", metadata={"purpose": "ranking_testing"}
|
name="ranking_test_store", metadata={"purpose": "ranking_testing"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -318,7 +324,7 @@ def test_openai_vector_store_search_with_ranking_options(
|
||||||
|
|
||||||
# Search with ranking options
|
# Search with ranking options
|
||||||
threshold = 0.1
|
threshold = 0.1
|
||||||
search_response = openai_client.vector_stores.search(
|
search_response = compat_client.vector_stores.search(
|
||||||
vector_store_id=vector_store.id,
|
vector_store_id=vector_store.id,
|
||||||
query="machine learning and artificial intelligence",
|
query="machine learning and artificial intelligence",
|
||||||
max_num_results=3,
|
max_num_results=3,
|
||||||
|
@ -334,16 +340,16 @@ def test_openai_vector_store_search_with_ranking_options(
|
||||||
|
|
||||||
|
|
||||||
def test_openai_vector_store_search_with_high_score_filter(
|
def test_openai_vector_store_search_with_high_score_filter(
|
||||||
openai_client_with_empty_stores, client_with_models, sample_chunks
|
compat_client_with_empty_stores, client_with_models, sample_chunks
|
||||||
):
|
):
|
||||||
"""Test that searching with text very similar to a document and high score threshold returns only that document."""
|
"""Test that searching with text very similar to a document and high score threshold returns only that document."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
||||||
openai_client = openai_client_with_empty_stores
|
compat_client = compat_client_with_empty_stores
|
||||||
llama_client = client_with_models
|
llama_client = client_with_models
|
||||||
|
|
||||||
# Create a vector store
|
# Create a vector store
|
||||||
vector_store = openai_client.vector_stores.create(
|
vector_store = compat_client.vector_stores.create(
|
||||||
name="high_score_filter_test", metadata={"purpose": "high_score_filtering"}
|
name="high_score_filter_test", metadata={"purpose": "high_score_filtering"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -358,7 +364,7 @@ def test_openai_vector_store_search_with_high_score_filter(
|
||||||
query = "Python is a high-level programming language with code readability and fewer lines than C++ or Java"
|
query = "Python is a high-level programming language with code readability and fewer lines than C++ or Java"
|
||||||
|
|
||||||
# picking up thrshold to be slightly higher than the second result
|
# picking up thrshold to be slightly higher than the second result
|
||||||
search_response = openai_client.vector_stores.search(
|
search_response = compat_client.vector_stores.search(
|
||||||
vector_store_id=vector_store.id,
|
vector_store_id=vector_store.id,
|
||||||
query=query,
|
query=query,
|
||||||
max_num_results=3,
|
max_num_results=3,
|
||||||
|
@ -367,7 +373,7 @@ def test_openai_vector_store_search_with_high_score_filter(
|
||||||
threshold = search_response.data[1].score + 0.0001
|
threshold = search_response.data[1].score + 0.0001
|
||||||
|
|
||||||
# we expect only one result with the requested threshold
|
# we expect only one result with the requested threshold
|
||||||
search_response = openai_client.vector_stores.search(
|
search_response = compat_client.vector_stores.search(
|
||||||
vector_store_id=vector_store.id,
|
vector_store_id=vector_store.id,
|
||||||
query=query,
|
query=query,
|
||||||
max_num_results=10, # Allow more results but expect filtering
|
max_num_results=10, # Allow more results but expect filtering
|
||||||
|
@ -379,25 +385,26 @@ def test_openai_vector_store_search_with_high_score_filter(
|
||||||
|
|
||||||
# The top result should be the Python document (doc1)
|
# The top result should be the Python document (doc1)
|
||||||
top_result = search_response.data[0]
|
top_result = search_response.data[0]
|
||||||
assert top_result.metadata["document_id"] == "doc1"
|
assert top_result.attributes["document_id"] == "doc1"
|
||||||
assert top_result.metadata["topic"] == "programming"
|
assert top_result.attributes["topic"] == "programming"
|
||||||
assert top_result.score >= threshold
|
assert top_result.score >= threshold
|
||||||
|
|
||||||
# Verify the content contains Python-related terms
|
# Verify the content contains Python-related terms
|
||||||
assert "python" in top_result.content.lower() or "programming" in top_result.content.lower()
|
top_content = top_result.content[0].text
|
||||||
|
assert "python" in top_content.lower() or "programming" in top_content.lower()
|
||||||
|
|
||||||
|
|
||||||
def test_openai_vector_store_search_with_max_num_results(
|
def test_openai_vector_store_search_with_max_num_results(
|
||||||
openai_client_with_empty_stores, client_with_models, sample_chunks
|
compat_client_with_empty_stores, client_with_models, sample_chunks
|
||||||
):
|
):
|
||||||
"""Test OpenAI vector store search with max_num_results."""
|
"""Test OpenAI vector store search with max_num_results."""
|
||||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||||
|
|
||||||
openai_client = openai_client_with_empty_stores
|
compat_client = compat_client_with_empty_stores
|
||||||
llama_client = client_with_models
|
llama_client = client_with_models
|
||||||
|
|
||||||
# Create a vector store
|
# Create a vector store
|
||||||
vector_store = openai_client.vector_stores.create(
|
vector_store = compat_client.vector_stores.create(
|
||||||
name="max_num_results_test_store", metadata={"purpose": "max_num_results_testing"}
|
name="max_num_results_test_store", metadata={"purpose": "max_num_results_testing"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -408,7 +415,7 @@ def test_openai_vector_store_search_with_max_num_results(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Search with max_num_results
|
# Search with max_num_results
|
||||||
search_response = openai_client.vector_stores.search(
|
search_response = compat_client.vector_stores.search(
|
||||||
vector_store_id=vector_store.id,
|
vector_store_id=vector_store.id,
|
||||||
query="machine learning and artificial intelligence",
|
query="machine learning and artificial intelligence",
|
||||||
max_num_results=2,
|
max_num_results=2,
|
||||||
|
|
|
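
The hunks above rename the OpenAI-specific fixtures to provider-neutral compat_client equivalents and move search results from a metadata dict and a plain content string to an attributes dict and a list of content objects. A minimal sketch of a test written against the renamed fixtures, using only fixtures and calls that appear in this diff; the store name and query string are illustrative, not taken from the test suite:

def test_openai_vector_store_minimal_sketch(compat_client_with_empty_stores, client_with_models):
    """Sketch only: create a store, search it, and read results via the new field names."""
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
    client = compat_client_with_empty_stores

    # Create an empty store and run a search against it
    store = client.vector_stores.create(name="sketch_store", metadata={"purpose": "sketch"})
    response = client.vector_stores.search(vector_store_id=store.id, query="python", max_num_results=1)

    # After this change, results expose attributes (not metadata) and content as a list of text parts
    for result in response.data:
        assert isinstance(result.attributes, dict)
        assert isinstance(result.content[0].text, str)
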
@@ -154,3 +154,36 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
     assert len(response.chunks) > 0
     assert response.chunks[0].metadata["document_id"] == "doc1"
     assert response.chunks[0].metadata["source"] == "precomputed"
+
+
+def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(client_with_empty_registry, embedding_model_id):
+    vector_db_id = "test_precomputed_embeddings_db"
+    client_with_empty_registry.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model=embedding_model_id,
+        embedding_dimension=384,
+    )
+
+    chunks_with_embeddings = [
+        Chunk(
+            content="duplicate",
+            metadata={"document_id": "doc1", "source": "precomputed"},
+            embedding=[0.1] * 384,
+        ),
+    ]
+
+    client_with_empty_registry.vector_io.insert(
+        vector_db_id=vector_db_id,
+        chunks=chunks_with_embeddings,
+    )
+
+    response = client_with_empty_registry.vector_io.query(
+        vector_db_id=vector_db_id,
+        query="duplicate",
+    )
+
+    # Verify the top result is the expected document
+    assert response is not None
+    assert len(response.chunks) > 0
+    assert response.chunks[0].metadata["document_id"] == "doc1"
+    assert response.chunks[0].metadata["source"] == "precomputed"
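
The added test registers a vector DB with embedding_dimension=384 and inserts a Chunk whose precomputed embedding has exactly that length. A small helper illustrating that constraint; the helper name is hypothetical and not part of the test file, and the Chunk fields mirror the ones used in the hunk above (Chunk is assumed to be imported as in that file):

def make_precomputed_chunk(text: str, doc_id: str, dim: int = 384) -> Chunk:
    # Placeholder vector; a real caller would supply model output of length `dim`
    embedding = [0.1] * dim
    assert len(embedding) == dim  # must match the registered embedding_dimension
    return Chunk(
        content=text,
        metadata={"document_id": doc_id, "source": "precomputed"},
        embedding=embedding,
    )

With such a helper, the new test could build its input as chunks_with_embeddings = [make_precomputed_chunk("duplicate", "doc1")].
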
@@ -345,6 +345,56 @@ def test_invalid_oauth2_authentication(oauth2_client, invalid_token):
     assert "Invalid JWT token" in response.json()["error"]["message"]
 
 
+async def mock_auth_jwks_response(*args, **kwargs):
+    if "headers" not in kwargs or "Authorization" not in kwargs["headers"]:
+        return MockResponse(401, {})
+    authz = kwargs["headers"]["Authorization"]
+    if authz != "Bearer my-jwks-token":
+        return MockResponse(401, {})
+    return await mock_jwks_response(args, kwargs)
+
+
+@pytest.fixture
+def oauth2_app_with_jwks_token():
+    app = FastAPI()
+    auth_config = AuthenticationConfig(
+        provider_type=AuthProviderType.OAUTH2_TOKEN,
+        config={
+            "jwks": {
+                "uri": "http://mock-authz-service/token/introspect",
+                "key_recheck_period": "3600",
+                "token": "my-jwks-token",
+            },
+            "audience": "llama-stack",
+        },
+    )
+    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
+
+    @app.get("/test")
+    def test_endpoint():
+        return {"message": "Authentication successful"}
+
+    return app
+
+
+@pytest.fixture
+def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token):
+    return TestClient(oauth2_app_with_jwks_token)
+
+
+@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
+def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
+    response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
+    assert response.status_code == 401
+
+
+@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
+def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
+    response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
+    assert response.status_code == 200
+    assert response.json() == {"message": "Authentication successful"}
+
+
 def test_get_attributes_from_claims():
     claims = {
         "sub": "my-user",
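
The new fixtures add an optional token under the jwks config, and mock_auth_jwks_response only serves keys when that token arrives as a bearer header, so the unconfigured client gets a 401 while the configured one succeeds. A rough sketch of that key-fetch handshake; the helper names (build_jwks_headers, fetch_jwks) are assumptions for illustration, not functions from this test file:

import httpx


def build_jwks_headers(token: str | None) -> dict[str, str]:
    # Mirror what the mock expects: a bearer header carrying the configured JWKS token
    return {"Authorization": f"Bearer {token}"} if token else {}


async def fetch_jwks(uri: str, token: str | None) -> dict:
    # Without the configured token the mock responds 401, which is what the first new test asserts
    async with httpx.AsyncClient() as client:
        response = await client.get(uri, headers=build_jwks_headers(token))
        response.raise_for_status()
        return response.json()
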