Merge branch 'main' into change-default-embedding-model

commit da35f2452e
Francisco Arceo, 2025-10-14 10:05:04 -04:00 (committed via GitHub)
15 changed files with 473 additions and 231 deletions

View file

@@ -66,11 +66,11 @@ runs:
     shell: bash
     run: |
       echo "Checking for recording changes"
-      git status --porcelain tests/integration/recordings/
-      if [[ -n $(git status --porcelain tests/integration/recordings/) ]]; then
+      git status --porcelain tests/integration/
+      if [[ -n $(git status --porcelain tests/integration/) ]]; then
         echo "New recordings detected, committing and pushing"
-        git add tests/integration/recordings/
+        git add tests/integration/
         git commit -m "Recordings update from CI (suite: ${{ inputs.suite }})"
         git fetch origin ${{ github.ref_name }}

View file

@@ -55,30 +55,18 @@ class VectorIORouter(VectorIO):
         logger.debug("VectorIORouter.shutdown")
         pass

-    async def _get_first_embedding_model(self) -> tuple[str, int] | None:
-        """Get the first available embedding model identifier."""
-        try:
-            # Get all models from the routing table
-            all_models = await self.routing_table.get_all_with_type("model")
-
-            # Filter for embedding models
-            embedding_models = [
-                model
-                for model in all_models
-                if hasattr(model, "model_type") and model.model_type == ModelType.embedding
-            ]
-
-            if embedding_models:
-                dimension = embedding_models[0].metadata.get("embedding_dimension", None)
-                if dimension is None:
-                    raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
-                return embedding_models[0].identifier, dimension
-            else:
-                logger.warning("No embedding models found in the routing table")
-                return None
-        except Exception as e:
-            logger.error(f"Error getting embedding models: {e}")
-            return None
+    async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int:
+        """Get the embedding dimension for a specific embedding model."""
+        all_models = await self.routing_table.get_all_with_type("model")
+
+        for model in all_models:
+            if model.identifier == embedding_model_id and model.model_type == ModelType.embedding:
+                dimension = model.metadata.get("embedding_dimension")
+                if dimension is None:
+                    raise ValueError(f"Embedding model '{embedding_model_id}' has no embedding_dimension in metadata")
+                return int(dimension)
+
+        raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model")

     async def register_vector_db(
         self,
@@ -129,20 +117,30 @@ class VectorIORouter(VectorIO):
         # Extract llama-stack-specific parameters from extra_body
         extra = params.model_extra or {}
         embedding_model = extra.get("embedding_model")
-        embedding_dimension = extra.get("embedding_dimension", 384)
+        embedding_dimension = extra.get("embedding_dimension")
         provider_id = extra.get("provider_id")

         logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}")

-        # If no embedding model is provided, use the first available one
-        # TODO: this branch will soon be deleted so you _must_ provide the embedding_model when
-        # creating a vector store
+        # Require explicit embedding model specification
         if embedding_model is None:
-            embedding_model_info = await self._get_first_embedding_model()
-            if embedding_model_info is None:
-                raise ValueError("No embedding model provided and no embedding models available in the system")
-            embedding_model, embedding_dimension = embedding_model_info
-            logger.info(f"No embedding model specified, using first available: {embedding_model}")
+            raise ValueError("embedding_model is required in extra_body when creating a vector store")
+
+        if embedding_dimension is None:
+            embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
+
+        # Auto-select provider if not specified
+        if provider_id is None:
+            num_providers = len(self.routing_table.impls_by_provider_id)
+            if num_providers == 0:
+                raise ValueError("No vector_io providers available")
+            if num_providers > 1:
+                available_providers = list(self.routing_table.impls_by_provider_id.keys())
+                raise ValueError(
+                    f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
+                    f"Available providers: {available_providers}"
+                )
+            provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]

         vector_db_id = f"vs_{uuid.uuid4()}"
         registered_vector_db = await self.routing_table.register_vector_db(
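
With this change, clients must pass embedding_model explicitly in extra_body when creating a vector store; the router resolves embedding_dimension from the model's metadata and auto-selects the provider when exactly one vector_io provider is configured. A minimal usage sketch, assuming a running llama-stack server with a registered all-MiniLM-L6-v2 embedding model (the base URL is a placeholder):

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder URL

    # embedding_model is now required; embedding_dimension and provider_id may be
    # omitted -- the dimension is looked up from the model's metadata and the
    # provider is picked automatically when only one vector_io provider exists.
    vector_store = client.vector_stores.create(
        name="my_docs",
        extra_body={"embedding_model": "all-MiniLM-L6-v2"},
    )
    print(vector_store.id)  # e.g. "vs_<uuid>"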

View file

@@ -5,13 +5,11 @@
 # the root directory of this source tree.

 import ssl
-import time
 from abc import ABC, abstractmethod
-from asyncio import Lock
 from urllib.parse import parse_qs, urljoin, urlparse

 import httpx
-from jose import jwt
+import jwt
 from pydantic import BaseModel, Field

 from llama_stack.apis.common.errors import TokenValidationError
@@ -98,9 +96,7 @@ class OAuth2TokenAuthProvider(AuthProvider):

     def __init__(self, config: OAuth2TokenAuthConfig):
         self.config = config
-        self._jwks_at: float = 0.0
-        self._jwks: dict[str, str] = {}
-        self._jwks_lock = Lock()
+        self._jwks_client: jwt.PyJWKClient | None = None

     async def validate_token(self, token: str, scope: dict | None = None) -> User:
         if self.config.jwks:
@@ -109,23 +105,60 @@ class OAuth2TokenAuthProvider(AuthProvider):
             return await self.introspect_token(token, scope)
         raise ValueError("One of jwks or introspection must be configured")

+    def _get_jwks_client(self) -> jwt.PyJWKClient:
+        if self._jwks_client is None:
+            ssl_context = None
+            if not self.config.verify_tls:
+                # Disable SSL verification if verify_tls is False
+                ssl_context = ssl.create_default_context()
+                ssl_context.check_hostname = False
+                ssl_context.verify_mode = ssl.CERT_NONE
+            elif self.config.tls_cafile:
+                # Use custom CA file if provided
+                ssl_context = ssl.create_default_context(
+                    cafile=self.config.tls_cafile.as_posix(),
+                )
+            # If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults)
+
+            # Prepare headers for JWKS request - this is needed for Kubernetes to authenticate
+            # to the JWK endpoint, we must use the token in the config to authenticate
+            headers = {}
+            if self.config.jwks and self.config.jwks.token:
+                headers["Authorization"] = f"Bearer {self.config.jwks.token}"
+
+            self._jwks_client = jwt.PyJWKClient(
+                self.config.jwks.uri if self.config.jwks else None,
+                cache_keys=True,
+                max_cached_keys=10,
+                lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None,
+                headers=headers,
+                ssl_context=ssl_context,
+            )
+        return self._jwks_client
+
     async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
         """Validate a token using the JWT token."""
-        await self._refresh_jwks()
         try:
-            header = jwt.get_unverified_header(token)
-            kid = header["kid"]
-            if kid not in self._jwks:
-                raise ValueError(f"Unknown key ID: {kid}")
-            key_data = self._jwks[kid]
-            algorithm = header.get("alg", "RS256")
+            jwks_client: jwt.PyJWKClient = self._get_jwks_client()
+            signing_key = jwks_client.get_signing_key_from_jwt(token)
+            algorithm = jwt.get_unverified_header(token)["alg"]
             claims = jwt.decode(
                 token,
-                key_data,
+                signing_key.key,
                 algorithms=[algorithm],
                 audience=self.config.audience,
                 issuer=self.config.issuer,
+                options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
+            )
+
+            # Decode and verify the JWT
+            claims = jwt.decode(
+                token,
+                signing_key.key,
+                algorithms=[algorithm],
+                audience=self.config.audience,
+                issuer=self.config.issuer,
+                options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
             )
         except Exception as exc:
             raise ValueError("Invalid JWT token") from exc
@@ -201,37 +234,6 @@ class OAuth2TokenAuthProvider(AuthProvider):
         else:
             return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"

-    async def _refresh_jwks(self) -> None:
-        """
-        Refresh the JWKS cache.
-
-        This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
-        If the cache is expired, we refresh the JWKS from the JWKS URI.
-
-        Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
-            * It doesn't have user authentication flows
-            * It doesn't have refresh tokens
-        """
-        async with self._jwks_lock:
-            if self.config.jwks is None:
-                raise ValueError("JWKS is not configured")
-            if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
-                headers = {}
-                if self.config.jwks.token:
-                    headers["Authorization"] = f"Bearer {self.config.jwks.token}"
-                verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
-                async with httpx.AsyncClient(verify=verify) as client:
-                    res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
-                    res.raise_for_status()
-                    jwks_data = res.json()["keys"]
-                    updated = {}
-                    for k in jwks_data:
-                        kid = k["kid"]
-                        # Store the entire key object as it may be needed for different algorithms
-                        updated[kid] = k
-                    self._jwks = updated
-                    self._jwks_at = time.time()
-
 class CustomAuthProvider(AuthProvider):
     """Custom authentication provider that uses an external endpoint."""

View file

@@ -277,7 +277,7 @@ Available Models:
         pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.watsonx",
         config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
-        provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
+        provider_data_validator="llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
         description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
     ),
     RemoteProviderSpec(

View file

@@ -7,18 +7,18 @@
 import os
 from typing import Any

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type


 class WatsonXProviderDataValidator(BaseModel):
-    model_config = ConfigDict(
-        from_attributes=True,
-        extra="forbid",
+    watsonx_project_id: str | None = Field(
+        default=None,
+        description="IBM WatsonX project ID",
     )
-    watsonx_api_key: str | None
+    watsonx_api_key: str | None = None


 @json_schema_type
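
The validator's fields describe per-request provider data, which llama-stack reads from the X-LlamaStack-Provider-Data request header. A hedged sketch of supplying watsonx credentials that way (the endpoint path and model ID are placeholders for illustration):

    import json

    import requests

    # Per-request watsonx credentials serialized into the provider-data header;
    # field names match WatsonXProviderDataValidator above.
    provider_data = {"watsonx_api_key": "my-key", "watsonx_project_id": "my-project"}
    resp = requests.post(
        "http://localhost:8321/v1/openai/v1/chat/completions",  # placeholder endpoint
        headers={"X-LlamaStack-Provider-Data": json.dumps(provider_data)},
        json={"model": "watsonx/granite-model-id", "messages": [{"role": "user", "content": "hi"}]},
    )
    print(resp.status_code)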

View file

@@ -4,42 +4,259 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncIterator
 from typing import Any

+import litellm
 import requests

-from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+)
 from llama_stack.apis.models import Model
 from llama_stack.apis.models.models import ModelType
+from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
+
+logger = get_logger(name=__name__, category="providers::remote::watsonx")


 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     _model_cache: dict[str, Model] = {}

+    provider_data_api_key_field: str = "watsonx_api_key"
+
     def __init__(self, config: WatsonXConfig):
+        self.available_models = None
+        self.config = config
+        api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
         LiteLLMOpenAIMixin.__init__(
             self,
             litellm_provider_name="watsonx",
-            api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
+            api_key_from_config=api_key,
             provider_data_api_key_field="watsonx_api_key",
-            openai_compat_api_base=self.get_base_url(),
         )
-        self.available_models = None
-        self.config = config
+
+    async def openai_chat_completion(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Override parent method to add timeout and inject usage object when missing.
+        This works around a LiteLLM defect where usage block is sometimes dropped.
+        """
+        # Add usage tracking for streaming when telemetry is active
+        stream_options = params.stream_options
+        if params.stream and get_current_span() is not None:
+            if stream_options is None:
+                stream_options = {"include_usage": True}
+            elif "include_usage" not in stream_options:
+                stream_options = {**stream_options, "include_usage": True}
+
+        model_obj = await self.model_store.get_model(params.model)
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            messages=params.messages,
+            frequency_penalty=params.frequency_penalty,
+            function_call=params.function_call,
+            functions=params.functions,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_completion_tokens=params.max_completion_tokens,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            parallel_tool_calls=params.parallel_tool_calls,
+            presence_penalty=params.presence_penalty,
+            response_format=params.response_format,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=stream_options,
+            temperature=params.temperature,
+            tool_choice=params.tool_choice,
+            tools=params.tools,
+            top_logprobs=params.top_logprobs,
+            top_p=params.top_p,
+            user=params.user,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+        result = await litellm.acompletion(**request_params)
+
+        # If not streaming, check and inject usage if missing
+        if not params.stream:
+            # Use getattr to safely handle cases where usage attribute might not exist
+            if getattr(result, "usage", None) is None:
+                # Create usage object with zeros
+                usage_obj = OpenAIChatCompletionUsage(
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    total_tokens=0,
+                )
+                # Use model_copy to create a new response with the usage injected
+                result = result.model_copy(update={"usage": usage_obj})
+            return result
+
+        # For streaming, wrap the iterator to normalize chunks
+        return self._normalize_stream(result)
+
+    def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk:
+        """
+        Normalize a chunk to ensure it has all expected attributes.
+        This works around LiteLLM not always including all expected attributes.
+        """
+        # Ensure chunk has usage attribute with zeros if missing
+        if not hasattr(chunk, "usage") or chunk.usage is None:
+            usage_obj = OpenAIChatCompletionUsage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+            chunk = chunk.model_copy(update={"usage": usage_obj})
+
+        # Ensure all delta objects in choices have expected attributes
+        if hasattr(chunk, "choices") and chunk.choices:
+            normalized_choices = []
+            for choice in chunk.choices:
+                if hasattr(choice, "delta") and choice.delta:
+                    delta = choice.delta
+                    # Build update dict for missing attributes
+                    delta_updates = {}
+                    if not hasattr(delta, "refusal"):
+                        delta_updates["refusal"] = None
+                    if not hasattr(delta, "reasoning_content"):
+                        delta_updates["reasoning_content"] = None
+                    # If we need to update delta, create a new choice with updated delta
+                    if delta_updates:
+                        new_delta = delta.model_copy(update=delta_updates)
+                        new_choice = choice.model_copy(update={"delta": new_delta})
+                        normalized_choices.append(new_choice)
+                    else:
+                        normalized_choices.append(choice)
+                else:
+                    normalized_choices.append(choice)
+            # If we modified any choices, create a new chunk with updated choices
+            if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))):
+                chunk = chunk.model_copy(update={"choices": normalized_choices})
+
+        return chunk
+
+    async def _normalize_stream(
+        self, stream: AsyncIterator[OpenAIChatCompletionChunk]
+    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Normalize all chunks in the stream to ensure they have expected attributes.
+        This works around LiteLLM sometimes not including expected attributes.
+        """
+        try:
+            async for chunk in stream:
+                # Normalize and yield each chunk immediately
+                yield self._normalize_chunk(chunk)
+        except Exception as e:
+            logger.error(f"Error normalizing stream: {e}", exc_info=True)
+            raise
+
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequestWithExtraBody,
+    ) -> OpenAICompletion:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+
+        model_obj = await self.model_store.get_model(params.model)
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            prompt=params.prompt,
+            best_of=params.best_of,
+            echo=params.echo,
+            frequency_penalty=params.frequency_penalty,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            presence_penalty=params.presence_penalty,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=params.stream_options,
+            temperature=params.temperature,
+            top_p=params.top_p,
+            user=params.user,
+            suffix=params.suffix,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+        return await litellm.atext_completion(**request_params)
+
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        model_obj = await self.model_store.get_model(params.model)
+
+        # Convert input to list if it's a string
+        input_list = [params.input] if isinstance(params.input, str) else params.input
+
+        # Call litellm embedding function with watsonx-specific parameters
+        response = litellm.embedding(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            input=input_list,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            dimensions=params.dimensions,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+
+        # Convert response to OpenAI format
+        from llama_stack.apis.inference import OpenAIEmbeddingUsage
+        from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response
+
+        data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response["usage"]["prompt_tokens"],
+            total_tokens=response["usage"]["total_tokens"],
+        )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=model_obj.provider_resource_id,
+            usage=usage,
+        )

     def get_base_url(self) -> str:
         return self.config.url

-    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
-        # Get base parameters from parent
-        params = await super()._get_params(request)
-
-        # Add watsonx.ai specific parameters
-        params["project_id"] = self.config.project_id
-        params["time_limit"] = self.config.timeout
-        return params
-
     # Copied from OpenAIMixin
     async def check_model_availability(self, model: str) -> bool:
         """

View file

@@ -353,14 +353,11 @@ class OpenAIVectorStoreMixin(ABC):
         provider_vector_db_id = extra.get("provider_vector_db_id")
         embedding_model = extra.get("embedding_model")
         embedding_dimension = extra.get("embedding_dimension", 768)
-        provider_id = extra.get("provider_id")
+        # use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
+        provider_id = extra.get("provider_id") or getattr(self, "__provider_id__", None)

         # Derive the canonical vector_db_id (allow override, else generate)
         vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")

-        if provider_id is None:
-            raise ValueError("Provider ID is required")
-
         if embedding_model is None:
             raise ValueError("Embedding model is required")
@@ -369,6 +366,9 @@ class OpenAIVectorStoreMixin(ABC):
             raise ValueError("Embedding dimension is required")

         # Register the VectorDB backing this vector store
+        if provider_id is None:
+            raise ValueError("Provider ID is required but was not provided")
+
         vector_db = VectorDB(
             identifier=vector_db_id,
             embedding_dimension=embedding_dimension,

View file

@@ -34,7 +34,7 @@ dependencies = [
     "openai>=1.107",  # for expires_after support
     "prompt-toolkit",
     "python-dotenv",
-    "python-jose[cryptography]",
+    "pyjwt[crypto]>=2.10.0",  # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support.
     "pydantic>=2.11.9",
     "rich",
     "starlette",

View file

@@ -58,7 +58,6 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         # does not work with the specified model, gpt-5-mini. Please choose different model and try
         # again. You can learn more about which models can be used with each operation here:
         # https://go.microsoft.com/fwlink/?linkid=2197993.'}}"}
-        "remote::watsonx",  # return 404 when hitting the /openai/v1 endpoint
         "remote::llama-openai-compat",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
@@ -68,6 +67,7 @@ def skip_if_doesnt_support_completions_logprobs(client_with_models, model_id):
     provider_type = provider_from_model(client_with_models, model_id).provider_type
     if provider_type in (
         "remote::ollama",  # logprobs is ignored
+        "remote::watsonx",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions logprobs.")
@@ -110,6 +110,7 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
         # Error code 400 - {'message': '"n" > 1 is not currently supported', 'type': 'invalid_request_error', 'param': 'n', 'code': 'wrong_api_format'}
         "remote::cerebras",
         "remote::databricks",  # Bad request: parameter "n" must be equal to 1 for streaming mode
+        "remote::watsonx",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
@@ -124,7 +125,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
         "remote::databricks",
         "remote::cerebras",
         "remote::runpod",
-        "remote::watsonx",  # watsonx returns 404 when hitting the /openai/v1 endpoint
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")
@@ -508,6 +508,12 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi
     assert "hello world" in normalized_content


+def skip_if_doesnt_support_completions_stop_sequence(client_with_models, model_id):
+    provider_type = provider_from_model(client_with_models, model_id).provider_type
+    if provider_type in ("remote::watsonx",):  # openai.BadRequestError: Error code: 400
+        pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions stop sequence.")
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
@@ -516,6 +522,7 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi
 )
 def test_openai_completion_stop_sequence(client_with_models, openai_client, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+    skip_if_doesnt_support_completions_stop_sequence(client_with_models, text_model_id)
     tc = TestCase(test_case)

View file

@@ -50,11 +50,15 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):

 def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_id):
     provider = provider_from_model(client_with_models, model_id)
-    if provider.provider_type in (
+    if (
+        provider.provider_type
+        in (
             "remote::together",  # returns 400
             "inline::sentence-transformers",
             # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'}
             "remote::databricks",
+            "remote::watsonx",  # openai.BadRequestError: Error code: 400 - {'detail': "litellm.UnsupportedParamsError: watsonx does not support parameters: {'dimensions': 384}
+        )
     ):
         pytest.skip(
             f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."

View file

@@ -146,8 +146,6 @@ def test_openai_create_vector_store(
         metadata={"purpose": "testing", "environment": "integration"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -175,8 +173,6 @@ def test_openai_list_vector_stores(
         metadata={"type": "test"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
     store2 = client.vector_stores.create(
@@ -184,8 +180,6 @@
         metadata={"type": "test"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -220,8 +214,6 @@ def test_openai_retrieve_vector_store(
         metadata={"purpose": "retrieval_test"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -249,8 +241,6 @@ def test_openai_update_vector_store(
         metadata={"version": "1.0"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
     time.sleep(1)
@@ -282,8 +272,6 @@ def test_openai_delete_vector_store(
         metadata={"purpose": "deletion_test"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -314,8 +302,6 @@ def test_openai_vector_store_search_empty(
         metadata={"purpose": "search_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -346,8 +332,6 @@ def test_openai_vector_store_with_chunks(
         metadata={"purpose": "chunks_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -412,8 +396,6 @@ def test_openai_vector_store_search_relevance(
         metadata={"purpose": "relevance_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -457,8 +439,6 @@ def test_openai_vector_store_search_with_ranking_options(
         metadata={"purpose": "ranking_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -512,8 +492,6 @@ def test_openai_vector_store_search_with_high_score_filter(
         metadata={"purpose": "high_score_filtering"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -573,8 +551,6 @@ def test_openai_vector_store_search_with_max_num_results(
         metadata={"purpose": "max_num_results_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -608,8 +584,6 @@ def test_openai_vector_store_attach_file(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -688,8 +662,6 @@ def test_openai_vector_store_attach_files_on_creation(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -735,8 +707,6 @@ def test_openai_vector_store_list_files(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -826,8 +796,6 @@ def test_openai_vector_store_retrieve_file_contents(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -851,8 +819,6 @@ def test_openai_vector_store_retrieve_file_contents(
         attributes=attributes,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -889,8 +855,6 @@ def test_openai_vector_store_delete_file(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -955,8 +919,6 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1007,8 +969,6 @@ def test_openai_vector_store_update_file(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1078,8 +1038,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(
         name="test_store_with_files",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
     assert vector_store.file_counts.completed == 0
@@ -1092,8 +1050,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(
         name="test_store_with_files",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1105,8 +1061,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(
         file_id=file_ids[0],
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
     assert created_file.status == "completed"
@@ -1117,8 +1071,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(
         file_id=file_ids[1],
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
     assert created_file_from_non_deleted_vector_store.status == "completed"
@@ -1139,8 +1091,6 @@ def test_openai_vector_store_search_modes(
         metadata={"purpose": "search_mode_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1172,8 +1122,6 @@ def test_openai_vector_store_file_batch_create_and_retrieve(
         name="batch_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1191,8 +1139,6 @@ def test_openai_vector_store_file_batch_create_and_retrieve(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1239,8 +1185,6 @@ def test_openai_vector_store_file_batch_list_files(
         name="batch_list_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1258,8 +1202,6 @@ def test_openai_vector_store_file_batch_list_files(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1336,8 +1278,6 @@ def test_openai_vector_store_file_batch_cancel(
         name="batch_cancel_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1355,8 +1295,6 @@ def test_openai_vector_store_file_batch_cancel(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1395,8 +1333,6 @@ def test_openai_vector_store_file_batch_retrieve_contents(
         name="batch_contents_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1419,8 +1355,6 @@ def test_openai_vector_store_file_batch_retrieve_contents(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1472,8 +1406,6 @@ def test_openai_vector_store_file_batch_error_handling(
         name="batch_error_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -1485,8 +1417,6 @@ def test_openai_vector_store_file_batch_error_handling(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

View file

@@ -52,8 +52,6 @@ def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embe
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -73,8 +71,6 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -110,8 +106,6 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -152,8 +146,6 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -202,8 +194,6 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
@@ -234,3 +224,35 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
     assert len(response.chunks) > 0
     assert response.chunks[0].metadata["document_id"] == "doc1"
     assert response.chunks[0].metadata["source"] == "precomputed"
+
+
+def test_auto_extract_embedding_dimension(client_with_empty_registry, embedding_model_id):
+    vs = client_with_empty_registry.vector_stores.create(
+        name="test_auto_extract", extra_body={"embedding_model": embedding_model_id}
+    )
+    assert vs.id is not None
+
+
+def test_provider_auto_selection_single_provider(client_with_empty_registry, embedding_model_id):
+    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
+    if len(providers) != 1:
+        pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")
+
+    vs = client_with_empty_registry.vector_stores.create(
+        name="test_auto_provider", extra_body={"embedding_model": embedding_model_id}
+    )
+    assert vs.id is not None
+
+
+def test_provider_id_override(client_with_empty_registry, embedding_model_id):
+    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
+    if len(providers) != 1:
+        pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")
+
+    provider_id = providers[0].provider_id
+    vs = client_with_empty_registry.vector_stores.create(
+        name="test_provider_override", extra_body={"embedding_model": embedding_model_id, "provider_id": provider_id}
+    )
+    assert vs.id is not None
+    assert vs.metadata.get("provider_id") == provider_id

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, Mock
import pytest
from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody
from llama_stack.core.routers.vector_io import VectorIORouter
async def test_single_provider_auto_selection():
# provider_id automatically selected during vector store create() when only one provider available
mock_routing_table = Mock()
mock_routing_table.impls_by_provider_id = {"inline::faiss": "mock_provider"}
mock_routing_table.get_all_with_type = AsyncMock(
return_value=[
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
]
)
mock_routing_table.register_vector_db = AsyncMock(
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
)
mock_routing_table.get_provider_impl = AsyncMock(
return_value=Mock(openai_create_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
)
router = VectorIORouter(mock_routing_table)
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
)
result = await router.openai_create_vector_store(request)
assert result.id == "vs_123"
async def test_create_vector_stores_multiple_providers_missing_provider_id_error():
# if multiple providers are available, vector store create will error without provider_id
mock_routing_table = Mock()
mock_routing_table.impls_by_provider_id = {
"inline::faiss": "mock_provider_1",
"inline::sqlite-vec": "mock_provider_2",
}
mock_routing_table.get_all_with_type = AsyncMock(
return_value=[
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
]
)
router = VectorIORouter(mock_routing_table)
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
)
with pytest.raises(ValueError, match="Multiple vector_io providers available"):
await router.openai_create_vector_store(request)

View file

@@ -5,7 +5,8 @@
 # the root directory of this source tree.

 import base64
-from unittest.mock import AsyncMock, patch
+import json
+from unittest.mock import AsyncMock, Mock, patch

 import pytest
 from fastapi import FastAPI
@@ -374,7 +375,7 @@ async def mock_jwks_response(*args, **kwargs):

 @pytest.fixture
 def jwt_token_valid():
-    from jose import jwt
+    import jwt

     return jwt.encode(
         {
@@ -389,8 +390,30 @@ def jwt_token_valid():
     )


-@patch("httpx.AsyncClient.get", new=mock_jwks_response)
-def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
+@pytest.fixture
+def mock_jwks_urlopen():
+    """Mock urllib.request.urlopen for PyJWKClient JWKS requests."""
+    with patch("urllib.request.urlopen") as mock_urlopen:
+        # Mock the JWKS response for PyJWKClient
+        mock_response = Mock()
+        mock_response.read.return_value = json.dumps(
+            {
+                "keys": [
+                    {
+                        "kid": "1234567890",
+                        "kty": "oct",
+                        "alg": "HS256",
+                        "use": "sig",
+                        "k": base64.b64encode(b"foobarbaz").decode(),
+                    }
+                ]
+            }
+        ).encode()
+        mock_urlopen.return_value.__enter__.return_value = mock_response
+        yield mock_urlopen
+
+
+def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
     response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
     assert response.status_code == 200
     assert response.json() == {"message": "Authentication successful"}
@@ -447,8 +470,7 @@ def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
     assert response.status_code == 401


-@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
-def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
+def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid, mock_jwks_urlopen):
     response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
     assert response.status_code == 200
     assert response.json() == {"message": "Authentication successful"}

uv.lock (generated): 49 changed lines
View file

@@ -874,18 +874,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b0/0d/9feae160378a3553fa9a339b0e9c1a048e147a4127210e286ef18b730f03/durationpy-0.10-py3-none-any.whl", hash = "sha256:3b41e1b601234296b4fb368338fdcd3e13e0b4fb5b67345948f4f2bf9868b286", size = 3922, upload-time = "2025-05-17T13:52:36.463Z" },
 ]

-[[package]]
-name = "ecdsa"
-version = "0.19.1"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "six" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/c0/1f/924e3caae75f471eae4b26bd13b698f6af2c44279f67af317439c2f4c46a/ecdsa-0.19.1.tar.gz", hash = "sha256:478cba7b62555866fcb3bb3fe985e06decbdb68ef55713c4e5ab98c57d508e61", size = 201793, upload-time = "2025-03-13T11:52:43.25Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" },
-]
-
 [[package]]
 name = "eval-type-backport"
 version = "0.2.2"
@@ -1787,8 +1775,8 @@ dependencies = [
     { name = "pillow" },
     { name = "prompt-toolkit" },
     { name = "pydantic" },
+    { name = "pyjwt", extra = ["crypto"] },
     { name = "python-dotenv" },
-    { name = "python-jose", extra = ["cryptography"] },
     { name = "python-multipart" },
     { name = "rich" },
     { name = "sqlalchemy", extra = ["asyncio"] },
@@ -1910,8 +1898,8 @@ requires-dist = [
     { name = "pillow" },
     { name = "prompt-toolkit" },
     { name = "pydantic", specifier = ">=2.11.9" },
+    { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.0" },
     { name = "python-dotenv" },
-    { name = "python-jose", extras = ["cryptography"] },
     { name = "python-multipart", specifier = ">=0.0.20" },
     { name = "rich" },
     { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
@@ -3558,6 +3546,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
 ]

+[[package]]
+name = "pyjwt"
+version = "2.10.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
+]
+
+[package.optional-dependencies]
+crypto = [
+    { name = "cryptography" },
+]
+
 [[package]]
 name = "pymilvus"
 version = "2.6.1"
@@ -3747,25 +3749,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" },
 ]

-[[package]]
-name = "python-jose"
-version = "3.5.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "ecdsa" },
-    { name = "pyasn1" },
-    { name = "rsa" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/c6/77/3a1c9039db7124eb039772b935f2244fbb73fc8ee65b9acf2375da1c07bf/python_jose-3.5.0.tar.gz", hash = "sha256:fb4eaa44dbeb1c26dcc69e4bd7ec54a1cb8dd64d3b4d81ef08d90ff453f2b01b", size = 92726, upload-time = "2025-05-28T17:31:54.288Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/d9/c3/0bd11992072e6a1c513b16500a5d07f91a24017c5909b02c72c62d7ad024/python_jose-3.5.0-py2.py3-none-any.whl", hash = "sha256:abd1202f23d34dfad2c3d28cb8617b90acf34132c7afd60abd0b0b7d3cb55771", size = 34624, upload-time = "2025-05-28T17:31:52.802Z" },
-]
-
-[package.optional-dependencies]
-cryptography = [
-    { name = "cryptography" },
-]
-
 [[package]]
 name = "python-multipart"
 version = "0.0.20"