forked from phoenix/litellm-mirror
Improve mocking in test_proxy_server
Mock the calls to the backend and assert that the correct parameters are passed to the backend.
This commit is contained in:
parent
762a1fbd50
commit
14e7c9b01c
1 changed files with 142 additions and 12 deletions
|
@ -2,7 +2,6 @@ import sys, os
|
||||||
import traceback
|
import traceback
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import contextlib
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
import os, io
|
import os, io
|
||||||
|
@ -47,15 +46,66 @@ example_completion_result = {
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
example_embedding_result = {
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": 0,
|
||||||
|
"embedding": [
|
||||||
|
-0.006929283495992422,
|
||||||
|
-0.005336422007530928,
|
||||||
|
-4.547132266452536e-05,
|
||||||
|
-0.024047505110502243,
|
||||||
|
-0.006929283495992422,
|
||||||
|
-0.005336422007530928,
|
||||||
|
-4.547132266452536e-05,
|
||||||
|
-0.024047505110502243,
|
||||||
|
-0.006929283495992422,
|
||||||
|
-0.005336422007530928,
|
||||||
|
-4.547132266452536e-05,
|
||||||
|
-0.024047505110502243,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": "text-embedding-3-small",
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 5,
|
||||||
|
"total_tokens": 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
example_image_generation_result = {
|
||||||
|
"created": 1589478378,
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"url": "https://..."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"url": "https://..."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def mock_patch_acompletion():
|
def mock_patch_acompletion():
|
||||||
with mock.patch(
|
return mock.patch(
|
||||||
"litellm.proxy.proxy_server.llm_router.acompletion",
|
"litellm.proxy.proxy_server.llm_router.acompletion",
|
||||||
return_value=example_completion_result,
|
return_value=example_completion_result,
|
||||||
):
|
)
|
||||||
yield
|
|
||||||
|
|
||||||
|
def mock_patch_aembedding():
|
||||||
|
return mock.patch(
|
||||||
|
"litellm.proxy.proxy_server.llm_router.aembedding",
|
||||||
|
return_value=example_embedding_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mock_patch_aimage_generation():
|
||||||
|
return mock.patch(
|
||||||
|
"litellm.proxy.proxy_server.llm_router.aimage_generation",
|
||||||
|
return_value=example_image_generation_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
|
@ -74,7 +124,8 @@ def client_no_auth():
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion(client_no_auth):
|
@mock_patch_acompletion()
|
||||||
|
def test_chat_completion(mock_acompletion, client_no_auth):
|
||||||
global headers
|
global headers
|
||||||
try:
|
try:
|
||||||
# Your test data
|
# Your test data
|
||||||
|
@ -88,6 +139,19 @@ def test_chat_completion(client_no_auth):
|
||||||
|
|
||||||
print("testing proxy server with chat completions")
|
print("testing proxy server with chat completions")
|
||||||
response = client_no_auth.post("/v1/chat/completions", json=test_data)
|
response = client_no_auth.post("/v1/chat/completions", json=test_data)
|
||||||
|
mock_acompletion.assert_called_once_with(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
max_tokens=10,
|
||||||
|
litellm_call_id=mock.ANY,
|
||||||
|
litellm_logging_obj=mock.ANY,
|
||||||
|
request_timeout=mock.ANY,
|
||||||
|
specific_deployment=True,
|
||||||
|
metadata=mock.ANY,
|
||||||
|
proxy_server_request=mock.ANY,
|
||||||
|
)
|
||||||
print(f"response - {response.text}")
|
print(f"response - {response.text}")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
|
@ -99,7 +163,8 @@ def test_chat_completion(client_no_auth):
|
||||||
# Run the test
|
# Run the test
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_azure(client_no_auth):
|
@mock_patch_acompletion()
|
||||||
|
def test_chat_completion_azure(mock_acompletion, client_no_auth):
|
||||||
global headers
|
global headers
|
||||||
try:
|
try:
|
||||||
# Your test data
|
# Your test data
|
||||||
|
@ -114,6 +179,19 @@ def test_chat_completion_azure(client_no_auth):
|
||||||
print("testing proxy server with Azure Request /chat/completions")
|
print("testing proxy server with Azure Request /chat/completions")
|
||||||
response = client_no_auth.post("/v1/chat/completions", json=test_data)
|
response = client_no_auth.post("/v1/chat/completions", json=test_data)
|
||||||
|
|
||||||
|
mock_acompletion.assert_called_once_with(
|
||||||
|
model="azure/chatgpt-v-2",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "write 1 sentence poem"},
|
||||||
|
],
|
||||||
|
max_tokens=10,
|
||||||
|
litellm_call_id=mock.ANY,
|
||||||
|
litellm_logging_obj=mock.ANY,
|
||||||
|
request_timeout=mock.ANY,
|
||||||
|
specific_deployment=True,
|
||||||
|
metadata=mock.ANY,
|
||||||
|
proxy_server_request=mock.ANY,
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
print(f"Received response: {result}")
|
print(f"Received response: {result}")
|
||||||
|
@ -127,7 +205,7 @@ def test_chat_completion_azure(client_no_auth):
|
||||||
|
|
||||||
|
|
||||||
@mock_patch_acompletion()
|
@mock_patch_acompletion()
|
||||||
def test_openai_deployments_model_chat_completions_azure(client_no_auth):
|
def test_openai_deployments_model_chat_completions_azure(mock_acompletion, client_no_auth):
|
||||||
global headers
|
global headers
|
||||||
try:
|
try:
|
||||||
# Your test data
|
# Your test data
|
||||||
|
@ -143,6 +221,19 @@ def test_openai_deployments_model_chat_completions_azure(client_no_auth):
|
||||||
print(f"testing proxy server with Azure Request {url}")
|
print(f"testing proxy server with Azure Request {url}")
|
||||||
response = client_no_auth.post(url, json=test_data)
|
response = client_no_auth.post(url, json=test_data)
|
||||||
|
|
||||||
|
mock_acompletion.assert_called_once_with(
|
||||||
|
model="azure/chatgpt-v-2",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "write 1 sentence poem"},
|
||||||
|
],
|
||||||
|
max_tokens=10,
|
||||||
|
litellm_call_id=mock.ANY,
|
||||||
|
litellm_logging_obj=mock.ANY,
|
||||||
|
request_timeout=mock.ANY,
|
||||||
|
specific_deployment=True,
|
||||||
|
metadata=mock.ANY,
|
||||||
|
proxy_server_request=mock.ANY,
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
print(f"Received response: {result}")
|
print(f"Received response: {result}")
|
||||||
|
@ -156,7 +247,8 @@ def test_openai_deployments_model_chat_completions_azure(client_no_auth):
|
||||||
|
|
||||||
|
|
||||||
### EMBEDDING
|
### EMBEDDING
|
||||||
def test_embedding(client_no_auth):
|
@mock_patch_aembedding()
|
||||||
|
def test_embedding(mock_aembedding, client_no_auth):
|
||||||
global headers
|
global headers
|
||||||
from litellm.proxy.proxy_server import user_custom_auth
|
from litellm.proxy.proxy_server import user_custom_auth
|
||||||
|
|
||||||
|
@ -168,6 +260,13 @@ def test_embedding(client_no_auth):
|
||||||
|
|
||||||
response = client_no_auth.post("/v1/embeddings", json=test_data)
|
response = client_no_auth.post("/v1/embeddings", json=test_data)
|
||||||
|
|
||||||
|
mock_aembedding.assert_called_once_with(
|
||||||
|
model="azure/azure-embedding-model",
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
specific_deployment=True,
|
||||||
|
metadata=mock.ANY,
|
||||||
|
proxy_server_request=mock.ANY,
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
print(len(result["data"][0]["embedding"]))
|
print(len(result["data"][0]["embedding"]))
|
||||||
|
@ -176,7 +275,8 @@ def test_embedding(client_no_auth):
|
||||||
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_bedrock_embedding(client_no_auth):
|
@mock_patch_aembedding()
|
||||||
|
def test_bedrock_embedding(mock_aembedding, client_no_auth):
|
||||||
global headers
|
global headers
|
||||||
from litellm.proxy.proxy_server import user_custom_auth
|
from litellm.proxy.proxy_server import user_custom_auth
|
||||||
|
|
||||||
|
@ -188,6 +288,12 @@ def test_bedrock_embedding(client_no_auth):
|
||||||
|
|
||||||
response = client_no_auth.post("/v1/embeddings", json=test_data)
|
response = client_no_auth.post("/v1/embeddings", json=test_data)
|
||||||
|
|
||||||
|
mock_aembedding.assert_called_once_with(
|
||||||
|
model="amazon-embeddings",
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
metadata=mock.ANY,
|
||||||
|
proxy_server_request=mock.ANY,
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
print(len(result["data"][0]["embedding"]))
|
print(len(result["data"][0]["embedding"]))
|
||||||
|
@ -222,7 +328,8 @@ def test_sagemaker_embedding(client_no_auth):
|
||||||
#### IMAGE GENERATION
|
#### IMAGE GENERATION
|
||||||
|
|
||||||
|
|
||||||
def test_img_gen(client_no_auth):
|
@mock_patch_aimage_generation()
|
||||||
|
def test_img_gen(mock_aimage_generation, client_no_auth):
|
||||||
global headers
|
global headers
|
||||||
from litellm.proxy.proxy_server import user_custom_auth
|
from litellm.proxy.proxy_server import user_custom_auth
|
||||||
|
|
||||||
|
@ -236,6 +343,14 @@ def test_img_gen(client_no_auth):
|
||||||
|
|
||||||
response = client_no_auth.post("/v1/images/generations", json=test_data)
|
response = client_no_auth.post("/v1/images/generations", json=test_data)
|
||||||
|
|
||||||
|
mock_aimage_generation.assert_called_once_with(
|
||||||
|
model='dall-e-3',
|
||||||
|
prompt='A cute baby sea otter',
|
||||||
|
n=1,
|
||||||
|
size='1024x1024',
|
||||||
|
metadata=mock.ANY,
|
||||||
|
proxy_server_request=mock.ANY,
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
print(len(result["data"][0]["url"]))
|
print(len(result["data"][0]["url"]))
|
||||||
|
@ -300,7 +415,8 @@ class MyCustomHandler(CustomLogger):
|
||||||
customHandler = MyCustomHandler()
|
customHandler = MyCustomHandler()
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_optional_params(client_no_auth):
|
@mock_patch_acompletion()
|
||||||
|
def test_chat_completion_optional_params(mock_acompletion, client_no_auth):
|
||||||
# [PROXY: PROD TEST] - DO NOT DELETE
|
# [PROXY: PROD TEST] - DO NOT DELETE
|
||||||
# This tests if all the /chat/completion params are passed to litellm
|
# This tests if all the /chat/completion params are passed to litellm
|
||||||
try:
|
try:
|
||||||
|
@ -318,6 +434,20 @@ def test_chat_completion_optional_params(client_no_auth):
|
||||||
litellm.callbacks = [customHandler]
|
litellm.callbacks = [customHandler]
|
||||||
print("testing proxy server: optional params")
|
print("testing proxy server: optional params")
|
||||||
response = client_no_auth.post("/v1/chat/completions", json=test_data)
|
response = client_no_auth.post("/v1/chat/completions", json=test_data)
|
||||||
|
mock_acompletion.assert_called_once_with(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
max_tokens=10,
|
||||||
|
user="proxy-user",
|
||||||
|
litellm_call_id=mock.ANY,
|
||||||
|
litellm_logging_obj=mock.ANY,
|
||||||
|
request_timeout=mock.ANY,
|
||||||
|
specific_deployment=True,
|
||||||
|
metadata=mock.ANY,
|
||||||
|
proxy_server_request=mock.ANY,
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
print(f"Received response: {result}")
|
print(f"Received response: {result}")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue