mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
Fix unit test failures
Signed-off-by: Bill Murdock <bmurdock@redhat.com>
This commit is contained in:
parent
a4b9b1e494
commit
e601fbc919
3 changed files with 33 additions and 10 deletions
|
|
@ -7,16 +7,18 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class WatsonXProviderDataValidator(BaseModel):
|
||||
url: str
|
||||
api_key: str
|
||||
project_id: str
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="forbid",
|
||||
)
|
||||
watsonx_api_key: str | None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="watsonx",
|
||||
api_key_from_config=config.api_key.get_secret_value(),
|
||||
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
|
||||
provider_data_api_key_field="watsonx_api_key",
|
||||
)
|
||||
self.available_models = None
|
||||
|
|
|
|||
|
|
@ -30,11 +30,6 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere
|
|||
GroqInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
),
|
||||
(
|
||||
WatsonXConfig,
|
||||
WatsonXInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
|
||||
),
|
||||
(
|
||||
OpenAIConfig,
|
||||
OpenAIInferenceAdapter,
|
||||
|
|
@ -65,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_valida
|
|||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||
):
|
||||
assert inference_adapter.client.api_key == api_key
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_cls,adapter_cls,provider_data_validator",
|
||||
[
|
||||
(
|
||||
WatsonXConfig,
|
||||
WatsonXInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
|
||||
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
|
||||
assumption that there is an OpenAI-compatible client object."""
|
||||
|
||||
inference_adapter = adapter_cls(config=config_cls())
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
|
||||
|
||||
for api_key in ["test1", "test2"]:
|
||||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||
):
|
||||
assert inference_adapter.get_api_key() == api_key
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue