litellm-mirror/tests/llm_translation/test_watsonx.py
Krish Dholakia 76795dba39
Deepseek r1 support + watsonx qa improvements (#7907)
* fix(types/utils.py): support returning 'reasoning_content' for deepseek models

Fixes https://github.com/BerriAI/litellm/issues/7877#issuecomment-2603813218

* fix(convert_dict_to_response.py): return deepseek response in provider_specific_field

allows for separating openai vs. non-openai params in model response

* fix(utils.py): support 'provider_specific_field' in delta chunk as well

allows deepseek reasoning content chunk to be returned to user from stream as well

Fixes https://github.com/BerriAI/litellm/issues/7877#issuecomment-2603813218

* fix(watsonx/chat/handler.py): fix passing space id to watsonx on chat route

* fix(watsonx/): fix watsonx_text/ route with space id

* fix(watsonx/): qa item - also adds better unit testing for watsonx embedding calls

* fix(utils.py): rename to '..fields'

* fix: fix linting errors

* fix(utils.py): fix typing - don't show provider-specific field if none or empty - prevents default response from being non-oai compatible

* fix: cleanup unused imports

* docs(deepseek.md): add docs for deepseek reasoning model
2025-01-21 23:13:15 -08:00

220 lines
6.9 KiB
Python

import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import completion, embedding
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock, Mock
import pytest
from typing import Optional
@pytest.fixture
def watsonx_chat_completion_call():
    """Fixture factory for issuing a mocked watsonx chat completion call.

    Returns a callable that runs ``litellm.completion`` with the HTTP client's
    ``post`` patched out, so tests can inspect the outbound request on the
    returned mocks instead of hitting the network.

    The callable returns ``(mock_post, mock_token_post)``:
      - ``mock_post``: the patched ``client.post`` used for the model request.
      - ``mock_token_post``: the patched module-level ``post`` that serves the
        IAM token exchange (``None`` when ``patch_token_call=False``).
    """

    def _call(
        model="watsonx/my-test-model",
        messages=None,
        api_key="test_api_key",
        space_id: Optional[str] = None,
        headers=None,
        client=None,
        patch_token_call=True,
    ):
        if messages is None:
            messages = [{"role": "user", "content": "Hello, how are you?"}]
        if client is None:
            client = HTTPHandler()

        def _invoke():
            # completion() is expected to raise because the patched post
            # returns a bare Mock; swallow the error so callers can still
            # inspect the request captured on the mock.
            try:
                completion(
                    model=model,
                    messages=messages,
                    api_key=api_key,
                    headers=headers or {},
                    client=client,
                    space_id=space_id,
                )
            except Exception as e:
                print(e)

        if patch_token_call:
            mock_response = Mock()
            mock_response.json.return_value = {
                "access_token": "mock_access_token",
                "expires_in": 3600,
            }
            mock_response.raise_for_status = Mock()  # No-op to simulate no exception
            with patch.object(client, "post") as mock_post, patch.object(
                litellm.module_level_client, "post", return_value=mock_response
            ) as mock_token_post:
                _invoke()
            return mock_post, mock_token_post

        with patch.object(client, "post") as mock_post:
            _invoke()
        return mock_post, None

    return _call
@pytest.fixture
def watsonx_embedding_call():
    """Fixture factory for issuing a mocked watsonx embedding call.

    Mirrors ``watsonx_chat_completion_call`` but drives ``litellm.embedding``.
    Returns a callable yielding ``(mock_post, mock_token_post)`` where
    ``mock_post`` is the patched ``client.post`` for the embedding request and
    ``mock_token_post`` is the patched token-exchange ``post`` (``None`` when
    ``patch_token_call=False``).
    """

    def _call(
        model="watsonx/my-test-model",
        input=None,
        api_key="test_api_key",
        space_id: Optional[str] = None,
        headers=None,
        client=None,
        patch_token_call=True,
    ):
        if input is None:
            input = ["Hello, how are you?"]
        if client is None:
            client = HTTPHandler()

        def _invoke():
            # embedding() is expected to raise against the bare Mock response;
            # swallow it so the captured request can still be asserted on.
            try:
                embedding(
                    model=model,
                    input=input,
                    api_key=api_key,
                    headers=headers or {},
                    client=client,
                    space_id=space_id,
                )
            except Exception as e:
                print(e)

        if patch_token_call:
            mock_response = Mock()
            mock_response.json.return_value = {
                "access_token": "mock_access_token",
                "expires_in": 3600,
            }
            mock_response.raise_for_status = Mock()  # No-op to simulate no exception
            with patch.object(client, "post") as mock_post, patch.object(
                litellm.module_level_client, "post", return_value=mock_response
            ) as mock_token_post:
                _invoke()
            return mock_post, mock_token_post

        with patch.object(client, "post") as mock_post:
            _invoke()
        return mock_post, None

    return _call
@pytest.mark.parametrize("with_custom_auth_header", [True, False])
def test_watsonx_custom_auth_header(
    with_custom_auth_header, watsonx_chat_completion_call
):
    """A caller-supplied Authorization header must be forwarded verbatim;
    otherwise the mocked IAM access token is used."""
    headers = {}
    if with_custom_auth_header:
        headers["Authorization"] = "Bearer my-custom-auth-header"
    mock_post, _ = watsonx_chat_completion_call(headers=headers)
    assert mock_post.call_count == 1
    expected_auth = (
        "Bearer my-custom-auth-header"
        if with_custom_auth_header
        else "Bearer mock_access_token"
    )
    assert mock_post.call_args[1]["headers"]["Authorization"] == expected_auth
@pytest.mark.parametrize("env_var_key", ["WATSONX_ZENAPIKEY", "WATSONX_TOKEN"])
def test_watsonx_token_in_env_var(
    monkeypatch, watsonx_chat_completion_call, env_var_key
):
    """Tokens from env vars must reach the Authorization header with the
    scheme matching the variable: ZenApiKey for WATSONX_ZENAPIKEY, Bearer
    for WATSONX_TOKEN."""
    monkeypatch.setenv(env_var_key, "my-custom-token")
    mock_post, _ = watsonx_chat_completion_call(patch_token_call=False)
    assert mock_post.call_count == 1
    scheme = "ZenApiKey" if env_var_key == "WATSONX_ZENAPIKEY" else "Bearer"
    assert (
        mock_post.call_args[1]["headers"]["Authorization"]
        == f"{scheme} my-custom-token"
    )
def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
    """A non-deployment model id must route to the regular chat endpoint,
    never a /deployment/ URL."""
    mock_post, _ = watsonx_chat_completion_call(
        model="watsonx/another-model",
        messages=[{"role": "user", "content": "Test message"}],
    )
    assert mock_post.call_count == 1
    assert "deployment" not in mock_post.call_args.kwargs["url"]
@pytest.mark.parametrize(
    "model",
    [
        "watsonx/deployment/<xxxx.xxx.xxx.xxxx>",
        "watsonx_text/deployment/<xxxx.xxx.xxx.xxxx>",
    ],
)
def test_watsonx_deployment_space_id(monkeypatch, watsonx_chat_completion_call, model):
    """WATSONX_SPACE_ID from the environment must be included as space_id in
    the request body for deployment models on both chat and text routes."""
    fake_space_id = "xxx-xxx-xxx-xxx-xxx"
    monkeypatch.setenv("WATSONX_SPACE_ID", fake_space_id)
    mock_post, _ = watsonx_chat_completion_call(
        model=model,
        messages=[{"content": "Hello, how are you?", "role": "user"}],
    )
    assert mock_post.call_count == 1
    sent_body = json.loads(mock_post.call_args.kwargs["data"])
    assert sent_body["space_id"] == fake_space_id
def test_watsonx_deployment_space_id_embedding(monkeypatch, watsonx_embedding_call):
    """WATSONX_SPACE_ID from the environment must be included as space_id in
    the request body for deployment-model embedding calls."""
    fake_space_id = "xxx-xxx-xxx-xxx-xxx"
    monkeypatch.setenv("WATSONX_SPACE_ID", fake_space_id)
    mock_post, _ = watsonx_embedding_call(model="watsonx/deployment/my-test-model")
    assert mock_post.call_count == 1
    sent_body = json.loads(mock_post.call_args.kwargs["data"])
    assert sent_body["space_id"] == fake_space_id