diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml index d240381c5..a5aa31af4 100644 --- a/.github/actions/run-and-record-tests/action.yml +++ b/.github/actions/run-and-record-tests/action.yml @@ -66,11 +66,11 @@ runs: shell: bash run: | echo "Checking for recording changes" - git status --porcelain tests/integration/recordings/ + git status --porcelain tests/integration/ - if [[ -n $(git status --porcelain tests/integration/recordings/) ]]; then + if [[ -n $(git status --porcelain tests/integration/) ]]; then echo "New recordings detected, committing and pushing" - git add tests/integration/recordings/ + git add tests/integration/ git commit -m "Recordings update from CI (suite: ${{ inputs.suite }})" git fetch origin ${{ github.ref_name }} diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 79789ef0a..dc7b3a694 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -55,30 +55,18 @@ class VectorIORouter(VectorIO): logger.debug("VectorIORouter.shutdown") pass - async def _get_first_embedding_model(self) -> tuple[str, int] | None: - """Get the first available embedding model identifier.""" - try: - # Get all models from the routing table - all_models = await self.routing_table.get_all_with_type("model") + async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int: + """Get the embedding dimension for a specific embedding model.""" + all_models = await self.routing_table.get_all_with_type("model") - # Filter for embedding models - embedding_models = [ - model - for model in all_models - if hasattr(model, "model_type") and model.model_type == ModelType.embedding - ] - - if embedding_models: - dimension = embedding_models[0].metadata.get("embedding_dimension", None) + for model in all_models: + if model.identifier == embedding_model_id and model.model_type == ModelType.embedding: + dimension = model.metadata.get("embedding_dimension") if dimension is None: - raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension") - return embedding_models[0].identifier, dimension - else: - logger.warning("No embedding models found in the routing table") - return None - except Exception as e: - logger.error(f"Error getting embedding models: {e}") - return None + raise ValueError(f"Embedding model '{embedding_model_id}' has no embedding_dimension in metadata") + return int(dimension) + + raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model") async def register_vector_db( self, @@ -129,20 +117,30 @@ class VectorIORouter(VectorIO): # Extract llama-stack-specific parameters from extra_body extra = params.model_extra or {} embedding_model = extra.get("embedding_model") - embedding_dimension = extra.get("embedding_dimension", 384) + embedding_dimension = extra.get("embedding_dimension") provider_id = extra.get("provider_id") logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}") - # If no embedding model is provided, use the first available one - # TODO: this branch will soon be deleted so you _must_ provide the embedding_model when - # creating a vector store + # Require explicit embedding model specification if embedding_model is None: - embedding_model_info = await self._get_first_embedding_model() - if embedding_model_info is None: - raise ValueError("No embedding model provided and no embedding models available in the 
system") - embedding_model, embedding_dimension = embedding_model_info - logger.info(f"No embedding model specified, using first available: {embedding_model}") + raise ValueError("embedding_model is required in extra_body when creating a vector store") + + if embedding_dimension is None: + embedding_dimension = await self._get_embedding_model_dimension(embedding_model) + + # Auto-select provider if not specified + if provider_id is None: + num_providers = len(self.routing_table.impls_by_provider_id) + if num_providers == 0: + raise ValueError("No vector_io providers available") + if num_providers > 1: + available_providers = list(self.routing_table.impls_by_provider_id.keys()) + raise ValueError( + f"Multiple vector_io providers available. Please specify provider_id in extra_body. " + f"Available providers: {available_providers}" + ) + provider_id = list(self.routing_table.impls_by_provider_id.keys())[0] vector_db_id = f"vs_{uuid.uuid4()}" registered_vector_db = await self.routing_table.register_vector_db( diff --git a/llama_stack/core/server/auth_providers.py b/llama_stack/core/server/auth_providers.py index 38188c49a..05a21c8d4 100644 --- a/llama_stack/core/server/auth_providers.py +++ b/llama_stack/core/server/auth_providers.py @@ -5,13 +5,11 @@ # the root directory of this source tree. import ssl -import time from abc import ABC, abstractmethod -from asyncio import Lock from urllib.parse import parse_qs, urljoin, urlparse import httpx -from jose import jwt +import jwt from pydantic import BaseModel, Field from llama_stack.apis.common.errors import TokenValidationError @@ -98,9 +96,7 @@ class OAuth2TokenAuthProvider(AuthProvider): def __init__(self, config: OAuth2TokenAuthConfig): self.config = config - self._jwks_at: float = 0.0 - self._jwks: dict[str, str] = {} - self._jwks_lock = Lock() + self._jwks_client: jwt.PyJWKClient | None = None async def validate_token(self, token: str, scope: dict | None = None) -> User: if self.config.jwks: @@ -109,23 +105,60 @@ class OAuth2TokenAuthProvider(AuthProvider): return await self.introspect_token(token, scope) raise ValueError("One of jwks or introspection must be configured") + def _get_jwks_client(self) -> jwt.PyJWKClient: + if self._jwks_client is None: + ssl_context = None + if not self.config.verify_tls: + # Disable SSL verification if verify_tls is False + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif self.config.tls_cafile: + # Use custom CA file if provided + ssl_context = ssl.create_default_context( + cafile=self.config.tls_cafile.as_posix(), + ) + # If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults) + + # Prepare headers for JWKS request - this is needed for Kubernetes to authenticate + # to the JWK endpoint, we must use the token in the config to authenticate + headers = {} + if self.config.jwks and self.config.jwks.token: + headers["Authorization"] = f"Bearer {self.config.jwks.token}" + + self._jwks_client = jwt.PyJWKClient( + self.config.jwks.uri if self.config.jwks else None, + cache_keys=True, + max_cached_keys=10, + lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None, + headers=headers, + ssl_context=ssl_context, + ) + return self._jwks_client + async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User: """Validate a token using the JWT token.""" - await self._refresh_jwks() - try: - header = jwt.get_unverified_header(token) - kid = header["kid"] - if kid not in 
self._jwks: - raise ValueError(f"Unknown key ID: {kid}") - key_data = self._jwks[kid] - algorithm = header.get("alg", "RS256") + jwks_client: jwt.PyJWKClient = self._get_jwks_client() + signing_key = jwks_client.get_signing_key_from_jwt(token) + algorithm = jwt.get_unverified_header(token)["alg"] + + # Decode and verify the JWT claims = jwt.decode( token, - key_data, + signing_key.key, algorithms=[algorithm], audience=self.config.audience, issuer=self.config.issuer, + options={"verify_exp": True, "verify_aud": True, "verify_iss": True}, ) except Exception as exc: raise ValueError("Invalid JWT token") from exc @@ -201,37 +234,6 @@ class OAuth2TokenAuthProvider(AuthProvider): else: return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header" - async def _refresh_jwks(self) -> None: - """ - Refresh the JWKS cache. - - This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`). - If the cache is expired, we refresh the JWKS from the JWKS URI. - - Notes: for Kubernetes which doesn't fully implement the OIDC protocol: - * It doesn't have user authentication flows - * It doesn't have refresh tokens - """ - async with self._jwks_lock: - if self.config.jwks is None: - raise ValueError("JWKS is not configured") - if time.time() - self._jwks_at > self.config.jwks.key_recheck_period: - headers = {} - if self.config.jwks.token: - headers["Authorization"] = f"Bearer {self.config.jwks.token}" - verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls - async with httpx.AsyncClient(verify=verify) as client: - res = await client.get(self.config.jwks.uri, timeout=5, headers=headers) - res.raise_for_status() - jwks_data = res.json()["keys"] - updated = {} - for k in jwks_data: - kid = k["kid"] - # Store the entire key object as it may be needed for different algorithms - updated[kid] = k - self._jwks = updated - self._jwks_at = time.time() - class CustomAuthProvider(AuthProvider): """Custom authentication provider that uses an external endpoint.""" diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index ab1434669..35afb296d 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -277,7 +277,7 @@ Available Models: pip_packages=["litellm"], module="llama_stack.providers.remote.inference.watsonx", config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", - provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", + provider_data_validator="llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator", description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.", ), RemoteProviderSpec( diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py index 022dc5ee7..8d8df13b4 100644 --- a/llama_stack/providers/remote/inference/watsonx/config.py +++ b/llama_stack/providers/remote/inference/watsonx/config.py @@ -7,18 +7,18 @@ import os from typing import Any -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, Field from 
llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type class WatsonXProviderDataValidator(BaseModel): - model_config = ConfigDict( - from_attributes=True, - extra="forbid", + watsonx_project_id: str | None = Field( + default=None, + description="IBM WatsonX project ID", ) - watsonx_api_key: str | None + watsonx_api_key: str | None = None @json_schema_type diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index 654d61f34..2c051719b 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -4,42 +4,259 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from collections.abc import AsyncIterator from typing import Any +import litellm import requests -from llama_stack.apis.inference import ChatCompletionRequest +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAIChatCompletionRequestWithExtraBody, + OpenAIChatCompletionUsage, + OpenAICompletion, + OpenAICompletionRequestWithExtraBody, + OpenAIEmbeddingsRequestWithExtraBody, + OpenAIEmbeddingsResponse, ) from llama_stack.apis.models import Model from llama_stack.apis.models.models import ModelType +from llama_stack.log import get_logger from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin +from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params +from llama_stack.providers.utils.telemetry.tracing import get_current_span + +logger = get_logger(name=__name__, category="providers::remote::watsonx") class WatsonXInferenceAdapter(LiteLLMOpenAIMixin): _model_cache: dict[str, Model] = {} + provider_data_api_key_field: str = "watsonx_api_key" + def __init__(self, config: WatsonXConfig): + self.available_models = None + self.config = config + api_key = config.auth_credential.get_secret_value() if config.auth_credential else None LiteLLMOpenAIMixin.__init__( self, litellm_provider_name="watsonx", - api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None, + api_key_from_config=api_key, provider_data_api_key_field="watsonx_api_key", + openai_compat_api_base=self.get_base_url(), + ) + + async def openai_chat_completion( + self, + params: OpenAIChatCompletionRequestWithExtraBody, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + """ + Override parent method to add timeout and inject usage object when missing. + This works around a LiteLLM defect where the usage block is sometimes dropped. 
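+
+        Rough usage sketch (``adapter`` is a constructed WatsonXInferenceAdapter and the
+        model id is hypothetical); the returned completion always carries a usage object,
+        even when LiteLLM drops it:
+
+            params = OpenAIChatCompletionRequestWithExtraBody(
+                model="meta-llama/llama-3-3-70b-instruct",
+                messages=[{"role": "user", "content": "Hello"}],
+            )
+            result = await adapter.openai_chat_completion(params)
+            assert result.usage is not None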
+ """ + + # Add usage tracking for streaming when telemetry is active + stream_options = params.stream_options + if params.stream and get_current_span() is not None: + if stream_options is None: + stream_options = {"include_usage": True} + elif "include_usage" not in stream_options: + stream_options = {**stream_options, "include_usage": True} + + model_obj = await self.model_store.get_model(params.model) + + request_params = await prepare_openai_completion_params( + model=self.get_litellm_model_name(model_obj.provider_resource_id), + messages=params.messages, + frequency_penalty=params.frequency_penalty, + function_call=params.function_call, + functions=params.functions, + logit_bias=params.logit_bias, + logprobs=params.logprobs, + max_completion_tokens=params.max_completion_tokens, + max_tokens=params.max_tokens, + n=params.n, + parallel_tool_calls=params.parallel_tool_calls, + presence_penalty=params.presence_penalty, + response_format=params.response_format, + seed=params.seed, + stop=params.stop, + stream=params.stream, + stream_options=stream_options, + temperature=params.temperature, + tool_choice=params.tool_choice, + tools=params.tools, + top_logprobs=params.top_logprobs, + top_p=params.top_p, + user=params.user, + api_key=self.get_api_key(), + api_base=self.api_base, + # These are watsonx-specific parameters + timeout=self.config.timeout, + project_id=self.config.project_id, + ) + + result = await litellm.acompletion(**request_params) + + # If not streaming, check and inject usage if missing + if not params.stream: + # Use getattr to safely handle cases where usage attribute might not exist + if getattr(result, "usage", None) is None: + # Create usage object with zeros + usage_obj = OpenAIChatCompletionUsage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + # Use model_copy to create a new response with the usage injected + result = result.model_copy(update={"usage": usage_obj}) + return result + + # For streaming, wrap the iterator to normalize chunks + return self._normalize_stream(result) + + def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk: + """ + Normalize a chunk to ensure it has all expected attributes. + This works around LiteLLM not always including all expected attributes. 
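+
+        Sketch of the normalization guarantee:
+
+            normalized = self._normalize_chunk(chunk)
+            assert normalized.usage is not None  # zeroed usage injected if missing
+            assert normalized.choices[0].delta.refusal is None  # filled in when absent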
+        """ + # Ensure chunk has usage attribute with zeros if missing + if not hasattr(chunk, "usage") or chunk.usage is None: + usage_obj = OpenAIChatCompletionUsage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + chunk = chunk.model_copy(update={"usage": usage_obj}) + + # Ensure all delta objects in choices have expected attributes + if hasattr(chunk, "choices") and chunk.choices: + normalized_choices = [] + for choice in chunk.choices: + if hasattr(choice, "delta") and choice.delta: + delta = choice.delta + # Build update dict for missing attributes + delta_updates = {} + if not hasattr(delta, "refusal"): + delta_updates["refusal"] = None + if not hasattr(delta, "reasoning_content"): + delta_updates["reasoning_content"] = None + + # If we need to update delta, create a new choice with updated delta + if delta_updates: + new_delta = delta.model_copy(update=delta_updates) + new_choice = choice.model_copy(update={"delta": new_delta}) + normalized_choices.append(new_choice) + else: + normalized_choices.append(choice) + else: + normalized_choices.append(choice) + + # If we modified any choices, create a new chunk with updated choices + if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))): + chunk = chunk.model_copy(update={"choices": normalized_choices}) + + return chunk + + async def _normalize_stream( + self, stream: AsyncIterator[OpenAIChatCompletionChunk] + ) -> AsyncIterator[OpenAIChatCompletionChunk]: + """ + Normalize all chunks in the stream to ensure they have expected attributes. + This works around LiteLLM sometimes not including expected attributes. + """ + try: + async for chunk in stream: + # Normalize and yield each chunk immediately + yield self._normalize_chunk(chunk) + except Exception as e: + logger.error(f"Error normalizing stream: {e}", exc_info=True) + raise + + async def openai_completion( + self, + params: OpenAICompletionRequestWithExtraBody, + ) -> OpenAICompletion: + """ + Override parent method to add watsonx-specific parameters. + """ + model_obj = await self.model_store.get_model(params.model) + + request_params = await prepare_openai_completion_params( + model=self.get_litellm_model_name(model_obj.provider_resource_id), + prompt=params.prompt, + best_of=params.best_of, + echo=params.echo, + frequency_penalty=params.frequency_penalty, + logit_bias=params.logit_bias, + logprobs=params.logprobs, + max_tokens=params.max_tokens, + n=params.n, + presence_penalty=params.presence_penalty, + seed=params.seed, + stop=params.stop, + stream=params.stream, + stream_options=params.stream_options, + temperature=params.temperature, + top_p=params.top_p, + user=params.user, + suffix=params.suffix, + api_key=self.get_api_key(), + api_base=self.api_base, + # These are watsonx-specific parameters + timeout=self.config.timeout, + project_id=self.config.project_id, + ) + return await litellm.atext_completion(**request_params) + + async def openai_embeddings( + self, + params: OpenAIEmbeddingsRequestWithExtraBody, + ) -> OpenAIEmbeddingsResponse: + """ + Override parent method to add watsonx-specific parameters. 
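+
+        Rough usage sketch (``adapter`` and the model id are hypothetical):
+
+            req = OpenAIEmbeddingsRequestWithExtraBody(
+                model="ibm/granite-embedding-107m-multilingual",
+                input=["first text", "second text"],
+            )
+            resp = await adapter.openai_embeddings(req)
+            assert resp.usage.total_tokens >= 0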
+        """ + model_obj = await self.model_store.get_model(params.model) + + # Convert input to list if it's a string + input_list = [params.input] if isinstance(params.input, str) else params.input + + # Call litellm's async embedding function with watsonx-specific parameters + response = await litellm.aembedding( + model=self.get_litellm_model_name(model_obj.provider_resource_id), + input=input_list, + api_key=self.get_api_key(), + api_base=self.api_base, + dimensions=params.dimensions, + # These are watsonx-specific parameters + timeout=self.config.timeout, + project_id=self.config.project_id, + ) + + # Convert response to OpenAI format + from llama_stack.apis.inference import OpenAIEmbeddingUsage + from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response + + data = b64_encode_openai_embeddings_response(response.data, params.encoding_format) + + usage = OpenAIEmbeddingUsage( + prompt_tokens=response["usage"]["prompt_tokens"], + total_tokens=response["usage"]["total_tokens"], + ) + + return OpenAIEmbeddingsResponse( + data=data, + model=model_obj.provider_resource_id, + usage=usage, ) - self.available_models = None - self.config = config def get_base_url(self) -> str: return self.config.url - async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: - # Get base parameters from parent - params = await super()._get_params(request) - - # Add watsonx.ai specific parameters - params["project_id"] = self.config.project_id - params["time_limit"] = self.config.timeout - return params - # Copied from OpenAIMixin async def check_model_availability(self, model: str) -> bool: """ diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 100c78a9a..d46e9bbd9 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -353,14 +353,11 @@ class OpenAIVectorStoreMixin(ABC): provider_vector_db_id = extra.get("provider_vector_db_id") embedding_model = extra.get("embedding_model") embedding_dimension = extra.get("embedding_dimension", 768) - provider_id = extra.get("provider_id") - + # Use the provider_id set by the router; fall back to the provider's own ID when used directly via --stack-config + provider_id = extra.get("provider_id") or getattr(self, "__provider_id__", None) # Derive the canonical vector_db_id (allow override, else generate) vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}") - if provider_id is None: - raise ValueError("Provider ID is required") - if embedding_model is None: raise ValueError("Embedding model is required") @@ -369,6 +366,9 @@ if embedding_dimension is None: raise ValueError("Embedding dimension is required") # Register the VectorDB backing this vector store + if provider_id is None: + raise ValueError("Provider ID is required but was not provided") + vector_db = VectorDB( identifier=vector_db_id, embedding_dimension=embedding_dimension, diff --git a/pyproject.toml b/pyproject.toml index 81997c249..d55de794d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "openai>=1.107", # for expires_after support "prompt-toolkit", "python-dotenv", - "python-jose[cryptography]", + "pyjwt[crypto]>=2.10.0", # Pull in the crypto extra for RS256 JWT support; 2.10.0+ is required for ssl_context support. 
"pydantic>=2.11.9", "rich", "starlette", diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 3f0cffb2d..65f773889 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -58,7 +58,6 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) # does not work with the specified model, gpt-5-mini. Please choose different model and try # again. You can learn more about which models can be used with each operation here: # https://go.microsoft.com/fwlink/?linkid=2197993.'}}"} - "remote::watsonx", # return 404 when hitting the /openai/v1 endpoint "remote::llama-openai-compat", ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") @@ -68,6 +67,7 @@ def skip_if_doesnt_support_completions_logprobs(client_with_models, model_id): provider_type = provider_from_model(client_with_models, model_id).provider_type if provider_type in ( "remote::ollama", # logprobs is ignored + "remote::watsonx", ): pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions logprobs.") @@ -110,6 +110,7 @@ def skip_if_doesnt_support_n(client_with_models, model_id): # Error code 400 - {'message': '"n" > 1 is not currently supported', 'type': 'invalid_request_error', 'param': 'n', 'code': 'wrong_api_format'} "remote::cerebras", "remote::databricks", # Bad request: parameter "n" must be equal to 1 for streaming mode + "remote::watsonx", ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.") @@ -124,7 +125,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode "remote::databricks", "remote::cerebras", "remote::runpod", - "remote::watsonx", # watsonx returns 404 when hitting the /openai/v1 endpoint ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.") @@ -508,6 +508,12 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi assert "hello world" in normalized_content +def skip_if_doesnt_support_completions_stop_sequence(client_with_models, model_id): + provider_type = provider_from_model(client_with_models, model_id).provider_type + if provider_type in ("remote::watsonx",): # openai.BadRequestError: Error code: 400 + pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions stop sequence.") + + @pytest.mark.parametrize( "test_case", [ @@ -516,6 +522,7 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi ) def test_openai_completion_stop_sequence(client_with_models, openai_client, text_model_id, test_case): skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) + skip_if_doesnt_support_completions_stop_sequence(client_with_models, text_model_id) tc = TestCase(test_case) diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py index 2827284c8..fc2f66b9c 100644 --- a/tests/integration/inference/test_openai_embeddings.py +++ b/tests/integration/inference/test_openai_embeddings.py @@ -50,11 +50,15 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id): def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_id): provider = provider_from_model(client_with_models, model_id) - if provider.provider_type in ( - "remote::together", # 
returns 400 - "inline::sentence-transformers", - # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'} - "remote::databricks", + if ( + provider.provider_type + in ( + "remote::together", # returns 400 + "inline::sentence-transformers", + # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'} + "remote::databricks", + "remote::watsonx", # openai.BadRequestError: Error code: 400 - {'detail': "litellm.UnsupportedParamsError: watsonx does not support parameters: {'dimensions': 384} + ) ): pytest.skip( f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions." diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 98c8fe920..3d5052f38 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -146,8 +146,6 @@ def test_openai_create_vector_store( metadata={"purpose": "testing", "environment": "integration"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -175,8 +173,6 @@ def test_openai_list_vector_stores( metadata={"type": "test"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) store2 = client.vector_stores.create( @@ -184,8 +180,6 @@ def test_openai_list_vector_stores( metadata={"type": "test"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -220,8 +214,6 @@ def test_openai_retrieve_vector_store( metadata={"purpose": "retrieval_test"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -249,8 +241,6 @@ def test_openai_update_vector_store( metadata={"version": "1.0"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) time.sleep(1) @@ -282,8 +272,6 @@ def test_openai_delete_vector_store( metadata={"purpose": "deletion_test"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -314,8 +302,6 @@ def test_openai_vector_store_search_empty( metadata={"purpose": "search_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -346,8 +332,6 @@ def test_openai_vector_store_with_chunks( metadata={"purpose": "chunks_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -412,8 +396,6 @@ def test_openai_vector_store_search_relevance( metadata={"purpose": "relevance_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -457,8 +439,6 @@ def test_openai_vector_store_search_with_ranking_options( metadata={"purpose": "ranking_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -512,8 +492,6 @@ def test_openai_vector_store_search_with_high_score_filter( metadata={"purpose": "high_score_filtering"}, 
extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -573,8 +551,6 @@ def test_openai_vector_store_search_with_max_num_results( metadata={"purpose": "max_num_results_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -608,8 +584,6 @@ def test_openai_vector_store_attach_file( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -688,8 +662,6 @@ def test_openai_vector_store_attach_files_on_creation( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -735,8 +707,6 @@ def test_openai_vector_store_list_files( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -826,8 +796,6 @@ def test_openai_vector_store_retrieve_file_contents( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -851,8 +819,6 @@ def test_openai_vector_store_retrieve_file_contents( attributes=attributes, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -889,8 +855,6 @@ def test_openai_vector_store_delete_file( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -955,8 +919,6 @@ def test_openai_vector_store_delete_file_removes_from_vector_store( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1007,8 +969,6 @@ def test_openai_vector_store_update_file( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1078,8 +1038,6 @@ def test_create_vector_store_files_duplicate_vector_store_name( name="test_store_with_files", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) assert vector_store.file_counts.completed == 0 @@ -1092,8 +1050,6 @@ def test_create_vector_store_files_duplicate_vector_store_name( name="test_store_with_files", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1105,8 +1061,6 @@ def test_create_vector_store_files_duplicate_vector_store_name( file_id=file_ids[0], extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) assert created_file.status == "completed" @@ -1117,8 +1071,6 @@ def test_create_vector_store_files_duplicate_vector_store_name( file_id=file_ids[1], extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) assert created_file_from_non_deleted_vector_store.status == "completed" @@ -1139,8 +1091,6 @@ def test_openai_vector_store_search_modes( metadata={"purpose": "search_mode_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - 
"provider_id": "my_provider", }, ) @@ -1172,8 +1122,6 @@ def test_openai_vector_store_file_batch_create_and_retrieve( name="batch_test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1191,8 +1139,6 @@ def test_openai_vector_store_file_batch_create_and_retrieve( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1239,8 +1185,6 @@ def test_openai_vector_store_file_batch_list_files( name="batch_list_test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1258,8 +1202,6 @@ def test_openai_vector_store_file_batch_list_files( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1336,8 +1278,6 @@ def test_openai_vector_store_file_batch_cancel( name="batch_cancel_test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1355,8 +1295,6 @@ def test_openai_vector_store_file_batch_cancel( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1395,8 +1333,6 @@ def test_openai_vector_store_file_batch_retrieve_contents( name="batch_contents_test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1419,8 +1355,6 @@ def test_openai_vector_store_file_batch_retrieve_contents( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1472,8 +1406,6 @@ def test_openai_vector_store_file_batch_error_handling( name="batch_error_test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1485,8 +1417,6 @@ def test_openai_vector_store_file_batch_error_handling( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index f2205ed0a..653299338 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -52,8 +52,6 @@ def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embe name=vector_db_name, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -73,8 +71,6 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe name=vector_db_name, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -110,8 +106,6 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding name=vector_db_name, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -152,8 +146,6 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e name=vector_db_name, extra_body={ 
"embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -202,8 +194,6 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( name=vector_db_name, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -234,3 +224,35 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( assert len(response.chunks) > 0 assert response.chunks[0].metadata["document_id"] == "doc1" assert response.chunks[0].metadata["source"] == "precomputed" + + +def test_auto_extract_embedding_dimension(client_with_empty_registry, embedding_model_id): + vs = client_with_empty_registry.vector_stores.create( + name="test_auto_extract", extra_body={"embedding_model": embedding_model_id} + ) + assert vs.id is not None + + +def test_provider_auto_selection_single_provider(client_with_empty_registry, embedding_model_id): + providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] + if len(providers) != 1: + pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}") + + vs = client_with_empty_registry.vector_stores.create( + name="test_auto_provider", extra_body={"embedding_model": embedding_model_id} + ) + assert vs.id is not None + + +def test_provider_id_override(client_with_empty_registry, embedding_model_id): + providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] + if len(providers) != 1: + pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}") + + provider_id = providers[0].provider_id + + vs = client_with_empty_registry.vector_stores.create( + name="test_provider_override", extra_body={"embedding_model": embedding_model_id, "provider_id": provider_id} + ) + assert vs.id is not None + assert vs.metadata.get("provider_id") == provider_id diff --git a/tests/unit/core/routers/test_vector_io.py b/tests/unit/core/routers/test_vector_io.py new file mode 100644 index 000000000..997df0d78 --- /dev/null +++ b/tests/unit/core/routers/test_vector_io.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from unittest.mock import AsyncMock, Mock + +import pytest + +from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody +from llama_stack.core.routers.vector_io import VectorIORouter + + +async def test_single_provider_auto_selection(): + # provider_id automatically selected during vector store create() when only one provider available + mock_routing_table = Mock() + mock_routing_table.impls_by_provider_id = {"inline::faiss": "mock_provider"} + mock_routing_table.get_all_with_type = AsyncMock( + return_value=[ + Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384}) + ] + ) + mock_routing_table.register_vector_db = AsyncMock( + return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123") + ) + mock_routing_table.get_provider_impl = AsyncMock( + return_value=Mock(openai_create_vector_store=AsyncMock(return_value=Mock(id="vs_123"))) + ) + router = VectorIORouter(mock_routing_table) + request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate( + {"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"} + ) + + result = await router.openai_create_vector_store(request) + assert result.id == "vs_123" + + +async def test_create_vector_stores_multiple_providers_missing_provider_id_error(): + # if multiple providers are available, vector store create will error without provider_id + mock_routing_table = Mock() + mock_routing_table.impls_by_provider_id = { + "inline::faiss": "mock_provider_1", + "inline::sqlite-vec": "mock_provider_2", + } + mock_routing_table.get_all_with_type = AsyncMock( + return_value=[ + Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384}) + ] + ) + router = VectorIORouter(mock_routing_table) + request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate( + {"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"} + ) + + with pytest.raises(ValueError, match="Multiple vector_io providers available"): + await router.openai_create_vector_store(request) diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 9dbabe195..04ae89db8 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -5,7 +5,8 @@ # the root directory of this source tree. 
import base64 -from unittest.mock import AsyncMock, patch +import json +from unittest.mock import AsyncMock, Mock, patch import pytest from fastapi import FastAPI @@ -374,7 +375,7 @@ async def mock_jwks_response(*args, **kwargs): @pytest.fixture def jwt_token_valid(): - from jose import jwt + import jwt return jwt.encode( { @@ -389,8 +390,30 @@ def jwt_token_valid(): ) -@patch("httpx.AsyncClient.get", new=mock_jwks_response) -def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid): +@pytest.fixture +def mock_jwks_urlopen(): + """Mock urllib.request.urlopen for PyJWKClient JWKS requests.""" + with patch("urllib.request.urlopen") as mock_urlopen: + # Mock the JWKS response for PyJWKClient + mock_response = Mock() + mock_response.read.return_value = json.dumps( + { + "keys": [ + { + "kid": "1234567890", + "kty": "oct", + "alg": "HS256", + "use": "sig", + "k": base64.b64encode(b"foobarbaz").decode(), + } + ] + } + ).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + yield mock_urlopen + + +def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen): response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) assert response.status_code == 200 assert response.json() == {"message": "Authentication successful"} @@ -447,8 +470,7 @@ def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid): assert response.status_code == 401 -@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response) -def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid): +def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid, mock_jwks_urlopen): response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) assert response.status_code == 200 assert response.json() == {"message": "Authentication successful"} diff --git a/uv.lock b/uv.lock index 0fcb02768..747e82aaa 100644 --- a/uv.lock +++ b/uv.lock @@ -874,18 +874,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b0/0d/9feae160378a3553fa9a339b0e9c1a048e147a4127210e286ef18b730f03/durationpy-0.10-py3-none-any.whl", hash = "sha256:3b41e1b601234296b4fb368338fdcd3e13e0b4fb5b67345948f4f2bf9868b286", size = 3922, upload-time = "2025-05-17T13:52:36.463Z" }, ] -[[package]] -name = "ecdsa" -version = "0.19.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "six" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c0/1f/924e3caae75f471eae4b26bd13b698f6af2c44279f67af317439c2f4c46a/ecdsa-0.19.1.tar.gz", hash = "sha256:478cba7b62555866fcb3bb3fe985e06decbdb68ef55713c4e5ab98c57d508e61", size = 201793, upload-time = "2025-03-13T11:52:43.25Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" }, -] - [[package]] name = "eval-type-backport" version = "0.2.2" @@ -1787,8 +1775,8 @@ dependencies = [ { name = "pillow" }, { name = "prompt-toolkit" }, { name = "pydantic" }, + { name = "pyjwt", extra = ["crypto"] }, { name = "python-dotenv" }, - { name = "python-jose", extra = ["cryptography"] }, { name = "python-multipart" }, { name = "rich" }, { name = "sqlalchemy", extra = ["asyncio"] }, @@ -1910,8 +1898,8 @@ requires-dist = [ { name = "pillow" }, { name = 
"prompt-toolkit" }, { name = "pydantic", specifier = ">=2.11.9" }, + { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.0" }, { name = "python-dotenv" }, - { name = "python-jose", extras = ["cryptography"] }, { name = "python-multipart", specifier = ">=0.0.20" }, { name = "rich" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, @@ -3558,6 +3546,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, +] + +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pymilvus" version = "2.6.1" @@ -3747,25 +3749,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" }, ] -[[package]] -name = "python-jose" -version = "3.5.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "ecdsa" }, - { name = "pyasn1" }, - { name = "rsa" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c6/77/3a1c9039db7124eb039772b935f2244fbb73fc8ee65b9acf2375da1c07bf/python_jose-3.5.0.tar.gz", hash = "sha256:fb4eaa44dbeb1c26dcc69e4bd7ec54a1cb8dd64d3b4d81ef08d90ff453f2b01b", size = 92726, upload-time = "2025-05-28T17:31:54.288Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/c3/0bd11992072e6a1c513b16500a5d07f91a24017c5909b02c72c62d7ad024/python_jose-3.5.0-py2.py3-none-any.whl", hash = "sha256:abd1202f23d34dfad2c3d28cb8617b90acf34132c7afd60abd0b0b7d3cb55771", size = 34624, upload-time = "2025-05-28T17:31:52.802Z" }, -] - -[package.optional-dependencies] -cryptography = [ - { name = "cryptography" }, -] - [[package]] name = "python-multipart" version = "0.0.20"