From e516cfe9f5d7d9405ff9b476ef6493e7a4cb5657 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 27 Dec 2023 17:24:02 +0530
Subject: [PATCH] fix(utils.py): allow text completion input to be either
 model or engine

---
 litellm/__init__.py                    |  5 +-
 litellm/tests/test_router_fallbacks.py | 90 ++++++++++++++++++++++++++
 litellm/utils.py                       | 21 +++++-
 3 files changed, 113 insertions(+), 3 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index f1fbbc7d3..42c0b49b3 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -136,10 +136,13 @@ suppress_debug_info = False
 dynamodb_table_name: Optional[str] = None
 #### RELIABILITY ####
 request_timeout: Optional[float] = 6000
-num_retries: Optional[int] = None
+num_retries: Optional[int] = None  # per model endpoint
 fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
 allowed_fails: int = 0
+num_retries_per_request: Optional[
+    int
+] = None  # for the request overall (incl. fallbacks + model retries)
 ####### SECRET MANAGERS #####################
 secret_manager_client: Optional[
     Any
diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
index 36baae346..d7ffd446d 100644
--- a/litellm/tests/test_router_fallbacks.py
+++ b/litellm/tests/test_router_fallbacks.py
@@ -554,3 +554,93 @@ def test_sync_fallbacks_streaming():
         router.reset()
     except Exception as e:
         print(e)
+
+
+@pytest.mark.asyncio
+async def test_async_fallbacks_max_retries_per_request():
+    litellm.set_verbose = False
+    litellm.num_retries_per_request = 0
+    model_list = [
+        {  # list of model deployments
+            "model_name": "azure/gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-v-2",
+                "api_key": "bad-key",
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {  # list of model deployments
+            "model_name": "azure/gpt-3.5-turbo-context-fallback",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-v-2",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {
+            "model_name": "azure/gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-functioncalling",
+                "api_key": "bad-key",
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {
+            "model_name": "gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "gpt-3.5-turbo",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+            "tpm": 1000000,
+            "rpm": 9000,
+        },
+        {
+            "model_name": "gpt-3.5-turbo-16k",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "gpt-3.5-turbo-16k",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+            "tpm": 1000000,
+            "rpm": 9000,
+        },
+    ]
+
+    router = Router(
+        model_list=model_list,
+        fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
+        context_window_fallbacks=[
+            {"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]},
+            {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]},
+        ],
+        set_verbose=False,
+    )
+    customHandler = MyCustomHandler()
+    litellm.callbacks = [customHandler]
+    user_message = "Hello, how are you?"
+    messages = [{"content": user_message, "role": "user"}]
+    try:
+        try:
+            response = await router.acompletion(model="azure/gpt-3.5-turbo", messages=messages, stream=True)
+        except:
+            pass
+        print(f"customHandler.previous_models: {customHandler.previous_models}")
+        await asyncio.sleep(
+            0.05
+        )  # allow a delay as success_callbacks are on a separate thread
+        assert customHandler.previous_models == 0  # 0 retries, 0 fallback
+        router.reset()
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {e}")
+    finally:
+        router.reset()
diff --git a/litellm/utils.py b/litellm/utils.py
index 1b1fc5c24..9dbd9bbc4 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1925,7 +1925,10 @@ def client(original_function):
         except:
             model = None
             call_type = original_function.__name__
-            if call_type != CallTypes.image_generation.value:
+            if (
+                call_type != CallTypes.image_generation.value
+                and call_type != CallTypes.text_completion.value
+            ):
                 raise ValueError("model param not passed in.")
 
         try:
@@ -1945,6 +1948,16 @@ def client(original_function):
                         max_budget=litellm.max_budget,
                     )
+            # [OPTIONAL] CHECK MAX RETRIES / REQUEST
+            if litellm.num_retries_per_request is not None:
+                # check if previous_models passed in as ['litellm_params']['metadata']['previous_models']
+                previous_models = kwargs.get("metadata", {}).get(
+                    "previous_models", None
+                )
+                if previous_models is not None:
+                    if litellm.num_retries_per_request <= len(previous_models):
+                        raise Exception(f"Max retries per request hit!")
+
             # [OPTIONAL] CHECK CACHE
             print_verbose(
                 f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
             )
@@ -2096,7 +2109,11 @@
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:
-            raise ValueError("model param not passed in.")
+            if (
+                call_type != CallTypes.aimage_generation.value  # model optional
+                and call_type != CallTypes.atext_completion.value  # can also be engine
+            ):
+                raise ValueError("model param not passed in.")
 
         try:
             if logging_obj is None:
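
Usage sketch for the new `litellm.num_retries_per_request` setting: a minimal, hypothetical example, not part of the patch. Model names and env vars are illustrative, and it assumes the Router populates `metadata["previous_models"]` across retries/fallbacks as the wrapper check above expects; once that list reaches the cap, the request raises instead of attempting another deployment.

# minimal sketch -- assumes OPENAI_API_KEY is set; names below are illustrative
import os

import litellm
from litellm import Router

litellm.num_retries_per_request = 2  # cap for the request overall (incl. fallbacks + model retries)

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        }
    ],
    num_retries=3,  # per-deployment retries; the global cap above still bounds the request
)

response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response)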
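
The utils.py change relaxes the "model param not passed in" check for `text_completion` / `atext_completion`, so a legacy-style call that passes `engine` instead of `model` is no longer rejected by the wrapper. A hedged example of the call shape this is meant to allow; whether `engine` is accepted and mapped onto a deployment is handled by `text_completion` itself, not by this patch, and the value below is illustrative.

# hedged sketch -- `engine=` mirrors the legacy OpenAI Completions API style;
# assumes litellm.text_completion handles/forwards `engine` downstream
import litellm

response = litellm.text_completion(
    engine="gpt-3.5-turbo-instruct",  # illustrative; `model=` continues to work as before
    prompt="Say this is a test",
    max_tokens=20,
)
print(response)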