test: update tests

This commit is contained in:
Krrish Dholakia 2024-11-30 12:43:45 -08:00
parent d72407515c
commit 7bdc940588
8 changed files with 199 additions and 248 deletions

View file

@ -1 +1,3 @@
More tests under `litellm/litellm/tests/*`. Unit tests for individual LLM providers.
Name of the test file is the name of the LLM provider - e.g. `test_openai.py` is for OpenAI.

View file

@ -42,7 +42,6 @@ def return_mocked_response(model: str):
"bedrock/mistral.mistral-large-2407-v1:0", "bedrock/mistral.mistral-large-2407-v1:0",
], ],
) )
@pytest.mark.respx
@pytest.mark.asyncio() @pytest.mark.asyncio()
async def test_bedrock_max_completion_tokens(model: str): async def test_bedrock_max_completion_tokens(model: str):
""" """
@ -87,7 +86,6 @@ async def test_bedrock_max_completion_tokens(model: str):
"model", "model",
["anthropic/claude-3-sonnet-20240229", "anthropic/claude-3-opus-20240229"], ["anthropic/claude-3-sonnet-20240229", "anthropic/claude-3-opus-20240229"],
) )
@pytest.mark.respx
@pytest.mark.asyncio() @pytest.mark.asyncio()
async def test_anthropic_api_max_completion_tokens(model: str): async def test_anthropic_api_max_completion_tokens(model: str):
""" """

View file

@ -19,7 +19,6 @@ from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage
from litellm import completion from litellm import completion
@pytest.mark.respx
def test_completion_nvidia_nim(): def test_completion_nvidia_nim():
from openai import OpenAI from openai import OpenAI

View file

@ -2,7 +2,7 @@ import json
import os import os
import sys import sys
from datetime import datetime from datetime import datetime
from unittest.mock import AsyncMock from unittest.mock import AsyncMock, patch
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -63,8 +63,7 @@ def test_openai_prediction_param():
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.respx async def test_openai_prediction_param_mock():
async def test_openai_prediction_param_mock(respx_mock: MockRouter):
""" """
Tests that prediction parameter is correctly passed to the API Tests that prediction parameter is correctly passed to the API
""" """
@ -92,38 +91,15 @@ async def test_openai_prediction_param_mock(respx_mock: MockRouter):
public string Username { get; set; } public string Username { get; set; }
} }
""" """
from openai import AsyncOpenAI
mock_response = ModelResponse( client = AsyncOpenAI(api_key="fake-api-key")
id="chatcmpl-AQ5RmV8GvVSRxEcDxnuXlQnsibiY9",
choices=[
Choices(
message=Message(
content=code.replace("Username", "Email").replace(
"username", "email"
),
role="assistant",
)
)
],
created=int(datetime.now().timestamp()),
model="gpt-4o-mini-2024-07-18",
usage={
"completion_tokens": 207,
"prompt_tokens": 175,
"total_tokens": 382,
"completion_tokens_details": {
"accepted_prediction_tokens": 0,
"reasoning_tokens": 0,
"rejected_prediction_tokens": 80,
},
},
)
mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock( with patch.object(
return_value=httpx.Response(200, json=mock_response.dict()) client.chat.completions.with_raw_response, "create"
) ) as mock_client:
try:
completion = await litellm.acompletion( await litellm.acompletion(
model="gpt-4o-mini", model="gpt-4o-mini",
messages=[ messages=[
{ {
@ -133,20 +109,19 @@ async def test_openai_prediction_param_mock(respx_mock: MockRouter):
{"role": "user", "content": code}, {"role": "user", "content": code},
], ],
prediction={"type": "content", "content": code}, prediction={"type": "content", "content": code},
client=client,
) )
except Exception as e:
print(f"Error: {e}")
assert mock_request.called mock_client.assert_called_once()
request_body = json.loads(mock_request.calls[0].request.content) request_body = mock_client.call_args.kwargs
# Verify the request contains the prediction parameter # Verify the request contains the prediction parameter
assert "prediction" in request_body assert "prediction" in request_body
# verify prediction is correctly sent to the API # verify prediction is correctly sent to the API
assert request_body["prediction"] == {"type": "content", "content": code} assert request_body["prediction"] == {"type": "content", "content": code}
# Verify the completion tokens details
assert completion.usage.completion_tokens_details.accepted_prediction_tokens == 0
assert completion.usage.completion_tokens_details.rejected_prediction_tokens == 80
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_openai_prediction_param_with_caching(): async def test_openai_prediction_param_with_caching():
@ -223,3 +198,80 @@ async def test_openai_prediction_param_with_caching():
) )
assert completion_response_3.id != completion_response_1.id assert completion_response_3.id != completion_response_1.id
@pytest.mark.asyncio()
@pytest.mark.respx
async def test_vision_with_custom_model(respx_mock: MockRouter):
"""
Tests that an OpenAI compatible endpoint when sent an image will receive the image in the request
"""
import base64
import requests
litellm.set_verbose = True
api_base = "https://my-custom.api.openai.com"
# Fetch and encode a test image
url = "https://dummyimage.com/100/100/fff&text=Test+image"
response = requests.get(url)
file_data = response.content
encoded_file = base64.b64encode(file_data).decode("utf-8")
base64_image = f"data:image/png;base64,{encoded_file}"
mock_response = ModelResponse(
id="cmpl-mock",
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
created=int(datetime.now().timestamp()),
model="my-custom-model",
)
mock_request = respx_mock.post(f"{api_base}/chat/completions").mock(
return_value=httpx.Response(200, json=mock_response.dict())
)
response = await litellm.acompletion(
model="openai/my-custom-model",
max_tokens=10,
api_base=api_base, # use the mock api
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": base64_image},
},
],
}
],
)
assert mock_request.called
request_body = json.loads(mock_request.calls[0].request.content)
print("request_body: ", request_body)
assert request_body == {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkBAMAAACCzIhnAAAAG1BMVEURAAD///+ln5/h39/Dv79qX18uHx+If39MPz9oMSdmAAAACXBIWXMAAA7EAAAOxAGVKw4bAAABB0lEQVRYhe2SzWrEIBCAh2A0jxEs4j6GLDS9hqWmV5Flt0cJS+lRwv742DXpEjY1kOZW6HwHFZnPmVEBEARBEARB/jd0KYA/bcUYbPrRLh6amXHJ/K+ypMoyUaGthILzw0l+xI0jsO7ZcmCcm4ILd+QuVYgpHOmDmz6jBeJImdcUCmeBqQpuqRIbVmQsLCrAalrGpfoEqEogqbLTWuXCPCo+Ki1XGqgQ+jVVuhB8bOaHkvmYuzm/b0KYLWwoK58oFqi6XfxQ4Uz7d6WeKpna6ytUs5e8betMcqAv5YPC5EZB2Lm9FIn0/VP6R58+/GEY1X1egVoZ/3bt/EqF6malgSAIgiDIH+QL41409QMY0LMAAAAASUVORK5CYII="
},
},
],
}
],
"model": "my-custom-model",
"max_tokens": 10,
}
print(f"response: {response}")
assert isinstance(response, ModelResponse)

View file

@ -2,7 +2,7 @@ import json
import os import os
import sys import sys
from datetime import datetime from datetime import datetime
from unittest.mock import AsyncMock from unittest.mock import AsyncMock, patch, MagicMock
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -18,87 +18,75 @@ from litellm import Choices, Message, ModelResponse
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.respx async def test_o1_handle_system_role():
async def test_o1_handle_system_role(respx_mock: MockRouter):
""" """
Tests that: Tests that:
- max_tokens is translated to 'max_completion_tokens' - max_tokens is translated to 'max_completion_tokens'
- role 'system' is translated to 'user' - role 'system' is translated to 'user'
""" """
from openai import AsyncOpenAI
litellm.set_verbose = True litellm.set_verbose = True
mock_response = ModelResponse( client = AsyncOpenAI(api_key="fake-api-key")
id="cmpl-mock",
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
created=int(datetime.now().timestamp()),
model="o1-preview",
)
mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock( with patch.object(
return_value=httpx.Response(200, json=mock_response.dict()) client.chat.completions.with_raw_response, "create"
) ) as mock_client:
try:
response = await litellm.acompletion( await litellm.acompletion(
model="o1-preview", model="o1-preview",
max_tokens=10, max_tokens=10,
messages=[{"role": "system", "content": "Hello!"}], messages=[{"role": "system", "content": "Hello!"}],
client=client,
) )
except Exception as e:
print(f"Error: {e}")
assert mock_request.called mock_client.assert_called_once()
request_body = json.loads(mock_request.calls[0].request.content) request_body = mock_client.call_args.kwargs
print("request_body: ", request_body) print("request_body: ", request_body)
assert request_body == { assert request_body["model"] == "o1-preview"
"model": "o1-preview", assert request_body["max_completion_tokens"] == 10
"max_completion_tokens": 10, assert request_body["messages"] == [{"role": "user", "content": "Hello!"}]
"messages": [{"role": "user", "content": "Hello!"}],
}
print(f"response: {response}")
assert isinstance(response, ModelResponse)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.respx
@pytest.mark.parametrize("model", ["gpt-4", "gpt-4-0314", "gpt-4-32k", "o1-preview"]) @pytest.mark.parametrize("model", ["gpt-4", "gpt-4-0314", "gpt-4-32k", "o1-preview"])
async def test_o1_max_completion_tokens(respx_mock: MockRouter, model: str): async def test_o1_max_completion_tokens(model: str):
""" """
Tests that: Tests that:
- max_completion_tokens is passed directly to OpenAI chat completion models - max_completion_tokens is passed directly to OpenAI chat completion models
""" """
from openai import AsyncOpenAI
litellm.set_verbose = True litellm.set_verbose = True
mock_response = ModelResponse( client = AsyncOpenAI(api_key="fake-api-key")
id="cmpl-mock",
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
created=int(datetime.now().timestamp()),
model=model,
)
mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock( with patch.object(
return_value=httpx.Response(200, json=mock_response.dict()) client.chat.completions.with_raw_response, "create"
) ) as mock_client:
try:
response = await litellm.acompletion( await litellm.acompletion(
model=model, model=model,
max_completion_tokens=10, max_completion_tokens=10,
messages=[{"role": "user", "content": "Hello!"}], messages=[{"role": "user", "content": "Hello!"}],
client=client,
) )
except Exception as e:
print(f"Error: {e}")
assert mock_request.called mock_client.assert_called_once()
request_body = json.loads(mock_request.calls[0].request.content) request_body = mock_client.call_args.kwargs
print("request_body: ", request_body) print("request_body: ", request_body)
assert request_body == { assert request_body["model"] == model
"model": model, assert request_body["max_completion_tokens"] == 10
"max_completion_tokens": 10, assert request_body["messages"] == [{"role": "user", "content": "Hello!"}]
"messages": [{"role": "user", "content": "Hello!"}],
}
print(f"response: {response}")
assert isinstance(response, ModelResponse)
def test_litellm_responses(): def test_litellm_responses():

View file

@ -1,94 +0,0 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import httpx
import pytest
from respx import MockRouter
import litellm
from litellm import Choices, Message, ModelResponse
@pytest.mark.asyncio()
@pytest.mark.respx
async def test_vision_with_custom_model(respx_mock: MockRouter):
"""
Tests that an OpenAI compatible endpoint when sent an image will receive the image in the request
"""
import base64
import requests
litellm.set_verbose = True
api_base = "https://my-custom.api.openai.com"
# Fetch and encode a test image
url = "https://dummyimage.com/100/100/fff&text=Test+image"
response = requests.get(url)
file_data = response.content
encoded_file = base64.b64encode(file_data).decode("utf-8")
base64_image = f"data:image/png;base64,{encoded_file}"
mock_response = ModelResponse(
id="cmpl-mock",
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
created=int(datetime.now().timestamp()),
model="my-custom-model",
)
mock_request = respx_mock.post(f"{api_base}/chat/completions").mock(
return_value=httpx.Response(200, json=mock_response.dict())
)
response = await litellm.acompletion(
model="openai/my-custom-model",
max_tokens=10,
api_base=api_base, # use the mock api
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": base64_image},
},
],
}
],
)
assert mock_request.called
request_body = json.loads(mock_request.calls[0].request.content)
print("request_body: ", request_body)
assert request_body == {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkBAMAAACCzIhnAAAAG1BMVEURAAD///+ln5/h39/Dv79qX18uHx+If39MPz9oMSdmAAAACXBIWXMAAA7EAAAOxAGVKw4bAAABB0lEQVRYhe2SzWrEIBCAh2A0jxEs4j6GLDS9hqWmV5Flt0cJS+lRwv742DXpEjY1kOZW6HwHFZnPmVEBEARBEARB/jd0KYA/bcUYbPrRLh6amXHJ/K+ypMoyUaGthILzw0l+xI0jsO7ZcmCcm4ILd+QuVYgpHOmDmz6jBeJImdcUCmeBqQpuqRIbVmQsLCrAalrGpfoEqEogqbLTWuXCPCo+Ki1XGqgQ+jVVuhB8bOaHkvmYuzm/b0KYLWwoK58oFqi6XfxQ4Uz7d6WeKpna6ytUs5e8betMcqAv5YPC5EZB2Lm9FIn0/VP6R58+/GEY1X1egVoZ/3bt/EqF6malgSAIgiDIH+QL41409QMY0LMAAAAASUVORK5CYII="
},
},
],
}
],
"model": "my-custom-model",
"max_tokens": 10,
}
print(f"response: {response}")
assert isinstance(response, ModelResponse)

View file

@ -6,6 +6,7 @@ from unittest.mock import AsyncMock
import pytest import pytest
import httpx import httpx
from respx import MockRouter from respx import MockRouter
from unittest.mock import patch, MagicMock, AsyncMock
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -68,13 +69,16 @@ def test_convert_dict_to_text_completion_response():
assert response.choices[0].logprobs.top_logprobs == [None, {",": -2.1568563}] assert response.choices[0].logprobs.top_logprobs == [None, {",": -2.1568563}]
@pytest.mark.skip(
reason="need to migrate huggingface to support httpx client being passed in"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.respx @pytest.mark.respx
async def test_huggingface_text_completion_logprobs(respx_mock: MockRouter): async def test_huggingface_text_completion_logprobs():
"""Test text completion with Hugging Face, focusing on logprobs structure""" """Test text completion with Hugging Face, focusing on logprobs structure"""
litellm.set_verbose = True litellm.set_verbose = True
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
# Mock the raw response from Hugging Face
mock_response = [ mock_response = [
{ {
"generated_text": ",\n\nI have a question...", # truncated for brevity "generated_text": ",\n\nI have a question...", # truncated for brevity
@ -91,19 +95,21 @@ async def test_huggingface_text_completion_logprobs(respx_mock: MockRouter):
} }
] ]
# Mock the API request return_val = AsyncMock()
mock_request = respx_mock.post(
"https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
).mock(return_value=httpx.Response(200, json=mock_response))
return_val.json.return_value = mock_response
client = AsyncHTTPHandler()
with patch.object(client, "post", return_value=return_val) as mock_post:
response = await litellm.atext_completion( response = await litellm.atext_completion(
model="huggingface/mistralai/Mistral-7B-v0.1", model="huggingface/mistralai/Mistral-7B-v0.1",
prompt="good morning", prompt="good morning",
client=client,
) )
# Verify the request # Verify the request
assert mock_request.called mock_post.assert_called_once()
request_body = json.loads(mock_request.calls[0].request.content) request_body = json.loads(mock_post.call_args.kwargs["data"])
assert request_body == { assert request_body == {
"inputs": "good morning", "inputs": "good morning",
"parameters": {"details": True, "return_full_text": False}, "parameters": {"details": True, "return_full_text": False},

View file

@ -33,7 +33,7 @@ from litellm.router import Router
@pytest.mark.asyncio() @pytest.mark.asyncio()
@pytest.mark.respx() @pytest.mark.respx()
async def test_azure_tenant_id_auth(respx_mock: MockRouter): async def test_aaaaazure_tenant_id_auth(respx_mock: MockRouter):
""" """
Tests when we set tenant_id, client_id, client_secret they don't get sent with the request Tests when we set tenant_id, client_id, client_secret they don't get sent with the request