litellm-mirror/tests/llm_translation/test_watsonx.py
Krish Dholakia fb1272b46b Support checking provider-specific /models endpoints for available models based on key (#7538)
* test(test_utils.py): initial test for valid models

Addresses https://github.com/BerriAI/litellm/issues/7525

* fix: test

* feat(fireworks_ai/transformation.py): support retrieving valid models from fireworks ai endpoint

* refactor(fireworks_ai/): support checking model info on `/v1/models` route

* docs(set_keys.md): update docs to clarify check llm provider api usage

* fix(watsonx/common_utils.py): support 'WATSONX_ZENAPIKEY' for iam auth

* fix(watsonx): read in watsonx token from env var

* fix: fix linting errors

* fix(utils.py): fix provider config check

* style: cleanup unused imports
2025-01-03 19:29:59 -08:00

113 lines
3.4 KiB
Python

import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import completion
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock, Mock
import pytest
@pytest.fixture
def watsonx_chat_completion_call():
    """Run a watsonx `completion()` call with the HTTP layer mocked out.

    Returns a callable. When `patch_token_call` is True, both the chat POST
    on `client` and the IAM token-exchange POST on
    `litellm.module_level_client` are patched; the token call yields a fake
    `access_token`. The callable returns `(chat_post_mock, token_post_mock)`,
    with the second element `None` when the token call is not patched.
    """

    def _call(
        model="watsonx/my-test-model",
        messages=None,
        api_key="test_api_key",
        headers=None,
        client=None,
        patch_token_call=True,
    ):
        if messages is None:
            messages = [{"role": "user", "content": "Hello, how are you?"}]
        if client is None:
            client = HTTPHandler()

        # Guard clause: no token patching requested — only mock the chat POST.
        if not patch_token_call:
            with patch.object(client, "post") as chat_post_mock:
                completion(
                    model=model,
                    messages=messages,
                    api_key=api_key,
                    headers=headers or {},
                    client=client,
                )
                return chat_post_mock, None

        # Fake a successful IAM token exchange.
        token_response = Mock()
        token_response.json.return_value = {
            "access_token": "mock_access_token",
            "expires_in": 3600,
        }
        token_response.raise_for_status = Mock()  # no-op: simulate HTTP success

        with patch.object(client, "post") as chat_post_mock, patch.object(
            litellm.module_level_client, "post", return_value=token_response
        ) as token_post_mock:
            completion(
                model=model,
                messages=messages,
                api_key=api_key,
                headers=headers or {},
                client=client,
            )
            return chat_post_mock, token_post_mock

    return _call
@pytest.mark.parametrize("with_custom_auth_header", [True, False])
def test_watsonx_custom_auth_header(
    with_custom_auth_header, watsonx_chat_completion_call
):
    """A caller-supplied Authorization header must win over the IAM token."""
    custom_headers = (
        {"Authorization": "Bearer my-custom-auth-header"}
        if with_custom_auth_header
        else {}
    )

    mock_post, _ = watsonx_chat_completion_call(headers=custom_headers)

    assert mock_post.call_count == 1
    # Without a custom header, the mocked token exchange supplies the bearer.
    expected_auth = (
        "Bearer my-custom-auth-header"
        if with_custom_auth_header
        else "Bearer mock_access_token"
    )
    assert mock_post.call_args[1]["headers"]["Authorization"] == expected_auth
@pytest.mark.parametrize("env_var_key", ["WATSONX_ZENAPIKEY", "WATSONX_TOKEN"])
def test_watsonx_token_in_env_var(
    monkeypatch, watsonx_chat_completion_call, env_var_key
):
    """A token from the environment should be sent as the Bearer credential."""
    monkeypatch.setenv(env_var_key, "my-custom-token")

    # No token-exchange mock: the env var alone must supply the credential.
    mock_post, _ = watsonx_chat_completion_call(patch_token_call=False)

    assert mock_post.call_count == 1
    sent_auth = mock_post.call_args[1]["headers"]["Authorization"]
    assert sent_auth == "Bearer my-custom-token"
def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
    """A plain (non-deployment) model must not hit a deployment-scoped URL."""
    mock_post, _ = watsonx_chat_completion_call(
        model="watsonx/another-model",
        messages=[{"role": "user", "content": "Test message"}],
    )

    assert mock_post.call_count == 1
    assert "deployment" not in mock_post.call_args.kwargs["url"]