mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-13 17:12:37 +00:00
Fix unit test failures
Signed-off-by: Bill Murdock <bmurdock@redhat.com>
This commit is contained in:
parent
a4b9b1e494
commit
e601fbc919
3 changed files with 33 additions and 10 deletions
|
|
@ -7,16 +7,18 @@
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
class WatsonXProviderDataValidator(BaseModel):
|
class WatsonXProviderDataValidator(BaseModel):
|
||||||
url: str
|
model_config = ConfigDict(
|
||||||
api_key: str
|
from_attributes=True,
|
||||||
project_id: str
|
extra="forbid",
|
||||||
|
)
|
||||||
|
watsonx_api_key: str | None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
litellm_provider_name="watsonx",
|
litellm_provider_name="watsonx",
|
||||||
api_key_from_config=config.api_key.get_secret_value(),
|
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
|
||||||
provider_data_api_key_field="watsonx_api_key",
|
provider_data_api_key_field="watsonx_api_key",
|
||||||
)
|
)
|
||||||
self.available_models = None
|
self.available_models = None
|
||||||
|
|
|
||||||
|
|
@ -30,11 +30,6 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere
|
||||||
GroqInferenceAdapter,
|
GroqInferenceAdapter,
|
||||||
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||||
),
|
),
|
||||||
(
|
|
||||||
WatsonXConfig,
|
|
||||||
WatsonXInferenceAdapter,
|
|
||||||
"llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
|
|
||||||
),
|
|
||||||
(
|
(
|
||||||
OpenAIConfig,
|
OpenAIConfig,
|
||||||
OpenAIInferenceAdapter,
|
OpenAIInferenceAdapter,
|
||||||
|
|
@ -65,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_valida
|
||||||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||||
):
|
):
|
||||||
assert inference_adapter.client.api_key == api_key
|
assert inference_adapter.client.api_key == api_key
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"config_cls,adapter_cls,provider_data_validator",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
WatsonXConfig,
|
||||||
|
WatsonXInferenceAdapter,
|
||||||
|
"llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
|
||||||
|
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
|
||||||
|
assumption that there is an OpenAI-compatible client object."""
|
||||||
|
|
||||||
|
inference_adapter = adapter_cls(config=config_cls())
|
||||||
|
|
||||||
|
inference_adapter.__provider_spec__ = MagicMock()
|
||||||
|
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
|
||||||
|
|
||||||
|
for api_key in ["test1", "test2"]:
|
||||||
|
with request_provider_data_context(
|
||||||
|
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||||
|
):
|
||||||
|
assert inference_adapter.get_api_key() == api_key
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue