fix(utils.py): allow text completion input to be either model or engine

This commit is contained in:
Krrish Dholakia 2023-12-27 17:24:02 +05:30
parent ed615e7df4
commit e516cfe9f5
3 changed files with 113 additions and 3 deletions

View file

@ -1925,7 +1925,10 @@ def client(original_function):
except:
model = None
call_type = original_function.__name__
if call_type != CallTypes.image_generation.value:
if (
call_type != CallTypes.image_generation.value
and call_type != CallTypes.text_completion.value
):
raise ValueError("model param not passed in.")
try:
@ -1945,6 +1948,16 @@ def client(original_function):
max_budget=litellm.max_budget,
)
# [OPTIONAL] CHECK MAX RETRIES / REQUEST
if litellm.num_retries_per_request is not None:
# check if previous_models passed in as ['litellm_params']['metadata]['previous_models']
previous_models = kwargs.get("metadata", {}).get(
"previous_models", None
)
if previous_models is not None:
if litellm.num_retries_per_request <= len(previous_models):
raise Exception(f"Max retries per request hit!")
# [OPTIONAL] CHECK CACHE
print_verbose(
f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
@ -2096,7 +2109,11 @@ def client(original_function):
try:
model = args[0] if len(args) > 0 else kwargs["model"]
except:
raise ValueError("model param not passed in.")
if (
call_type != CallTypes.aimage_generation.value # model optional
and call_type != CallTypes.atext_completion.value # can also be engine
):
raise ValueError("model param not passed in.")
try:
if logging_obj is None: