test fallbacks

ishaan-jaff 2023-08-22 15:05:42 -07:00
parent c67e59cd10
commit 9a04a3b95b
3 changed files with 58 additions and 1 deletion


@@ -17,6 +17,7 @@ from litellm.utils import (
     install_and_import,
     CustomStreamWrapper,
     read_config_args,
+    completion_with_fallbacks
 )
 from .llms.anthropic import AnthropicLLM
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM
@@ -90,9 +91,12 @@ def completion(
     # used by text-bison only
     top_k=40,
     request_timeout=0, # unused var for old version of OpenAI API
+    fallbacks=[],
 ) -> ModelResponse:
     args = locals()
     try:
+        if fallbacks != []:
+            return completion_with_fallbacks(**args)
         model_response = ModelResponse()
         if azure: # this flag is deprecated, remove once notebooks are also updated.
             custom_llm_provider = "azure"
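
With this change, completion() accepts an optional fallbacks list; when it is non-empty, the call (captured via locals()) is handed to completion_with_fallbacks(), which tries the primary model first and then each fallback in order. A minimal usage sketch, assuming litellm is installed and the relevant provider API keys are set in the environment; the model names are placeholders:

    import litellm

    messages = [{"role": "user", "content": "Hey, how's it going?"}]

    # If the primary model errors out or is rate limited, the fallback loop
    # retries the call against each backup model until one succeeds.
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=messages,
        fallbacks=["gpt-3.5-turbo-16k", "command-nightly"],  # placeholder backup models
    )
    print(response)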


@@ -11,6 +11,7 @@ sys.path.insert(
 import pytest
 import litellm
 from litellm import embedding, completion
+litellm.debugger = True

 # from infisical import InfisicalClient
@@ -38,7 +39,7 @@ def test_completion_custom_provider_model_name():
         pytest.fail(f"Error occurred: {e}")

-test_completion_custom_provider_model_name()
+# test_completion_custom_provider_model_name()

 def test_completion_claude():
@@ -347,6 +348,21 @@ def test_petals():
         pytest.fail(f"Error occurred: {e}")

+def test_completion_with_fallbacks():
+    fallbacks = ['gpt-3.5-turb', 'gpt-3.5-turbo', 'command-nightly']
+    try:
+        response = completion(
+            model='bad-model',
+            messages=messages,
+            force_timeout=120,
+            fallbacks=fallbacks
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")

 # def test_baseten_falcon_7bcompletion():
 #     model_name = "qvv0xeq"
 #     try:
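
The new test deliberately passes an invalid primary model ('bad-model') so that the response has to come from one of the fallbacks. Where the test says "Add any assertions here", a check could look roughly like this, assuming the OpenAI-style response shape that litellm returns (the exact shape is an assumption, not part of this commit):

    assert response is not None
    # choices[0].message.content should hold text produced by whichever
    # fallback model ended up serving the request
    content = response["choices"][0]["message"]["content"]
    assert isinstance(content, str) and len(content) > 0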


@@ -1464,3 +1464,40 @@ async def together_ai_completion_streaming(json_data, headers):
         pass
     finally:
         await session.close()

+def completion_with_fallbacks(**kwargs):
+    response = None
+    rate_limited_models = set()
+    model_expiration_times = {}
+    start_time = time.time()
+    fallbacks = [kwargs['model']] + kwargs['fallbacks']
+    del kwargs['fallbacks']  # remove fallbacks so it's not recursive
+
+    while response == None and time.time() - start_time < 45:
+        for model in fallbacks:
+            # loop thru all models
+            try:
+                if model in rate_limited_models: # check if model is currently cooling down
+                    if model_expiration_times.get(model) and time.time() >= model_expiration_times[model]:
+                        rate_limited_models.remove(model) # check if it's been 60s of cool down and remove model
+                    else:
+                        continue # skip model
+
+                # delete model from kwargs if it exists
+                if kwargs.get('model'):
+                    del kwargs['model']
+
+                print("making completion call", model)
+                response = litellm.completion(**kwargs, model=model)
+
+                if response != None:
+                    return response
+            except Exception as e:
+                print(f"got exception {e} for model {model}")
+                rate_limited_models.add(model)
+                model_expiration_times[model] = time.time() + 60 # cool down this selected model
+                # print(f"rate_limited_models {rate_limited_models}")
+                pass
+    return response
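
completion_with_fallbacks() keeps cycling through the primary model plus its fallbacks for up to 45 seconds; a model that raises is put on a 60-second cooldown so the loop does not keep hitting a rate-limited provider, and the first successful response is returned. Stripped of the litellm-specific kwargs handling, the pattern looks roughly like this; call_with_rotation and make_call are illustrative names, not part of the library:

    import time

    def call_with_rotation(models, make_call, overall_timeout=45, cooldown=60):
        # Try each model in order; park failing models on a cooldown and
        # give up once the overall deadline passes.
        cooling_until = {}  # model -> timestamp when it may be retried
        deadline = time.time() + overall_timeout
        while time.time() < deadline:
            for model in models:
                if cooling_until.get(model, 0) > time.time():
                    continue  # still cooling down, skip this model
                try:
                    return make_call(model)  # first success wins
                except Exception:
                    cooling_until[model] = time.time() + cooldown
        return None  # nothing succeeded before the deadline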