fix(utils.py): allow text completion input to be either model or engine

Krrish Dholakia 2023-12-27 17:24:02 +05:30
parent ed615e7df4
commit e516cfe9f5
3 changed files with 113 additions and 3 deletions

@@ -136,10 +136,13 @@ suppress_debug_info = False
 dynamodb_table_name: Optional[str] = None
 #### RELIABILITY ####
 request_timeout: Optional[float] = 6000
-num_retries: Optional[int] = None
+num_retries: Optional[int] = None  # per model endpoint
 fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
 allowed_fails: int = 0
+num_retries_per_request: Optional[
+    int
+] = None  # for the request overall (incl. fallbacks + model retries)
 ####### SECRET MANAGERS #####################
 secret_manager_client: Optional[
     Any
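
The two knobs are easy to confuse: num_retries caps retries against a single deployment, while the new num_retries_per_request caps the request as a whole, counting every deployment visited through retries and fallbacks. A minimal sketch of setting both from client code (the Router configuration below is illustrative, not taken from this commit):

    import litellm
    from litellm import Router

    litellm.num_retries = 2              # retries against one deployment
    litellm.num_retries_per_request = 3  # cap across the whole request, fallbacks included

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            },
        ],
    )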

@@ -554,3 +554,93 @@ def test_sync_fallbacks_streaming():
         router.reset()
     except Exception as e:
         print(e)
+
+
+@pytest.mark.asyncio
+async def test_async_fallbacks_max_retries_per_request():
+    litellm.set_verbose = False
+    litellm.num_retries_per_request = 0
+    model_list = [
+        {  # list of model deployments
+            "model_name": "azure/gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-v-2",
+                "api_key": "bad-key",
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {  # list of model deployments
+            "model_name": "azure/gpt-3.5-turbo-context-fallback",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-v-2",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {
+            "model_name": "azure/gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/chatgpt-functioncalling",
+                "api_key": "bad-key",
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {
+            "model_name": "gpt-3.5-turbo",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "gpt-3.5-turbo",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+            "tpm": 1000000,
+            "rpm": 9000,
+        },
+        {
+            "model_name": "gpt-3.5-turbo-16k",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "gpt-3.5-turbo-16k",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+            "tpm": 1000000,
+            "rpm": 9000,
+        },
+    ]
+
+    router = Router(
+        model_list=model_list,
+        fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
+        context_window_fallbacks=[
+            {"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]},
+            {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]},
+        ],
+        set_verbose=False,
+    )
+    customHandler = MyCustomHandler()
+    litellm.callbacks = [customHandler]
+    user_message = "Hello, how are you?"
+    messages = [{"content": user_message, "role": "user"}]
+    try:
+        try:
+            # expected to fail: the "bad-key" deployments raise auth errors,
+            # and num_retries_per_request=0 blocks any fallback attempt
+            response = await router.acompletion(
+                model="azure/gpt-3.5-turbo", messages=messages, stream=True
+            )
+        except Exception:
+            pass
+        print(f"customHandler.previous_models: {customHandler.previous_models}")
+        await asyncio.sleep(
+            0.05
+        )  # allow a delay as success_callbacks are on a separate thread
+        assert customHandler.previous_models == 0  # 0 retries, 0 fallback
+        router.reset()
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {e}")
+    finally:
+        router.reset()
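
The test references MyCustomHandler, which is defined earlier in this test file and not shown in the diff. A rough sketch of what such a handler might look like, assuming it subclasses litellm's CustomLogger and tallies the previous_models metadata the Router attaches; the attribute name previous_models matches the assertion above, but the method bodies are assumptions:

    from litellm.integrations.custom_logger import CustomLogger

    class MyCustomHandler(CustomLogger):
        previous_models = 0  # running count of deployments retried / fallen back to

        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            # assumption: the Router threads its attempt history through
            # kwargs["litellm_params"]["metadata"]["previous_models"]
            self.previous_models += len(
                kwargs.get("litellm_params", {}).get("metadata", {}).get("previous_models", [])
            )

        def log_failure_event(self, kwargs, response_obj, start_time, end_time):
            self.previous_models += len(
                kwargs.get("litellm_params", {}).get("metadata", {}).get("previous_models", [])
            )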

@@ -1925,7 +1925,10 @@ def client(original_function):
         except:
             model = None
             call_type = original_function.__name__
-            if call_type != CallTypes.image_generation.value:
+            if (
+                call_type != CallTypes.image_generation.value
+                and call_type != CallTypes.text_completion.value
+            ):
                 raise ValueError("model param not passed in.")
         try:
@@ -1945,6 +1948,16 @@ def client(original_function):
                 max_budget=litellm.max_budget,
             )
+        # [OPTIONAL] CHECK MAX RETRIES / REQUEST
+        if litellm.num_retries_per_request is not None:
+            # check if previous_models passed in as ['litellm_params']['metadata']['previous_models']
+            previous_models = kwargs.get("metadata", {}).get(
+                "previous_models", None
+            )
+            if previous_models is not None:
+                if litellm.num_retries_per_request <= len(previous_models):
+                    raise Exception("Max retries per request hit!")
         # [OPTIONAL] CHECK CACHE
         print_verbose(
             f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
@@ -2096,7 +2109,11 @@ def client(original_function):
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:
-            raise ValueError("model param not passed in.")
+            if (
+                call_type != CallTypes.aimage_generation.value  # model optional
+                and call_type != CallTypes.atext_completion.value  # can also be engine
+            ):
+                raise ValueError("model param not passed in.")
         try:
             if logging_obj is None:
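
Together, the sync and async guards implement the commit title: a text-completion request may identify its target with an Azure-style engine argument instead of model, so the wrapper no longer raises "model param not passed in" for text_completion/atext_completion. A hedged sketch of the call shape this unblocks; the deployment name is hypothetical and the downstream engine handling lives elsewhere in utils.py:

    import litellm

    # Azure's legacy completions API addresses a deployment via `engine`,
    # so `model` can legitimately be absent when the wrapper runs
    response = litellm.text_completion(
        engine="my-azure-text-deployment",  # hypothetical Azure deployment name
        prompt="Say this is a test.",
    )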