Merge branch 'main' into feat/add-url-to-paginated-response

Author: Rohan Awhad, 2025-06-13 13:07:45 -04:00 (committed by GitHub)
Commit: b5047db685
24 changed files with 911 additions and 856 deletions


@@ -52,30 +52,7 @@ jobs:
run: |
kubectl create namespace llama-stack
kubectl create serviceaccount llama-stack-auth -n llama-stack
-kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
-cat <<EOF | kubectl apply -f -
-apiVersion: rbac.authorization.k8s.io/v1
-kind: ClusterRole
-metadata:
-  name: allow-anonymous-openid
-rules:
-- nonResourceURLs: ["/openid/v1/jwks"]
-  verbs: ["get"]
----
-apiVersion: rbac.authorization.k8s.io/v1
-kind: ClusterRoleBinding
-metadata:
-  name: allow-anonymous-openid
-roleRef:
-  apiGroup: rbac.authorization.k8s.io
-  kind: ClusterRole
-  name: allow-anonymous-openid
-subjects:
-- kind: User
-  name: system:anonymous
-  apiGroup: rbac.authorization.k8s.io
-EOF

- name: Set Kubernetes Config
if: ${{ matrix.auth-provider == 'oauth2_token' }}
@@ -84,6 +61,7 @@ jobs:
echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV
echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV
echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV
+echo "TOKEN=$(cat llama-stack-auth-token)" >> $GITHUB_ENV

- name: Set Kube Auth Config and run server
env:
@@ -101,7 +79,7 @@ jobs:
EOF
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
-yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml
+yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
cat $run_dir/run.yaml
nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &


@@ -24,7 +24,7 @@ jobs:
matrix:
# Listing tests manually since some of them currently fail
# TODO: generate matrix list from tests/integration when fixed
-test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
+test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime, vector_io]
client-type: [library, http]
python-version: ["3.10", "3.11", "3.12"]
fail-fast: false # we want to run all tests regardless of failure


@@ -45,20 +45,22 @@ jobs:
- name: Build distro from config file
run: |
-USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
+USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml

- name: Start Llama Stack server in background
if: ${{ matrix.image-type }} == 'venv'
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
run: |
-uv run pip list
-nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
+# Use the virtual environment created by the build step (name comes from build config)
+source ci-test/bin/activate
+uv pip list
+nohup llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &

- name: Wait for Llama Stack server to be ready
run: |
for i in {1..30}; do
-if ! grep -q "remote::custom_ollama from /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml" server.log; then
+if ! grep -q "Successfully loaded external provider remote::custom_ollama" server.log; then
echo "Waiting for Llama Stack server to load the provider..."
sleep 1
else


@ -3318,7 +3318,7 @@
"name": "limit", "name": "limit",
"in": "query", "in": "query",
"description": "A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.", "description": "A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.",
"required": true, "required": false,
"schema": { "schema": {
"type": "integer" "type": "integer"
} }
@ -3327,7 +3327,7 @@
"name": "order", "name": "order",
"in": "query", "in": "query",
"description": "Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.", "description": "Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.",
"required": true, "required": false,
"schema": { "schema": {
"type": "string" "type": "string"
} }
@ -3864,7 +3864,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"$ref": "#/components/schemas/VectorStoreSearchResponse" "$ref": "#/components/schemas/VectorStoreSearchResponsePage"
} }
} }
} }
@ -12587,6 +12587,9 @@
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [
"name"
],
"title": "OpenaiCreateVectorStoreRequest" "title": "OpenaiCreateVectorStoreRequest"
}, },
"VectorStoreObject": { "VectorStoreObject": {
@ -13129,13 +13132,74 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"query", "query"
"max_num_results",
"rewrite_query"
], ],
"title": "OpenaiSearchVectorStoreRequest" "title": "OpenaiSearchVectorStoreRequest"
}, },
"VectorStoreContent": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "text"
},
"text": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"text"
],
"title": "VectorStoreContent"
},
"VectorStoreSearchResponse": { "VectorStoreSearchResponse": {
"type": "object",
"properties": {
"file_id": {
"type": "string"
},
"filename": {
"type": "string"
},
"score": {
"type": "number"
},
"attributes": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "number"
},
{
"type": "boolean"
}
]
}
},
"content": {
"type": "array",
"items": {
"$ref": "#/components/schemas/VectorStoreContent"
}
}
},
"additionalProperties": false,
"required": [
"file_id",
"filename",
"score",
"content"
],
"title": "VectorStoreSearchResponse",
"description": "Response from searching a vector store."
},
"VectorStoreSearchResponsePage": {
"type": "object", "type": "object",
"properties": { "properties": {
"object": { "object": {
@ -13148,29 +13212,7 @@
"data": { "data": {
"type": "array", "type": "array",
"items": { "items": {
"type": "object", "$ref": "#/components/schemas/VectorStoreSearchResponse"
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
} }
}, },
"has_more": { "has_more": {
@ -13188,7 +13230,7 @@
"data", "data",
"has_more" "has_more"
], ],
"title": "VectorStoreSearchResponse", "title": "VectorStoreSearchResponsePage",
"description": "Response from searching a vector store." "description": "Response from searching a vector store."
}, },
"OpenaiUpdateVectorStoreRequest": { "OpenaiUpdateVectorStoreRequest": {


@ -2323,7 +2323,7 @@ paths:
description: >- description: >-
A limit on the number of objects to be returned. Limit can range between A limit on the number of objects to be returned. Limit can range between
1 and 100, and the default is 20. 1 and 100, and the default is 20.
required: true required: false
schema: schema:
type: integer type: integer
- name: order - name: order
@ -2331,7 +2331,7 @@ paths:
description: >- description: >-
Sort order by the `created_at` timestamp of the objects. `asc` for ascending Sort order by the `created_at` timestamp of the objects. `asc` for ascending
order and `desc` for descending order. order and `desc` for descending order.
required: true required: false
schema: schema:
type: string type: string
- name: after - name: after
@ -2734,7 +2734,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/VectorStoreSearchResponse' $ref: '#/components/schemas/VectorStoreSearchResponsePage'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -8794,6 +8794,8 @@ components:
description: >- description: >-
The provider-specific vector database ID. The provider-specific vector database ID.
additionalProperties: false additionalProperties: false
required:
- name
title: OpenaiCreateVectorStoreRequest title: OpenaiCreateVectorStoreRequest
VectorStoreObject: VectorStoreObject:
type: object type: object
@ -9190,10 +9192,49 @@ components:
additionalProperties: false additionalProperties: false
required: required:
- query - query
- max_num_results
- rewrite_query
title: OpenaiSearchVectorStoreRequest title: OpenaiSearchVectorStoreRequest
VectorStoreContent:
type: object
properties:
type:
type: string
const: text
text:
type: string
additionalProperties: false
required:
- type
- text
title: VectorStoreContent
VectorStoreSearchResponse: VectorStoreSearchResponse:
type: object
properties:
file_id:
type: string
filename:
type: string
score:
type: number
attributes:
type: object
additionalProperties:
oneOf:
- type: string
- type: number
- type: boolean
content:
type: array
items:
$ref: '#/components/schemas/VectorStoreContent'
additionalProperties: false
required:
- file_id
- filename
- score
- content
title: VectorStoreSearchResponse
description: Response from searching a vector store.
VectorStoreSearchResponsePage:
type: object type: object
properties: properties:
object: object:
@ -9204,15 +9245,7 @@ components:
data: data:
type: array type: array
items: items:
type: object $ref: '#/components/schemas/VectorStoreSearchResponse'
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
has_more: has_more:
type: boolean type: boolean
default: false default: false
@ -9224,7 +9257,7 @@ components:
- search_query - search_query
- data - data
- has_more - has_more
title: VectorStoreSearchResponse title: VectorStoreSearchResponsePage
description: Response from searching a vector store. description: Response from searching a vector store.
OpenaiUpdateVectorStoreRequest: OpenaiUpdateVectorStoreRequest:
type: object type: object


@@ -56,10 +56,10 @@ shields: []
server:
  port: 8321
  auth:
-   provider_type: "kubernetes"
+   provider_type: "oauth2_token"
    config:
-     api_server_url: "https://kubernetes.default.svc"
-     ca_cert_path: "/path/to/ca.crt"
+     jwks:
+       uri: "https://my-token-issuing-svc.com/jwks"
```

Let's break this down into the different sections. The first section specifies the set of APIs that the stack server will serve:
@@ -132,16 +132,52 @@ The server supports multiple authentication providers:

#### OAuth 2.0/OpenID Connect Provider with Kubernetes

-The Kubernetes cluster must be configured to use a service account for authentication.
+The server can be configured to use service account tokens for authorization, validating these against the Kubernetes API server, e.g.:
+
+```yaml
+server:
+  auth:
+    provider_type: "oauth2_token"
+    config:
+      jwks:
+        uri: "https://kubernetes.default.svc:8443/openid/v1/jwks"
+        token: "${env.TOKEN:}"
+        key_recheck_period: 3600
+      tls_cafile: "/path/to/ca.crt"
+      issuer: "https://kubernetes.default.svc"
+      audience: "https://kubernetes.default.svc"
+```
+
+To find your cluster's jwks uri (from which the public key(s) to verify the token signature are obtained), run:
+```
+kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri
+```
+
+For the tls_cafile, you can use the CA certificate of the OIDC provider:
+```bash
+kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
+```
+
+For the issuer, you can use the OIDC provider's URL:
+```bash
+kubectl get --raw /.well-known/openid-configuration| jq .issuer
+```
+
+The audience can be obtained from a token, e.g. run:
+```bash
+kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
+```
+
+The jwks token is used to authorize access to the jwks endpoint. You can obtain a token by running:

```bash
kubectl create namespace llama-stack
kubectl create serviceaccount llama-stack-auth -n llama-stack
-kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
+export TOKEN=$(cat llama-stack-auth-token)
```
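As a quick sanity check, the token can be tested against the JWKS endpoint before starting the server. A minimal sketch, assuming the placeholder URI, CA bundle, and token file from the examples above:

```python
# Sketch: confirm the service-account token is authorized to fetch the cluster's JWKS.
# The URI, CA bundle path, and token file below are the placeholders used above.
import json

import httpx

jwks_uri = "https://kubernetes.default.svc:8443/openid/v1/jwks"
token = open("llama-stack-auth-token").read().strip()

resp = httpx.get(
    jwks_uri,
    headers={"Authorization": f"Bearer {token}"},
    verify="/path/to/ca.crt",  # same file as tls_cafile in the config above
    timeout=5,
)
resp.raise_for_status()
print(json.dumps(resp.json()["keys"], indent=2))  # public keys used to verify client tokens
```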
-Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests
+Alternatively, you can configure the jwks endpoint to allow anonymous access. To do this, make sure
+the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests
and that the correct RoleBinding is created to allow the service account to access the necessary
resources. If that is not the case, you can create a RoleBinding for the service account to access
the necessary resources:
@@ -175,35 +211,6 @@ And then apply the configuration:
kubectl apply -f allow-anonymous-openid.yaml
```

-Validates tokens against the Kubernetes API server through the OIDC provider:
-```yaml
-server:
-  auth:
-    provider_type: "oauth2_token"
-    config:
-      jwks:
-        uri: "https://kubernetes.default.svc"
-        key_recheck_period: 3600
-      tls_cafile: "/path/to/ca.crt"
-      issuer: "https://kubernetes.default.svc"
-      audience: "https://kubernetes.default.svc"
-```
-
-To find your cluster's audience, run:
-```bash
-kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
-```
-
-For the issuer, you can use the OIDC provider's URL:
-```bash
-kubectl get --raw /.well-known/openid-configuration| jq .issuer
-```
-
-For the tls_cafile, you can use the CA certificate of the OIDC provider:
-```bash
-kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
-```
-
The provider extracts user information from the JWT token:
- Username from the `sub` claim becomes a role
- Kubernetes groups become teams


@@ -8,7 +8,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import Any, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable

from pydantic import BaseModel, Field
@@ -96,13 +96,30 @@ class VectorStoreSearchRequest(BaseModel):
    rewrite_query: bool = False


+@json_schema_type
+class VectorStoreContent(BaseModel):
+    type: Literal["text"]
+    text: str
+
+
@json_schema_type
class VectorStoreSearchResponse(BaseModel):
    """Response from searching a vector store."""

+    file_id: str
+    filename: str
+    score: float
+    attributes: dict[str, str | float | bool] | None = None
+    content: list[VectorStoreContent]
+
+
+@json_schema_type
+class VectorStoreSearchResponsePage(BaseModel):
+    """Response from searching a vector store."""
+
    object: str = "vector_store.search_results.page"
    search_query: str
-    data: list[dict[str, Any]]
+    data: list[VectorStoreSearchResponse]
    has_more: bool = False
    next_page: str | None = None
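For orientation, this is roughly how the reworked models nest: a page wraps typed result items, each of which carries typed content chunks. The values below are made up for illustration only:

```python
# Illustrative only: how the new search-response models compose (values are made up).
content = VectorStoreContent(type="text", text="Chunk text returned by the search.")
result = VectorStoreSearchResponse(
    file_id="file_123",
    filename="notes.md",
    score=0.87,
    attributes={"source": "docs"},
    content=[content],
)
page = VectorStoreSearchResponsePage(search_query="example query", data=[result], has_more=False)
```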
@@ -165,7 +182,7 @@ class VectorIO(Protocol):
    @webmethod(route="/openai/v1/vector_stores", method="POST")
    async def openai_create_vector_store(
        self,
-        name: str | None = None,
+        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
@@ -193,8 +210,8 @@
    @webmethod(route="/openai/v1/vector_stores", method="GET")
    async def openai_list_vector_stores(
        self,
-        limit: int = 20,
-        order: str = "desc",
+        limit: int | None = 20,
+        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
@@ -256,10 +273,10 @@
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
-        max_num_results: int = 10,
+        max_num_results: int | None = 10,
        ranking_options: dict[str, Any] | None = None,
-        rewrite_query: bool = False,
-    ) -> VectorStoreSearchResponse:
+        rewrite_query: bool | None = False,
+    ) -> VectorStoreSearchResponsePage:
        """Search for chunks in a vector store.

        Searches a vector store for relevant chunks based on a query and optional file attribute filters.


@@ -180,6 +180,7 @@ def get_provider_registry(
            if provider_type_key in ret[api]:
                logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
            ret[api][provider_type_key] = spec
+            logger.info(f"Successfully loaded external provider {provider_type_key}")
        except yaml.YAMLError as yaml_err:
            logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
            raise yaml_err


@@ -394,9 +394,13 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
            logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
            missing_methods.append((name, "signature_mismatch"))
        else:
-            # Check if the method is actually implemented in the class
-            method_owner = next((cls for cls in mro if name in cls.__dict__), None)
-            if method_owner is None or method_owner.__name__ == protocol.__name__:
+            # Check if the method has a concrete implementation (not just a protocol stub)
+            # Find all classes in MRO that define this method
+            method_owners = [cls for cls in mro if name in cls.__dict__]
+            # Allow methods from mixins/parents, only reject if ONLY the protocol defines it
+            if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
+                # Only reject if the method is ONLY defined in the protocol itself (abstract stub)
                missing_methods.append((name, "not_actually_implemented"))

    if missing_methods:


@@ -163,6 +163,9 @@ class InferenceRouter(Inference):
        messages: list[Message] | InterleavedContent,
        tool_prompt_format: ToolPromptFormat | None = None,
    ) -> int | None:
+        if not hasattr(self, "formatter") or self.formatter is None:
+            return None
+
        if isinstance(messages, list):
            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
        else:


@@ -17,7 +17,7 @@ from llama_stack.apis.vector_io import (
    VectorStoreDeleteResponse,
    VectorStoreListResponse,
    VectorStoreObject,
-    VectorStoreSearchResponse,
+    VectorStoreSearchResponsePage,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@@ -108,7 +108,7 @@ class VectorIORouter(VectorIO):
    # OpenAI Vector Stores API endpoints
    async def openai_create_vector_store(
        self,
-        name: str | None = None,
+        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
@@ -151,8 +151,8 @@
    async def openai_list_vector_stores(
        self,
-        limit: int = 20,
-        order: str = "desc",
+        limit: int | None = 20,
+        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
@@ -239,10 +239,10 @@
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
-        max_num_results: int = 10,
+        max_num_results: int | None = 10,
        ranking_options: dict[str, Any] | None = None,
-        rewrite_query: bool = False,
-    ) -> VectorStoreSearchResponse:
+        rewrite_query: bool | None = False,
+    ) -> VectorStoreSearchResponsePage:
        logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)


@@ -84,6 +84,7 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
class OAuth2JWKSConfig(BaseModel):
    # The JWKS URI for collecting public keys
    uri: str
+    token: str | None = Field(default=None, description="token to authorise access to jwks")
    key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
@@ -246,9 +247,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
        if self.config.jwks is None:
            raise ValueError("JWKS is not configured")
        if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
+            headers = {}
+            if self.config.jwks.token:
+                headers["Authorization"] = f"Bearer {self.config.jwks.token}"
            verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
            async with httpx.AsyncClient(verify=verify) as client:
-                res = await client.get(self.config.jwks.uri, timeout=5)
+                res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
                res.raise_for_status()
                jwks_data = res.json()["keys"]
                updated = {}


@@ -115,7 +115,7 @@ def parse_environment_config(env_config: str) -> dict[str, int]:
class CustomRichHandler(RichHandler):
    def __init__(self, *args, **kwargs):
-        kwargs["console"] = Console(width=120)
+        kwargs["console"] = Console(width=150)
        super().__init__(*args, **kwargs)

    def emit(self, record):


@ -9,9 +9,7 @@ import base64
import io import io
import json import json
import logging import logging
import time from typing import Any
import uuid
from typing import Any, Literal
import faiss import faiss
import numpy as np import numpy as np
@ -24,14 +22,11 @@ from llama_stack.apis.vector_io import (
Chunk, Chunk,
QueryChunksResponse, QueryChunksResponse,
VectorIO, VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
) )
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorDBWithIndex,
@ -47,10 +42,6 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::" OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
# In faiss, since we do
CHUNK_MULTIPLIER = 5
class FaissIndex(EmbeddingIndex): class FaissIndex(EmbeddingIndex):
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None): def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
self.index = faiss.IndexFlatL2(dimension) self.index = faiss.IndexFlatL2(dimension)
@ -140,7 +131,7 @@ class FaissIndex(EmbeddingIndex):
raise NotImplementedError("Keyword search is not supported in FAISS") raise NotImplementedError("Keyword search is not supported in FAISS")
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None: def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
self.config = config self.config = config
self.inference_api = inference_api self.inference_api = inference_api
@ -164,14 +155,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
) )
self.cache[vector_db.identifier] = index self.cache[vector_db.identifier] = index
# Load existing OpenAI vector stores # Load existing OpenAI vector stores using the mixin method
start_key = OPENAI_VECTOR_STORES_PREFIX self.openai_vector_stores = await self._load_openai_vector_stores()
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
for store_data in stored_openai_stores:
store_info = json.loads(store_data)
self.openai_vector_stores[store_info["id"]] = store_info
async def shutdown(self) -> None: async def shutdown(self) -> None:
# Cleanup if needed # Cleanup if needed
@ -234,285 +219,34 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
# OpenAI Vector Stores API endpoints implementation # OpenAI Vector Store Mixin abstract method implementations
async def openai_create_vector_store( async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
self, """Save vector store metadata to kvstore."""
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
assert self.kvstore is not None assert self.kvstore is not None
# store and vector_db have the same id
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
if provider_id is None:
raise ValueError("Provider ID is required")
if embedding_model is None:
raise ValueError("Embedding model is required")
# Use provided embedding dimension or default to 384
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
vector_db = VectorDB(
identifier=store_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
store_info = {
"id": store_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"usage_bytes": 0,
"file_counts": {},
"status": "completed",
"expires_after": expires_after,
"expires_at": None,
"last_active_at": created_at,
"file_ids": file_ids or [],
"chunking_strategy": chunking_strategy,
}
# Add provider information to metadata if provided
metadata = metadata or {}
if provider_id:
metadata["provider_id"] = provider_id
if provider_vector_db_id:
metadata["provider_vector_db_id"] = provider_vector_db_id
store_info["metadata"] = metadata
# Store in kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info)) await self.kvstore.set(key=key, value=json.dumps(store_info))
# Store in memory cache async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
self.openai_vector_stores[store_id] = store_info """Load all vector store metadata from kvstore."""
return VectorStoreObject(
id=store_id,
created_at=created_at,
name=store_id,
usage_bytes=0,
file_counts={},
status="completed",
expires_after=expires_after,
expires_at=None,
last_active_at=created_at,
metadata=metadata,
)
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores."""
# Get all vector stores
all_stores = list(self.openai_vector_stores.values())
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Convert to VectorStoreObject instances
data = [VectorStoreObject(**store) for store in limited_stores]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = data[0].id if data else None
last_id = data[-1].id if data else None
return VectorStoreListResponse(
data=data,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
return VectorStoreObject(**store_info)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Modifies a vector store."""
assert self.kvstore is not None assert self.kvstore is not None
if vector_store_id not in self.openai_vector_stores: start_key = OPENAI_VECTOR_STORES_PREFIX
raise ValueError(f"Vector store {vector_store_id} not found") end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
store_info = self.openai_vector_stores[vector_store_id].copy() stores = {}
for store_data in stored_openai_stores:
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
# Update fields if provided async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
if name is not None: """Update vector store metadata in kvstore."""
store_info["name"] = name assert self.kvstore is not None
if expires_after is not None: key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
store_info["expires_after"] = expires_after
if metadata is not None:
store_info["metadata"] = metadata
# Update last_active_at
store_info["last_active_at"] = int(time.time())
# Save to kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info)) await self.kvstore.set(key=key, value=json.dumps(store_info))
# Update in-memory cache async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
self.openai_vector_stores[vector_store_id] = store_info """Delete vector store metadata from kvstore."""
return VectorStoreObject(**store_info)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store."""
assert self.kvstore is not None assert self.kvstore is not None
if vector_store_id not in self.openai_vector_stores: key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
raise ValueError(f"Vector store {vector_store_id} not found")
# Delete from kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
await self.kvstore.delete(key) await self.kvstore.delete(key)
# Delete from in-memory cache
del self.openai_vector_stores[vector_store_id]
# Also delete the underlying vector DB
try:
await self.unregister_vector_db(vector_store_id)
except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse(
id=vector_store_id,
deleted=True,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
"""Search for chunks in a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
if isinstance(query, list):
search_query = " ".join(query)
else:
search_query = query
try:
score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
"mode": search_mode,
}
# TODO: Add support for ranking_options.ranker
response = await self.query_chunks(
vector_db_id=vector_store_id,
query=search_query,
params=params,
)
# Convert response to OpenAI format
data = []
for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
# Apply score based filtering
if score < score_threshold:
continue
# Apply filters if provided
if filters:
# Simple metadata filtering
if not self._matches_filters(chunk.metadata, filters):
continue
chunk_data = {
"id": f"chunk_{i}",
"object": "vector_store.search_result",
"score": score,
"content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
"metadata": chunk.metadata,
}
data.append(chunk_data)
if len(data) >= max_num_results:
break
return VectorStoreSearchResponse(
search_query=search_query,
data=data,
has_more=False, # For simplicity, we don't implement pagination here
next_page=None,
)
except Exception as e:
logger.error(f"Error searching vector store {vector_store_id}: {e}")
# Return empty results on error
return VectorStoreSearchResponse(
search_query=search_query,
data=[],
has_more=False,
next_page=None,
)
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
"""Check if metadata matches the provided filters."""
for key, value in filters.items():
if key not in metadata:
return False
if metadata[key] != value:
return False
return True


@ -10,9 +10,8 @@ import json
import logging import logging
import sqlite3 import sqlite3
import struct import struct
import time
import uuid import uuid
from typing import Any, Literal from typing import Any
import numpy as np import numpy as np
import sqlite_vec import sqlite_vec
@ -24,12 +23,9 @@ from llama_stack.apis.vector_io import (
Chunk, Chunk,
QueryChunksResponse, QueryChunksResponse,
VectorIO, VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
) )
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,11 +35,6 @@ VECTOR_SEARCH = "vector"
KEYWORD_SEARCH = "keyword" KEYWORD_SEARCH = "keyword"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH} SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
# Constants for OpenAI vector stores (similar to faiss)
VERSION = "v3"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
CHUNK_MULTIPLIER = 5
def serialize_vector(vector: list[float]) -> bytes: def serialize_vector(vector: list[float]) -> bytes:
"""Serialize a list of floats into a compact binary representation.""" """Serialize a list of floats into a compact binary representation."""
@ -303,7 +294,7 @@ class SQLiteVecIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores) return QueryChunksResponse(chunks=chunks, scores=scores)
class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
""" """
A VectorIO implementation using SQLite + sqlite_vec. A VectorIO implementation using SQLite + sqlite_vec.
This class handles vector database registration (with metadata stored in a table named `vector_dbs`) This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
@ -340,15 +331,12 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
# Load any existing vector DB registrations. # Load any existing vector DB registrations.
cur.execute("SELECT metadata FROM vector_dbs") cur.execute("SELECT metadata FROM vector_dbs")
vector_db_rows = cur.fetchall() vector_db_rows = cur.fetchall()
# Load any existing OpenAI vector stores. return vector_db_rows
cur.execute("SELECT metadata FROM openai_vector_stores")
openai_store_rows = cur.fetchall()
return vector_db_rows, openai_store_rows
finally: finally:
cur.close() cur.close()
connection.close() connection.close()
vector_db_rows, openai_store_rows = await asyncio.to_thread(_setup_connection) vector_db_rows = await asyncio.to_thread(_setup_connection)
# Load existing vector DBs # Load existing vector DBs
for row in vector_db_rows: for row in vector_db_rows:
@ -359,11 +347,8 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
) )
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
# Load existing OpenAI vector stores # Load existing OpenAI vector stores using the mixin method
for row in openai_store_rows: self.openai_vector_stores = await self._load_openai_vector_stores()
store_data = row[0]
store_info = json.loads(store_data)
self.openai_vector_stores[store_info["id"]] = store_info
async def shutdown(self) -> None: async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection # nothing to do since we don't maintain a persistent connection
@ -409,6 +394,87 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
await asyncio.to_thread(_delete_vector_db_from_registry) await asyncio.to_thread(_delete_vector_db_from_registry)
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to SQLite database."""
def _store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
(store_id, json.dumps(store_info)),
)
connection.commit()
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
finally:
cur.close()
connection.close()
try:
await asyncio.to_thread(_store)
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("SELECT metadata FROM openai_vector_stores")
rows = cur.fetchall()
return rows
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_load)
stores = {}
for row in rows:
store_data = row[0]
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in SQLite database."""
def _update():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
(json.dumps(store_info), store_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update)
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from SQLite database."""
def _delete():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
if vector_db_id not in self.cache: if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}") raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
@ -423,318 +489,6 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params) return await self.cache[vector_db_id].query_chunks(query, params)
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
# store and vector_db have the same id
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
if provider_id is None:
raise ValueError("Provider ID is required")
if embedding_model is None:
raise ValueError("Embedding model is required")
# Use provided embedding dimension or default to 384
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
vector_db = VectorDB(
identifier=store_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
store_info = {
"id": store_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"usage_bytes": 0,
"file_counts": {},
"status": "completed",
"expires_after": expires_after,
"expires_at": None,
"last_active_at": created_at,
"file_ids": file_ids or [],
"chunking_strategy": chunking_strategy,
}
# Add provider information to metadata if provided
metadata = metadata or {}
if provider_id:
metadata["provider_id"] = provider_id
if provider_vector_db_id:
metadata["provider_vector_db_id"] = provider_vector_db_id
store_info["metadata"] = metadata
# Store in SQLite database
def _store_openai_vector_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
(store_id, json.dumps(store_info)),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_store_openai_vector_store)
# Store in memory cache
self.openai_vector_stores[store_id] = store_info
return VectorStoreObject(
id=store_id,
created_at=created_at,
name=store_id,
usage_bytes=0,
file_counts={},
status="completed",
expires_after=expires_after,
expires_at=None,
last_active_at=created_at,
metadata=metadata,
)
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores."""
# Get all vector stores
all_stores = list(self.openai_vector_stores.values())
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Convert to VectorStoreObject instances
data = [VectorStoreObject(**store) for store in limited_stores]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = data[0].id if data else None
last_id = data[-1].id if data else None
return VectorStoreListResponse(
data=data,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
return VectorStoreObject(**store_info)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Modifies a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id].copy()
# Update fields if provided
if name is not None:
store_info["name"] = name
if expires_after is not None:
store_info["expires_after"] = expires_after
if metadata is not None:
store_info["metadata"] = metadata
# Update last_active_at
store_info["last_active_at"] = int(time.time())
# Save to SQLite database
def _update_openai_vector_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
(json.dumps(store_info), vector_store_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update_openai_vector_store)
# Update in-memory cache
self.openai_vector_stores[vector_store_id] = store_info
return VectorStoreObject(**store_info)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
# Delete from SQLite database
def _delete_openai_vector_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (vector_store_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_openai_vector_store)
# Delete from in-memory cache
del self.openai_vector_stores[vector_store_id]
# Also delete the underlying vector DB
try:
await self.unregister_vector_db(vector_store_id)
except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse(
id=vector_store_id,
deleted=True,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
"""Search for chunks in a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
if isinstance(query, list):
search_query = " ".join(query)
else:
search_query = query
try:
score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
"mode": search_mode,
}
# TODO: Add support for ranking_options.ranker
response = await self.query_chunks(
vector_db_id=vector_store_id,
query=search_query,
params=params,
)
# Convert response to OpenAI format
data = []
for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
# Apply score based filtering
if score < score_threshold:
continue
# Apply filters if provided
if filters:
# Simple metadata filtering
if not self._matches_filters(chunk.metadata, filters):
continue
chunk_data = {
"id": f"chunk_{i}",
"object": "vector_store.search_result",
"score": score,
"content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
"metadata": chunk.metadata,
}
data.append(chunk_data)
if len(data) >= max_num_results:
break
return VectorStoreSearchResponse(
search_query=search_query,
data=data,
has_more=False, # For simplicity, we don't implement pagination here
next_page=None,
)
except Exception as e:
logger.error(f"Error searching vector store {vector_store_id}: {e}")
# Return empty results on error
return VectorStoreSearchResponse(
search_query=search_query,
data=[],
has_more=False,
next_page=None,
)
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
"""Check if metadata matches the provided filters."""
for key, value in filters.items():
if key not in metadata:
return False
if metadata[key] != value:
return False
return True
def generate_chunk_id(document_id: str, chunk_text: str) -> str: def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text.""" """Generate a unique chunk ID using a hash of document ID and chunk text."""


@@ -6,7 +6,7 @@
import asyncio
import json
import logging
-from typing import Any, Literal
+from typing import Any
from urllib.parse import urlparse

import chromadb
@@ -21,7 +21,7 @@ from llama_stack.apis.vector_io import (
    VectorStoreDeleteResponse,
    VectorStoreListResponse,
    VectorStoreObject,
-    VectorStoreSearchResponse,
+    VectorStoreSearchResponsePage,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
@@ -189,7 +189,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    async def openai_create_vector_store(
        self,
-        name: str | None = None,
+        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
@@ -203,8 +203,8 @@
    async def openai_list_vector_stores(
        self,
-        limit: int = 20,
-        order: str = "desc",
+        limit: int | None = 20,
+        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
@@ -236,9 +236,8 @@
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
-        max_num_results: int = 10,
+        max_num_results: int | None = 10,
        ranking_options: dict[str, Any] | None = None,
-        rewrite_query: bool = False,
-        search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
-    ) -> VectorStoreSearchResponse:
+        rewrite_query: bool | None = False,
+    ) -> VectorStoreSearchResponsePage:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

View file

@ -9,7 +9,7 @@ import hashlib
import logging import logging
import os import os
import uuid import uuid
from typing import Any, Literal from typing import Any
from numpy.typing import NDArray from numpy.typing import NDArray
from pymilvus import MilvusClient from pymilvus import MilvusClient
@ -23,7 +23,7 @@ from llama_stack.apis.vector_io import (
VectorStoreDeleteResponse, VectorStoreDeleteResponse,
VectorStoreListResponse, VectorStoreListResponse,
VectorStoreObject, VectorStoreObject,
VectorStoreSearchResponse, VectorStoreSearchResponsePage,
) )
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
@ -187,7 +187,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_create_vector_store( async def openai_create_vector_store(
self, self,
name: str | None = None, name: str,
file_ids: list[str] | None = None, file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None, expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None,
@ -201,8 +201,8 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_list_vector_stores( async def openai_list_vector_stores(
self, self,
limit: int = 20, limit: int | None = 20,
order: str = "desc", order: str | None = "desc",
after: str | None = None, after: str | None = None,
before: str | None = None, before: str | None = None,
) -> VectorStoreListResponse: ) -> VectorStoreListResponse:
@ -234,11 +234,10 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
vector_store_id: str, vector_store_id: str,
query: str | list[str], query: str | list[str],
filters: dict[str, Any] | None = None, filters: dict[str, Any] | None = None,
max_num_results: int = 10, max_num_results: int | None = 10,
ranking_options: dict[str, Any] | None = None, ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False, rewrite_query: bool | None = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector", ) -> VectorStoreSearchResponsePage:
) -> VectorStoreSearchResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@ -6,7 +6,7 @@
import logging import logging
import uuid import uuid
from typing import Any, Literal from typing import Any
from numpy.typing import NDArray from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
@ -21,7 +21,7 @@ from llama_stack.apis.vector_io import (
VectorStoreDeleteResponse, VectorStoreDeleteResponse,
VectorStoreListResponse, VectorStoreListResponse,
VectorStoreObject, VectorStoreObject,
VectorStoreSearchResponse, VectorStoreSearchResponsePage,
) )
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
@ -189,7 +189,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_create_vector_store( async def openai_create_vector_store(
self, self,
name: str | None = None, name: str,
file_ids: list[str] | None = None, file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None, expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None,
@ -203,8 +203,8 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_list_vector_stores( async def openai_list_vector_stores(
self, self,
limit: int = 20, limit: int | None = 20,
order: str = "desc", order: str | None = "desc",
after: str | None = None, after: str | None = None,
before: str | None = None, before: str | None = None,
) -> VectorStoreListResponse: ) -> VectorStoreListResponse:
@ -236,9 +236,8 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
vector_store_id: str, vector_store_id: str,
query: str | list[str], query: str | list[str],
filters: dict[str, Any] | None = None, filters: dict[str, Any] | None = None,
max_num_results: int = 10, max_num_results: int | None = 10,
ranking_options: dict[str, Any] | None = None, ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False, rewrite_query: bool | None = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector", ) -> VectorStoreSearchResponsePage:
) -> VectorStoreSearchResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@ -76,7 +76,7 @@ class WeaviateIndex(EmbeddingIndex):
continue continue
chunks.append(chunk) chunks.append(chunk)
scores.append(1.0 / doc.metadata.distance) scores.append(1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf"))
return QueryChunksResponse(chunks=chunks, scores=scores) return QueryChunksResponse(chunks=chunks, scores=scores)
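The guard added above matters for exact matches: when a stored vector equals the query embedding, Weaviate reports a distance of 0, and a plain 1.0 / distance would raise ZeroDivisionError, so the score is pinned to float("inf") instead. A minimal sketch of the scoring rule (the helper name is illustrative, not part of the adapter):

def distance_to_score(distance: float) -> float:
    # Inverse-distance scoring; a zero distance (exact match) maps to +infinity
    return 1.0 / distance if distance != 0 else float("inf")

assert distance_to_score(0.5) == 2.0
assert distance_to_score(0.0) == float("inf")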

View file

@ -0,0 +1,385 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
logger = logging.getLogger(__name__)
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5
class OpenAIVectorStoreMixin(ABC):
"""
Mixin class that provides a common implementation of the OpenAI Vector Store API.
Providers need to implement the abstract storage methods and maintain
an openai_vector_stores in-memory cache.
"""
# These should be provided by the implementing class
openai_vector_stores: dict[str, dict[str, Any]]
@abstractmethod
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to persistent storage."""
pass
@abstractmethod
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from persistent storage."""
pass
@abstractmethod
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in persistent storage."""
pass
@abstractmethod
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from persistent storage."""
pass
@abstractmethod
async def register_vector_db(self, vector_db: VectorDB) -> None:
"""Register a vector database (provider-specific implementation)."""
pass
@abstractmethod
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database (provider-specific implementation)."""
pass
@abstractmethod
async def query_chunks(
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
"""Query chunks from a vector database (provider-specific implementation)."""
pass
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
# store and vector_db have the same id
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
if provider_id is None:
raise ValueError("Provider ID is required")
if embedding_model is None:
raise ValueError("Embedding model is required")
# Embedding dimension must be set (the signature defaults it to 384); an explicit None is rejected
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
vector_db = VectorDB(
identifier=store_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
store_info = {
"id": store_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"usage_bytes": 0,
"file_counts": {},
"status": "completed",
"expires_after": expires_after,
"expires_at": None,
"last_active_at": created_at,
"file_ids": file_ids or [],
"chunking_strategy": chunking_strategy,
}
# Add provider information to metadata if provided
metadata = metadata or {}
if provider_id:
metadata["provider_id"] = provider_id
if provider_vector_db_id:
metadata["provider_vector_db_id"] = provider_vector_db_id
store_info["metadata"] = metadata
# Save to persistent storage (provider-specific)
await self._save_openai_vector_store(store_id, store_info)
# Store in memory cache
self.openai_vector_stores[store_id] = store_info
return VectorStoreObject(
id=store_id,
created_at=created_at,
name=store_id,
usage_bytes=0,
file_counts={},
status="completed",
expires_after=expires_after,
expires_at=None,
last_active_at=created_at,
metadata=metadata,
)
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores."""
limit = limit or 20
order = order or "desc"
# Get all vector stores
all_stores = list(self.openai_vector_stores.values())
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Convert to VectorStoreObject instances
data = [VectorStoreObject(**store) for store in limited_stores]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = data[0].id if data else None
last_id = data[-1].id if data else None
return VectorStoreListResponse(
data=data,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
return VectorStoreObject(**store_info)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Modifies a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id].copy()
# Update fields if provided
if name is not None:
store_info["name"] = name
if expires_after is not None:
store_info["expires_after"] = expires_after
if metadata is not None:
store_info["metadata"] = metadata
# Update last_active_at
store_info["last_active_at"] = int(time.time())
# Save to persistent storage (provider-specific)
await self._update_openai_vector_store(vector_store_id, store_info)
# Update in-memory cache
self.openai_vector_stores[vector_store_id] = store_info
return VectorStoreObject(**store_info)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
# Delete from persistent storage (provider-specific)
await self._delete_openai_vector_store_from_storage(vector_store_id)
# Delete from in-memory cache
del self.openai_vector_stores[vector_store_id]
# Also delete the underlying vector DB
try:
await self.unregister_vector_db(vector_store_id)
except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse(
id=vector_store_id,
deleted=True,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool | None = False,
# search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponsePage:
"""Search for chunks in a vector store."""
# TODO: Add support in the API for this
search_mode = "vector"
max_num_results = max_num_results or 10
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
if isinstance(query, list):
search_query = " ".join(query)
else:
search_query = query
try:
score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
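# Over-fetch CHUNK_MULTIPLIER times the requested count so the score/metadata filtering below can still fill max_num_results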
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
"mode": search_mode,
}
# TODO: Add support for ranking_options.ranker
response = await self.query_chunks(
vector_db_id=vector_store_id,
query=search_query,
params=params,
)
# Convert response to OpenAI format
data = []
for chunk, score in zip(response.chunks, response.scores, strict=False):
# Apply score based filtering
if score < score_threshold:
continue
# Apply filters if provided
if filters:
# Simple metadata filtering
if not self._matches_filters(chunk.metadata, filters):
continue
# content is InterleavedContent
if isinstance(chunk.content, str):
content = [
VectorStoreContent(
type="text",
text=chunk.content,
)
]
elif isinstance(chunk.content, list):
# TODO: Add support for other types of content
content = [
VectorStoreContent(
type="text",
text=item.text,
)
for item in chunk.content
if item.type == "text"
]
else:
if chunk.content.type != "text":
raise ValueError(f"Unsupported content type: {chunk.content.type}")
content = [
VectorStoreContent(
type="text",
text=chunk.content.text,
)
]
response_data_item = VectorStoreSearchResponse(
file_id=chunk.metadata.get("file_id", ""),
filename=chunk.metadata.get("filename", ""),
score=score,
attributes=chunk.metadata,
content=content,
)
data.append(response_data_item)
if len(data) >= max_num_results:
break
return VectorStoreSearchResponsePage(
search_query=search_query,
data=data,
has_more=False, # For simplicity, we don't implement pagination here
next_page=None,
)
except Exception as e:
logger.error(f"Error searching vector store {vector_store_id}: {e}")
# Return empty results on error
return VectorStoreSearchResponsePage(
search_query=search_query,
data=[],
has_more=False,
next_page=None,
)
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
"""Check if metadata matches the provided filters."""
for key, value in filters.items():
if key not in metadata:
return False
if metadata[key] != value:
return False
return True
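For orientation, a brief usage sketch of the OpenAI-compatible surface this mixin implements, written against an OpenAI-compatible client pointed at the stack server, as the tests later in this commit do; the client variable, store name, filter values, and threshold below are illustrative. Filters are exact top-level key/value matches against chunk metadata, and ranking_options.score_threshold drops chunks that score below it.

store = client.vector_stores.create(name="docs_store", metadata={"purpose": "demo"})

listed = client.vector_stores.list(limit=20, order="desc")
print([vs.id for vs in listed.data])

results = client.vector_stores.search(
    vector_store_id=store.id,
    query="What is Python?",
    filters={"topic": "programming"},
    max_num_results=3,
    ranking_options={"score_threshold": 0.5},
)
for item in results.data:
    # Each result carries the chunk's metadata as `attributes` and its text as VectorStoreContent entries
    print(item.score, item.attributes["document_id"], item.content[0].text)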

View file

@ -34,11 +34,15 @@ def skip_if_model_doesnt_support_variable_dimensions(model_id):
pytest.skip("{model_id} does not support variable output embedding dimensions") pytest.skip("{model_id} does not support variable output embedding dimensions")
def skip_if_model_doesnt_support_openai_embeddings(client_with_models, model_id): @pytest.fixture(params=["openai_client", "llama_stack_client"])
if isinstance(client_with_models, LlamaStackAsLibraryClient): def compat_client(request, client_with_models):
pytest.skip("OpenAI embeddings are not supported when testing with library client yet.") if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI client tests not supported with library client")
return request.getfixturevalue(request.param)
provider = provider_from_model(client_with_models, model_id)
def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
provider = provider_from_model(client, model_id)
if provider.provider_type in ( if provider.provider_type in (
"inline::meta-reference", "inline::meta-reference",
"remote::bedrock", "remote::bedrock",
@ -58,13 +62,13 @@ def openai_client(client_with_models):
return OpenAI(base_url=base_url, api_key="fake") return OpenAI(base_url=base_url, api_key="fake")
def test_openai_embeddings_single_string(openai_client, client_with_models, embedding_model_id): def test_openai_embeddings_single_string(compat_client, client_with_models, embedding_model_id):
"""Test OpenAI embeddings endpoint with a single string input.""" """Test OpenAI embeddings endpoint with a single string input."""
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
input_text = "Hello, world!" input_text = "Hello, world!"
response = openai_client.embeddings.create( response = compat_client.embeddings.create(
model=embedding_model_id, model=embedding_model_id,
input=input_text, input=input_text,
encoding_format="float", encoding_format="float",
@ -80,13 +84,13 @@ def test_openai_embeddings_single_string(openai_client, client_with_models, embe
assert all(isinstance(x, float) for x in response.data[0].embedding) assert all(isinstance(x, float) for x in response.data[0].embedding)
def test_openai_embeddings_multiple_strings(openai_client, client_with_models, embedding_model_id): def test_openai_embeddings_multiple_strings(compat_client, client_with_models, embedding_model_id):
"""Test OpenAI embeddings endpoint with multiple string inputs.""" """Test OpenAI embeddings endpoint with multiple string inputs."""
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
input_texts = ["Hello, world!", "How are you today?", "This is a test."] input_texts = ["Hello, world!", "How are you today?", "This is a test."]
response = openai_client.embeddings.create( response = compat_client.embeddings.create(
model=embedding_model_id, model=embedding_model_id,
input=input_texts, input=input_texts,
) )
@ -103,13 +107,13 @@ def test_openai_embeddings_multiple_strings(openai_client, client_with_models, e
assert all(isinstance(x, float) for x in embedding_data.embedding) assert all(isinstance(x, float) for x in embedding_data.embedding)
def test_openai_embeddings_with_encoding_format_float(openai_client, client_with_models, embedding_model_id): def test_openai_embeddings_with_encoding_format_float(compat_client, client_with_models, embedding_model_id):
"""Test OpenAI embeddings endpoint with float encoding format.""" """Test OpenAI embeddings endpoint with float encoding format."""
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
input_text = "Test encoding format" input_text = "Test encoding format"
response = openai_client.embeddings.create( response = compat_client.embeddings.create(
model=embedding_model_id, model=embedding_model_id,
input=input_text, input=input_text,
encoding_format="float", encoding_format="float",
@ -121,7 +125,7 @@ def test_openai_embeddings_with_encoding_format_float(openai_client, client_with
assert all(isinstance(x, float) for x in response.data[0].embedding) assert all(isinstance(x, float) for x in response.data[0].embedding)
def test_openai_embeddings_with_dimensions(openai_client, client_with_models, embedding_model_id): def test_openai_embeddings_with_dimensions(compat_client, client_with_models, embedding_model_id):
"""Test OpenAI embeddings endpoint with custom dimensions parameter.""" """Test OpenAI embeddings endpoint with custom dimensions parameter."""
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
skip_if_model_doesnt_support_variable_dimensions(embedding_model_id) skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)
@ -129,7 +133,7 @@ def test_openai_embeddings_with_dimensions(openai_client, client_with_models, em
input_text = "Test dimensions parameter" input_text = "Test dimensions parameter"
dimensions = 16 dimensions = 16
response = openai_client.embeddings.create( response = compat_client.embeddings.create(
model=embedding_model_id, model=embedding_model_id,
input=input_text, input=input_text,
dimensions=dimensions, dimensions=dimensions,
@ -142,14 +146,14 @@ def test_openai_embeddings_with_dimensions(openai_client, client_with_models, em
assert len(response.data[0].embedding) > 0 assert len(response.data[0].embedding) > 0
def test_openai_embeddings_with_user_parameter(openai_client, client_with_models, embedding_model_id): def test_openai_embeddings_with_user_parameter(compat_client, client_with_models, embedding_model_id):
"""Test OpenAI embeddings endpoint with user parameter.""" """Test OpenAI embeddings endpoint with user parameter."""
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
input_text = "Test user parameter" input_text = "Test user parameter"
user_id = "test-user-123" user_id = "test-user-123"
response = openai_client.embeddings.create( response = compat_client.embeddings.create(
model=embedding_model_id, model=embedding_model_id,
input=input_text, input=input_text,
user=user_id, user=user_id,
@ -161,41 +165,41 @@ def test_openai_embeddings_with_user_parameter(openai_client, client_with_models
assert len(response.data[0].embedding) > 0 assert len(response.data[0].embedding) > 0
def test_openai_embeddings_empty_list_error(openai_client, client_with_models, embedding_model_id): def test_openai_embeddings_empty_list_error(compat_client, client_with_models, embedding_model_id):
"""Test that empty list input raises an appropriate error.""" """Test that empty list input raises an appropriate error."""
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
with pytest.raises(Exception): # noqa: B017 with pytest.raises(Exception): # noqa: B017
openai_client.embeddings.create( compat_client.embeddings.create(
model=embedding_model_id, model=embedding_model_id,
input=[], input=[],
) )
def test_openai_embeddings_invalid_model_error(openai_client, client_with_models, embedding_model_id): def test_openai_embeddings_invalid_model_error(compat_client, client_with_models, embedding_model_id):
"""Test that invalid model ID raises an appropriate error.""" """Test that invalid model ID raises an appropriate error."""
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
with pytest.raises(Exception): # noqa: B017 with pytest.raises(Exception): # noqa: B017
openai_client.embeddings.create( compat_client.embeddings.create(
model="invalid-model-id", model="invalid-model-id",
input="Test text", input="Test text",
) )
def test_openai_embeddings_different_inputs_different_outputs(openai_client, client_with_models, embedding_model_id): def test_openai_embeddings_different_inputs_different_outputs(compat_client, client_with_models, embedding_model_id):
"""Test that different inputs produce different embeddings.""" """Test that different inputs produce different embeddings."""
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
input_text1 = "This is the first text" input_text1 = "This is the first text"
input_text2 = "This is completely different content" input_text2 = "This is completely different content"
response1 = openai_client.embeddings.create( response1 = compat_client.embeddings.create(
model=embedding_model_id, model=embedding_model_id,
input=input_text1, input=input_text1,
) )
response2 = openai_client.embeddings.create( response2 = compat_client.embeddings.create(
model=embedding_model_id, model=embedding_model_id,
input=input_text2, input=input_text2,
) )
@ -208,7 +212,7 @@ def test_openai_embeddings_different_inputs_different_outputs(openai_client, cli
assert embedding1 != embedding2 assert embedding1 != embedding2
def test_openai_embeddings_with_encoding_format_base64(openai_client, client_with_models, embedding_model_id): def test_openai_embeddings_with_encoding_format_base64(compat_client, client_with_models, embedding_model_id):
"""Test OpenAI embeddings endpoint with base64 encoding format.""" """Test OpenAI embeddings endpoint with base64 encoding format."""
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
skip_if_model_doesnt_support_variable_dimensions(embedding_model_id) skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)
@ -216,7 +220,7 @@ def test_openai_embeddings_with_encoding_format_base64(openai_client, client_wit
input_text = "Test base64 encoding format" input_text = "Test base64 encoding format"
dimensions = 12 dimensions = 12
response = openai_client.embeddings.create( response = compat_client.embeddings.create(
model=embedding_model_id, model=embedding_model_id,
input=input_text, input=input_text,
encoding_format="base64", encoding_format="base64",
@ -241,13 +245,13 @@ def test_openai_embeddings_with_encoding_format_base64(openai_client, client_wit
assert all(isinstance(x, float) for x in embedding_floats) assert all(isinstance(x, float) for x in embedding_floats)
def test_openai_embeddings_base64_batch_processing(openai_client, client_with_models, embedding_model_id): def test_openai_embeddings_base64_batch_processing(compat_client, client_with_models, embedding_model_id):
"""Test OpenAI embeddings endpoint with base64 encoding for batch processing.""" """Test OpenAI embeddings endpoint with base64 encoding for batch processing."""
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
input_texts = ["First text for base64", "Second text for base64", "Third text for base64"] input_texts = ["First text for base64", "Second text for base64", "Third text for base64"]
response = openai_client.embeddings.create( response = compat_client.embeddings.create(
model=embedding_model_id, model=embedding_model_id,
input=input_texts, input=input_texts,
encoding_format="base64", encoding_format="base64",

View file

@ -17,9 +17,6 @@ logger = logging.getLogger(__name__)
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI vector stores are not supported when testing with library client yet.")
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers: for p in vector_io_providers:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]: if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]:
@ -34,6 +31,13 @@ def openai_client(client_with_models):
return OpenAI(base_url=base_url, api_key="fake") return OpenAI(base_url=base_url, api_key="fake")
@pytest.fixture(params=["openai_client", "llama_stack_client"])
def compat_client(request, client_with_models):
if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI client tests not supported with library client")
return request.getfixturevalue(request.param)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def sample_chunks(): def sample_chunks():
return [ return [
@ -57,29 +61,29 @@ def sample_chunks():
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def openai_client_with_empty_stores(openai_client): def compat_client_with_empty_stores(compat_client):
def clear_vector_stores(): def clear_vector_stores():
# List and delete all existing vector stores # List and delete all existing vector stores
try: try:
response = openai_client.vector_stores.list() response = compat_client.vector_stores.list()
for store in response.data: for store in response.data:
openai_client.vector_stores.delete(vector_store_id=store.id) compat_client.vector_stores.delete(vector_store_id=store.id)
except Exception: except Exception:
# If the API is not available or fails, just continue # If the API is not available or fails, just continue
logger.warning("Failed to clear vector stores") logger.warning("Failed to clear vector stores")
pass pass
clear_vector_stores() clear_vector_stores()
yield openai_client yield compat_client
# Clean up after the test # Clean up after the test
clear_vector_stores() clear_vector_stores()
def test_openai_create_vector_store(openai_client_with_empty_stores, client_with_models): def test_openai_create_vector_store(compat_client_with_empty_stores, client_with_models):
"""Test creating a vector store using OpenAI API.""" """Test creating a vector store using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
client = openai_client_with_empty_stores client = compat_client_with_empty_stores
# Create a vector store # Create a vector store
vector_store = client.vector_stores.create( vector_store = client.vector_stores.create(
@ -96,11 +100,11 @@ def test_openai_create_vector_store(openai_client_with_empty_stores, client_with
assert hasattr(vector_store, "created_at") assert hasattr(vector_store, "created_at")
def test_openai_list_vector_stores(openai_client_with_empty_stores, client_with_models): def test_openai_list_vector_stores(compat_client_with_empty_stores, client_with_models):
"""Test listing vector stores using OpenAI API.""" """Test listing vector stores using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
client = openai_client_with_empty_stores client = compat_client_with_empty_stores
# Create a few vector stores # Create a few vector stores
store1 = client.vector_stores.create(name="store1", metadata={"type": "test"}) store1 = client.vector_stores.create(name="store1", metadata={"type": "test"})
@ -123,11 +127,11 @@ def test_openai_list_vector_stores(openai_client_with_empty_stores, client_with_
assert len(limited_response.data) == 1 assert len(limited_response.data) == 1
def test_openai_retrieve_vector_store(openai_client_with_empty_stores, client_with_models): def test_openai_retrieve_vector_store(compat_client_with_empty_stores, client_with_models):
"""Test retrieving a specific vector store using OpenAI API.""" """Test retrieving a specific vector store using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
client = openai_client_with_empty_stores client = compat_client_with_empty_stores
# Create a vector store # Create a vector store
created_store = client.vector_stores.create(name="retrieve_test_store", metadata={"purpose": "retrieval_test"}) created_store = client.vector_stores.create(name="retrieve_test_store", metadata={"purpose": "retrieval_test"})
@ -142,11 +146,11 @@ def test_openai_retrieve_vector_store(openai_client_with_empty_stores, client_wi
assert retrieved_store.object == "vector_store" assert retrieved_store.object == "vector_store"
def test_openai_update_vector_store(openai_client_with_empty_stores, client_with_models): def test_openai_update_vector_store(compat_client_with_empty_stores, client_with_models):
"""Test modifying a vector store using OpenAI API.""" """Test modifying a vector store using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
client = openai_client_with_empty_stores client = compat_client_with_empty_stores
# Create a vector store # Create a vector store
created_store = client.vector_stores.create(name="original_name", metadata={"version": "1.0"}) created_store = client.vector_stores.create(name="original_name", metadata={"version": "1.0"})
@ -165,11 +169,11 @@ def test_openai_update_vector_store(openai_client_with_empty_stores, client_with
assert modified_store.last_active_at > created_store.last_active_at assert modified_store.last_active_at > created_store.last_active_at
def test_openai_delete_vector_store(openai_client_with_empty_stores, client_with_models): def test_openai_delete_vector_store(compat_client_with_empty_stores, client_with_models):
"""Test deleting a vector store using OpenAI API.""" """Test deleting a vector store using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
client = openai_client_with_empty_stores client = compat_client_with_empty_stores
# Create a vector store # Create a vector store
created_store = client.vector_stores.create(name="delete_test_store", metadata={"purpose": "deletion_test"}) created_store = client.vector_stores.create(name="delete_test_store", metadata={"purpose": "deletion_test"})
@ -187,11 +191,11 @@ def test_openai_delete_vector_store(openai_client_with_empty_stores, client_with
client.vector_stores.retrieve(vector_store_id=created_store.id) client.vector_stores.retrieve(vector_store_id=created_store.id)
def test_openai_vector_store_search_empty(openai_client_with_empty_stores, client_with_models): def test_openai_vector_store_search_empty(compat_client_with_empty_stores, client_with_models):
"""Test searching an empty vector store using OpenAI API.""" """Test searching an empty vector store using OpenAI API."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
client = openai_client_with_empty_stores client = compat_client_with_empty_stores
# Create a vector store # Create a vector store
vector_store = client.vector_stores.create(name="search_test_store", metadata={"purpose": "search_testing"}) vector_store = client.vector_stores.create(name="search_test_store", metadata={"purpose": "search_testing"})
@ -208,15 +212,15 @@ def test_openai_vector_store_search_empty(openai_client_with_empty_stores, clien
assert search_response.has_more is False assert search_response.has_more is False
def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client_with_models, sample_chunks): def test_openai_vector_store_with_chunks(compat_client_with_empty_stores, client_with_models, sample_chunks):
"""Test vector store functionality with actual chunks using both OpenAI and native APIs.""" """Test vector store functionality with actual chunks using both OpenAI and native APIs."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
openai_client = openai_client_with_empty_stores compat_client = compat_client_with_empty_stores
llama_client = client_with_models llama_client = client_with_models
# Create a vector store using OpenAI API # Create a vector store using OpenAI API
vector_store = openai_client.vector_stores.create(name="chunks_test_store", metadata={"purpose": "chunks_testing"}) vector_store = compat_client.vector_stores.create(name="chunks_test_store", metadata={"purpose": "chunks_testing"})
# Insert chunks using the native LlamaStack API (since OpenAI API doesn't have direct chunk insertion) # Insert chunks using the native LlamaStack API (since OpenAI API doesn't have direct chunk insertion)
llama_client.vector_io.insert( llama_client.vector_io.insert(
@ -225,7 +229,7 @@ def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client
) )
# Search using OpenAI API # Search using OpenAI API
search_response = openai_client.vector_stores.search( search_response = compat_client.vector_stores.search(
vector_store_id=vector_store.id, query="What is Python programming language?", max_num_results=3 vector_store_id=vector_store.id, query="What is Python programming language?", max_num_results=3
) )
assert search_response is not None assert search_response is not None
@ -233,18 +237,19 @@ def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client
# The top result should be about Python (doc1) # The top result should be about Python (doc1)
top_result = search_response.data[0] top_result = search_response.data[0]
assert "python" in top_result.content.lower() or "programming" in top_result.content.lower() top_content = top_result.content[0].text
assert top_result.metadata["document_id"] == "doc1" assert "python" in top_content.lower() or "programming" in top_content.lower()
assert top_result.attributes["document_id"] == "doc1"
# Test filtering by metadata # Test filtering by metadata
filtered_search = openai_client.vector_stores.search( filtered_search = compat_client.vector_stores.search(
vector_store_id=vector_store.id, query="artificial intelligence", filters={"topic": "ai"}, max_num_results=5 vector_store_id=vector_store.id, query="artificial intelligence", filters={"topic": "ai"}, max_num_results=5
) )
assert filtered_search is not None assert filtered_search is not None
# All results should have topic "ai" # All results should have topic "ai"
for result in filtered_search.data: for result in filtered_search.data:
assert result.metadata["topic"] == "ai" assert result.attributes["topic"] == "ai"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -257,18 +262,18 @@ def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client
], ],
) )
def test_openai_vector_store_search_relevance( def test_openai_vector_store_search_relevance(
openai_client_with_empty_stores, client_with_models, sample_chunks, test_case compat_client_with_empty_stores, client_with_models, sample_chunks, test_case
): ):
"""Test that OpenAI vector store search returns relevant results for different queries.""" """Test that OpenAI vector store search returns relevant results for different queries."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
openai_client = openai_client_with_empty_stores compat_client = compat_client_with_empty_stores
llama_client = client_with_models llama_client = client_with_models
query, expected_doc_id, expected_topic = test_case query, expected_doc_id, expected_topic = test_case
# Create a vector store # Create a vector store
vector_store = openai_client.vector_stores.create( vector_store = compat_client.vector_stores.create(
name=f"relevance_test_{expected_doc_id}", metadata={"purpose": "relevance_testing"} name=f"relevance_test_{expected_doc_id}", metadata={"purpose": "relevance_testing"}
) )
@ -279,7 +284,7 @@ def test_openai_vector_store_search_relevance(
) )
# Search using OpenAI API # Search using OpenAI API
search_response = openai_client.vector_stores.search( search_response = compat_client.vector_stores.search(
vector_store_id=vector_store.id, query=query, max_num_results=4 vector_store_id=vector_store.id, query=query, max_num_results=4
) )
@ -288,8 +293,9 @@ def test_openai_vector_store_search_relevance(
# The top result should match the expected document # The top result should match the expected document
top_result = search_response.data[0] top_result = search_response.data[0]
assert top_result.metadata["document_id"] == expected_doc_id
assert top_result.metadata["topic"] == expected_topic assert top_result.attributes["document_id"] == expected_doc_id
assert top_result.attributes["topic"] == expected_topic
# Verify score is included and reasonable # Verify score is included and reasonable
assert isinstance(top_result.score, int | float) assert isinstance(top_result.score, int | float)
@ -297,16 +303,16 @@ def test_openai_vector_store_search_relevance(
def test_openai_vector_store_search_with_ranking_options( def test_openai_vector_store_search_with_ranking_options(
openai_client_with_empty_stores, client_with_models, sample_chunks compat_client_with_empty_stores, client_with_models, sample_chunks
): ):
"""Test OpenAI vector store search with ranking options.""" """Test OpenAI vector store search with ranking options."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
openai_client = openai_client_with_empty_stores compat_client = compat_client_with_empty_stores
llama_client = client_with_models llama_client = client_with_models
# Create a vector store # Create a vector store
vector_store = openai_client.vector_stores.create( vector_store = compat_client.vector_stores.create(
name="ranking_test_store", metadata={"purpose": "ranking_testing"} name="ranking_test_store", metadata={"purpose": "ranking_testing"}
) )
@ -318,7 +324,7 @@ def test_openai_vector_store_search_with_ranking_options(
# Search with ranking options # Search with ranking options
threshold = 0.1 threshold = 0.1
search_response = openai_client.vector_stores.search( search_response = compat_client.vector_stores.search(
vector_store_id=vector_store.id, vector_store_id=vector_store.id,
query="machine learning and artificial intelligence", query="machine learning and artificial intelligence",
max_num_results=3, max_num_results=3,
@ -334,16 +340,16 @@ def test_openai_vector_store_search_with_ranking_options(
def test_openai_vector_store_search_with_high_score_filter( def test_openai_vector_store_search_with_high_score_filter(
openai_client_with_empty_stores, client_with_models, sample_chunks compat_client_with_empty_stores, client_with_models, sample_chunks
): ):
"""Test that searching with text very similar to a document and high score threshold returns only that document.""" """Test that searching with text very similar to a document and high score threshold returns only that document."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
openai_client = openai_client_with_empty_stores compat_client = compat_client_with_empty_stores
llama_client = client_with_models llama_client = client_with_models
# Create a vector store # Create a vector store
vector_store = openai_client.vector_stores.create( vector_store = compat_client.vector_stores.create(
name="high_score_filter_test", metadata={"purpose": "high_score_filtering"} name="high_score_filter_test", metadata={"purpose": "high_score_filtering"}
) )
@ -358,7 +364,7 @@ def test_openai_vector_store_search_with_high_score_filter(
query = "Python is a high-level programming language with code readability and fewer lines than C++ or Java" query = "Python is a high-level programming language with code readability and fewer lines than C++ or Java"
# picking the threshold to be slightly higher than the second result # picking the threshold to be slightly higher than the second result
search_response = openai_client.vector_stores.search( search_response = compat_client.vector_stores.search(
vector_store_id=vector_store.id, vector_store_id=vector_store.id,
query=query, query=query,
max_num_results=3, max_num_results=3,
@ -367,7 +373,7 @@ def test_openai_vector_store_search_with_high_score_filter(
threshold = search_response.data[1].score + 0.0001 threshold = search_response.data[1].score + 0.0001
# we expect only one result with the requested threshold # we expect only one result with the requested threshold
search_response = openai_client.vector_stores.search( search_response = compat_client.vector_stores.search(
vector_store_id=vector_store.id, vector_store_id=vector_store.id,
query=query, query=query,
max_num_results=10, # Allow more results but expect filtering max_num_results=10, # Allow more results but expect filtering
@ -379,25 +385,26 @@ def test_openai_vector_store_search_with_high_score_filter(
# The top result should be the Python document (doc1) # The top result should be the Python document (doc1)
top_result = search_response.data[0] top_result = search_response.data[0]
assert top_result.metadata["document_id"] == "doc1" assert top_result.attributes["document_id"] == "doc1"
assert top_result.metadata["topic"] == "programming" assert top_result.attributes["topic"] == "programming"
assert top_result.score >= threshold assert top_result.score >= threshold
# Verify the content contains Python-related terms # Verify the content contains Python-related terms
assert "python" in top_result.content.lower() or "programming" in top_result.content.lower() top_content = top_result.content[0].text
assert "python" in top_content.lower() or "programming" in top_content.lower()
def test_openai_vector_store_search_with_max_num_results( def test_openai_vector_store_search_with_max_num_results(
openai_client_with_empty_stores, client_with_models, sample_chunks compat_client_with_empty_stores, client_with_models, sample_chunks
): ):
"""Test OpenAI vector store search with max_num_results.""" """Test OpenAI vector store search with max_num_results."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
openai_client = openai_client_with_empty_stores compat_client = compat_client_with_empty_stores
llama_client = client_with_models llama_client = client_with_models
# Create a vector store # Create a vector store
vector_store = openai_client.vector_stores.create( vector_store = compat_client.vector_stores.create(
name="max_num_results_test_store", metadata={"purpose": "max_num_results_testing"} name="max_num_results_test_store", metadata={"purpose": "max_num_results_testing"}
) )
@ -408,7 +415,7 @@ def test_openai_vector_store_search_with_max_num_results(
) )
# Search with max_num_results # Search with max_num_results
search_response = openai_client.vector_stores.search( search_response = compat_client.vector_stores.search(
vector_store_id=vector_store.id, vector_store_id=vector_store.id,
query="machine learning and artificial intelligence", query="machine learning and artificial intelligence",
max_num_results=2, max_num_results=2,

View file

@ -154,3 +154,36 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
assert len(response.chunks) > 0 assert len(response.chunks) > 0
assert response.chunks[0].metadata["document_id"] == "doc1" assert response.chunks[0].metadata["document_id"] == "doc1"
assert response.chunks[0].metadata["source"] == "precomputed" assert response.chunks[0].metadata["source"] == "precomputed"
def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(client_with_empty_registry, embedding_model_id):
vector_db_id = "test_precomputed_embeddings_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=384,
)
chunks_with_embeddings = [
Chunk(
content="duplicate",
metadata={"document_id": "doc1", "source": "precomputed"},
embedding=[0.1] * 384,
),
]
client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id,
chunks=chunks_with_embeddings,
)
response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
query="duplicate",
)
# Verify the top result is the expected document
assert response is not None
assert len(response.chunks) > 0
assert response.chunks[0].metadata["document_id"] == "doc1"
assert response.chunks[0].metadata["source"] == "precomputed"

View file

@ -345,6 +345,56 @@ def test_invalid_oauth2_authentication(oauth2_client, invalid_token):
assert "Invalid JWT token" in response.json()["error"]["message"] assert "Invalid JWT token" in response.json()["error"]["message"]
async def mock_auth_jwks_response(*args, **kwargs):
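# Serve the JWKS only when the caller presents the expected bearer token; otherwise answer 401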
if "headers" not in kwargs or "Authorization" not in kwargs["headers"]:
return MockResponse(401, {})
authz = kwargs["headers"]["Authorization"]
if authz != "Bearer my-jwks-token":
return MockResponse(401, {})
return await mock_jwks_response(args, kwargs)
@pytest.fixture
def oauth2_app_with_jwks_token():
app = FastAPI()
auth_config = AuthenticationConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": {
"uri": "http://mock-authz-service/token/introspect",
"key_recheck_period": "3600",
"token": "my-jwks-token",
},
"audience": "llama-stack",
},
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}
return app
@pytest.fixture
def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token):
return TestClient(oauth2_app_with_jwks_token)
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 401
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}
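The new token field under jwks is for JWKS endpoints that themselves require authentication, which is exactly what mock_auth_jwks_response enforces above. A minimal sketch of attaching it as a bearer header when fetching the key set (fetch_jwks is an illustrative helper, not the provider's actual code):

import httpx

async def fetch_jwks(uri: str, token: str | None = None) -> dict:
    # Present the configured token as a bearer credential, mirroring what the mock expects
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    async with httpx.AsyncClient() as client:
        resp = await client.get(uri, headers=headers)
        resp.raise_for_status()
        return resp.json()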
def test_get_attributes_from_claims(): def test_get_attributes_from_claims():
claims = { claims = {
"sub": "my-user", "sub": "my-user",