Fix exception mapping

This commit is contained in:
Krrish Dholakia 2023-08-05 09:52:01 -07:00
parent 9b0e9bf57c
commit 92a13958ce
8 changed files with 188 additions and 115 deletions

View file

@ -6,7 +6,7 @@ from functools import partial
import dotenv
import traceback
import litellm
from litellm import client, logging, exception_type, timeout
from litellm import client, logging, exception_type, timeout, get_optional_params
import random
import asyncio
from tenacity import (
@ -20,51 +20,6 @@ dotenv.load_dotenv() # Loading env variables using dotenv
# TODO move this to utils.py
# TODO add translations
# TODO see if this worked - model_name == krrish
def get_optional_params(
    # 12 optional params
    functions=None,
    function_call="",
    temperature=1,
    top_p=1,
    n=1,
    stream=False,
    stop=None,
    max_tokens=float('inf'),
    presence_penalty=0,
    frequency_penalty=0,
    logit_bias=None,
    user="",
    deployment_id=None
):
    """Build a dict containing only the optional OpenAI completion params
    that differ from their API defaults.

    Each parameter is compared against its OpenAI-default value; only
    non-default values are included, so the returned dict can be splatted
    directly into a provider call without sending redundant keys.

    Returns:
        dict: mapping of param name -> caller-supplied non-default value.
    """
    # Avoid mutable default arguments: materialize the empty containers here.
    if functions is None:
        functions = []
    if logit_bias is None:
        logit_bias = {}

    optional_params = {}
    if functions != []:
        optional_params["functions"] = functions
    if function_call != "":
        optional_params["function_call"] = function_call
    if temperature != 1:
        optional_params["temperature"] = temperature
    if top_p != 1:
        optional_params["top_p"] = top_p
    if n != 1:
        optional_params["n"] = n
    if stream:
        optional_params["stream"] = stream
    if stop is not None:
        optional_params["stop"] = stop
    if max_tokens != float('inf'):
        optional_params["max_tokens"] = max_tokens
    if presence_penalty != 0:
        optional_params["presence_penalty"] = presence_penalty
    if frequency_penalty != 0:
        optional_params["frequency_penalty"] = frequency_penalty
    if logit_bias != {}:
        optional_params["logit_bias"] = logit_bias
    if user != "":
        optional_params["user"] = user
    if deployment_id is not None:
        # BUG FIX: original assigned `user` here instead of `deployment_id`.
        optional_params["deployment_id"] = deployment_id
    return optional_params
####### COMPLETION ENDPOINTS ################
#############################################
async def acompletion(*args, **kwargs):
@ -285,12 +240,13 @@ def completion(
}
response = new_response
else:
## LOGGING
logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
args = locals()
raise ValueError(f"No valid completion model args passed in - {args}")
return response
except Exception as e:
# log the original exception
## LOGGING
logging(model=model, input=messages, azure=azure, additional_args={"max_tokens": max_tokens}, logger_fn=logger_fn, exception=e)
## Map to OpenAI Exception
raise exception_type(model=model, original_exception=e)