diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 51574c7ab..43a070556 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -2,7 +2,6 @@ import sys, os import traceback from unittest import mock from dotenv import load_dotenv -import contextlib load_dotenv() import os, io @@ -47,15 +46,66 @@ example_completion_result = { } ], } +example_embedding_result = { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + -0.006929283495992422, + -0.005336422007530928, + -4.547132266452536e-05, + -0.024047505110502243, + -0.006929283495992422, + -0.005336422007530928, + -4.547132266452536e-05, + -0.024047505110502243, + -0.006929283495992422, + -0.005336422007530928, + -4.547132266452536e-05, + -0.024047505110502243, + ], + } + ], + "model": "text-embedding-3-small", + "usage": { + "prompt_tokens": 5, + "total_tokens": 5 + } +} +example_image_generation_result = { + "created": 1589478378, + "data": [ + { + "url": "https://..." + }, + { + "url": "https://..." + } + ] +} -@contextlib.contextmanager def mock_patch_acompletion(): - with mock.patch( + return mock.patch( "litellm.proxy.proxy_server.llm_router.acompletion", return_value=example_completion_result, - ): - yield + ) + + +def mock_patch_aembedding(): + return mock.patch( + "litellm.proxy.proxy_server.llm_router.aembedding", + return_value=example_embedding_result, + ) + + +def mock_patch_aimage_generation(): + return mock.patch( + "litellm.proxy.proxy_server.llm_router.aimage_generation", + return_value=example_image_generation_result, + ) @pytest.fixture(scope="function") @@ -74,7 +124,8 @@ def client_no_auth(): return TestClient(app) -def test_chat_completion(client_no_auth): +@mock_patch_acompletion() +def test_chat_completion(mock_acompletion, client_no_auth): global headers try: # Your test data @@ -88,6 +139,19 @@ def test_chat_completion(client_no_auth): print("testing proxy server with chat completions") response = client_no_auth.post("/v1/chat/completions", json=test_data) + mock_acompletion.assert_called_once_with( + model="gpt-3.5-turbo", + messages=[ + {"role": "user", "content": "hi"}, + ], + max_tokens=10, + litellm_call_id=mock.ANY, + litellm_logging_obj=mock.ANY, + request_timeout=mock.ANY, + specific_deployment=True, + metadata=mock.ANY, + proxy_server_request=mock.ANY, + ) print(f"response - {response.text}") assert response.status_code == 200 result = response.json() @@ -99,7 +163,8 @@ def test_chat_completion(client_no_auth): # Run the test -def test_chat_completion_azure(client_no_auth): +@mock_patch_acompletion() +def test_chat_completion_azure(mock_acompletion, client_no_auth): global headers try: # Your test data @@ -114,6 +179,19 @@ def test_chat_completion_azure(client_no_auth): print("testing proxy server with Azure Request /chat/completions") response = client_no_auth.post("/v1/chat/completions", json=test_data) + mock_acompletion.assert_called_once_with( + model="azure/chatgpt-v-2", + messages=[ + {"role": "user", "content": "write 1 sentence poem"}, + ], + max_tokens=10, + litellm_call_id=mock.ANY, + litellm_logging_obj=mock.ANY, + request_timeout=mock.ANY, + specific_deployment=True, + metadata=mock.ANY, + proxy_server_request=mock.ANY, + ) assert response.status_code == 200 result = response.json() print(f"Received response: {result}") @@ -127,7 +205,7 @@ def test_chat_completion_azure(client_no_auth): @mock_patch_acompletion() -def test_openai_deployments_model_chat_completions_azure(client_no_auth): +def test_openai_deployments_model_chat_completions_azure(mock_acompletion, client_no_auth): global headers try: # Your test data @@ -143,6 +221,19 @@ def test_openai_deployments_model_chat_completions_azure(client_no_auth): print(f"testing proxy server with Azure Request {url}") response = client_no_auth.post(url, json=test_data) + mock_acompletion.assert_called_once_with( + model="azure/chatgpt-v-2", + messages=[ + {"role": "user", "content": "write 1 sentence poem"}, + ], + max_tokens=10, + litellm_call_id=mock.ANY, + litellm_logging_obj=mock.ANY, + request_timeout=mock.ANY, + specific_deployment=True, + metadata=mock.ANY, + proxy_server_request=mock.ANY, + ) assert response.status_code == 200 result = response.json() print(f"Received response: {result}") @@ -156,7 +247,8 @@ def test_openai_deployments_model_chat_completions_azure(client_no_auth): ### EMBEDDING -def test_embedding(client_no_auth): +@mock_patch_aembedding() +def test_embedding(mock_aembedding, client_no_auth): global headers from litellm.proxy.proxy_server import user_custom_auth @@ -168,6 +260,13 @@ def test_embedding(client_no_auth): response = client_no_auth.post("/v1/embeddings", json=test_data) + mock_aembedding.assert_called_once_with( + model="azure/azure-embedding-model", + input=["good morning from litellm"], + specific_deployment=True, + metadata=mock.ANY, + proxy_server_request=mock.ANY, + ) assert response.status_code == 200 result = response.json() print(len(result["data"][0]["embedding"])) @@ -176,7 +275,8 @@ def test_embedding(client_no_auth): pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") -def test_bedrock_embedding(client_no_auth): +@mock_patch_aembedding() +def test_bedrock_embedding(mock_aembedding, client_no_auth): global headers from litellm.proxy.proxy_server import user_custom_auth @@ -188,6 +288,12 @@ def test_bedrock_embedding(client_no_auth): response = client_no_auth.post("/v1/embeddings", json=test_data) + mock_aembedding.assert_called_once_with( + model="amazon-embeddings", + input=["good morning from litellm"], + metadata=mock.ANY, + proxy_server_request=mock.ANY, + ) assert response.status_code == 200 result = response.json() print(len(result["data"][0]["embedding"])) @@ -222,7 +328,8 @@ def test_sagemaker_embedding(client_no_auth): #### IMAGE GENERATION -def test_img_gen(client_no_auth): +@mock_patch_aimage_generation() +def test_img_gen(mock_aimage_generation, client_no_auth): global headers from litellm.proxy.proxy_server import user_custom_auth @@ -236,6 +343,14 @@ def test_img_gen(client_no_auth): response = client_no_auth.post("/v1/images/generations", json=test_data) + mock_aimage_generation.assert_called_once_with( + model='dall-e-3', + prompt='A cute baby sea otter', + n=1, + size='1024x1024', + metadata=mock.ANY, + proxy_server_request=mock.ANY, + ) assert response.status_code == 200 result = response.json() print(len(result["data"][0]["url"])) @@ -300,7 +415,8 @@ class MyCustomHandler(CustomLogger): customHandler = MyCustomHandler() -def test_chat_completion_optional_params(client_no_auth): +@mock_patch_acompletion() +def test_chat_completion_optional_params(mock_acompletion, client_no_auth): # [PROXY: PROD TEST] - DO NOT DELETE # This tests if all the /chat/completion params are passed to litellm try: @@ -318,6 +434,20 @@ def test_chat_completion_optional_params(client_no_auth): litellm.callbacks = [customHandler] print("testing proxy server: optional params") response = client_no_auth.post("/v1/chat/completions", json=test_data) + mock_acompletion.assert_called_once_with( + model="gpt-3.5-turbo", + messages=[ + {"role": "user", "content": "hi"}, + ], + max_tokens=10, + user="proxy-user", + litellm_call_id=mock.ANY, + litellm_logging_obj=mock.ANY, + request_timeout=mock.ANY, + specific_deployment=True, + metadata=mock.ANY, + proxy_server_request=mock.ANY, + ) assert response.status_code == 200 result = response.json() print(f"Received response: {result}")