formatting improvements

ishaan-jaff 2023-08-28 09:20:50 -07:00
parent 3e0a16acf4
commit a69b7ffcfa
17 changed files with 464 additions and 323 deletions

@ -69,7 +69,6 @@ last_fetched_at_keys = None
class Message(OpenAIObject):
def __init__(self, content="default", role="assistant", **params):
super(Message, self).__init__(**params)
self.content = content
@ -77,12 +76,7 @@ class Message(OpenAIObject):
class Choices(OpenAIObject):
def __init__(self,
finish_reason="stop",
index=0,
message=Message(),
**params):
def __init__(self, finish_reason="stop", index=0, message=Message(), **params):
super(Choices, self).__init__(**params)
self.finish_reason = finish_reason
self.index = index
@ -90,22 +84,20 @@ class Choices(OpenAIObject):
class ModelResponse(OpenAIObject):
def __init__(self,
choices=None,
created=None,
model=None,
usage=None,
**params):
def __init__(self, choices=None, created=None, model=None, usage=None, **params):
super(ModelResponse, self).__init__(**params)
self.choices = choices if choices else [Choices()]
self.created = created
self.model = model
self.usage = (usage if usage else {
"prompt_tokens": None,
"completion_tokens": None,
"total_tokens": None,
})
self.usage = (
usage
if usage
else {
"prompt_tokens": None,
"completion_tokens": None,
"total_tokens": None,
}
)
def to_dict_recursive(self):
d = super().to_dict_recursive()
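
For orientation, the three classes above build an OpenAI-compatible response object: a ModelResponse holds a list of Choices, each Choices wraps a Message, and usage defaults to a dict of None token counts. A minimal, self-contained sketch of that shape (plain dataclasses standing in for OpenAIObject, which is outside this diff; everything here is illustrative only):

from dataclasses import dataclass, field


@dataclass
class Message:
    content: str = "default"
    role: str = "assistant"


@dataclass
class Choices:
    finish_reason: str = "stop"
    index: int = 0
    message: Message = field(default_factory=Message)


@dataclass
class ModelResponse:
    choices: list = None
    created: float = None
    model: str = None
    usage: dict = None

    def __post_init__(self):
        # mirror the diff: default to a single Choices() and an empty usage block
        self.choices = self.choices if self.choices else [Choices()]
        self.usage = self.usage if self.usage else {
            "prompt_tokens": None,
            "completion_tokens": None,
            "total_tokens": None,
        }


resp = ModelResponse(model="gpt-3.5-turbo")
print(resp.choices[0].message.content)  # -> "default"
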
@ -173,7 +165,9 @@ class Logging:
self.model_call_details["api_key"] = api_key
self.model_call_details["additional_args"] = additional_args
if model: # if model name was changes pre-call, overwrite the initial model call name with the new one
if (
model
): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
# User Logging -> if you pass in a custom logging function
@ -203,8 +197,7 @@ class Logging:
model=model,
messages=messages,
end_user=litellm._thread_context.user,
litellm_call_id=self.
litellm_params["litellm_call_id"],
litellm_call_id=self.litellm_params["litellm_call_id"],
print_verbose=print_verbose,
)
@ -217,8 +210,7 @@ class Logging:
model=model,
messages=messages,
end_user=litellm._thread_context.user,
litellm_call_id=self.
litellm_params["litellm_call_id"],
litellm_call_id=self.litellm_params["litellm_call_id"],
litellm_params=self.model_call_details["litellm_params"],
optional_params=self.model_call_details["optional_params"],
print_verbose=print_verbose,
@ -263,7 +255,7 @@ class Logging:
print_verbose(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
)
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
for callback in litellm.input_callback:
try:
@ -274,8 +266,7 @@ class Logging:
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
liteDebuggerClient.post_call_log_event(
original_response=original_response,
litellm_call_id=self.
litellm_params["litellm_call_id"],
litellm_call_id=self.litellm_params["litellm_call_id"],
print_verbose=print_verbose,
)
except:
@ -295,6 +286,7 @@ class Logging:
# Add more methods as needed
def exception_logging(
additional_args={},
logger_fn=None,
@ -329,13 +321,18 @@ def exception_logging(
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def client(original_function):
global liteDebuggerClient, get_all_keys
def function_setup(
*args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
try:
global callback_list, add_breadcrumb, user_logger_fn
if litellm.email is not None or os.getenv("LITELLM_EMAIL", None) is not None or litellm.token is not None or os.getenv("LITELLM_TOKEN", None): # add to input, success and failure callbacks if user is using hosted product
if (
litellm.email is not None
or os.getenv("LITELLM_EMAIL", None) is not None
or litellm.token is not None
or os.getenv("LITELLM_TOKEN", None)
): # add to input, success and failure callbacks if user is using hosted product
get_all_keys()
if "lite_debugger" not in callback_list and litellm.logging:
litellm.input_callback.append("lite_debugger")
@ -381,11 +378,12 @@ def client(original_function):
if litellm.telemetry:
try:
model = args[0] if len(args) > 0 else kwargs["model"]
exception = kwargs[
"exception"] if "exception" in kwargs else None
custom_llm_provider = (kwargs["custom_llm_provider"]
if "custom_llm_provider" in kwargs else
None)
exception = kwargs["exception"] if "exception" in kwargs else None
custom_llm_provider = (
kwargs["custom_llm_provider"]
if "custom_llm_provider" in kwargs
else None
)
safe_crash_reporting(
model=model,
exception=exception,
@ -410,10 +408,10 @@ def client(original_function):
def check_cache(*args, **kwargs):
try: # never block execution
prompt = get_prompt(*args, **kwargs)
if (prompt != None): # check if messages / prompt exists
if prompt != None: # check if messages / prompt exists
if litellm.caching_with_models:
# if caching with model names is enabled, key is prompt + model name
if ("model" in kwargs):
if "model" in kwargs:
cache_key = prompt + kwargs["model"]
if cache_key in local_cache:
return local_cache[cache_key]
@ -423,7 +421,7 @@ def client(original_function):
return result
else:
return None
return None # default to return None
return None # default to return None
except:
return None
@ -431,7 +429,7 @@ def client(original_function):
try: # never block execution
prompt = get_prompt(*args, **kwargs)
if litellm.caching_with_models: # caching with model + prompt
if ("model" in kwargs):
if "model" in kwargs:
cache_key = prompt + kwargs["model"]
local_cache[cache_key] = result
else: # caching based only on prompts
@ -449,7 +447,8 @@ def client(original_function):
start_time = datetime.datetime.now()
# [OPTIONAL] CHECK CACHE
if (litellm.caching or litellm.caching_with_models) and (
cached_result := check_cache(*args, **kwargs)) is not None:
cached_result := check_cache(*args, **kwargs)
) is not None:
result = cached_result
return result
# MODEL CALL
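
The check_cache/add_cache hunks above, together with the walrus-operator lookup in the wrapper, implement a simple in-process cache: the key is the prompt alone, or prompt plus model name when litellm.caching_with_models is enabled, and a cache hit short-circuits the model call. A minimal sketch of the same idea, with get_prompt, local_cache and the _cache_key helper being stand-ins introduced for this sketch:

local_cache = {}


def get_prompt(*args, **kwargs):
    # stand-in: the real helper extracts the prompt/messages from the call
    return kwargs.get("prompt")


def _cache_key(caching_with_models=False, **kwargs):
    prompt = get_prompt(**kwargs)
    if prompt is None:
        return None
    if caching_with_models and "model" in kwargs:
        return prompt + kwargs["model"]
    return prompt


def check_cache(**kwargs):
    key = _cache_key(**kwargs)
    return local_cache.get(key) if key is not None else None


def add_cache(result, **kwargs):
    key = _cache_key(**kwargs)
    if key is not None:
        local_cache[key] = result


def call_model(**kwargs):
    # same short-circuit shape as the wrapper in the diff
    if (cached_result := check_cache(**kwargs)) is not None:
        return cached_result
    result = f"response for {kwargs['prompt']}"  # pretend model call
    add_cache(result, **kwargs)
    return result


print(call_model(prompt="hi", model="gpt-3.5-turbo"))
print(call_model(prompt="hi", model="gpt-3.5-turbo"))  # served from the cache
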
@ -458,25 +457,22 @@ def client(original_function):
return result
end_time = datetime.datetime.now()
# [OPTIONAL] ADD TO CACHE
if (litellm.caching or litellm.caching_with_models):
if litellm.caching or litellm.caching_with_models:
add_cache(result, *args, **kwargs)
# LOG SUCCESS
my_thread = threading.Thread(
target=handle_success,
args=(args, kwargs, result, start_time,
end_time)) # don't interrupt execution of main thread
target=handle_success, args=(args, kwargs, result, start_time, end_time)
) # don't interrupt execution of main thread
my_thread.start()
# RETURN RESULT
return result
except Exception as e:
traceback_exception = traceback.format_exc()
crash_reporting(*args, **kwargs, exception=traceback_exception)
end_time = datetime.datetime.now()
my_thread = threading.Thread(
target=handle_failure,
args=(e, traceback_exception, start_time, end_time, args,
kwargs),
args=(e, traceback_exception, start_time, end_time, args, kwargs),
) # don't interrupt execution of main thread
my_thread.start()
if hasattr(e, "message"):
@ -506,18 +502,18 @@ def token_counter(model, text):
return num_tokens
def cost_per_token(model="gpt-3.5-turbo",
prompt_tokens=0,
completion_tokens=0):
def cost_per_token(model="gpt-3.5-turbo", prompt_tokens=0, completion_tokens=0):
# given
prompt_tokens_cost_usd_dollar = 0
completion_tokens_cost_usd_dollar = 0
model_cost_ref = litellm.model_cost
if model in model_cost_ref:
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens)
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model]["output_cost_per_token"] * completion_tokens)
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
else:
# calculate average input cost
@ -538,9 +534,8 @@ def completion_cost(model="gpt-3.5-turbo", prompt="", completion=""):
prompt_tokens = token_counter(model=model, text=prompt)
completion_tokens = token_counter(model=model, text=completion)
prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens)
model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
)
return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
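
cost_per_token above is a per-token multiplication against the litellm.model_cost table, and completion_cost simply counts tokens on both sides and adds the two products. A worked sketch with made-up prices (the real rates live in litellm.model_cost):

# hypothetical per-token prices, purely illustrative
model_cost = {
    "gpt-3.5-turbo": {
        "input_cost_per_token": 0.0000015,
        "output_cost_per_token": 0.000002,
    }
}


def cost_per_token(model="gpt-3.5-turbo", prompt_tokens=0, completion_tokens=0):
    ref = model_cost[model]
    return (
        ref["input_cost_per_token"] * prompt_tokens,
        ref["output_cost_per_token"] * completion_tokens,
    )


prompt_usd, completion_usd = cost_per_token(
    model="gpt-3.5-turbo", prompt_tokens=1000, completion_tokens=500
)
print(prompt_usd + completion_usd)  # 0.0015 + 0.001 = 0.0025 USD
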
@ -558,7 +553,7 @@ def get_litellm_params(
custom_llm_provider=None,
custom_api_base=None,
litellm_call_id=None,
model_alias_map=None
model_alias_map=None,
):
litellm_params = {
"return_async": return_async,
@ -569,13 +564,13 @@ def get_litellm_params(
"custom_llm_provider": custom_llm_provider,
"custom_api_base": custom_api_base,
"litellm_call_id": litellm_call_id,
"model_alias_map": model_alias_map
"model_alias_map": model_alias_map,
}
return litellm_params
def get_optional_params( # use the openai defaults
def get_optional_params( # use the openai defaults
# 12 optional params
functions=[],
function_call="",
@ -588,7 +583,7 @@ def get_optional_params( # use the openai defaults
presence_penalty=0,
frequency_penalty=0,
logit_bias={},
num_beams=1,
num_beams=1,
user="",
deployment_id=None,
model=None,
@ -635,8 +630,9 @@ def get_optional_params( # use the openai defaults
optional_params["max_tokens"] = max_tokens
if frequency_penalty != 0:
optional_params["frequency_penalty"] = frequency_penalty
elif (model == "chat-bison"
): # chat-bison has diff args from chat-bison@001 ty Google
elif (
model == "chat-bison"
): # chat-bison has diff args from chat-bison@001 ty Google
if temperature != 1:
optional_params["temperature"] = temperature
if top_p != 1:
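
get_optional_params follows one pattern throughout: start from the OpenAI defaults and only copy a value into optional_params when the caller changed it, with provider-specific branches (chat-bison here) deciding which keys apply. A reduced, illustrative sketch of that pattern:

def get_optional_params(model=None, temperature=1, top_p=1, max_tokens=float("inf")):
    optional_params = {}
    if model == "chat-bison":  # provider-specific branch, as in the diff
        if temperature != 1:
            optional_params["temperature"] = temperature
        if top_p != 1:
            optional_params["top_p"] = top_p
    else:  # openai-style defaults
        if temperature != 1:
            optional_params["temperature"] = temperature
        if max_tokens != float("inf"):
            optional_params["max_tokens"] = max_tokens
    return optional_params


print(get_optional_params(model="gpt-3.5-turbo", temperature=0.2))  # {'temperature': 0.2}
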
@ -702,10 +698,7 @@ def load_test_model(
test_prompt = prompt
if num_calls:
test_calls = num_calls
messages = [[{
"role": "user",
"content": test_prompt
}] for _ in range(test_calls)]
messages = [[{"role": "user", "content": test_prompt}] for _ in range(test_calls)]
start_time = time.time()
try:
litellm.batch_completion(
@ -743,15 +736,17 @@ def set_callbacks(callback_list):
try:
import sentry_sdk
except ImportError:
print_verbose(
"Package 'sentry_sdk' is missing. Installing it...")
print_verbose("Package 'sentry_sdk' is missing. Installing it...")
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "sentry_sdk"])
[sys.executable, "-m", "pip", "install", "sentry_sdk"]
)
import sentry_sdk
sentry_sdk_instance = sentry_sdk
sentry_trace_rate = (os.environ.get("SENTRY_API_TRACE_RATE")
if "SENTRY_API_TRACE_RATE" in os.environ
else "1.0")
sentry_trace_rate = (
os.environ.get("SENTRY_API_TRACE_RATE")
if "SENTRY_API_TRACE_RATE" in os.environ
else "1.0"
)
sentry_sdk_instance.init(
dsn=os.environ.get("SENTRY_API_URL"),
traces_sample_rate=float(sentry_trace_rate),
@ -762,10 +757,10 @@ def set_callbacks(callback_list):
try:
from posthog import Posthog
except ImportError:
print_verbose(
"Package 'posthog' is missing. Installing it...")
print_verbose("Package 'posthog' is missing. Installing it...")
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "posthog"])
[sys.executable, "-m", "pip", "install", "posthog"]
)
from posthog import Posthog
posthog = Posthog(
project_api_key=os.environ.get("POSTHOG_API_KEY"),
@ -775,10 +770,10 @@ def set_callbacks(callback_list):
try:
from slack_bolt import App
except ImportError:
print_verbose(
"Package 'slack_bolt' is missing. Installing it...")
print_verbose("Package 'slack_bolt' is missing. Installing it...")
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "slack_bolt"])
[sys.executable, "-m", "pip", "install", "slack_bolt"]
)
from slack_bolt import App
slack_app = App(
token=os.environ.get("SLACK_API_TOKEN"),
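
Each integration in set_callbacks uses the same lazy-install pattern: try the import, pip-install into the running interpreter on ImportError, then import again. A generic sketch of that pattern (ensure_package is a hypothetical helper written for this sketch, not part of the library):

import subprocess
import sys


def ensure_package(pkg):
    # try the import first; install with the current interpreter's pip only if missing
    try:
        return __import__(pkg)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])
        return __import__(pkg)


sentry_sdk = ensure_package("sentry_sdk")  # installs sentry_sdk on first use if absent
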
@ -809,8 +804,7 @@ def set_callbacks(callback_list):
raise e
def handle_failure(exception, traceback_exception, start_time, end_time, args,
kwargs):
def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient
try:
# print_verbose(f"handle_failure args: {args}")
@ -820,7 +814,8 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
failure_handler = additional_details.pop("failure_handler", None)
additional_details["Event_Name"] = additional_details.pop(
"failed_event_name", "litellm.failed_query")
"failed_event_name", "litellm.failed_query"
)
print_verbose(f"self.failure_callback: {litellm.failure_callback}")
for callback in litellm.failure_callback:
try:
@ -835,8 +830,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
for detail in additional_details:
slack_msg += f"{detail}: {additional_details[detail]}\n"
slack_msg += f"Traceback: {traceback_exception}"
slack_app.client.chat_postMessage(channel=alerts_channel,
text=slack_msg)
slack_app.client.chat_postMessage(
channel=alerts_channel, text=slack_msg
)
elif callback == "sentry":
capture_exception(exception)
elif callback == "posthog":
@ -855,8 +851,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
print_verbose(f"ph_obj: {ph_obj}")
print_verbose(f"PostHog Event Name: {event_name}")
if "user_id" in additional_details:
posthog.capture(additional_details["user_id"],
event_name, ph_obj)
posthog.capture(
additional_details["user_id"], event_name, ph_obj
)
else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python
unique_id = str(uuid.uuid4())
posthog.capture(unique_id, event_name)
@ -870,10 +867,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
"created": time.time(),
"error": traceback_exception,
"usage": {
"prompt_tokens":
prompt_token_calculator(model, messages=messages),
"completion_tokens":
0,
"prompt_tokens": prompt_token_calculator(
model, messages=messages
),
"completion_tokens": 0,
},
}
berrispendLogger.log_event(
@ -892,10 +889,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
"model": model,
"created": time.time(),
"usage": {
"prompt_tokens":
prompt_token_calculator(model, messages=messages),
"completion_tokens":
0,
"prompt_tokens": prompt_token_calculator(
model, messages=messages
),
"completion_tokens": 0,
},
}
aispendLogger.log_event(
@ -910,10 +907,13 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
model = args[0] if len(args) > 0 else kwargs["model"]
input = args[1] if len(args) > 1 else kwargs.get(
"messages", kwargs.get("input", None))
input = (
args[1]
if len(args) > 1
else kwargs.get("messages", kwargs.get("input", None))
)
type = 'embed' if 'input' in kwargs else 'llm'
type = "embed" if "input" in kwargs else "llm"
llmonitorLogger.log_event(
type=type,
@ -937,10 +937,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
"created": time.time(),
"error": traceback_exception,
"usage": {
"prompt_tokens":
prompt_token_calculator(model, messages=messages),
"completion_tokens":
0,
"prompt_tokens": prompt_token_calculator(
model, messages=messages
),
"completion_tokens": 0,
},
}
supabaseClient.log_event(
@ -957,16 +957,28 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
print_verbose("reaches lite_debugger for logging!")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
model = args[0] if len(args) > 0 else kwargs["model"]
messages = args[1] if len(args) > 1 else kwargs.get("messages", [{"role": "user", "content": ' '.join(kwargs.get("input", ""))}])
messages = (
args[1]
if len(args) > 1
else kwargs.get(
"messages",
[
{
"role": "user",
"content": " ".join(kwargs.get("input", "")),
}
],
)
)
result = {
"model": model,
"created": time.time(),
"error": traceback_exception,
"usage": {
"prompt_tokens":
prompt_token_calculator(model, messages=messages),
"completion_tokens":
0,
"prompt_tokens": prompt_token_calculator(
model, messages=messages
),
"completion_tokens": 0,
},
}
liteDebuggerClient.log_event(
@ -1002,11 +1014,16 @@ def handle_success(args, kwargs, result, start_time, end_time):
global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
try:
model = args[0] if len(args) > 0 else kwargs["model"]
input = args[1] if len(args) > 1 else kwargs.get("messages", kwargs.get("input", None))
input = (
args[1]
if len(args) > 1
else kwargs.get("messages", kwargs.get("input", None))
)
success_handler = additional_details.pop("success_handler", None)
failure_handler = additional_details.pop("failure_handler", None)
additional_details["Event_Name"] = additional_details.pop(
"successful_event_name", "litellm.succes_query")
"successful_event_name", "litellm.succes_query"
)
for callback in litellm.success_callback:
try:
if callback == "posthog":
@ -1015,8 +1032,9 @@ def handle_success(args, kwargs, result, start_time, end_time):
ph_obj[detail] = additional_details[detail]
event_name = additional_details["Event_Name"]
if "user_id" in additional_details:
posthog.capture(additional_details["user_id"],
event_name, ph_obj)
posthog.capture(
additional_details["user_id"], event_name, ph_obj
)
else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python
unique_id = str(uuid.uuid4())
posthog.capture(unique_id, event_name, ph_obj)
@ -1025,8 +1043,9 @@ def handle_success(args, kwargs, result, start_time, end_time):
slack_msg = ""
for detail in additional_details:
slack_msg += f"{detail}: {additional_details[detail]}\n"
slack_app.client.chat_postMessage(channel=alerts_channel,
text=slack_msg)
slack_app.client.chat_postMessage(
channel=alerts_channel, text=slack_msg
)
elif callback == "helicone":
print_verbose("reaches helicone for logging!")
model = args[0] if len(args) > 0 else kwargs["model"]
@ -1043,11 +1062,14 @@ def handle_success(args, kwargs, result, start_time, end_time):
print_verbose("reaches llmonitor for logging!")
model = args[0] if len(args) > 0 else kwargs["model"]
input = args[1] if len(args) > 1 else kwargs.get(
"messages", kwargs.get("input", None))
input = (
args[1]
if len(args) > 1
else kwargs.get("messages", kwargs.get("input", None))
)
#if contains input, it's 'embedding', otherwise 'llm'
type = 'embed' if 'input' in kwargs else 'llm'
# if contains input, it's 'embedding', otherwise 'llm'
type = "embed" if "input" in kwargs else "llm"
llmonitorLogger.log_event(
type=type,
@ -1069,7 +1091,6 @@ def handle_success(args, kwargs, result, start_time, end_time):
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
)
elif callback == "aispend":
print_verbose("reaches aispend for logging!")
@ -1084,7 +1105,11 @@ def handle_success(args, kwargs, result, start_time, end_time):
elif callback == "supabase":
print_verbose("reaches supabase for logging!")
model = args[0] if len(args) > 0 else kwargs["model"]
messages = args[1] if len(args) > 1 else kwargs.get("messages", {"role": "user", "content": ""})
messages = (
args[1]
if len(args) > 1
else kwargs.get("messages", {"role": "user", "content": ""})
)
print(f"supabaseClient: {supabaseClient}")
supabaseClient.log_event(
model=model,
@ -1099,7 +1124,19 @@ def handle_success(args, kwargs, result, start_time, end_time):
elif callback == "lite_debugger":
print_verbose("reaches lite_debugger for logging!")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
messages = args[1] if len(args) > 1 else kwargs.get("messages", [{"role": "user", "content": ' '.join(kwargs.get("input", ""))}])
messages = (
args[1]
if len(args) > 1
else kwargs.get(
"messages",
[
{
"role": "user",
"content": " ".join(kwargs.get("input", "")),
}
],
)
)
liteDebuggerClient.log_event(
model=model,
messages=messages,
@ -1129,6 +1166,7 @@ def handle_success(args, kwargs, result, start_time, end_time):
)
pass
def acreate(*args, **kwargs): ## Thin client to handle the acreate langchain call
return litellm.acompletion(*args, **kwargs)
@ -1170,28 +1208,43 @@ def modify_integration(integration_name, integration_params):
if "table_name" in integration_params:
Supabase.supabase_table_name = integration_params["table_name"]
####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
def get_all_keys(llm_provider=None):
try:
global last_fetched_at_keys
# if user is using hosted product -> instantiate their env with their hosted api keys - refresh every 5 minutes
print_verbose(f"Reaches get all keys, llm_provider: {llm_provider}")
user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
user_email = (
os.getenv("LITELLM_EMAIL")
or litellm.email
or litellm.token
or os.getenv("LITELLM_TOKEN")
)
if user_email:
time_delta = 0
if last_fetched_at_keys != None:
current_time = time.time()
time_delta = current_time - last_fetched_at_keys
if time_delta > 300 or last_fetched_at_keys == None or llm_provider: # if the llm provider is passed in , assume this happening due to an AuthError for that provider
if (
time_delta > 300 or last_fetched_at_keys == None or llm_provider
): # if the llm provider is passed in , assume this happening due to an AuthError for that provider
# make the api call
last_fetched_at = time.time()
print_verbose(f"last_fetched_at: {last_fetched_at}")
response = requests.post(url="http://api.litellm.ai/get_all_keys", headers={"content-type": "application/json"}, data=json.dumps({"user_email": user_email}))
response = requests.post(
url="http://api.litellm.ai/get_all_keys",
headers={"content-type": "application/json"},
data=json.dumps({"user_email": user_email}),
)
print_verbose(f"get model key response: {response.text}")
data = response.json()
# update model list
for key, value in data["model_keys"].items(): # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
for key, value in data[
"model_keys"
].items(): # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
os.environ[key] = value
# set model alias map
for model_alias, value in data["model_alias_map"].items():
@ -1200,19 +1253,31 @@ def get_all_keys(llm_provider=None):
return None
return None
except:
print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
print_verbose(
f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
)
pass
def get_model_list():
global last_fetched_at
try:
# if user is using hosted product -> get their updated model list
user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
user_email = (
os.getenv("LITELLM_EMAIL")
or litellm.email
or litellm.token
or os.getenv("LITELLM_TOKEN")
)
if user_email:
# make the api call
last_fetched_at = time.time()
print(f"last_fetched_at: {last_fetched_at}")
response = requests.post(url="http://api.litellm.ai/get_model_list", headers={"content-type": "application/json"}, data=json.dumps({"user_email": user_email}))
response = requests.post(
url="http://api.litellm.ai/get_model_list",
headers={"content-type": "application/json"},
data=json.dumps({"user_email": user_email}),
)
print_verbose(f"get_model_list response: {response.text}")
data = response.json()
# update model list
@ -1224,12 +1289,14 @@ def get_model_list():
if f"{item.upper()}_API_KEY" not in os.environ:
missing_llm_provider = item
break
# update environment - if required
# update environment - if required
threading.Thread(target=get_all_keys, args=(missing_llm_provider)).start()
return model_list
return [] # return empty list by default
return [] # return empty list by default
except:
print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
print_verbose(
f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
)
####### EXCEPTION MAPPING ################
@ -1253,36 +1320,33 @@ def exception_type(model, original_exception, custom_llm_provider):
exception_type = ""
if "claude" in model: # one of the anthropics
if hasattr(original_exception, "status_code"):
print_verbose(
f"status_code: {original_exception.status_code}")
print_verbose(f"status_code: {original_exception.status_code}")
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=
f"AnthropicException - {original_exception.message}",
message=f"AnthropicException - {original_exception.message}",
llm_provider="anthropic",
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
raise InvalidRequestError(
message=
f"AnthropicException - {original_exception.message}",
message=f"AnthropicException - {original_exception.message}",
model=model,
llm_provider="anthropic",
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=
f"AnthropicException - {original_exception.message}",
message=f"AnthropicException - {original_exception.message}",
llm_provider="anthropic",
)
elif ("Could not resolve authentication method. Expected either api_key or auth_token to be set."
in error_str):
elif (
"Could not resolve authentication method. Expected either api_key or auth_token to be set."
in error_str
):
exception_mapping_worked = True
raise AuthenticationError(
message=
f"AnthropicException - {original_exception.message}",
message=f"AnthropicException - {original_exception.message}",
llm_provider="anthropic",
)
elif "replicate" in model:
@ -1306,36 +1370,35 @@ def exception_type(model, original_exception, custom_llm_provider):
llm_provider="replicate",
)
elif (
exception_type == "ReplicateError"
exception_type == "ReplicateError"
): # ReplicateError implies an error on Replicate server side, not user side
raise ServiceUnavailableError(
message=f"ReplicateException - {error_str}",
llm_provider="replicate",
)
elif model == "command-nightly": # Cohere
if ("invalid api token" in error_str
or "No API key provided." in error_str):
if (
"invalid api token" in error_str
or "No API key provided." in error_str
):
exception_mapping_worked = True
raise AuthenticationError(
message=
f"CohereException - {original_exception.message}",
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
)
elif "too many tokens" in error_str:
exception_mapping_worked = True
raise InvalidRequestError(
message=
f"CohereException - {original_exception.message}",
message=f"CohereException - {original_exception.message}",
model=model,
llm_provider="cohere",
)
elif (
"CohereConnectionError" in exception_type
"CohereConnectionError" in exception_type
): # cohere seems to fire these errors when we load test it (1k+ messages / min)
exception_mapping_worked = True
raise RateLimitError(
message=
f"CohereException - {original_exception.message}",
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
)
elif custom_llm_provider == "huggingface":
@ -1343,23 +1406,20 @@ def exception_type(model, original_exception, custom_llm_provider):
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=
f"HuggingfaceException - {original_exception.message}",
message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface",
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
raise InvalidRequestError(
message=
f"HuggingfaceException - {original_exception.message}",
message=f"HuggingfaceException - {original_exception.message}",
model=model,
llm_provider="huggingface",
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=
f"HuggingfaceException - {original_exception.message}",
message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface",
)
raise original_exception # base case - return the original exception
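
exception_type maps provider errors (Anthropic, Replicate, Cohere, Hugging Face in this hunk) onto litellm's own exception classes by status code or message text, and re-raises the original exception as the base case. Assuming those classes are importable from litellm (the diff only shows them being raised here), caller-side handling reduces to something like the following sketch, with a placeholder model name:

import litellm
from litellm import AuthenticationError, InvalidRequestError, RateLimitError

try:
    response = litellm.completion(
        model="claude-instant-1",  # placeholder model name
        messages=[{"role": "user", "content": "hi"}],
    )
except AuthenticationError:
    ...  # bad or missing provider key
except RateLimitError:
    ...  # back off and retry
except InvalidRequestError:
    ...  # malformed input, do not retry
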
@ -1375,8 +1435,10 @@ def exception_type(model, original_exception, custom_llm_provider):
},
exception=e,
)
## AUTH ERROR
if isinstance(e, AuthenticationError) and (litellm.email or "LITELLM_EMAIL" in os.environ):
## AUTH ERROR
if isinstance(e, AuthenticationError) and (
litellm.email or "LITELLM_EMAIL" in os.environ
):
threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start()
if exception_mapping_worked:
raise e
@ -1391,7 +1453,8 @@ def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None):
"exception": str(exception),
"custom_llm_provider": custom_llm_provider,
}
threading.Thread(target=litellm_telemetry, args=(data, )).start()
threading.Thread(target=litellm_telemetry, args=(data,)).start()
def get_or_generate_uuid():
uuid_file = "litellm_uuid.txt"
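
Telemetry and the success/failure handlers are all dispatched on background threads so they never block the main call; note that threading.Thread expects args as a tuple, which is why the single payload is written args=(data,) with the trailing comma. A minimal sketch:

import threading


def litellm_telemetry(payload):
    # stand-in for the real sender; just pretend to ship the payload
    print("telemetry:", payload)


data = {"model": "gpt-3.5-turbo", "exception": None, "custom_llm_provider": None}
# args must be a tuple; (data,) with the trailing comma, not (data)
t = threading.Thread(target=litellm_telemetry, args=(data,))
t.start()
t.join()  # only for the sketch; the library fires and forgets
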
@ -1445,8 +1508,7 @@ def get_secret(secret_name):
# TODO: check which secret manager is being used
# currently only supports Infisical
try:
secret = litellm.secret_manager_client.get_secret(
secret_name).secret_value
secret = litellm.secret_manager_client.get_secret(secret_name).secret_value
except:
secret = None
return secret
@ -1460,7 +1522,6 @@ def get_secret(secret_name):
# wraps the completion stream to return the correct format for the model
# replicate/anthropic/cohere
class CustomStreamWrapper:
def __init__(self, completion_stream, model, custom_llm_provider=None):
self.model = model
self.custom_llm_provider = custom_llm_provider
@ -1509,8 +1570,9 @@ class CustomStreamWrapper:
elif self.model == "replicate":
chunk = next(self.completion_stream)
completion_obj["content"] = chunk
elif (self.custom_llm_provider and self.custom_llm_provider == "together_ai") or ("togethercomputer"
in self.model):
elif (
self.custom_llm_provider and self.custom_llm_provider == "together_ai"
) or ("togethercomputer" in self.model):
chunk = next(self.completion_stream)
text_data = self.handle_together_ai_chunk(chunk)
if text_data == "":
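
CustomStreamWrapper gives every provider stream the same face: __next__ pulls a provider-specific chunk, extracts its text (handle_together_ai_chunk here), and returns an OpenAI-style delta dict. A stripped-down sketch of that shape, with the per-provider handling collapsed to a stand-in:

class CustomStreamWrapper:
    def __init__(self, completion_stream, model, custom_llm_provider=None):
        self.completion_stream = iter(completion_stream)
        self.model = model
        self.custom_llm_provider = custom_llm_provider

    def __iter__(self):
        return self

    def __next__(self):
        chunk = next(self.completion_stream)
        # the real code branches per provider (anthropic, replicate, together_ai, ...)
        content = str(chunk)
        return {"choices": [{"delta": {"role": "assistant", "content": content}}]}


for piece in CustomStreamWrapper(["Hel", "lo"], model="replicate"):
    print(piece["choices"][0]["delta"]["content"], end="")
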
@ -1545,9 +1607,9 @@ def read_config_args(config_path):
########## ollama implementation ############################
async def get_ollama_response_stream(api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?"):
async def get_ollama_response_stream(
api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"
):
session = aiohttp.ClientSession()
url = f"{api_base}/api/generate"
data = {
@ -1570,11 +1632,7 @@ async def get_ollama_response_stream(api_base="http://localhost:11434",
"content": "",
}
completion_obj["content"] = j["response"]
yield {
"choices": [{
"delta": completion_obj
}]
}
yield {"choices": [{"delta": completion_obj}]}
# self.responses.append(j["response"])
# yield "blank"
except Exception as e:
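
get_ollama_response_stream is an async generator: it POSTs the prompt to {api_base}/api/generate via aiohttp, reads the streamed response, and yields OpenAI-style delta chunks. A usage sketch, assuming a local Ollama server is listening on the default port and that the function above is in scope:

import asyncio


async def main():
    async for chunk in get_ollama_response_stream(
        api_base="http://localhost:11434",
        model="llama2",
        prompt="Why is the sky blue?",
    ):
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)


asyncio.run(main())
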