From 642126e13b9d42cdc95bc5169d8179a293f523e2 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Mon, 13 Oct 2025 17:55:55 +0100 Subject: [PATCH 1/4] fix: record job checking wrong directory (#3799) Fixed CI job to check the correct directory for file changes Artifacts are now stored in multiple directories not just ./tests/integration/recordings Signed-off-by: Derek Higgins --- .github/actions/run-and-record-tests/action.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml index d240381c5..a5aa31af4 100644 --- a/.github/actions/run-and-record-tests/action.yml +++ b/.github/actions/run-and-record-tests/action.yml @@ -66,11 +66,11 @@ runs: shell: bash run: | echo "Checking for recording changes" - git status --porcelain tests/integration/recordings/ + git status --porcelain tests/integration/ - if [[ -n $(git status --porcelain tests/integration/recordings/) ]]; then + if [[ -n $(git status --porcelain tests/integration/) ]]; then echo "New recordings detected, committing and pushing" - git add tests/integration/recordings/ + git add tests/integration/ git commit -m "Recordings update from CI (suite: ${{ inputs.suite }})" git fetch origin ${{ github.ref_name }} From 968c364a3ea0f7134c90e6c342f8bfb22499e71d Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Mon, 13 Oct 2025 13:25:36 -0400 Subject: [PATCH 2/4] =?UTF-8?q?chore:=20Auto-detect=20Provider=20ID=20when?= =?UTF-8?q?=20only=201=20Vector=20Store=20Provider=20avai=E2=80=A6=20(#380?= =?UTF-8?q?2)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? 2 main changes: 1. Remove `provider_id` requirement in call to vector stores and 2. Removes "register first embedding model" logic - Now forces embedding model id as required on Vector Store creation Simplifies the UX for OpenAI to: ```python vs = client.vector_stores.create( name="my_citations_db", extra_body={ "embedding_model": "ollama/nomic-embed-text:latest", } ) ``` ## Test Plan --------- Signed-off-by: Francisco Javier Arceo --- llama_stack/core/routers/vector_io.py | 60 ++++++++-------- .../utils/memory/openai_vector_store_mixin.py | 9 +-- .../vector_io/test_openai_vector_stores.py | 70 ------------------- tests/integration/vector_io/test_vector_io.py | 42 ++++++++--- tests/unit/core/routers/test_vector_io.py | 57 +++++++++++++++ 5 files changed, 123 insertions(+), 115 deletions(-) create mode 100644 tests/unit/core/routers/test_vector_io.py diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 79789ef0a..dc7b3a694 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -55,30 +55,18 @@ class VectorIORouter(VectorIO): logger.debug("VectorIORouter.shutdown") pass - async def _get_first_embedding_model(self) -> tuple[str, int] | None: - """Get the first available embedding model identifier.""" - try: - # Get all models from the routing table - all_models = await self.routing_table.get_all_with_type("model") + async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int: + """Get the embedding dimension for a specific embedding model.""" + all_models = await self.routing_table.get_all_with_type("model") - # Filter for embedding models - embedding_models = [ - model - for model in all_models - if hasattr(model, "model_type") and model.model_type == ModelType.embedding - ] - - if embedding_models: - dimension = embedding_models[0].metadata.get("embedding_dimension", None) + for model in all_models: + if model.identifier == embedding_model_id and model.model_type == ModelType.embedding: + dimension = model.metadata.get("embedding_dimension") if dimension is None: - raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension") - return embedding_models[0].identifier, dimension - else: - logger.warning("No embedding models found in the routing table") - return None - except Exception as e: - logger.error(f"Error getting embedding models: {e}") - return None + raise ValueError(f"Embedding model '{embedding_model_id}' has no embedding_dimension in metadata") + return int(dimension) + + raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model") async def register_vector_db( self, @@ -129,20 +117,30 @@ class VectorIORouter(VectorIO): # Extract llama-stack-specific parameters from extra_body extra = params.model_extra or {} embedding_model = extra.get("embedding_model") - embedding_dimension = extra.get("embedding_dimension", 384) + embedding_dimension = extra.get("embedding_dimension") provider_id = extra.get("provider_id") logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}") - # If no embedding model is provided, use the first available one - # TODO: this branch will soon be deleted so you _must_ provide the embedding_model when - # creating a vector store + # Require explicit embedding model specification if embedding_model is None: - embedding_model_info = await self._get_first_embedding_model() - if embedding_model_info is None: - raise ValueError("No embedding model provided and no embedding models available in the system") - embedding_model, embedding_dimension = embedding_model_info - logger.info(f"No embedding model specified, using first available: {embedding_model}") + raise ValueError("embedding_model is required in extra_body when creating a vector store") + + if embedding_dimension is None: + embedding_dimension = await self._get_embedding_model_dimension(embedding_model) + + # Auto-select provider if not specified + if provider_id is None: + num_providers = len(self.routing_table.impls_by_provider_id) + if num_providers == 0: + raise ValueError("No vector_io providers available") + if num_providers > 1: + available_providers = list(self.routing_table.impls_by_provider_id.keys()) + raise ValueError( + f"Multiple vector_io providers available. Please specify provider_id in extra_body. " + f"Available providers: {available_providers}" + ) + provider_id = list(self.routing_table.impls_by_provider_id.keys())[0] vector_db_id = f"vs_{uuid.uuid4()}" registered_vector_db = await self.routing_table.register_vector_db( diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 70bcbba32..02c3d9730 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -353,14 +353,12 @@ class OpenAIVectorStoreMixin(ABC): provider_vector_db_id = extra.get("provider_vector_db_id") embedding_model = extra.get("embedding_model") embedding_dimension = extra.get("embedding_dimension", 384) - provider_id = extra.get("provider_id") + # use provider_id set by router; fallback to provider's own ID when used directly via --stack-config + provider_id = extra.get("provider_id") or getattr(self, "__provider_id__", None) # Derive the canonical vector_db_id (allow override, else generate) vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}") - if provider_id is None: - raise ValueError("Provider ID is required") - if embedding_model is None: raise ValueError("Embedding model is required") @@ -369,6 +367,9 @@ class OpenAIVectorStoreMixin(ABC): raise ValueError("Embedding dimension is required") # Register the VectorDB backing this vector store + if provider_id is None: + raise ValueError("Provider ID is required but was not provided") + vector_db = VectorDB( identifier=vector_db_id, embedding_dimension=embedding_dimension, diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 347b43145..904e382e1 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -146,8 +146,6 @@ def test_openai_create_vector_store( metadata={"purpose": "testing", "environment": "integration"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -175,8 +173,6 @@ def test_openai_list_vector_stores( metadata={"type": "test"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) store2 = client.vector_stores.create( @@ -184,8 +180,6 @@ def test_openai_list_vector_stores( metadata={"type": "test"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -220,8 +214,6 @@ def test_openai_retrieve_vector_store( metadata={"purpose": "retrieval_test"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -249,8 +241,6 @@ def test_openai_update_vector_store( metadata={"version": "1.0"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) time.sleep(1) @@ -282,8 +272,6 @@ def test_openai_delete_vector_store( metadata={"purpose": "deletion_test"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -314,8 +302,6 @@ def test_openai_vector_store_search_empty( metadata={"purpose": "search_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -346,8 +332,6 @@ def test_openai_vector_store_with_chunks( metadata={"purpose": "chunks_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -412,8 +396,6 @@ def test_openai_vector_store_search_relevance( metadata={"purpose": "relevance_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -457,8 +439,6 @@ def test_openai_vector_store_search_with_ranking_options( metadata={"purpose": "ranking_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -500,8 +480,6 @@ def test_openai_vector_store_search_with_high_score_filter( metadata={"purpose": "high_score_filtering"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -561,8 +539,6 @@ def test_openai_vector_store_search_with_max_num_results( metadata={"purpose": "max_num_results_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -596,8 +572,6 @@ def test_openai_vector_store_attach_file( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -666,8 +640,6 @@ def test_openai_vector_store_attach_files_on_creation( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -713,8 +685,6 @@ def test_openai_vector_store_list_files( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -799,8 +769,6 @@ def test_openai_vector_store_retrieve_file_contents( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -819,8 +787,6 @@ def test_openai_vector_store_retrieve_file_contents( attributes=attributes, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -857,8 +823,6 @@ def test_openai_vector_store_delete_file( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -918,8 +882,6 @@ def test_openai_vector_store_delete_file_removes_from_vector_store( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -965,8 +927,6 @@ def test_openai_vector_store_update_file( name="test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1026,8 +986,6 @@ def test_create_vector_store_files_duplicate_vector_store_name( name="test_store_with_files", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) assert vector_store.file_counts.completed == 0 @@ -1040,8 +998,6 @@ def test_create_vector_store_files_duplicate_vector_store_name( name="test_store_with_files", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1053,8 +1009,6 @@ def test_create_vector_store_files_duplicate_vector_store_name( file_id=file_ids[0], extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) assert created_file.status == "completed" @@ -1065,8 +1019,6 @@ def test_create_vector_store_files_duplicate_vector_store_name( file_id=file_ids[1], extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) assert created_file_from_non_deleted_vector_store.status == "completed" @@ -1087,8 +1039,6 @@ def test_openai_vector_store_search_modes( metadata={"purpose": "search_mode_testing"}, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1120,8 +1070,6 @@ def test_openai_vector_store_file_batch_create_and_retrieve( name="batch_test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1139,8 +1087,6 @@ def test_openai_vector_store_file_batch_create_and_retrieve( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1187,8 +1133,6 @@ def test_openai_vector_store_file_batch_list_files( name="batch_list_test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1206,8 +1150,6 @@ def test_openai_vector_store_file_batch_list_files( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1284,8 +1226,6 @@ def test_openai_vector_store_file_batch_cancel( name="batch_cancel_test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1303,8 +1243,6 @@ def test_openai_vector_store_file_batch_cancel( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1343,8 +1281,6 @@ def test_openai_vector_store_file_batch_retrieve_contents( name="batch_contents_test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1367,8 +1303,6 @@ def test_openai_vector_store_file_batch_retrieve_contents( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1420,8 +1354,6 @@ def test_openai_vector_store_file_batch_error_handling( name="batch_error_test_store", extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -1433,8 +1365,6 @@ def test_openai_vector_store_file_batch_error_handling( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index f2205ed0a..653299338 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -52,8 +52,6 @@ def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embe name=vector_db_name, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -73,8 +71,6 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe name=vector_db_name, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -110,8 +106,6 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding name=vector_db_name, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -152,8 +146,6 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e name=vector_db_name, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -202,8 +194,6 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( name=vector_db_name, extra_body={ "embedding_model": embedding_model_id, - "embedding_dimension": embedding_dimension, - "provider_id": "my_provider", }, ) @@ -234,3 +224,35 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( assert len(response.chunks) > 0 assert response.chunks[0].metadata["document_id"] == "doc1" assert response.chunks[0].metadata["source"] == "precomputed" + + +def test_auto_extract_embedding_dimension(client_with_empty_registry, embedding_model_id): + vs = client_with_empty_registry.vector_stores.create( + name="test_auto_extract", extra_body={"embedding_model": embedding_model_id} + ) + assert vs.id is not None + + +def test_provider_auto_selection_single_provider(client_with_empty_registry, embedding_model_id): + providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] + if len(providers) != 1: + pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}") + + vs = client_with_empty_registry.vector_stores.create( + name="test_auto_provider", extra_body={"embedding_model": embedding_model_id} + ) + assert vs.id is not None + + +def test_provider_id_override(client_with_empty_registry, embedding_model_id): + providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] + if len(providers) != 1: + pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}") + + provider_id = providers[0].provider_id + + vs = client_with_empty_registry.vector_stores.create( + name="test_provider_override", extra_body={"embedding_model": embedding_model_id, "provider_id": provider_id} + ) + assert vs.id is not None + assert vs.metadata.get("provider_id") == provider_id diff --git a/tests/unit/core/routers/test_vector_io.py b/tests/unit/core/routers/test_vector_io.py new file mode 100644 index 000000000..997df0d78 --- /dev/null +++ b/tests/unit/core/routers/test_vector_io.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, Mock + +import pytest + +from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody +from llama_stack.core.routers.vector_io import VectorIORouter + + +async def test_single_provider_auto_selection(): + # provider_id automatically selected during vector store create() when only one provider available + mock_routing_table = Mock() + mock_routing_table.impls_by_provider_id = {"inline::faiss": "mock_provider"} + mock_routing_table.get_all_with_type = AsyncMock( + return_value=[ + Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384}) + ] + ) + mock_routing_table.register_vector_db = AsyncMock( + return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123") + ) + mock_routing_table.get_provider_impl = AsyncMock( + return_value=Mock(openai_create_vector_store=AsyncMock(return_value=Mock(id="vs_123"))) + ) + router = VectorIORouter(mock_routing_table) + request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate( + {"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"} + ) + + result = await router.openai_create_vector_store(request) + assert result.id == "vs_123" + + +async def test_create_vector_stores_multiple_providers_missing_provider_id_error(): + # if multiple providers are available, vector store create will error without provider_id + mock_routing_table = Mock() + mock_routing_table.impls_by_provider_id = { + "inline::faiss": "mock_provider_1", + "inline::sqlite-vec": "mock_provider_2", + } + mock_routing_table.get_all_with_type = AsyncMock( + return_value=[ + Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384}) + ] + ) + router = VectorIORouter(mock_routing_table) + request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate( + {"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"} + ) + + with pytest.raises(ValueError, match="Multiple vector_io providers available"): + await router.openai_create_vector_store(request) From 1136daf310b6f9cf5215fc682e0b37d242b2ebdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 14 Oct 2025 09:35:48 +0200 Subject: [PATCH 3/4] fix: replace python-jose with PyJWT for JWT handling (#3756) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This commit migrates the authentication system from python-jose to PyJWT to eliminate the dependency on the archived rsa package. The migration includes: - Refactored OAuth2TokenAuthProvider to use PyJWT's PyJWKClient for clean JWKS handling - Removed manual JWKS fetching, caching and key extraction logic in favor of PyJWT's built-in functionality The new implementation is cleaner, more maintainable, and follows PyJWT best practices while maintaining full backward compatibility. ## Test Plan Unit tests. Auth CI. --------- Signed-off-by: Sébastien Han --- llama_stack/core/server/auth_providers.py | 94 ++++++++++++----------- pyproject.toml | 2 +- tests/unit/server/test_auth.py | 34 ++++++-- uv.lock | 49 ++++-------- 4 files changed, 93 insertions(+), 86 deletions(-) diff --git a/llama_stack/core/server/auth_providers.py b/llama_stack/core/server/auth_providers.py index 38188c49a..05a21c8d4 100644 --- a/llama_stack/core/server/auth_providers.py +++ b/llama_stack/core/server/auth_providers.py @@ -5,13 +5,11 @@ # the root directory of this source tree. import ssl -import time from abc import ABC, abstractmethod -from asyncio import Lock from urllib.parse import parse_qs, urljoin, urlparse import httpx -from jose import jwt +import jwt from pydantic import BaseModel, Field from llama_stack.apis.common.errors import TokenValidationError @@ -98,9 +96,7 @@ class OAuth2TokenAuthProvider(AuthProvider): def __init__(self, config: OAuth2TokenAuthConfig): self.config = config - self._jwks_at: float = 0.0 - self._jwks: dict[str, str] = {} - self._jwks_lock = Lock() + self._jwks_client: jwt.PyJWKClient | None = None async def validate_token(self, token: str, scope: dict | None = None) -> User: if self.config.jwks: @@ -109,23 +105,60 @@ class OAuth2TokenAuthProvider(AuthProvider): return await self.introspect_token(token, scope) raise ValueError("One of jwks or introspection must be configured") + def _get_jwks_client(self) -> jwt.PyJWKClient: + if self._jwks_client is None: + ssl_context = None + if not self.config.verify_tls: + # Disable SSL verification if verify_tls is False + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif self.config.tls_cafile: + # Use custom CA file if provided + ssl_context = ssl.create_default_context( + cafile=self.config.tls_cafile.as_posix(), + ) + # If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults) + + # Prepare headers for JWKS request - this is needed for Kubernetes to authenticate + # to the JWK endpoint, we must use the token in the config to authenticate + headers = {} + if self.config.jwks and self.config.jwks.token: + headers["Authorization"] = f"Bearer {self.config.jwks.token}" + + self._jwks_client = jwt.PyJWKClient( + self.config.jwks.uri if self.config.jwks else None, + cache_keys=True, + max_cached_keys=10, + lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None, + headers=headers, + ssl_context=ssl_context, + ) + return self._jwks_client + async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User: """Validate a token using the JWT token.""" - await self._refresh_jwks() - try: - header = jwt.get_unverified_header(token) - kid = header["kid"] - if kid not in self._jwks: - raise ValueError(f"Unknown key ID: {kid}") - key_data = self._jwks[kid] - algorithm = header.get("alg", "RS256") + jwks_client: jwt.PyJWKClient = self._get_jwks_client() + signing_key = jwks_client.get_signing_key_from_jwt(token) + algorithm = jwt.get_unverified_header(token)["alg"] claims = jwt.decode( token, - key_data, + signing_key.key, algorithms=[algorithm], audience=self.config.audience, issuer=self.config.issuer, + options={"verify_exp": True, "verify_aud": True, "verify_iss": True}, + ) + + # Decode and verify the JWT + claims = jwt.decode( + token, + signing_key.key, + algorithms=[algorithm], + audience=self.config.audience, + issuer=self.config.issuer, + options={"verify_exp": True, "verify_aud": True, "verify_iss": True}, ) except Exception as exc: raise ValueError("Invalid JWT token") from exc @@ -201,37 +234,6 @@ class OAuth2TokenAuthProvider(AuthProvider): else: return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header" - async def _refresh_jwks(self) -> None: - """ - Refresh the JWKS cache. - - This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`). - If the cache is expired, we refresh the JWKS from the JWKS URI. - - Notes: for Kubernetes which doesn't fully implement the OIDC protocol: - * It doesn't have user authentication flows - * It doesn't have refresh tokens - """ - async with self._jwks_lock: - if self.config.jwks is None: - raise ValueError("JWKS is not configured") - if time.time() - self._jwks_at > self.config.jwks.key_recheck_period: - headers = {} - if self.config.jwks.token: - headers["Authorization"] = f"Bearer {self.config.jwks.token}" - verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls - async with httpx.AsyncClient(verify=verify) as client: - res = await client.get(self.config.jwks.uri, timeout=5, headers=headers) - res.raise_for_status() - jwks_data = res.json()["keys"] - updated = {} - for k in jwks_data: - kid = k["kid"] - # Store the entire key object as it may be needed for different algorithms - updated[kid] = k - self._jwks = updated - self._jwks_at = time.time() - class CustomAuthProvider(AuthProvider): """Custom authentication provider that uses an external endpoint.""" diff --git a/pyproject.toml b/pyproject.toml index 81997c249..d55de794d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "openai>=1.107", # for expires_after support "prompt-toolkit", "python-dotenv", - "python-jose[cryptography]", + "pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support. "pydantic>=2.11.9", "rich", "starlette", diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 9dbabe195..04ae89db8 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import base64 -from unittest.mock import AsyncMock, patch +import json +from unittest.mock import AsyncMock, Mock, patch import pytest from fastapi import FastAPI @@ -374,7 +375,7 @@ async def mock_jwks_response(*args, **kwargs): @pytest.fixture def jwt_token_valid(): - from jose import jwt + import jwt return jwt.encode( { @@ -389,8 +390,30 @@ def jwt_token_valid(): ) -@patch("httpx.AsyncClient.get", new=mock_jwks_response) -def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid): +@pytest.fixture +def mock_jwks_urlopen(): + """Mock urllib.request.urlopen for PyJWKClient JWKS requests.""" + with patch("urllib.request.urlopen") as mock_urlopen: + # Mock the JWKS response for PyJWKClient + mock_response = Mock() + mock_response.read.return_value = json.dumps( + { + "keys": [ + { + "kid": "1234567890", + "kty": "oct", + "alg": "HS256", + "use": "sig", + "k": base64.b64encode(b"foobarbaz").decode(), + } + ] + } + ).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + yield mock_urlopen + + +def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen): response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) assert response.status_code == 200 assert response.json() == {"message": "Authentication successful"} @@ -447,8 +470,7 @@ def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid): assert response.status_code == 401 -@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response) -def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid): +def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid, mock_jwks_urlopen): response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) assert response.status_code == 200 assert response.json() == {"message": "Authentication successful"} diff --git a/uv.lock b/uv.lock index 0fcb02768..747e82aaa 100644 --- a/uv.lock +++ b/uv.lock @@ -874,18 +874,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b0/0d/9feae160378a3553fa9a339b0e9c1a048e147a4127210e286ef18b730f03/durationpy-0.10-py3-none-any.whl", hash = "sha256:3b41e1b601234296b4fb368338fdcd3e13e0b4fb5b67345948f4f2bf9868b286", size = 3922, upload-time = "2025-05-17T13:52:36.463Z" }, ] -[[package]] -name = "ecdsa" -version = "0.19.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "six" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c0/1f/924e3caae75f471eae4b26bd13b698f6af2c44279f67af317439c2f4c46a/ecdsa-0.19.1.tar.gz", hash = "sha256:478cba7b62555866fcb3bb3fe985e06decbdb68ef55713c4e5ab98c57d508e61", size = 201793, upload-time = "2025-03-13T11:52:43.25Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" }, -] - [[package]] name = "eval-type-backport" version = "0.2.2" @@ -1787,8 +1775,8 @@ dependencies = [ { name = "pillow" }, { name = "prompt-toolkit" }, { name = "pydantic" }, + { name = "pyjwt", extra = ["crypto"] }, { name = "python-dotenv" }, - { name = "python-jose", extra = ["cryptography"] }, { name = "python-multipart" }, { name = "rich" }, { name = "sqlalchemy", extra = ["asyncio"] }, @@ -1910,8 +1898,8 @@ requires-dist = [ { name = "pillow" }, { name = "prompt-toolkit" }, { name = "pydantic", specifier = ">=2.11.9" }, + { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.0" }, { name = "python-dotenv" }, - { name = "python-jose", extras = ["cryptography"] }, { name = "python-multipart", specifier = ">=0.0.20" }, { name = "rich" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, @@ -3558,6 +3546,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, +] + +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pymilvus" version = "2.6.1" @@ -3747,25 +3749,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" }, ] -[[package]] -name = "python-jose" -version = "3.5.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "ecdsa" }, - { name = "pyasn1" }, - { name = "rsa" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c6/77/3a1c9039db7124eb039772b935f2244fbb73fc8ee65b9acf2375da1c07bf/python_jose-3.5.0.tar.gz", hash = "sha256:fb4eaa44dbeb1c26dcc69e4bd7ec54a1cb8dd64d3b4d81ef08d90ff453f2b01b", size = 92726, upload-time = "2025-05-28T17:31:54.288Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/c3/0bd11992072e6a1c513b16500a5d07f91a24017c5909b02c72c62d7ad024/python_jose-3.5.0-py2.py3-none-any.whl", hash = "sha256:abd1202f23d34dfad2c3d28cb8617b90acf34132c7afd60abd0b0b7d3cb55771", size = 34624, upload-time = "2025-05-28T17:31:52.802Z" }, -] - -[package.optional-dependencies] -cryptography = [ - { name = "cryptography" }, -] - [[package]] name = "python-multipart" version = "0.0.20" From 0dbf79c328d8444cd9fa90891be9a4e9c36588df Mon Sep 17 00:00:00 2001 From: Cesare Pompeiano <195810094+are-ces@users.noreply.github.com> Date: Tue, 14 Oct 2025 14:52:32 +0200 Subject: [PATCH 4/4] fix: Fixed WatsonX remote inference provider (#3801) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This PR fixes issues with the WatsonX provider so it works correctly with LiteLLM. The main problem was that WatsonX requests failed because the provider data validator didn’t properly handle the API key and project ID. This was fixed by updating the WatsonXProviderDataValidator and ensuring the provider data is loaded correctly. The openai_chat_completion method was also updated to match the behavior of other providers while adding WatsonX-specific fields like project_id. It still calls await super().openai_chat_completion.__func__(self, params) to keep the existing setup and tracing logic. After these changes, WatsonX requests now run correctly. ## Test Plan The changes were tested by running chat completion requests and confirming that credentials and project parameters are passed correctly. I have tested with my WatsonX credentials, by using the cli with `uv run llama-stack-client inference chat-completion --session` --------- Signed-off-by: Sébastien Han Co-authored-by: Sébastien Han --- llama_stack/providers/registry/inference.py | 2 +- .../remote/inference/watsonx/config.py | 10 +- .../remote/inference/watsonx/watsonx.py | 243 +++++++++++++++++- .../inference/test_openai_completion.py | 11 +- .../inference/test_openai_embeddings.py | 14 +- 5 files changed, 254 insertions(+), 26 deletions(-) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index f89565892..6033c3186 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -271,7 +271,7 @@ Available Models: pip_packages=["litellm"], module="llama_stack.providers.remote.inference.watsonx", config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", - provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", + provider_data_validator="llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator", description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.", ), RemoteProviderSpec( diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py index 022dc5ee7..8d8df13b4 100644 --- a/llama_stack/providers/remote/inference/watsonx/config.py +++ b/llama_stack/providers/remote/inference/watsonx/config.py @@ -7,18 +7,18 @@ import os from typing import Any -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, Field from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type class WatsonXProviderDataValidator(BaseModel): - model_config = ConfigDict( - from_attributes=True, - extra="forbid", + watsonx_project_id: str | None = Field( + default=None, + description="IBM WatsonX project ID", ) - watsonx_api_key: str | None + watsonx_api_key: str | None = None @json_schema_type diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index 654d61f34..2c051719b 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -4,42 +4,259 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from collections.abc import AsyncIterator from typing import Any +import litellm import requests -from llama_stack.apis.inference import ChatCompletionRequest +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAIChatCompletionRequestWithExtraBody, + OpenAIChatCompletionUsage, + OpenAICompletion, + OpenAICompletionRequestWithExtraBody, + OpenAIEmbeddingsRequestWithExtraBody, + OpenAIEmbeddingsResponse, +) from llama_stack.apis.models import Model from llama_stack.apis.models.models import ModelType +from llama_stack.log import get_logger from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin +from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params +from llama_stack.providers.utils.telemetry.tracing import get_current_span + +logger = get_logger(name=__name__, category="providers::remote::watsonx") class WatsonXInferenceAdapter(LiteLLMOpenAIMixin): _model_cache: dict[str, Model] = {} + provider_data_api_key_field: str = "watsonx_api_key" + def __init__(self, config: WatsonXConfig): + self.available_models = None + self.config = config + api_key = config.auth_credential.get_secret_value() if config.auth_credential else None LiteLLMOpenAIMixin.__init__( self, litellm_provider_name="watsonx", - api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None, + api_key_from_config=api_key, provider_data_api_key_field="watsonx_api_key", + openai_compat_api_base=self.get_base_url(), + ) + + async def openai_chat_completion( + self, + params: OpenAIChatCompletionRequestWithExtraBody, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + """ + Override parent method to add timeout and inject usage object when missing. + This works around a LiteLLM defect where usage block is sometimes dropped. + """ + + # Add usage tracking for streaming when telemetry is active + stream_options = params.stream_options + if params.stream and get_current_span() is not None: + if stream_options is None: + stream_options = {"include_usage": True} + elif "include_usage" not in stream_options: + stream_options = {**stream_options, "include_usage": True} + + model_obj = await self.model_store.get_model(params.model) + + request_params = await prepare_openai_completion_params( + model=self.get_litellm_model_name(model_obj.provider_resource_id), + messages=params.messages, + frequency_penalty=params.frequency_penalty, + function_call=params.function_call, + functions=params.functions, + logit_bias=params.logit_bias, + logprobs=params.logprobs, + max_completion_tokens=params.max_completion_tokens, + max_tokens=params.max_tokens, + n=params.n, + parallel_tool_calls=params.parallel_tool_calls, + presence_penalty=params.presence_penalty, + response_format=params.response_format, + seed=params.seed, + stop=params.stop, + stream=params.stream, + stream_options=stream_options, + temperature=params.temperature, + tool_choice=params.tool_choice, + tools=params.tools, + top_logprobs=params.top_logprobs, + top_p=params.top_p, + user=params.user, + api_key=self.get_api_key(), + api_base=self.api_base, + # These are watsonx-specific parameters + timeout=self.config.timeout, + project_id=self.config.project_id, + ) + + result = await litellm.acompletion(**request_params) + + # If not streaming, check and inject usage if missing + if not params.stream: + # Use getattr to safely handle cases where usage attribute might not exist + if getattr(result, "usage", None) is None: + # Create usage object with zeros + usage_obj = OpenAIChatCompletionUsage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + # Use model_copy to create a new response with the usage injected + result = result.model_copy(update={"usage": usage_obj}) + return result + + # For streaming, wrap the iterator to normalize chunks + return self._normalize_stream(result) + + def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk: + """ + Normalize a chunk to ensure it has all expected attributes. + This works around LiteLLM not always including all expected attributes. + """ + # Ensure chunk has usage attribute with zeros if missing + if not hasattr(chunk, "usage") or chunk.usage is None: + usage_obj = OpenAIChatCompletionUsage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + chunk = chunk.model_copy(update={"usage": usage_obj}) + + # Ensure all delta objects in choices have expected attributes + if hasattr(chunk, "choices") and chunk.choices: + normalized_choices = [] + for choice in chunk.choices: + if hasattr(choice, "delta") and choice.delta: + delta = choice.delta + # Build update dict for missing attributes + delta_updates = {} + if not hasattr(delta, "refusal"): + delta_updates["refusal"] = None + if not hasattr(delta, "reasoning_content"): + delta_updates["reasoning_content"] = None + + # If we need to update delta, create a new choice with updated delta + if delta_updates: + new_delta = delta.model_copy(update=delta_updates) + new_choice = choice.model_copy(update={"delta": new_delta}) + normalized_choices.append(new_choice) + else: + normalized_choices.append(choice) + else: + normalized_choices.append(choice) + + # If we modified any choices, create a new chunk with updated choices + if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))): + chunk = chunk.model_copy(update={"choices": normalized_choices}) + + return chunk + + async def _normalize_stream( + self, stream: AsyncIterator[OpenAIChatCompletionChunk] + ) -> AsyncIterator[OpenAIChatCompletionChunk]: + """ + Normalize all chunks in the stream to ensure they have expected attributes. + This works around LiteLLM sometimes not including expected attributes. + """ + try: + async for chunk in stream: + # Normalize and yield each chunk immediately + yield self._normalize_chunk(chunk) + except Exception as e: + logger.error(f"Error normalizing stream: {e}", exc_info=True) + raise + + async def openai_completion( + self, + params: OpenAICompletionRequestWithExtraBody, + ) -> OpenAICompletion: + """ + Override parent method to add watsonx-specific parameters. + """ + from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params + + model_obj = await self.model_store.get_model(params.model) + + request_params = await prepare_openai_completion_params( + model=self.get_litellm_model_name(model_obj.provider_resource_id), + prompt=params.prompt, + best_of=params.best_of, + echo=params.echo, + frequency_penalty=params.frequency_penalty, + logit_bias=params.logit_bias, + logprobs=params.logprobs, + max_tokens=params.max_tokens, + n=params.n, + presence_penalty=params.presence_penalty, + seed=params.seed, + stop=params.stop, + stream=params.stream, + stream_options=params.stream_options, + temperature=params.temperature, + top_p=params.top_p, + user=params.user, + suffix=params.suffix, + api_key=self.get_api_key(), + api_base=self.api_base, + # These are watsonx-specific parameters + timeout=self.config.timeout, + project_id=self.config.project_id, + ) + return await litellm.atext_completion(**request_params) + + async def openai_embeddings( + self, + params: OpenAIEmbeddingsRequestWithExtraBody, + ) -> OpenAIEmbeddingsResponse: + """ + Override parent method to add watsonx-specific parameters. + """ + model_obj = await self.model_store.get_model(params.model) + + # Convert input to list if it's a string + input_list = [params.input] if isinstance(params.input, str) else params.input + + # Call litellm embedding function with watsonx-specific parameters + response = litellm.embedding( + model=self.get_litellm_model_name(model_obj.provider_resource_id), + input=input_list, + api_key=self.get_api_key(), + api_base=self.api_base, + dimensions=params.dimensions, + # These are watsonx-specific parameters + timeout=self.config.timeout, + project_id=self.config.project_id, + ) + + # Convert response to OpenAI format + from llama_stack.apis.inference import OpenAIEmbeddingUsage + from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response + + data = b64_encode_openai_embeddings_response(response.data, params.encoding_format) + + usage = OpenAIEmbeddingUsage( + prompt_tokens=response["usage"]["prompt_tokens"], + total_tokens=response["usage"]["total_tokens"], + ) + + return OpenAIEmbeddingsResponse( + data=data, + model=model_obj.provider_resource_id, + usage=usage, ) - self.available_models = None - self.config = config def get_base_url(self) -> str: return self.config.url - async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: - # Get base parameters from parent - params = await super()._get_params(request) - - # Add watsonx.ai specific parameters - params["project_id"] = self.config.project_id - params["time_limit"] = self.config.timeout - return params - # Copied from OpenAIMixin async def check_model_availability(self, model: str) -> bool: """ diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 3f0cffb2d..65f773889 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -58,7 +58,6 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) # does not work with the specified model, gpt-5-mini. Please choose different model and try # again. You can learn more about which models can be used with each operation here: # https://go.microsoft.com/fwlink/?linkid=2197993.'}}"} - "remote::watsonx", # return 404 when hitting the /openai/v1 endpoint "remote::llama-openai-compat", ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") @@ -68,6 +67,7 @@ def skip_if_doesnt_support_completions_logprobs(client_with_models, model_id): provider_type = provider_from_model(client_with_models, model_id).provider_type if provider_type in ( "remote::ollama", # logprobs is ignored + "remote::watsonx", ): pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions logprobs.") @@ -110,6 +110,7 @@ def skip_if_doesnt_support_n(client_with_models, model_id): # Error code 400 - {'message': '"n" > 1 is not currently supported', 'type': 'invalid_request_error', 'param': 'n', 'code': 'wrong_api_format'} "remote::cerebras", "remote::databricks", # Bad request: parameter "n" must be equal to 1 for streaming mode + "remote::watsonx", ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.") @@ -124,7 +125,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode "remote::databricks", "remote::cerebras", "remote::runpod", - "remote::watsonx", # watsonx returns 404 when hitting the /openai/v1 endpoint ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.") @@ -508,6 +508,12 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi assert "hello world" in normalized_content +def skip_if_doesnt_support_completions_stop_sequence(client_with_models, model_id): + provider_type = provider_from_model(client_with_models, model_id).provider_type + if provider_type in ("remote::watsonx",): # openai.BadRequestError: Error code: 400 + pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions stop sequence.") + + @pytest.mark.parametrize( "test_case", [ @@ -516,6 +522,7 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi ) def test_openai_completion_stop_sequence(client_with_models, openai_client, text_model_id, test_case): skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) + skip_if_doesnt_support_completions_stop_sequence(client_with_models, text_model_id) tc = TestCase(test_case) diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py index 84e92706a..0c1d4d08e 100644 --- a/tests/integration/inference/test_openai_embeddings.py +++ b/tests/integration/inference/test_openai_embeddings.py @@ -50,11 +50,15 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id): def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_id): provider = provider_from_model(client_with_models, model_id) - if provider.provider_type in ( - "remote::together", # returns 400 - "inline::sentence-transformers", - # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'} - "remote::databricks", + if ( + provider.provider_type + in ( + "remote::together", # returns 400 + "inline::sentence-transformers", + # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'} + "remote::databricks", + "remote::watsonx", # openai.BadRequestError: Error code: 400 - {'detail': "litellm.UnsupportedParamsError: watsonx does not support parameters: {'dimensions': 384} + ) ): pytest.skip( f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."