test fallbacks

Repository: https://github.com/BerriAI/litellm.git
commit 72414919dd (parent ccf875f84b)
3 changed files with 58 additions and 1 deletion
@@ -17,6 +17,7 @@ from litellm.utils import (
     install_and_import,
     CustomStreamWrapper,
     read_config_args,
+    completion_with_fallbacks
 )
 from .llms.anthropic import AnthropicLLM
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM

@@ -90,9 +91,12 @@ def completion(
     # used by text-bison only
     top_k=40,
     request_timeout=0, # unused var for old version of OpenAI API
+    fallbacks=[],
 ) -> ModelResponse:
     args = locals()
     try:
+        if fallbacks != []:
+            return completion_with_fallbacks(**args)
         model_response = ModelResponse()
         if azure: # this flag is deprecated, remove once notebooks are also updated.
             custom_llm_provider = "azure"
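For reference, a minimal usage sketch of the new parameter (the model names and message payload below are illustrative only, not part of this commit): when fallbacks is non-empty, completion() hands the call straight to completion_with_fallbacks(), which tries the primary model and then each fallback in order.

    import litellm

    messages = [{"role": "user", "content": "Hello, how are you?"}]

    # 'bad-model' is expected to fail; the call should eventually be served
    # by one of the fallback models instead of raising.
    response = litellm.completion(
        model="bad-model",
        messages=messages,
        fallbacks=["gpt-3.5-turbo", "command-nightly"],
    )
    print(response)

If every model keeps failing, the helper gives up after its time budget and the caller receives None rather than an exception.
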
@@ -11,6 +11,7 @@ sys.path.insert(
 import pytest
 import litellm
 from litellm import embedding, completion
+litellm.debugger = True

 # from infisical import InfisicalClient

@@ -38,7 +39,7 @@ def test_completion_custom_provider_model_name():
         pytest.fail(f"Error occurred: {e}")


-test_completion_custom_provider_model_name()
+# test_completion_custom_provider_model_name()


 def test_completion_claude():
@@ -347,6 +348,21 @@ def test_petals():
         pytest.fail(f"Error occurred: {e}")


+def test_completion_with_fallbacks():
+    fallbacks = ['gpt-3.5-turb', 'gpt-3.5-turbo', 'command-nightly']
+    try:
+        response = completion(
+            model='bad-model',
+            messages=messages,
+            force_timeout=120,
+            fallbacks=fallbacks
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # def test_baseten_falcon_7bcompletion():
 # model_name = "qvv0xeq"
 # try:

@@ -1464,3 +1464,40 @@ async def together_ai_completion_streaming(json_data, headers):
             pass
     finally:
         await session.close()
+
+
+def completion_with_fallbacks(**kwargs):
+    response = None
+    rate_limited_models = set()
+    model_expiration_times = {}
+    start_time = time.time()
+    fallbacks = [kwargs['model']] + kwargs['fallbacks']
+    del kwargs['fallbacks']  # remove fallbacks so it's not recursive
+
+    while response == None and time.time() - start_time < 45:
+        for model in fallbacks:
+            # loop thru all models
+            try:
+                if model in rate_limited_models: # check if model is currently cooling down
+                    if model_expiration_times.get(model) and time.time() >= model_expiration_times[model]:
+                        rate_limited_models.remove(model) # check if it's been 60s of cool down and remove model
+                    else:
+                        continue # skip model
+
+                # delete model from kwargs if it exists
+                if kwargs.get('model'):
+                    del kwargs['model']
+
+                print("making completion call", model)
+                response = litellm.completion(**kwargs, model=model)
+
+                if response != None:
+                    return response
+
+            except Exception as e:
+                print(f"got exception {e} for model {model}")
+                rate_limited_models.add(model)
+                model_expiration_times[model] = time.time() + 60 # cool down this selected model
+                #print(f"rate_limited_models {rate_limited_models}")
+                pass
+    return response
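The helper above keeps retrying for up to 45 seconds and puts a failing model on a 60-second cool-down before it is tried again. Below is a stripped-down, self-contained sketch of that retry/cool-down pattern; flaky_completion, the timeout/cooldown parameters, and the model names are hypothetical stand-ins so the loop can run without API keys — they are not part of this commit.

    import time

    def flaky_completion(model, **kwargs):
        # Hypothetical stand-in for litellm.completion: only "good-model" succeeds.
        if model != "good-model":
            raise RuntimeError(f"{model} is unavailable")
        return {"model": model, "content": "ok"}

    def call_with_fallbacks(models, timeout=45, cooldown=60, **kwargs):
        cooling_down = {}  # model -> earliest time it may be retried
        start = time.time()
        while time.time() - start < timeout:
            for model in models:
                if cooling_down.get(model, 0) > time.time():
                    continue  # model is still cooling down, skip it
                try:
                    return flaky_completion(model, **kwargs)
                except Exception as exc:
                    print(f"{model} failed ({exc}); cooling down for {cooldown}s")
                    cooling_down[model] = time.time() + cooldown
        return None  # every model kept failing within the time budget

    print(call_with_fallbacks(["bad-model", "good-model"]))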