Merge branch 'BerriAI:main' into main

commit 1c93c9c945
Zakhar Kogan, 2023-08-22 08:52:55 +03:00 (committed by GitHub)
38 changed files with 8936 additions and 114 deletions

@@ -12,6 +12,7 @@ from .integrations.helicone import HeliconeLogger
 from .integrations.aispend import AISpendLogger
 from .integrations.berrispend import BerriSpendLogger
 from .integrations.supabase import Supabase
+from .integrations.litedebugger import LiteDebugger
 from openai.error import OpenAIError as OriginalError
 from openai.openai_object import OpenAIObject
 from .exceptions import (
@@ -35,6 +36,7 @@ heliconeLogger = None
 aispendLogger = None
 berrispendLogger = None
 supabaseClient = None
+liteDebuggerClient = None
 callback_list: Optional[List[str]] = []
 user_logger_fn = None
 additional_details: Optional[Dict[str, str]] = {}
@@ -136,6 +138,7 @@ def install_and_import(package: str):
 ####### LOGGING ###################
 # Logging function -> log the exact model details + what's being sent | Non-Blocking
 class Logging:
+    global supabaseClient, liteDebuggerClient
     def __init__(self, model, messages, optional_params, litellm_params):
         self.model = model
         self.messages = messages
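A pattern worth noting before the next hunks: every integration client in this module is a module-level global that defaults to None and is only populated once set_callbacks runs, so every logging path has to guard its client before use. A minimal sketch of that lifecycle (the MiniDebugger stub below is illustrative, standing in for the real LiteDebugger):

```python
# Minimal sketch of the lazy-singleton pattern used for integration clients.
miniDebuggerClient = None  # stays None until callbacks are configured

class MiniDebugger:  # illustrative stub, not the real LiteDebugger
    def __init__(self, email=None):
        self.email = email

    def input_log_event(self, **kwargs):
        print(f"logging: {kwargs}")

def configure_callbacks(callback_list):
    global miniDebuggerClient
    for callback in callback_list:
        if callback == "mini_debugger":
            miniDebuggerClient = MiniDebugger(email=None)

configure_callbacks(["mini_debugger"])
if miniDebuggerClient:  # guard: the client may never have been instantiated
    miniDebuggerClient.input_log_event(model="gpt-3.5-turbo", messages=[])
```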
@@ -151,7 +154,7 @@ class Logging:
     def pre_call(self, input, api_key, additional_args={}):
         try:
-            print(f"logging pre call for model: {self.model}")
+            print_verbose(f"logging pre call for model: {self.model}")
             self.model_call_details["input"] = input
             self.model_call_details["api_key"] = api_key
             self.model_call_details["additional_args"] = additional_args
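The print to print_verbose change gates debug output behind an opt-in flag. The helper is presumably along these lines (a sketch assuming a module-level set_verbose flag, which matches how litellm exposes verbosity):

```python
set_verbose = False  # users flip this on to see LiteLLM's internal logging

def print_verbose(print_statement):
    # No-op unless the user explicitly opted into verbose output.
    if set_verbose:
        print(print_statement)
```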
@@ -177,7 +180,7 @@ class Logging:
                         print_verbose("reaches supabase for logging!")
                         model = self.model
                         messages = self.messages
-                        print(f"litellm._thread_context: {litellm._thread_context}")
+                        print(f"supabaseClient: {supabaseClient}")
                         supabaseClient.input_log_event(
                             model=model,
                             messages=messages,
@@ -185,14 +188,34 @@ class Logging:
                             litellm_call_id=self.litellm_params["litellm_call_id"],
                             print_verbose=print_verbose,
                         )
-                        pass
-                except:
-                    pass
+                    elif callback == "lite_debugger":
+                        print_verbose("reaches litedebugger for logging!")
+                        model = self.model
+                        messages = self.messages
+                        print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
+                        liteDebuggerClient.input_log_event(
+                            model=model,
+                            messages=messages,
+                            end_user=litellm._thread_context.user,
+                            litellm_call_id=self.litellm_params["litellm_call_id"],
+                            print_verbose=print_verbose,
+                        )
+                except Exception as e:
+                    print_verbose(f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while input logging with integrations {traceback.format_exc()}")
+                    print_verbose(
+                        f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}"
+                    )
+                    if capture_exception: # log this error to sentry for debugging
+                        capture_exception(e)
         except:
             print_verbose(
                 f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
             )
-            pass
+            print_verbose(
+                f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}"
+            )
+            if capture_exception: # log this error to sentry for debugging
+                capture_exception(e)

     def post_call(self, input, api_key, original_response, additional_args={}):
         # Do something here
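Stripped of the integration-specific calls, pre_call is a fan-out loop: every registered input callback gets a chance to log, and a failing logger is swallowed (and optionally forwarded to Sentry) so it can never break the user's request. A condensed sketch of the pattern (dispatch_input_callbacks is illustrative, not a function in this diff):

```python
import traceback

def dispatch_input_callbacks(input_callbacks, clients, model, messages):
    # Fan out to every registered logger; a failing integration must be
    # non-blocking, i.e. it never raises into the caller's request path.
    for callback in input_callbacks:
        try:
            client = clients.get(callback)
            if client is not None:
                client.input_log_event(model=model, messages=messages)
        except Exception:
            # Swallow and report; never re-raise from a logging hook.
            print(f"[Non-Blocking] logging failed: {traceback.format_exc()}")
```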
@@ -220,9 +243,6 @@ class Logging:
                 print_verbose(
                     f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
                 )
                 pass

     # Add more methods as needed

 def exception_logging(
     additional_args={},
@@ -257,11 +277,16 @@
 ####### CLIENT ###################
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
 def client(original_function):
+    global liteDebuggerClient
     def function_setup(
         *args, **kwargs
     ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
         try:
             global callback_list, add_breadcrumb, user_logger_fn
+            if litellm.debugger or os.getenv("LITELLM_EMAIL", None) != None: # add to input, success and failure callbacks if user sets debugging to true
+                litellm.input_callback.append("lite_debugger")
+                litellm.success_callback.append("lite_debugger")
+                litellm.failure_callback.append("lite_debugger")
             if (
                 len(litellm.input_callback) > 0 or len(litellm.success_callback) > 0 or len(litellm.failure_callback) > 0
             ) and len(callback_list) == 0:
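The effect of this block is that turning on the debugger requires no manual callback wiring; either the flag or the environment variable causes "lite_debugger" to be appended to all three callback lists on the first wrapped call. Usage would look roughly like this (the model and messages are illustrative):

```python
import litellm
from litellm import completion

litellm.debugger = True  # or: os.environ["LITELLM_EMAIL"] = "you@example.com"

# On this first call, function_setup appends "lite_debugger" to the
# input, success, and failure callback lists automatically.
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello World"}],
)
```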
@@ -387,6 +412,9 @@ def client(original_function):
                 args=(e, traceback_exception, start_time, end_time, args, kwargs),
             ) # don't interrupt execution of main thread
             my_thread.start()
+            if hasattr(e, "message"):
+                if liteDebuggerClient and liteDebuggerClient.dashboard_url != None: # make it easy to get to the debugger logs if you've initialized it
+                    e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
             raise e

     return wrapper
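Two details of the failure path are worth noting: the handler runs on its own thread so logging latency never delays the re-raise, and the dashboard URL is appended to the exception message only when the debugger client was actually initialized. A reduced sketch of that shape (fail_and_annotate is illustrative):

```python
import threading

def fail_and_annotate(e, handler, dashboard_url=None):
    # Log the failure off-thread; the caller re-raises immediately.
    threading.Thread(target=handler, args=(e,)).start()
    # Point the user at the debugger logs if a dashboard exists.
    if hasattr(e, "message") and dashboard_url is not None:
        e.message += f"\n Check the log in your dashboard - {dashboard_url}"
    raise e
```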
@@ -626,7 +654,7 @@ def load_test_model(

 def set_callbacks(callback_list):
-    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient
+    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient
     try:
         for callback in callback_list:
             print(f"callback: {callback}")
@@ -688,12 +716,15 @@ def set_callbacks(callback_list):
             elif callback == "supabase":
                 print(f"instantiating supabase")
                 supabaseClient = Supabase()
+            elif callback == "lite_debugger":
+                print(f"instantiating lite_debugger")
+                liteDebuggerClient = LiteDebugger(email=litellm.email)
     except Exception as e:
         raise e

 def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
-    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger
+    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient
     try:
         # print_verbose(f"handle_failure args: {args}")
         # print_verbose(f"handle_failure kwargs: {kwargs}")
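set_callbacks (above) is the single construction point: each callback name maps to one client instantiation, so wiring in lite_debugger is one more elif. The wrapper presumably invokes it once with the merged callback lists, roughly:

```python
# Illustrative: run once, before the first request, with the merged lists.
callback_list = list(
    set(litellm.input_callback + litellm.success_callback + litellm.failure_callback)
)
set_callbacks(callback_list=callback_list)
# Afterwards, liteDebuggerClient / supabaseClient are non-None for any
# integration that was named, and the logging paths above can use them.
```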
@@ -794,6 +825,7 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
                     )
                 elif callback == "supabase":
                     print_verbose("reaches supabase for logging!")
+                    print_verbose(f"supabaseClient: {supabaseClient}")
                     model = args[0] if len(args) > 0 else kwargs["model"]
                     messages = args[1] if len(args) > 1 else kwargs["messages"]
                     result = {
@@ -817,6 +849,32 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
                         litellm_call_id=kwargs["litellm_call_id"],
                         print_verbose=print_verbose,
                     )
+                elif callback == "lite_debugger":
+                    print_verbose("reaches lite_debugger for logging!")
+                    print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
+                    model = args[0] if len(args) > 0 else kwargs["model"]
+                    messages = args[1] if len(args) > 1 else kwargs["messages"]
+                    result = {
+                        "model": model,
+                        "created": time.time(),
+                        "error": traceback_exception,
+                        "usage": {
+                            "prompt_tokens": prompt_token_calculator(
+                                model, messages=messages
+                            ),
+                            "completion_tokens": 0,
+                        },
+                    }
+                    liteDebuggerClient.log_event(
+                        model=model,
+                        messages=messages,
+                        end_user=litellm._thread_context.user,
+                        response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        litellm_call_id=kwargs["litellm_call_id"],
+                        print_verbose=print_verbose,
+                    )
             except:
                 print_verbose(
                     f"Error Occurred while logging failure: {traceback.format_exc()}"
@@ -837,7 +895,7 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):

 def handle_success(args, kwargs, result, start_time, end_time):
-    global heliconeLogger, aispendLogger
+    global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient
     try:
         success_handler = additional_details.pop("success_handler", None)
         failure_handler = additional_details.pop("failure_handler", None)
@@ -904,7 +962,7 @@ def handle_success(args, kwargs, result, start_time, end_time):
                 print_verbose("reaches supabase for logging!")
                 model = args[0] if len(args) > 0 else kwargs["model"]
                 messages = args[1] if len(args) > 1 else kwargs["messages"]
-                print(f"litellm._thread_context: {litellm._thread_context}")
+                print(f"supabaseClient: {supabaseClient}")
                 supabaseClient.log_event(
                     model=model,
                     messages=messages,
@@ -915,6 +973,21 @@ def handle_success(args, kwargs, result, start_time, end_time):
                     litellm_call_id=kwargs["litellm_call_id"],
                     print_verbose=print_verbose,
                 )
+            elif callback == "lite_debugger":
+                print_verbose("reaches lite_debugger for logging!")
+                model = args[0] if len(args) > 0 else kwargs["model"]
+                messages = args[1] if len(args) > 1 else kwargs["messages"]
+                print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
+                liteDebuggerClient.log_event(
+                    model=model,
+                    messages=messages,
+                    end_user=litellm._thread_context.user,
+                    response_obj=result,
+                    start_time=start_time,
+                    end_time=end_time,
+                    litellm_call_id=kwargs["litellm_call_id"],
+                    print_verbose=print_verbose,
+                )
     except Exception as e:
         ## LOGGING
         exception_logging(logger_fn=user_logger_fn, exception=e)
@@ -935,6 +1008,9 @@ def handle_success(args, kwargs, result, start_time, end_time):
         pass

+def acreate(*args, **kwargs): ## Thin client to handle the acreate langchain call
+    return litellm.acompletion(*args, **kwargs)

 def prompt_token_calculator(model, messages):
     # use tiktoken or anthropic's tokenizer depending on the model
     text = " ".join(message["content"] for message in messages)
@@ -949,6 +1025,16 @@ def prompt_token_calculator(model, messages):
         num_tokens = len(encoding.encode(text))
     return num_tokens
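prompt_token_calculator joins the message contents and counts tokens with a model-appropriate tokenizer. The tiktoken half of that looks roughly like the following; the Anthropic branch and the cl100k_base fallback are assumptions here, not shown in the hunk:

```python
import tiktoken

def count_prompt_tokens(model, messages):
    text = " ".join(message["content"] for message in messages)
    try:
        # Use the model-specific encoding when tiktoken knows the model...
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # ...otherwise fall back to a general-purpose encoding (assumption).
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

print(count_prompt_tokens("gpt-3.5-turbo", [{"role": "user", "content": "Hello World"}]))
```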
+def valid_model(model):
+    try:
+        # for a given model name, check if the user has the right permissions to access the model
+        if model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_text_completion_models:
+            openai.Model.retrieve(model)
+        else:
+            messages = [{"role": "user", "content": "Hello World"}]
+            litellm.completion(model=model, messages=messages)
+    except:
+        raise InvalidRequestError(message="", model=model, llm_provider="")
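valid_model gives callers a pre-flight permission check: OpenAI models are verified with a cheap metadata lookup, anything else with a real one-message completion (so the non-OpenAI path does spend an API call). A usage sketch:

```python
try:
    valid_model("gpt-3.5-turbo")  # OpenAI: verified via openai.Model.retrieve
except InvalidRequestError:
    print("model not accessible with the current credentials")
```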

 # integration helper function
 def modify_integration(integration_name, integration_params):
@@ -958,8 +1044,9 @@ def modify_integration(integration_name, integration_params):
             Supabase.supabase_table_name = integration_params["table_name"]

 ####### EXCEPTION MAPPING ################
 def exception_type(model, original_exception, custom_llm_provider):
-    global user_logger_fn
+    global user_logger_fn, liteDebuggerClient
     exception_mapping_worked = False
     try:
         if isinstance(original_exception, OriginalError):
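exception_type re-raises provider errors as litellm's own exception types; the exception_mapping_worked flag is how the final except distinguishes "we mapped it, surface our exception" from "mapping itself failed, surface the original". A reduced sketch of that control flow (the single mapping rule and the ValueError stand-in are illustrative; the real function carries a much larger table of typed exceptions):

```python
def map_exception(model, original_exception, custom_llm_provider):
    exception_mapping_worked = False
    try:
        if "maximum context length" in str(original_exception):
            exception_mapping_worked = True
            # Illustrative stand-in for litellm's typed exceptions.
            raise ValueError(
                f"ContextWindowExceeded: {model} ({custom_llm_provider})"
            )
        raise original_exception  # unrecognized: pass through unchanged
    except Exception as e:
        if exception_mapping_worked:
            raise e  # our mapped exception
        raise original_exception  # never mask the provider's error
```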
@@ -1099,6 +1186,7 @@ def exception_type(model, original_exception, custom_llm_provider):
             raise original_exception

 ####### CRASH REPORTING ################
 def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None):
     data = {
         "model": model,
@@ -1297,7 +1385,7 @@ async def stream_to_string(generator):
     return response

-########## Together AI streaming #############################
+########## Together AI streaming ############################# [TODO] move together ai to it's own llm class
 async def together_ai_completion_streaming(json_data, headers):
     session = aiohttp.ClientSession()
     url = "https://api.together.xyz/inference"
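together_ai_completion_streaming opens a long-lived POST against the inference endpoint and yields chunks as they arrive; consuming it looks roughly like this (the payload fields are illustrative, and the real generator also owns session cleanup):

```python
import asyncio

async def demo():
    json_data = {
        "model": "togethercomputer/llama-2-7b-chat",  # illustrative
        "prompt": "Hello World",
        "stream_tokens": True,
    }
    headers = {"Authorization": "Bearer <TOGETHER_API_KEY>"}
    async for chunk in together_ai_completion_streaming(json_data, headers):
        print(chunk)

asyncio.run(demo())
```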