litellm-mirror/tests/llm_translation/test_huggingface.py
Krish Dholakia 9237357bcc
Litellm dev 12 25 2024 p1 (#7411)
* test(test_watsonx.py): e2e unit test for watsonx custom header

covers https://github.com/BerriAI/litellm/issues/7408

* fix(common_utils.py): handle auth token already present in headers (watsonx + openai-like base handler)

Fixes https://github.com/BerriAI/litellm/issues/7408

* fix(watsonx/chat): fix chat route

Fixes https://github.com/BerriAI/litellm/issues/7408

* fix(huggingface/chat/handler.py): fix huggingface async completion calls

* Correct handling of max_retries=0 to disable AzureOpenAI retries (#7379)

* test: fix test

---------

Co-authored-by: Minh Duc <phamminhduc0711@gmail.com>
2024-12-25 17:36:30 -08:00

169 lines
5 KiB
Python

"""
Unit Tests Huggingface route
"""
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import completion, acompletion
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock, Mock
import pytest
def tgi_mock_post(url, **kwargs):
    """Build a fake successful TGI-style HTTP response.

    Intended as a ``side_effect`` replacement for an HTTP handler's ``post``.
    The ``url`` and any keyword arguments are accepted but ignored.

    Returns:
        A ``MagicMock`` with ``status_code`` 200, a JSON ``Content-Type``
        header, and ``.json()`` returning a one-element list holding a canned
        generation. The ten token texts of that generation concatenate to the
        ``generated_text`` value.
    """
    # (token id, token text, logprob) for each generated token, in order.
    token_rows = (
        (28789, "<", -0.025222778),
        (28766, "|", -0.000003695488),
        (489, "ass", -0.0000019073486),
        (11143, "istant", -0.000002026558),
        (28766, "|", -0.0000015497208),
        (28767, ">", -0.0000011920929),
        (13, "\n", -0.00009703636),
        (28737, "I", -0.1953125),
        (28742, "'", -0.88183594),
        (28719, "m", -0.00032639503),
    )
    tokens = [
        {"id": token_id, "text": text, "logprob": logprob, "special": False}
        for token_id, text, logprob in token_rows
    ]
    payload = [
        {
            "generated_text": "<|assistant|>\nI'm",
            "details": {
                "finish_reason": "length",
                "generated_tokens": 10,
                "seed": None,
                "prefill": [],
                "tokens": tokens,
            },
        }
    ]

    response = MagicMock()
    response.status_code = 200
    response.headers = {"Content-Type": "application/json"}
    response.json.return_value = payload
    return response
@pytest.fixture
def huggingface_chat_completion_call():
    """Fixture returning a helper that runs one sync ``completion`` with a mocked ``post``.

    The returned ``_call`` patches ``client.post`` with ``tgi_mock_post`` for the
    duration of a single ``completion`` call. It then returns the patch mock so
    tests can inspect how the request was issued.
    """

    def _call(
        model="huggingface/my-test-model",
        messages=None,
        api_key="test_api_key",
        headers=None,
        client=None,
    ):
        """Run one mocked sync completion and return the patched ``post`` mock.

        Args:
            model: litellm model string passed to ``completion``.
            messages: Chat messages; defaults to a single user greeting.
            api_key: API key forwarded to ``completion``.
            headers: Extra request headers; ``None`` is sent as ``{}``.
            client: ``HTTPHandler`` whose ``post`` is patched; a fresh one is
                created when omitted.

        Returns:
            The mock object that replaced ``client.post`` during the call.
        """
        if messages is None:
            messages = [{"role": "user", "content": "Hello, how are you?"}]
        if client is None:
            client = HTTPHandler()
        # The canned response comes entirely from ``tgi_mock_post`` via
        # ``side_effect``; no separate response object is needed here.
        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_post:
            completion(
                model=model,
                messages=messages,
                api_key=api_key,
                headers=headers or {},
                client=client,
            )
        return mock_post

    return _call
@pytest.fixture
def async_huggingface_chat_completion_call():
    """Fixture returning a coroutine helper that runs one ``acompletion`` with a mocked ``post``.

    Awaiting the helper patches ``client.post`` with ``tgi_mock_post`` for the
    duration of a single ``acompletion`` call. It then hands back the mock so
    the caller can assert on it.
    """

    async def _invoke(
        model="huggingface/my-test-model",
        messages=None,
        api_key="test_api_key",
        headers=None,
        client=None,
    ):
        # Fill in defaults for anything the caller left unspecified.
        chat_messages = (
            [{"role": "user", "content": "Hello, how are you?"}]
            if messages is None
            else messages
        )
        http_client = AsyncHTTPHandler() if client is None else client

        request_kwargs = dict(
            model=model,
            messages=chat_messages,
            api_key=api_key,
            headers=headers or {},
            client=http_client,
        )
        with patch.object(
            http_client, "post", side_effect=tgi_mock_post
        ) as patched_post:
            await acompletion(**request_kwargs)
        return patched_post

    return _invoke
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_huggingface_chat_completions_endpoint(
    sync_mode, huggingface_chat_completion_call, async_huggingface_chat_completion_call
):
    """A single chat completion, sync or async, issues exactly one HTTP ``post``."""
    call_kwargs = {
        "model": "huggingface/another-model",
        "messages": [{"role": "user", "content": "Test message"}],
    }

    if not sync_mode:
        mock_post = await async_huggingface_chat_completion_call(**call_kwargs)
    else:
        mock_post = huggingface_chat_completion_call(**call_kwargs)

    assert mock_post.call_count == 1