From 72414919ddee5cb35a3edd8a4c8e84ed1d43632e Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 22 Aug 2023 15:05:42 -0700
Subject: [PATCH] test fallbacks

---
 litellm/main.py                  |  4 ++++
 litellm/tests/test_completion.py | 18 +++++++++++++++-
 litellm/utils.py                 | 37 ++++++++++++++++++++++++++++++++
 3 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/litellm/main.py b/litellm/main.py
index 298d2ce95..e0240892d 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -17,6 +17,7 @@ from litellm.utils import (
     install_and_import,
     CustomStreamWrapper,
     read_config_args,
+    completion_with_fallbacks
 )
 from .llms.anthropic import AnthropicLLM
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM
@@ -90,9 +91,12 @@ def completion(
     # used by text-bison only
     top_k=40,
     request_timeout=0, # unused var for old version of OpenAI API
+    fallbacks=[],
 ) -> ModelResponse:
     args = locals()
     try:
+        if fallbacks != []:
+            return completion_with_fallbacks(**args)
         model_response = ModelResponse()
         if azure: # this flag is deprecated, remove once notebooks are also updated.
             custom_llm_provider = "azure"
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 6fd1964b8..e6f75b8dc 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -11,6 +11,7 @@ sys.path.insert(
 import pytest
 import litellm
 from litellm import embedding, completion
+litellm.debugger = True
 
 # from infisical import InfisicalClient
 
@@ -38,7 +39,7 @@ def test_completion_custom_provider_model_name():
         pytest.fail(f"Error occurred: {e}")
 
 
-test_completion_custom_provider_model_name()
+# test_completion_custom_provider_model_name()
 
 
 def test_completion_claude():
@@ -347,6 +348,21 @@ def test_petals():
         pytest.fail(f"Error occurred: {e}")
 
 
+def test_completion_with_fallbacks():
+    fallbacks = ['gpt-3.5-turb', 'gpt-3.5-turbo', 'command-nightly']
+    try:
+        response = completion(
+            model='bad-model',
+            messages=messages,
+            force_timeout=120,
+            fallbacks=fallbacks
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # def test_baseten_falcon_7bcompletion():
 #     model_name = "qvv0xeq"
 #     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 1684bb757..6cbeb3e1a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1464,3 +1464,40 @@ async def together_ai_completion_streaming(json_data, headers):
         pass
     finally:
         await session.close()
+
+
+def completion_with_fallbacks(**kwargs):
+    response = None
+    rate_limited_models = set()
+    model_expiration_times = {}
+    start_time = time.time()
+    fallbacks = [kwargs['model']] + kwargs['fallbacks']
+    del kwargs['fallbacks'] # remove fallbacks so it's not recursive
+
+    while response == None and time.time() - start_time < 45:
+        for model in fallbacks:
+            # loop thru all models
+            try:
+                if model in rate_limited_models: # check if model is currently cooling down
+                    if model_expiration_times.get(model) and time.time() >= model_expiration_times[model]:
+                        rate_limited_models.remove(model) # check if it's been 60s of cool down and remove model
+                    else:
+                        continue # skip model
+
+                # delete model from kwargs if it exists
+                if kwargs.get('model'):
+                    del kwargs['model']
+
+                print("making completion call", model)
+                response = litellm.completion(**kwargs, model=model)
+
+                if response != None:
+                    return response
+
+            except Exception as e:
+                print(f"got exception {e} for model {model}")
+                rate_limited_models.add(model)
+                model_expiration_times[model] = time.time() + 60 # cool down this selected model
+                #print(f"rate_limited_models {rate_limited_models}")
+                pass
+    return response
\ No newline at end of file
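-- 
Usage sketch for the new fallbacks parameter, mirroring the test added in this patch. The messages variable and the model names below are illustrative placeholders, not part of the commit; an API key for the fallback providers is assumed to be configured.

    import litellm

    messages = [{"role": "user", "content": "Hey, how's it going?"}]

    # When fallbacks is non-empty, completion() routes through
    # completion_with_fallbacks(): the primary model is tried first, each
    # fallback is then tried in order, and any model that raises is put on a
    # 60-second cooldown before it is retried, for up to ~45 seconds total.
    response = litellm.completion(
        model="bad-model",  # primary model; expected to fail in this sketch
        messages=messages,
        fallbacks=["gpt-3.5-turbo", "command-nightly"],  # tried in order after the primary
    )
    print(response)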