diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py index 2a77f9edb..9e98d4003 100644 --- a/llama_stack/providers/remote/inference/watsonx/config.py +++ b/llama_stack/providers/remote/inference/watsonx/config.py @@ -7,16 +7,18 @@ import os from typing import Any -from pydantic import BaseModel, Field, SecretStr +from pydantic import BaseModel, ConfigDict, Field, SecretStr from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type class WatsonXProviderDataValidator(BaseModel): - url: str - api_key: str - project_id: str + model_config = ConfigDict( + from_attributes=True, + extra="forbid", + ) + watsonx_api_key: str | None @json_schema_type diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index e7f96405a..d04472936 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -22,7 +22,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin): LiteLLMOpenAIMixin.__init__( self, litellm_provider_name="watsonx", - api_key_from_config=config.api_key.get_secret_value(), + api_key_from_config=config.api_key.get_secret_value() if config.api_key else None, provider_data_api_key_field="watsonx_api_key", ) self.available_models = None diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py index eb9444e34..55a6793c2 100644 --- a/tests/unit/providers/inference/test_inference_client_caching.py +++ b/tests/unit/providers/inference/test_inference_client_caching.py @@ -30,11 +30,6 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere GroqInferenceAdapter, "llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", ), - ( - WatsonXConfig, - WatsonXInferenceAdapter, - "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator", - ), ( OpenAIConfig, OpenAIInferenceAdapter, @@ -65,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_valida {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} ): assert inference_adapter.client.api_key == api_key + + +@pytest.mark.parametrize( + "config_cls,adapter_cls,provider_data_validator", + [ + ( + WatsonXConfig, + WatsonXInferenceAdapter, + "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator", + ), + ], +) +def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str): + """Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the + assumption that there is an OpenAI-compatible client object.""" + + inference_adapter = adapter_cls(config=config_cls()) + + inference_adapter.__provider_spec__ = MagicMock() + inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator + + for api_key in ["test1", "test2"]: + with request_provider_data_context( + {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} + ): + assert inference_adapter.get_api_key() == api_key