Fix unit test failures

Signed-off-by: Bill Murdock <bmurdock@redhat.com>
This commit is contained in:
Bill Murdock 2025-10-06 20:00:58 -04:00
parent a4b9b1e494
commit e601fbc919
3 changed files with 33 additions and 10 deletions

View file

@ -7,16 +7,18 @@
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, ConfigDict, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class WatsonXProviderDataValidator(BaseModel):
url: str
api_key: str
project_id: str
model_config = ConfigDict(
from_attributes=True,
extra="forbid",
)
watsonx_api_key: str | None
@json_schema_type

View file

@ -22,7 +22,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="watsonx",
api_key_from_config=config.api_key.get_secret_value(),
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
provider_data_api_key_field="watsonx_api_key",
)
self.available_models = None

View file

@ -30,11 +30,6 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere
GroqInferenceAdapter,
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
),
(
WatsonXConfig,
WatsonXInferenceAdapter,
"llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
),
(
OpenAIConfig,
OpenAIInferenceAdapter,
@ -65,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_valida
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
assert inference_adapter.client.api_key == api_key
@pytest.mark.parametrize(
"config_cls,adapter_cls,provider_data_validator",
[
(
WatsonXConfig,
WatsonXInferenceAdapter,
"llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
),
],
)
def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
assumption that there is an OpenAI-compatible client object."""
inference_adapter = adapter_cls(config=config_cls())
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
assert inference_adapter.get_api_key() == api_key