diff --git a/.circleci/config.yml b/.circleci/config.yml
index 39925ab3e..dbe4ac5fb 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -199,7 +199,7 @@ jobs:
           command: |
             pwd
             ls
-            python -m pytest tests/local_testing --cov=litellm --cov-report=xml -vv -k "router" -x -s -v --junitxml=test-results/junit.xml --durations=5
+            python -m pytest tests/local_testing tests/router_unit_tests --cov=litellm --cov-report=xml -vv -k "router" -x -s -v --junitxml=test-results/junit.xml --durations=5
           no_output_timeout: 120m
       - run:
           name: Rename the coverage files
@@ -380,6 +380,7 @@ jobs:
       - run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
       - run: ruff check ./litellm
       - run: python ./tests/documentation_tests/test_general_setting_keys.py
+      - run: python ./tests/code_coverage_tests/router_code_coverage.py
       - run: python ./tests/documentation_tests/test_env_keys.py

   db_migration_disable_update_check:
diff --git a/litellm/litellm_core_utils/mock_functions.py b/litellm/litellm_core_utils/mock_functions.py
new file mode 100644
index 000000000..76425651a
--- /dev/null
+++ b/litellm/litellm_core_utils/mock_functions.py
@@ -0,0 +1,28 @@
+from typing import List, Optional
+
+from ..types.utils import (
+    Categories,
+    CategoryAppliedInputTypes,
+    CategoryScores,
+    Embedding,
+    EmbeddingResponse,
+    ImageObject,
+    ImageResponse,
+    Moderation,
+    ModerationCreateResponse,
+)
+
+
+def mock_embedding(model: str, mock_response: Optional[List[float]]):
+    if mock_response is None:
+        mock_response = [0.0] * 1536
+    return EmbeddingResponse(
+        model=model,
+        data=[Embedding(embedding=mock_response, index=0, object="embedding")],
+    )
+
+
+def mock_image_generation(model: str, mock_response: str):
+    return ImageResponse(
+        data=[ImageObject(url=mock_response)],
+    )
diff --git a/litellm/main.py b/litellm/main.py
index fe1453836..c1cf3ae23 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -42,6 +42,10 @@ from litellm import (  # type: ignore
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.mock_functions import (
+    mock_embedding,
+    mock_image_generation,
+)
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
@@ -3163,6 +3167,7 @@ def embedding(
     tpm = kwargs.pop("tpm", None)
     litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
     cooldown_time = kwargs.get("cooldown_time", None)
+    mock_response: Optional[List[float]] = kwargs.get("mock_response", None)  # type: ignore
     max_parallel_requests = kwargs.pop("max_parallel_requests", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", None)
@@ -3268,6 +3273,9 @@ def embedding(
         custom_llm_provider=custom_llm_provider,
         **non_default_params,
     )
+
+    if mock_response is not None:
+        return mock_embedding(model=model, mock_response=mock_response)
     ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
     if input_cost_per_token is not None and output_cost_per_token is not None:
         litellm.register_model(
@@ -4377,6 +4385,7 @@ def image_generation(
     aimg_generation = kwargs.get("aimg_generation", False)
     litellm_call_id = kwargs.get("litellm_call_id", None)
     logger_fn = kwargs.get("logger_fn", None)
+    mock_response: Optional[str] = kwargs.get("mock_response", None)  # type: ignore
     proxy_server_request = kwargs.get("proxy_server_request", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", {})
@@ -4486,6 +4495,8 @@ def image_generation(
             },
             custom_llm_provider=custom_llm_provider,
         )
+        if mock_response is not None:
+            return mock_image_generation(model=model, mock_response=mock_response)

         if custom_llm_provider == "azure":
             # azure configs
diff --git a/litellm/router.py b/litellm/router.py
index c31536bd6..26730402f 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -111,6 +111,7 @@ from litellm.types.router import (
     RouterModelGroupAliasItem,
     RouterRateLimitError,
     RouterRateLimitErrorBasic,
+    RoutingStrategy,
     updateDeployment,
     updateLiteLLMParams,
 )
@@ -519,6 +520,9 @@ class Router:
         self._initialize_alerting()

     def validate_fallbacks(self, fallback_param: Optional[List]):
+        """
+        Validate the fallbacks parameter.
+        """
         if fallback_param is None:
             return

@@ -530,8 +534,13 @@ class Router:
                     f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
                 )

-    def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
-        if routing_strategy == "least-busy":
+    def routing_strategy_init(
+        self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
+    ):
+        if (
+            routing_strategy == RoutingStrategy.LEAST_BUSY.value
+            or routing_strategy == RoutingStrategy.LEAST_BUSY
+        ):
             self.leastbusy_logger = LeastBusyLoggingHandler(
                 router_cache=self.cache, model_list=self.model_list
             )
@@ -542,7 +551,10 @@ class Router:
                 litellm.input_callback = [self.leastbusy_logger]  # type: ignore
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.leastbusy_logger)  # type: ignore
-        elif routing_strategy == "usage-based-routing":
+        elif (
+            routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING.value
+            or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING
+        ):
             self.lowesttpm_logger = LowestTPMLoggingHandler(
                 router_cache=self.cache,
                 model_list=self.model_list,
@@ -550,7 +562,10 @@ class Router:
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
-        elif routing_strategy == "usage-based-routing-v2":
+        elif (
+            routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2.value
+            or routing_strategy == RoutingStrategy.USAGE_BASED_ROUTING_V2
+        ):
             self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
                 router_cache=self.cache,
                 model_list=self.model_list,
@@ -558,7 +573,10 @@ class Router:
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger_v2)  # type: ignore
-        elif routing_strategy == "latency-based-routing":
+        elif (
+            routing_strategy == RoutingStrategy.LATENCY_BASED.value
+            or routing_strategy == RoutingStrategy.LATENCY_BASED
+        ):
             self.lowestlatency_logger = LowestLatencyLoggingHandler(
                 router_cache=self.cache,
                 model_list=self.model_list,
@@ -566,7 +584,10 @@ class Router:
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowestlatency_logger)  # type: ignore
-        elif routing_strategy == "cost-based-routing":
+        elif (
+            routing_strategy == RoutingStrategy.COST_BASED.value
+            or routing_strategy == RoutingStrategy.COST_BASED
+        ):
             self.lowestcost_logger = LowestCostLoggingHandler(
                 router_cache=self.cache,
                 model_list=self.model_list,
@@ -574,10 +595,14 @@ class Router:
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowestcost_logger)  # type: ignore
+        else:
+            pass

     def print_deployment(self, deployment: dict):
         """
         returns a copy of the deployment with the api key masked
+
+        Only returns 2 characters of the api key and masks the rest with * (10 *).
         """
         try:
             _deployment_copy = copy.deepcopy(deployment)
@@ -1746,7 +1771,6 @@ class Router:
         try:
             kwargs["model"] = model
             kwargs["prompt"] = prompt
-            kwargs["original_function"] = self.text_completion
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
@@ -1770,13 +1794,7 @@ class Router:
             # call via litellm.completion()
             return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs})  # type: ignore
         except Exception as e:
-            if self.num_retries > 0:
-                kwargs["model"] = model
-                kwargs["messages"] = messages
-                kwargs["original_function"] = self.text_completion
-                return self.function_with_retries(**kwargs)
-            else:
-                raise e
+            raise e

     async def atext_completion(
         self,
@@ -3005,7 +3023,7 @@ class Router:

     async def make_call(self, original_function: Any, *args, **kwargs):
         """
-        Handler for making a call to the .completion()/.embeddings() functions.
+        Handler for making a call to the .completion()/.embeddings()/etc. functions.
         """
         model_group = kwargs.get("model")
         response = await original_function(*args, **kwargs)
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 1f2c7224f..f0737b3ef 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -608,3 +608,11 @@ VALID_LITELLM_ENVIRONMENTS = [
     "staging",
     "production",
 ]
+
+
+class RoutingStrategy(enum.Enum):
+    LEAST_BUSY = "least-busy"
+    LATENCY_BASED = "latency-based-routing"
+    COST_BASED = "cost-based-routing"
+    USAGE_BASED_ROUTING_V2 = "usage-based-routing-v2"
+    USAGE_BASED_ROUTING = "usage-based-routing"
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index c3118b453..cbc0f0274 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -11,6 +11,12 @@ from openai.types.completion_usage import (
     CompletionUsage,
     PromptTokensDetails,
 )
+from openai.types.moderation import (
+    Categories,
+    CategoryAppliedInputTypes,
+    CategoryScores,
+)
+from openai.types.moderation_create_response import Moderation, ModerationCreateResponse
 from pydantic import BaseModel, ConfigDict, PrivateAttr
 from typing_extensions import Callable, Dict, Required, TypedDict, override

@@ -20,6 +26,7 @@ from .llms.openai import (
     ChatCompletionUsageBlock,
     OpenAIChatCompletionChunk,
 )
+from .rerank import RerankResponse


 def _generate_id():  # private helper function
@@ -811,7 +818,7 @@ class EmbeddingResponse(OpenAIObject):
         model: Optional[str] = None,
         usage: Optional[Usage] = None,
         response_ms=None,
-        data: Optional[List] = None,
+        data: Optional[Union[List, List[Embedding]]] = None,
         hidden_params=None,
         _response_headers=None,
         **params,
diff --git a/tests/code_coverage_tests/router_code_coverage.py b/tests/code_coverage_tests/router_code_coverage.py
new file mode 100644
index 000000000..fb88c3504
--- /dev/null
+++ b/tests/code_coverage_tests/router_code_coverage.py
@@ -0,0 +1,117 @@
+import ast
+import os
+
+
+def get_function_names_from_file(file_path):
+    """
+    Extracts all function names from a given Python file.
+ """ + with open(file_path, "r") as file: + tree = ast.parse(file.read()) + + function_names = [] + + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + function_names.append(node.name) + + return function_names + + +def get_all_functions_called_in_tests(base_dir): + """ + Returns a set of function names that are called in test functions + inside 'local_testing' and 'router_unit_test' directories, + specifically in files containing the word 'router'. + """ + called_functions = set() + test_dirs = ["local_testing", "router_unit_tests"] + + for test_dir in test_dirs: + dir_path = os.path.join(base_dir, test_dir) + if not os.path.exists(dir_path): + print(f"Warning: Directory {dir_path} does not exist.") + continue + + print("dir_path: ", dir_path) + for root, _, files in os.walk(dir_path): + for file in files: + if file.endswith(".py") and "router" in file.lower(): + print("file: ", file) + file_path = os.path.join(root, file) + with open(file_path, "r") as f: + try: + tree = ast.parse(f.read()) + except SyntaxError: + print(f"Warning: Syntax error in file {file_path}") + continue + if file == "test_router_validate_fallbacks.py": + print(f"tree: {tree}") + for node in ast.walk(tree): + if isinstance(node, ast.Call) and isinstance( + node.func, ast.Name + ): + called_functions.add(node.func.id) + elif isinstance(node, ast.Call) and isinstance( + node.func, ast.Attribute + ): + called_functions.add(node.func.attr) + + return called_functions + + +def get_functions_from_router(file_path): + """ + Extracts all functions defined in router.py. + """ + return get_function_names_from_file(file_path) + + +ignored_function_names = [ + "__init__", + "_acreate_file", + "_acreate_batch", + "acreate_assistants", + "adelete_assistant", + "aget_assistants", + "acreate_thread", + "aget_thread", + "a_add_message", + "aget_messages", + "arun_thread", +] + + +def main(): + router_file = "./litellm/router.py" # Update this path if it's located elsewhere + # router_file = "../../litellm/router.py" ## LOCAL TESTING + tests_dir = ( + "./tests/" # Update this path if your tests directory is located elsewhere + ) + # tests_dir = "../../tests/" # LOCAL TESTING + + router_functions = get_functions_from_router(router_file) + print("router_functions: ", router_functions) + called_functions_in_tests = get_all_functions_called_in_tests(tests_dir) + untested_functions = [ + fn for fn in router_functions if fn not in called_functions_in_tests + ] + + if untested_functions: + all_untested_functions = [] + for func in untested_functions: + if func not in ignored_function_names: + all_untested_functions.append(func) + untested_perc = (len(all_untested_functions)) / len(router_functions) + print("perc_covered: ", untested_perc) + if untested_perc < 0.3: + print("The following functions in router.py are not tested:") + raise Exception( + f"{untested_perc * 100:.2f}% of functions in router.py are not tested: {all_untested_functions}" + ) + else: + print("All functions in router.py are covered by tests.") + + +if __name__ == "__main__": + main() diff --git a/tests/llm_translation/test_anthropic_completion.py b/tests/llm_translation/test_anthropic_completion.py index 2d5dd570a..ca402903b 100644 --- a/tests/llm_translation/test_anthropic_completion.py +++ b/tests/llm_translation/test_anthropic_completion.py @@ -366,41 +366,6 @@ def test_anthropic_tool_streaming(): assert tool_use["index"] == correct_tool_index -@pytest.mark.asyncio -async def test_anthropic_router_completion_e2e(): - 
litellm.set_verbose = True - - litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}] - - router = Router( - model_list=[ - { - "model_name": "claude-3-5-sonnet-20240620", - "litellm_params": { - "model": "gpt-3.5-turbo", - "mock_response": "hi this is macintosh.", - }, - } - ] - ) - messages = [{"role": "user", "content": "Hey, how's it going?"}] - - response = await router.aadapter_completion( - model="claude-3-5-sonnet-20240620", - messages=messages, - adapter_id="anthropic", - mock_response="This is a fake call", - ) - - print("Response: {}".format(response)) - - assert response is not None - - assert isinstance(response, AnthropicResponse) - - assert response.model == "gpt-3.5-turbo" - - def test_anthropic_tool_calling_translation(): kwargs = { "model": "claude-3-5-sonnet-20240620", diff --git a/tests/local_testing/test_audio_speech.py b/tests/local_testing/test_audio_speech.py index 4e45b9953..ac37b1b0c 100644 --- a/tests/local_testing/test_audio_speech.py +++ b/tests/local_testing/test_audio_speech.py @@ -83,43 +83,6 @@ async def test_audio_speech_litellm(sync_mode, model, api_base, api_key): assert isinstance(response, HttpxBinaryResponseContent) -@pytest.mark.parametrize("mode", ["iterator"]) # "file", -@pytest.mark.asyncio -async def test_audio_speech_router(mode): - speech_file_path = Path(__file__).parent / "speech.mp3" - - from litellm import Router - - client = Router( - model_list=[ - { - "model_name": "tts", - "litellm_params": { - "model": "openai/tts-1", - }, - }, - ] - ) - - response = await client.aspeech( - model="tts", - voice="alloy", - input="the quick brown fox jumped over the lazy dogs", - api_base=None, - api_key=None, - organization=None, - project=None, - max_retries=1, - timeout=600, - client=None, - optional_params={}, - ) - - from litellm.llms.OpenAI.openai import HttpxBinaryResponseContent - - assert isinstance(response, HttpxBinaryResponseContent) - - @pytest.mark.parametrize( "sync_mode", [False, True], diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index 42148d9ab..36ebdb382 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -1310,19 +1310,38 @@ def test_aembedding_on_router(): router = Router(model_list=model_list) async def embedding_call(): + ## Test 1: user facing function response = await router.aembedding( model="text-embedding-ada-002", input=["good morning from litellm", "this is another item"], ) print(response) + ## Test 2: underlying function + response = await router._aembedding( + model="text-embedding-ada-002", + input=["good morning from litellm 2"], + ) + print(response) + router.reset() + asyncio.run(embedding_call()) print("\n Making sync Embedding call\n") + ## Test 1: user facing function response = router.embedding( model="text-embedding-ada-002", input=["good morning from litellm 2"], ) + print(response) + router.reset() + + ## Test 2: underlying function + response = router._embedding( + model="text-embedding-ada-002", + input=["good morning from litellm 2"], + ) + print(response) router.reset() except Exception as e: if "Your task failed as a result of our safety system." 
in str(e): @@ -1843,10 +1862,16 @@ async def test_router_amoderation(): ] router = Router(model_list=model_list) + ## Test 1: user facing function result = await router.amoderation( model="openai-moderations", input="this is valid good text" ) + ## Test 2: underlying function + result = await router._amoderation( + model="openai-moderations", input="this is valid good text" + ) + print("moderation result", result) diff --git a/tests/local_testing/test_whisper.py b/tests/local_testing/test_whisper.py index 087028928..f66ad8b13 100644 --- a/tests/local_testing/test_whisper.py +++ b/tests/local_testing/test_whisper.py @@ -79,81 +79,6 @@ async def test_transcription(model, api_key, api_base, response_format, sync_mod assert transcript.text is not None -# This file includes the custom callbacks for LiteLLM Proxy -# Once defined, these can be passed in proxy_config.yaml -class MyCustomHandler(CustomLogger): - def __init__(self): - self.openai_client = None - - async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - try: - # init logging config - print("logging a transcript kwargs: ", kwargs) - print("openai client=", kwargs.get("client")) - self.openai_client = kwargs.get("client") - - except Exception: - pass - - -proxy_handler_instance = MyCustomHandler() - - -# Set litellm.callbacks = [proxy_handler_instance] on the proxy -# need to set litellm.callbacks = [proxy_handler_instance] # on the proxy -@pytest.mark.asyncio -async def test_transcription_on_router(): - litellm.set_verbose = True - litellm.callbacks = [proxy_handler_instance] - print("\n Testing async transcription on router\n") - try: - model_list = [ - { - "model_name": "whisper", - "litellm_params": { - "model": "whisper-1", - }, - }, - { - "model_name": "whisper", - "litellm_params": { - "model": "azure/azure-whisper", - "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/", - "api_key": os.getenv("AZURE_EUROPE_API_KEY"), - "api_version": "2024-02-15-preview", - }, - }, - ] - - router = Router(model_list=model_list) - - router_level_clients = [] - for deployment in router.model_list: - _deployment_openai_client = router._get_client( - deployment=deployment, - kwargs={"model": "whisper-1"}, - client_type="async", - ) - - router_level_clients.append(str(_deployment_openai_client)) - - response = await router.atranscription( - model="whisper", - file=audio_file, - ) - print(response) - - # PROD Test - # Ensure we ONLY use OpenAI/Azure client initialized on the router level - await asyncio.sleep(5) - print("OpenAI Client used= ", proxy_handler_instance.openai_client) - print("all router level clients= ", router_level_clients) - assert proxy_handler_instance.openai_client in router_level_clients - except Exception as e: - traceback.print_exc() - pytest.fail(f"Error occurred: {e}") - - @pytest.mark.asyncio() async def test_transcription_caching(): import litellm diff --git a/tests/router_unit_tests/README.md b/tests/router_unit_tests/README.md new file mode 100644 index 000000000..0206bcbb8 --- /dev/null +++ b/tests/router_unit_tests/README.md @@ -0,0 +1,5 @@ +## Router component unit tests. + +Please name all files with the word 'router' in them. + +This is used to ensure all functions in the router are tested. 
\ No newline at end of file
diff --git a/tests/router_unit_tests/gettysburg.wav b/tests/router_unit_tests/gettysburg.wav
new file mode 100644
index 000000000..9690f521e
Binary files /dev/null and b/tests/router_unit_tests/gettysburg.wav differ
diff --git a/tests/router_unit_tests/test_router_endpoints.py b/tests/router_unit_tests/test_router_endpoints.py
new file mode 100644
index 000000000..accd5ea40
--- /dev/null
+++ b/tests/router_unit_tests/test_router_endpoints.py
@@ -0,0 +1,279 @@
+import sys
+import os
+import traceback
+from dotenv import load_dotenv
+from fastapi import Request
+from datetime import datetime
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+from litellm import Router, CustomLogger
+
+# Get the current directory of the file being run
+pwd = os.path.dirname(os.path.realpath(__file__))
+print(pwd)
+
+file_path = os.path.join(pwd, "gettysburg.wav")
+
+audio_file = open(file_path, "rb")
+from pathlib import Path
+import litellm
+import pytest
+import asyncio
+
+
+@pytest.fixture
+def model_list():
+    return [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "gpt-3.5-turbo",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+        },
+        {
+            "model_name": "gpt-4o",
+            "litellm_params": {
+                "model": "gpt-4o",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+        },
+        {
+            "model_name": "dall-e-3",
+            "litellm_params": {
+                "model": "dall-e-3",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+        },
+        {
+            "model_name": "cohere-rerank",
+            "litellm_params": {
+                "model": "cohere/rerank-english-v3.0",
+                "api_key": os.getenv("COHERE_API_KEY"),
+            },
+        },
+        {
+            "model_name": "claude-3-5-sonnet-20240620",
+            "litellm_params": {
+                "model": "gpt-3.5-turbo",
+                "mock_response": "hi this is macintosh.",
+            },
+        },
+    ]
+
+
+# This file includes the custom callbacks for LiteLLM Proxy
+# Once defined, these can be passed in proxy_config.yaml
+class MyCustomHandler(CustomLogger):
+    def __init__(self):
+        self.openai_client = None
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            # init logging config
+            print("logging a transcript kwargs: ", kwargs)
+            print("openai client=", kwargs.get("client"))
+            self.openai_client = kwargs.get("client")
+
+        except Exception:
+            pass
+
+
+proxy_handler_instance = MyCustomHandler()
+
+
+# Set litellm.callbacks = [proxy_handler_instance] on the proxy
+# need to set litellm.callbacks = [proxy_handler_instance] # on the proxy
+@pytest.mark.asyncio
+async def test_transcription_on_router():
+    litellm.set_verbose = True
+    litellm.callbacks = [proxy_handler_instance]
+    print("\n Testing async transcription on router\n")
+    try:
+        model_list = [
+            {
+                "model_name": "whisper",
+                "litellm_params": {
+                    "model": "whisper-1",
+                },
+            },
+            {
+                "model_name": "whisper",
+                "litellm_params": {
+                    "model": "azure/azure-whisper",
+                    "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/",
+                    "api_key": os.getenv("AZURE_EUROPE_API_KEY"),
+                    "api_version": "2024-02-15-preview",
+                },
+            },
+        ]
+
+        router = Router(model_list=model_list)
+
+        router_level_clients = []
+        for deployment in router.model_list:
+            _deployment_openai_client = router._get_client(
+                deployment=deployment,
+                kwargs={"model": "whisper-1"},
+                client_type="async",
+            )
+
+            router_level_clients.append(str(_deployment_openai_client))
+
+        ## test 1: user facing function
+        response = await router.atranscription(
+            model="whisper",
+            file=audio_file,
+        )
+
+        ## test 2: underlying function
+        response = await router._atranscription(
+            model="whisper",
+            file=audio_file,
+        )
+        print(response)
+
+        # PROD Test
+        # Ensure we ONLY use OpenAI/Azure client initialized on the router level
+        await asyncio.sleep(5)
+        print("OpenAI Client used= ", proxy_handler_instance.openai_client)
+        print("all router level clients= ", router_level_clients)
+        assert proxy_handler_instance.openai_client in router_level_clients
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.parametrize("mode", ["iterator"])  # "file",
+@pytest.mark.asyncio
+async def test_audio_speech_router(mode):
+
+    from litellm import Router
+
+    client = Router(
+        model_list=[
+            {
+                "model_name": "tts",
+                "litellm_params": {
+                    "model": "openai/tts-1",
+                },
+            },
+        ]
+    )
+
+    response = await client.aspeech(
+        model="tts",
+        voice="alloy",
+        input="the quick brown fox jumped over the lazy dogs",
+        api_base=None,
+        api_key=None,
+        organization=None,
+        project=None,
+        max_retries=1,
+        timeout=600,
+        client=None,
+        optional_params={},
+    )
+
+    from litellm.llms.OpenAI.openai import HttpxBinaryResponseContent
+
+    assert isinstance(response, HttpxBinaryResponseContent)
+
+
+@pytest.mark.asyncio()
+async def test_rerank_endpoint(model_list):
+    from litellm.types.utils import RerankResponse
+
+    router = Router(model_list=model_list)
+
+    ## Test 1: user facing function
+    response = await router.arerank(
+        model="cohere-rerank",
+        query="hello",
+        documents=["hello", "world"],
+        top_n=3,
+    )
+
+    ## Test 2: underlying function
+    response = await router._arerank(
+        model="cohere-rerank",
+        query="hello",
+        documents=["hello", "world"],
+        top_n=3,
+    )
+
+    print("async re rank response: ", response)
+
+    assert response.id is not None
+    assert response.results is not None
+
+    RerankResponse.model_validate(response)
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_text_completion_endpoint(model_list, sync_mode):
+    router = Router(model_list=model_list)
+
+    if sync_mode:
+        response = router.text_completion(
+            model="gpt-3.5-turbo",
+            prompt="Hello, how are you?",
+            mock_response="I'm fine, thank you!",
+        )
+    else:
+        ## Test 1: user facing function
+        response = await router.atext_completion(
+            model="gpt-3.5-turbo",
+            prompt="Hello, how are you?",
+            mock_response="I'm fine, thank you!",
+        )
+
+        ## Test 2: underlying function
+        response_2 = await router._atext_completion(
+            model="gpt-3.5-turbo",
+            prompt="Hello, how are you?",
+            mock_response="I'm fine, thank you!",
+        )
+        assert response_2.choices[0].text == "I'm fine, thank you!"
+
+    assert response.choices[0].text == "I'm fine, thank you!"
+
+
+@pytest.mark.asyncio
+async def test_anthropic_router_completion_e2e(model_list):
+    from litellm.adapters.anthropic_adapter import anthropic_adapter
+    from litellm.types.llms.anthropic import AnthropicResponse
+
+    litellm.set_verbose = True
+
+    litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
+
+    router = Router(model_list=model_list)
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+
+    ## Test 1: user facing function
+    response = await router.aadapter_completion(
+        model="claude-3-5-sonnet-20240620",
+        messages=messages,
+        adapter_id="anthropic",
+        mock_response="This is a fake call",
+    )
+
+    ## Test 2: underlying function
+    await router._aadapter_completion(
+        model="claude-3-5-sonnet-20240620",
+        messages=messages,
+        adapter_id="anthropic",
+        mock_response="This is a fake call",
+    )
+
+    print("Response: {}".format(response))
+
+    assert response is not None
+
+    AnthropicResponse.model_validate(response)
+
+    assert response.model == "gpt-3.5-turbo"
diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py
new file mode 100644
index 000000000..97660471b
--- /dev/null
+++ b/tests/router_unit_tests/test_router_helper_utils.py
@@ -0,0 +1,252 @@
+import sys
+import os
+import traceback
+from dotenv import load_dotenv
+from fastapi import Request
+from datetime import datetime
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+from litellm import Router
+import pytest
+from unittest.mock import patch, MagicMock, AsyncMock
+
+
+@pytest.fixture
+def model_list():
+    return [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "gpt-3.5-turbo",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+        },
+        {
+            "model_name": "gpt-4o",
+            "litellm_params": {
+                "model": "gpt-4o",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+        },
+        {
+            "model_name": "dall-e-3",
+            "litellm_params": {
+                "model": "dall-e-3",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+        },
+    ]
+
+
+def test_validate_fallbacks(model_list):
+    router = Router(model_list=model_list, fallbacks=[{"gpt-4o": "gpt-3.5-turbo"}])
+    router.validate_fallbacks(fallback_param=[{"gpt-4o": "gpt-3.5-turbo"}])
+
+
+def test_routing_strategy_init(model_list):
+    """Test if all routing strategies are initialized correctly"""
+    from litellm.types.router import RoutingStrategy
+
+    router = Router(model_list=model_list)
+    for strategy in RoutingStrategy._member_names_:
+        router.routing_strategy_init(
+            routing_strategy=strategy, routing_strategy_args={}
+        )
+
+
+def test_print_deployment(model_list):
+    """Test if the api key is masked correctly"""
+
+    router = Router(model_list=model_list)
+    deployment = {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "gpt-3.5-turbo",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+    }
+    printed_deployment = router.print_deployment(deployment)
+    assert 10 * "*" in printed_deployment["litellm_params"]["api_key"]
+
+
+def test_completion(model_list):
+    """Test if the completion function is working correctly"""
+    router = Router(model_list=model_list)
+    response = router._completion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        mock_response="I'm fine, thank you!",
+    )
+    assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.flaky(retries=6, delay=1)
+@pytest.mark.asyncio
+async def test_image_generation(model_list, sync_mode):
+    """Test if the underlying '_image_generation' function is working correctly"""
+    from litellm.types.utils import ImageResponse
+
+    router = Router(model_list=model_list)
+    if sync_mode:
+        response = router._image_generation(
+            model="dall-e-3",
+            prompt="A cute baby sea otter",
+        )
+    else:
+        response = await router._aimage_generation(
+            model="dall-e-3",
+            prompt="A cute baby sea otter",
+        )
+
+    ImageResponse.model_validate(response)
+
+
+@pytest.mark.asyncio
+async def test_router_acompletion_util(model_list):
+    """Test if the underlying '_acompletion' function is working correctly"""
+    router = Router(model_list=model_list)
+    response = await router._acompletion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        mock_response="I'm fine, thank you!",
+    )
+    assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"
+
+
+@pytest.mark.asyncio
+async def test_router_abatch_completion_one_model_multiple_requests_util(model_list):
+    """Test if the 'abatch_completion_one_model_multiple_requests' function is working correctly"""
+    router = Router(model_list=model_list)
+    response = await router.abatch_completion_one_model_multiple_requests(
+        model="gpt-3.5-turbo",
+        messages=[
+            [{"role": "user", "content": "Hello, how are you?"}],
+            [{"role": "user", "content": "Hello, how are you?"}],
+        ],
+        mock_response="I'm fine, thank you!",
+    )
+    print(response)
+    assert response[0]["choices"][0]["message"]["content"] == "I'm fine, thank you!"
+    assert response[1]["choices"][0]["message"]["content"] == "I'm fine, thank you!"
+
+
+@pytest.mark.asyncio
+async def test_router_schedule_acompletion(model_list):
+    """Test if the 'schedule_acompletion' function is working correctly"""
+    router = Router(model_list=model_list)
+    response = await router.schedule_acompletion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        mock_response="I'm fine, thank you!",
+        priority=1,
+    )
+    assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"
+
+
+@pytest.mark.asyncio
+async def test_router_arealtime(model_list):
+    """Test if the '_arealtime' function is working correctly"""
+    import litellm
+
+    router = Router(model_list=model_list)
+    with patch.object(litellm, "_arealtime", AsyncMock()) as mock_arealtime:
+        mock_arealtime.return_value = "I'm fine, thank you!"
+        await router._arealtime(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hello, how are you?"}],
+        )
+
+        mock_arealtime.assert_awaited_once()
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_router_function_with_fallbacks(model_list, sync_mode):
+    """Test if the router 'async_function_with_fallbacks' + 'function_with_fallbacks' are working correctly"""
+    router = Router(model_list=model_list)
+    data = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hello, how are you?"}],
+        "mock_response": "I'm fine, thank you!",
+        "num_retries": 0,
+    }
+    if sync_mode:
+        response = router.function_with_fallbacks(
+            original_function=router._completion,
+            **data,
+        )
+    else:
+        response = await router.async_function_with_fallbacks(
+            original_function=router._acompletion,
+            **data,
+        )
+    assert response.choices[0].message.content == "I'm fine, thank you!"
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_router_function_with_retries(model_list, sync_mode):
+    """Test if the router 'async_function_with_retries' + 'function_with_retries' are working correctly"""
+    router = Router(model_list=model_list)
+    data = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hello, how are you?"}],
+        "mock_response": "I'm fine, thank you!",
+        "num_retries": 0,
+    }
+    if sync_mode:
+        response = router.function_with_retries(
+            original_function=router._completion,
+            **data,
+        )
+    else:
+        response = await router.async_function_with_retries(
+            original_function=router._acompletion,
+            **data,
+        )
+    assert response.choices[0].message.content == "I'm fine, thank you!"
+
+
+@pytest.mark.asyncio
+async def test_router_make_call(model_list):
+    """Test if the router 'make_call' function is working correctly"""
+
+    ## ACOMPLETION
+    router = Router(model_list=model_list)
+    response = await router.make_call(
+        original_function=router._acompletion,
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        mock_response="I'm fine, thank you!",
+    )
+    assert response.choices[0].message.content == "I'm fine, thank you!"
+
+    ## ATEXT_COMPLETION
+    response = await router.make_call(
+        original_function=router._atext_completion,
+        model="gpt-3.5-turbo",
+        prompt="Hello, how are you?",
+        mock_response="I'm fine, thank you!",
+    )
+    assert response.choices[0].text == "I'm fine, thank you!"
+
+    ## AEMBEDDING
+    response = await router.make_call(
+        original_function=router._aembedding,
+        model="gpt-3.5-turbo",
+        input="Hello, how are you?",
+        mock_response=[0.1, 0.2, 0.3],
+    )
+    assert response.data[0].embedding == [0.1, 0.2, 0.3]
+
+    ## AIMAGE_GENERATION
+    response = await router.make_call(
+        original_function=router._aimage_generation,
+        model="dall-e-3",
+        prompt="A cute baby sea otter",
+        mock_response="https://example.com/image.png",
+    )
+    assert response.data[0].url == "https://example.com/image.png"