forked from phoenix/litellm-mirror
test fallbacks
This commit is contained in:
parent
ccf875f84b
commit
72414919dd
3 changed files with 58 additions and 1 deletions
|
@ -17,6 +17,7 @@ from litellm.utils import (
|
||||||
install_and_import,
|
install_and_import,
|
||||||
CustomStreamWrapper,
|
CustomStreamWrapper,
|
||||||
read_config_args,
|
read_config_args,
|
||||||
|
completion_with_fallbacks
|
||||||
)
|
)
|
||||||
from .llms.anthropic import AnthropicLLM
|
from .llms.anthropic import AnthropicLLM
|
||||||
from .llms.huggingface_restapi import HuggingfaceRestAPILLM
|
from .llms.huggingface_restapi import HuggingfaceRestAPILLM
|
||||||
|
@ -90,9 +91,12 @@ def completion(
|
||||||
# used by text-bison only
|
# used by text-bison only
|
||||||
top_k=40,
|
top_k=40,
|
||||||
request_timeout=0, # unused var for old version of OpenAI API
|
request_timeout=0, # unused var for old version of OpenAI API
|
||||||
|
fallbacks=[],
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
args = locals()
|
args = locals()
|
||||||
try:
|
try:
|
||||||
|
if fallbacks != []:
|
||||||
|
return completion_with_fallbacks(**args)
|
||||||
model_response = ModelResponse()
|
model_response = ModelResponse()
|
||||||
if azure: # this flag is deprecated, remove once notebooks are also updated.
|
if azure: # this flag is deprecated, remove once notebooks are also updated.
|
||||||
custom_llm_provider = "azure"
|
custom_llm_provider = "azure"
|
||||||
|
|
|
@ -11,6 +11,7 @@ sys.path.insert(
|
||||||
import pytest
|
import pytest
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion
|
from litellm import embedding, completion
|
||||||
|
litellm.debugger = True
|
||||||
|
|
||||||
# from infisical import InfisicalClient
|
# from infisical import InfisicalClient
|
||||||
|
|
||||||
|
@ -38,7 +39,7 @@ def test_completion_custom_provider_model_name():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
test_completion_custom_provider_model_name()
|
# test_completion_custom_provider_model_name()
|
||||||
|
|
||||||
|
|
||||||
def test_completion_claude():
|
def test_completion_claude():
|
||||||
|
@ -347,6 +348,21 @@ def test_petals():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion_with_fallbacks():
    """Exercise the fallbacks path: a nonexistent primary model should fall
    through to a working fallback model rather than raising."""
    # first entry is a deliberate typo so the fallback chain is exercised
    fallbacks = ['gpt-3.5-turb', 'gpt-3.5-turbo', 'command-nightly']
    try:
        result = completion(model='bad-model', messages=messages,
                            force_timeout=120, fallbacks=fallbacks)
        # Add any assertions here to check the response
        print(result)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
# def test_baseten_falcon_7bcompletion():
|
# def test_baseten_falcon_7bcompletion():
|
||||||
# model_name = "qvv0xeq"
|
# model_name = "qvv0xeq"
|
||||||
# try:
|
# try:
|
||||||
|
|
|
@ -1464,3 +1464,40 @@ async def together_ai_completion_streaming(json_data, headers):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def completion_with_fallbacks(**kwargs):
    """Try ``kwargs['model']`` first, then each model in ``kwargs['fallbacks']``,
    returning the first successful ``litellm.completion`` response.

    A model whose call raises is benched for a 60-second cool-down before it
    is retried; the overall retry loop gives up after 45 seconds. Returns the
    first non-None response, or None if every attempt failed.
    """
    response = None
    rate_limited_models = set()   # models currently benched after a failure
    model_expiration_times = {}   # model -> unix time its cool-down ends
    start_time = time.time()

    # The primary model gets the first attempt, then the declared fallbacks.
    fallbacks = [kwargs["model"]] + kwargs["fallbacks"]
    del kwargs["fallbacks"]  # remove fallbacks so it's not recursive
    # 'model' is supplied explicitly per attempt below; strip it once up
    # front instead of re-checking/deleting on every loop iteration.
    kwargs.pop("model", None)

    # Overall deadline: keep cycling through the candidate models until one
    # answers or 45 seconds have elapsed.
    while response is None and time.time() - start_time < 45:
        for model in fallbacks:
            try:
                if model in rate_limited_models:  # model is cooling down?
                    expiry = model_expiration_times.get(model)
                    if expiry is not None and time.time() >= expiry:
                        # 60s cool-down elapsed — allow this model again
                        rate_limited_models.remove(model)
                    else:
                        continue  # skip model

                print("making completion call", model)
                response = litellm.completion(**kwargs, model=model)

                if response is not None:
                    return response

            except Exception as e:
                # Any failure (rate limit, bad model name, ...) benches the
                # model for 60 seconds and moves on to the next candidate.
                print(f"got exception {e} for model {model}")
                rate_limited_models.add(model)
                model_expiration_times[model] = time.time() + 60  # cool down this selected model
                # print(f"rate_limited_models {rate_limited_models}")
    return response
|
Loading…
Add table
Add a link
Reference in a new issue