Improve mocking in test_proxy_server

Mock the calls to the backend and assert that the correct parameters are
passed to it.
Marc Abramowitz 2024-05-02 13:36:23 -07:00
parent 762a1fbd50
commit 14e7c9b01c

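For orientation before the diff: the refactor leans on the fact that the patcher returned by unittest.mock.patch can be used either as a context manager or as a decorator, and as a decorator it injects the created MagicMock as an extra leading argument of the wrapped function. A minimal sketch of that pattern, with os.getcwd as a hypothetical stand-in for the real patch target (litellm.proxy.proxy_server.llm_router.acompletion):

import os
from unittest import mock

example_result = {"object": "list", "data": []}

def mock_patch_backend():
    # Return the patcher itself rather than entering it, so callers can
    # use it as `with mock_patch_backend():` or as `@mock_patch_backend()`.
    return mock.patch("os.getcwd", return_value=example_result)

@mock_patch_backend()
def test_backend_called(mock_backend):
    # As a decorator, mock.patch passes the MagicMock as the first argument.
    assert os.getcwd() == example_result
    mock_backend.assert_called_once_with()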

@@ -2,7 +2,6 @@ import sys, os
 import traceback
 from unittest import mock
 from dotenv import load_dotenv
-import contextlib
 
 load_dotenv()
 import os, io
@@ -47,15 +46,66 @@ example_completion_result = {
         }
     ],
 }
+example_embedding_result = {
+    "object": "list",
+    "data": [
+        {
+            "object": "embedding",
+            "index": 0,
+            "embedding": [
+                -0.006929283495992422,
+                -0.005336422007530928,
+                -4.547132266452536e-05,
+                -0.024047505110502243,
+                -0.006929283495992422,
+                -0.005336422007530928,
+                -4.547132266452536e-05,
+                -0.024047505110502243,
+                -0.006929283495992422,
+                -0.005336422007530928,
+                -4.547132266452536e-05,
+                -0.024047505110502243,
+            ],
+        }
+    ],
+    "model": "text-embedding-3-small",
+    "usage": {
+        "prompt_tokens": 5,
+        "total_tokens": 5
+    }
+}
+example_image_generation_result = {
+    "created": 1589478378,
+    "data": [
+        {
+            "url": "https://..."
+        },
+        {
+            "url": "https://..."
+        }
+    ]
+}
 
-@contextlib.contextmanager
 def mock_patch_acompletion():
-    with mock.patch(
+    return mock.patch(
         "litellm.proxy.proxy_server.llm_router.acompletion",
         return_value=example_completion_result,
-    ):
-        yield
+    )
+
+
+def mock_patch_aembedding():
+    return mock.patch(
+        "litellm.proxy.proxy_server.llm_router.aembedding",
+        return_value=example_embedding_result,
+    )
+
+
+def mock_patch_aimage_generation():
+    return mock.patch(
+        "litellm.proxy.proxy_server.llm_router.aimage_generation",
+        return_value=example_image_generation_result,
+    )
 
 
 @pytest.fixture(scope="function")
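The hunk above is the heart of the change: mock_patch_acompletion was previously a @contextlib.contextmanager that yielded nothing, so tests could patch the router but never reach the mock to assert on its call arguments. Returning the patcher directly exposes the MagicMock. A simplified before/after sketch (os.getcwd is again a hypothetical stand-in target):

import contextlib
from unittest import mock

# Before: the MagicMock is created inside the context manager and never
# surfaced, so a test body cannot inspect what the backend was called with.
@contextlib.contextmanager
def mock_patch_old():
    with mock.patch("os.getcwd", return_value={}):
        yield

# After: hand back the patcher; used as @mock_patch_new(), it injects the
# MagicMock into the decorated test, enabling assert_called_once_with(...).
def mock_patch_new():
    return mock.patch("os.getcwd", return_value={})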
@@ -74,7 +124,8 @@ def client_no_auth():
     return TestClient(app)
 
 
-def test_chat_completion(client_no_auth):
+@mock_patch_acompletion()
+def test_chat_completion(mock_acompletion, client_no_auth):
     global headers
     try:
         # Your test data
@@ -88,6 +139,19 @@ def test_chat_completion(client_no_auth):
 
         print("testing proxy server with chat completions")
         response = client_no_auth.post("/v1/chat/completions", json=test_data)
+        mock_acompletion.assert_called_once_with(
+            model="gpt-3.5-turbo",
+            messages=[
+                {"role": "user", "content": "hi"},
+            ],
+            max_tokens=10,
+            litellm_call_id=mock.ANY,
+            litellm_logging_obj=mock.ANY,
+            request_timeout=mock.ANY,
+            specific_deployment=True,
+            metadata=mock.ANY,
+            proxy_server_request=mock.ANY,
+        )
         print(f"response - {response.text}")
         assert response.status_code == 200
         result = response.json()
@@ -99,7 +163,8 @@ def test_chat_completion(client_no_auth):
 
 
 # Run the test
-def test_chat_completion_azure(client_no_auth):
+@mock_patch_acompletion()
+def test_chat_completion_azure(mock_acompletion, client_no_auth):
     global headers
     try:
         # Your test data
@@ -114,6 +179,19 @@ def test_chat_completion_azure(client_no_auth):
 
         print("testing proxy server with Azure Request /chat/completions")
         response = client_no_auth.post("/v1/chat/completions", json=test_data)
+        mock_acompletion.assert_called_once_with(
+            model="azure/chatgpt-v-2",
+            messages=[
+                {"role": "user", "content": "write 1 sentence poem"},
+            ],
+            max_tokens=10,
+            litellm_call_id=mock.ANY,
+            litellm_logging_obj=mock.ANY,
+            request_timeout=mock.ANY,
+            specific_deployment=True,
+            metadata=mock.ANY,
+            proxy_server_request=mock.ANY,
+        )
         assert response.status_code == 200
         result = response.json()
         print(f"Received response: {result}")
@@ -127,7 +205,7 @@ def test_chat_completion_azure(client_no_auth):
 
 
 @mock_patch_acompletion()
-def test_openai_deployments_model_chat_completions_azure(client_no_auth):
+def test_openai_deployments_model_chat_completions_azure(mock_acompletion, client_no_auth):
     global headers
     try:
         # Your test data
@@ -143,6 +221,19 @@ def test_openai_deployments_model_chat_completions_azure(client_no_auth):
 
         print(f"testing proxy server with Azure Request {url}")
         response = client_no_auth.post(url, json=test_data)
+        mock_acompletion.assert_called_once_with(
+            model="azure/chatgpt-v-2",
+            messages=[
+                {"role": "user", "content": "write 1 sentence poem"},
+            ],
+            max_tokens=10,
+            litellm_call_id=mock.ANY,
+            litellm_logging_obj=mock.ANY,
+            request_timeout=mock.ANY,
+            specific_deployment=True,
+            metadata=mock.ANY,
+            proxy_server_request=mock.ANY,
+        )
         assert response.status_code == 200
         result = response.json()
         print(f"Received response: {result}")
@@ -156,7 +247,8 @@ def test_openai_deployments_model_chat_completions_azure(client_no_auth):
 
 
 ### EMBEDDING
-def test_embedding(client_no_auth):
+@mock_patch_aembedding()
+def test_embedding(mock_aembedding, client_no_auth):
     global headers
     from litellm.proxy.proxy_server import user_custom_auth
@@ -168,6 +260,13 @@ def test_embedding(client_no_auth):
 
         response = client_no_auth.post("/v1/embeddings", json=test_data)
+        mock_aembedding.assert_called_once_with(
+            model="azure/azure-embedding-model",
+            input=["good morning from litellm"],
+            specific_deployment=True,
+            metadata=mock.ANY,
+            proxy_server_request=mock.ANY,
+        )
         assert response.status_code == 200
         result = response.json()
         print(len(result["data"][0]["embedding"]))
@@ -176,7 +275,8 @@ def test_embedding(client_no_auth):
         pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
 
 
-def test_bedrock_embedding(client_no_auth):
+@mock_patch_aembedding()
+def test_bedrock_embedding(mock_aembedding, client_no_auth):
     global headers
     from litellm.proxy.proxy_server import user_custom_auth
@@ -188,6 +288,12 @@ def test_bedrock_embedding(client_no_auth):
 
         response = client_no_auth.post("/v1/embeddings", json=test_data)
+        mock_aembedding.assert_called_once_with(
+            model="amazon-embeddings",
+            input=["good morning from litellm"],
+            metadata=mock.ANY,
+            proxy_server_request=mock.ANY,
+        )
         assert response.status_code == 200
         result = response.json()
         print(len(result["data"][0]["embedding"]))
@@ -222,7 +328,8 @@ def test_sagemaker_embedding(client_no_auth):
 
 
 #### IMAGE GENERATION
-def test_img_gen(client_no_auth):
+@mock_patch_aimage_generation()
+def test_img_gen(mock_aimage_generation, client_no_auth):
     global headers
     from litellm.proxy.proxy_server import user_custom_auth
@@ -236,6 +343,14 @@ def test_img_gen(client_no_auth):
 
         response = client_no_auth.post("/v1/images/generations", json=test_data)
+        mock_aimage_generation.assert_called_once_with(
+            model='dall-e-3',
+            prompt='A cute baby sea otter',
+            n=1,
+            size='1024x1024',
+            metadata=mock.ANY,
+            proxy_server_request=mock.ANY,
+        )
         assert response.status_code == 200
         result = response.json()
         print(len(result["data"][0]["url"]))
@@ -300,7 +415,8 @@ class MyCustomHandler(CustomLogger):
 customHandler = MyCustomHandler()
 
 
-def test_chat_completion_optional_params(client_no_auth):
+@mock_patch_acompletion()
+def test_chat_completion_optional_params(mock_acompletion, client_no_auth):
     # [PROXY: PROD TEST] - DO NOT DELETE
     # This tests if all the /chat/completion params are passed to litellm
     try:
@@ -318,6 +434,20 @@ def test_chat_completion_optional_params(client_no_auth):
         litellm.callbacks = [customHandler]
         print("testing proxy server: optional params")
         response = client_no_auth.post("/v1/chat/completions", json=test_data)
+        mock_acompletion.assert_called_once_with(
+            model="gpt-3.5-turbo",
+            messages=[
+                {"role": "user", "content": "hi"},
+            ],
+            max_tokens=10,
+            user="proxy-user",
+            litellm_call_id=mock.ANY,
+            litellm_logging_obj=mock.ANY,
+            request_timeout=mock.ANY,
+            specific_deployment=True,
+            metadata=mock.ANY,
+            proxy_server_request=mock.ANY,
+        )
         assert response.status_code == 200
         result = response.json()
         print(f"Received response: {result}")