diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index dc397aa76..a9ccc8091 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -168,13 +168,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): is used instead of any config API key. """ - api_key = self.get_api_key() - - if self.provider_data_api_key_field: - provider_data = self.get_request_provider_data() - if provider_data and getattr(provider_data, self.provider_data_api_key_field, None): - api_key = getattr(provider_data, self.provider_data_api_key_field) - + api_key = self._get_api_key_from_config_or_provider_data() if not api_key: message = "API key not provided." if self.provider_data_api_key_field: @@ -187,6 +181,16 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): **self.get_extra_client_params(), ) + def _get_api_key_from_config_or_provider_data(self) -> str | None: + api_key = self.get_api_key() + + if self.provider_data_api_key_field: + provider_data = self.get_request_provider_data() + if provider_data and getattr(provider_data, self.provider_data_api_key_field, None): + api_key = getattr(provider_data, self.provider_data_api_key_field) + + return api_key + async def _get_provider_model_id(self, model: str) -> str: """ Get the provider-specific model ID from the model store. @@ -387,6 +391,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): """ self._model_cache = {} + api_key = self._get_api_key_from_config_or_provider_data() + if not api_key: + logger.debug(f"{self.__class__.__name__}.list_provider_model_ids() disabled because API key not provided") + return None + try: iterable = await self.list_provider_model_ids() except Exception as e: diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 78241bc22..61a1f8f61 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -23,10 +23,10 @@ class OpenAIMixinImpl(OpenAIMixin): __provider_id__: str = "test-provider" def get_api_key(self) -> str: - raise NotImplementedError("This method should be mocked in tests") + return "test-api-key" def get_base_url(self) -> str: - raise NotImplementedError("This method should be mocked in tests") + return "http://test-base-url" class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):