Merge branch 'main' into main

Vince Lwt 2023-08-23 18:39:09 +02:00 committed by GitHub
commit dcbf5e2444
48 changed files with 9611 additions and 239 deletions


@@ -1,7 +1,20 @@
import aiohttp
import subprocess
import importlib
from typing import List, Dict, Union, Optional
import sys
import dotenv, json, traceback, threading
import subprocess, os
import litellm, openai
import random, uuid, requests
import datetime, time
import tiktoken
encoding = tiktoken.get_encoding("cl100k_base")
import importlib.metadata
from .integrations.helicone import HeliconeLogger
from .integrations.aispend import AISpendLogger
from .integrations.berrispend import BerriSpendLogger
from .integrations.supabase import Supabase
from .integrations.litedebugger import LiteDebugger
from openai.error import OpenAIError as OriginalError
from openai.openai_object import OpenAIObject
from .exceptions import (
AuthenticationError,
InvalidRequestError,
@@ -9,34 +22,10 @@ from .exceptions import (
ServiceUnavailableError,
OpenAIError,
)
from openai.openai_object import OpenAIObject
from openai.error import OpenAIError as OriginalError
from .integrations.llmonitor import LLMonitorLogger
from .integrations.litedebugger import LiteDebugger
from .integrations.supabase import Supabase
from .integrations.berrispend import BerriSpendLogger
from .integrations.aispend import AISpendLogger
from .integrations.helicone import HeliconeLogger
import pkg_resources
import sys
import dotenv
import json
import traceback
import threading
import subprocess
import os
import litellm
import openai
import random
import uuid
import requests
import datetime
import time
import tiktoken
from typing import List, Dict, Union, Optional
encoding = tiktoken.get_encoding("cl100k_base")
####### ENVIRONMENT VARIABLES ###################
####### ENVIRONMENT VARIABLES ####################
dotenv.load_dotenv() # Loading env variables using dotenv
sentry_sdk_instance = None
capture_exception = None
@@ -49,12 +38,11 @@ aispendLogger = None
berrispendLogger = None
supabaseClient = None
liteDebuggerClient = None
llmonitorLogger = None
callback_list: Optional[List[str]] = []
user_logger_fn = None
additional_details: Optional[Dict[str, str]] = {}
local_cache: Optional[Dict[str, str]] = {}
last_fetched_at = None
######## Model Response #########################
# All liteLLM model responses will be in this format; it follows the OpenAI format
# https://docs.litellm.ai/docs/completion/output
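# Illustrative sketch of the OpenAI-format shape referenced above; the values below are
# placeholders for illustration only.
_example_model_response = {
    "choices": [
        {
            "finish_reason": "stop",
            "index": 0,
            "message": {"role": "assistant", "content": "Hello!"},
        }
    ],
    "created": 1692800000,  # unix timestamp
    "model": "gpt-3.5-turbo",
    "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12},
}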
@@ -174,7 +162,7 @@ class Logging:
def pre_call(self, input, api_key, additional_args={}):
try:
print(f"logging pre call for model: {self.model}")
print_verbose(f"logging pre call for model: {self.model}")
self.model_call_details["input"] = input
self.model_call_details["api_key"] = api_key
self.model_call_details["additional_args"] = additional_args
@@ -214,7 +202,7 @@ class Logging:
print_verbose("reaches litedebugger for logging!")
model = self.model
messages = self.messages
print(f"liteDebuggerClient: {liteDebuggerClient}")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
liteDebuggerClient.input_log_event(
model=model,
messages=messages,
@@ -271,7 +259,6 @@ class Logging:
# Add more methods as needed
def exception_logging(
additional_args={},
logger_fn=None,
@@ -305,20 +292,34 @@ def exception_logging(
####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def client(original_function):
global liteDebuggerClient, get_all_keys
def function_setup(
def function_setup(
*args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
try:
global callback_list, add_breadcrumb, user_logger_fn
if (len(litellm.input_callback) > 0
or len(litellm.success_callback) > 0
or len(litellm.failure_callback)
> 0) and len(callback_list) == 0:
if litellm.email is not None or os.getenv("LITELLM_EMAIL", None) is not None: # add to input, success and failure callbacks if user is using hosted product
get_all_keys()
if "lite_debugger" not in callback_list:
litellm.input_callback.append("lite_debugger")
litellm.success_callback.append("lite_debugger")
litellm.failure_callback.append("lite_debugger")
if (
len(litellm.input_callback) > 0
or len(litellm.success_callback) > 0
or len(litellm.failure_callback) > 0
) and len(callback_list) == 0:
callback_list = list(
set(litellm.input_callback + litellm.success_callback +
litellm.failure_callback))
set_callbacks(callback_list=callback_list, )
set(
litellm.input_callback
+ litellm.success_callback
+ litellm.failure_callback
)
)
set_callbacks(
callback_list=callback_list,
)
if add_breadcrumb:
add_breadcrumb(
category="litellm.llm_call",
@@ -432,6 +433,11 @@ def client(original_function):
kwargs),
) # don't interrupt execution of main thread
my_thread.start()
if hasattr(e, "message"):
if (
liteDebuggerClient and liteDebuggerClient.dashboard_url is not None
): # make it easy to get to the debugger logs if you've initialized it
e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
raise e
return wrapper
@@ -515,7 +521,7 @@ def get_litellm_params(
"verbose": verbose,
"custom_llm_provider": custom_llm_provider,
"custom_api_base": custom_api_base,
"litellm_call_id": litellm_call_id
"litellm_call_id": litellm_call_id,
}
return litellm_params
@@ -738,7 +744,7 @@ def set_callbacks(callback_list):
supabaseClient = Supabase()
elif callback == "lite_debugger":
print(f"instantiating lite_debugger")
liteDebuggerClient = LiteDebugger()
liteDebuggerClient = LiteDebugger(email=litellm.email)
except Exception as e:
raise e
@@ -1036,7 +1042,7 @@ def handle_success(args, kwargs, result, start_time, end_time):
print_verbose("reaches lite_debugger for logging!")
model = args[0] if len(args) > 0 else kwargs["model"]
messages = args[1] if len(args) > 1 else kwargs["messages"]
print(f"liteDebuggerClient: {liteDebuggerClient}")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
liteDebuggerClient.log_event(
model=model,
messages=messages,
@@ -1066,6 +1072,9 @@ def handle_success(args, kwargs, result, start_time, end_time):
)
pass
def acreate(*args, **kwargs): ## Thin client to handle the acreate langchain call
return litellm.acompletion(*args, **kwargs)
def prompt_token_calculator(model, messages):
# use tiktoken or anthropic's tokenizer depending on the model
@@ -1082,6 +1091,21 @@ def prompt_token_calculator(model, messages):
return num_tokens
def valid_model(model):
try:
# for a given model name, check if the user has the right permissions to access the model
if (
model in litellm.open_ai_chat_completion_models
or model in litellm.open_ai_text_completion_models
):
openai.Model.retrieve(model)
else:
messages = [{"role": "user", "content": "Hello World"}]
litellm.completion(model=model, messages=messages)
except:
raise InvalidRequestError(message="", model=model, llm_provider="")
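# Illustrative sketch of using valid_model as a pre-flight permission check; the default
# model name here is a placeholder, not something this diff prescribes.
def _example_model_preflight(model_name="gpt-3.5-turbo"):
    try:
        valid_model(model_name)
        return True
    except InvalidRequestError:
        return False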
# integration helper function
def modify_integration(integration_name, integration_params):
global supabaseClient
@@ -1089,9 +1113,65 @@ def modify_integration(integration_name, integration_params):
if "table_name" in integration_params:
Supabase.supabase_table_name = integration_params["table_name"]
####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
def get_all_keys(llm_provider=None):
try:
global last_fetched_at
# if user is using hosted product -> instantiate their env with their hosted api keys - refresh every 5 minutes
user_email = os.getenv("LITELLM_EMAIL") or litellm.email
if user_email:
time_delta = 0
if last_fetched_at is not None:
current_time = time.time()
time_delta = current_time - last_fetched_at
if time_delta > 300 or last_fetched_at is None or llm_provider:  # if an llm provider is passed in, assume this is happening due to an AuthError for that provider
# make the api call
last_fetched_at = time.time()
print(f"last_fetched_at: {last_fetched_at}")
response = requests.post(
    url="http://api.litellm.ai/get_all_keys",
    headers={"content-type": "application/json"},
    data=json.dumps({"user_email": user_email}),
)
print_verbose(f"get model key response: {response.text}")
data = response.json()
# update model list
for key, value in data["model_keys"].items(): # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
os.environ[key] = value
return "it worked!"
return None
# return None by default
return None
except:
print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
pass
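# Illustrative sketch of the payload shape get_all_keys consumes above; the endpoint's real
# schema is not shown here, and the keys below are placeholders that follow the
# <UPPERCASE_PROVIDER_NAME>_API_KEY convention noted in the loop.
_example_get_all_keys_payload = {
    "model_keys": {
        "OPENAI_API_KEY": "sk-placeholder",
        "HUGGINGFACE_API_KEY": "hf-placeholder",
    }
}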
def get_model_list():
global last_fetched_at
try:
# if user is using hosted product -> get their updated model list - refresh every 5 minutes
user_email = os.getenv("LITELLM_EMAIL") or litellm.email
if user_email:
time_delta = 0
if last_fetched_at is not None:
current_time = time.time()
time_delta = current_time - last_fetched_at
if time_delta > 300 or last_fetched_at is None:
# make the api call
last_fetched_at = time.time()
print(f"last_fetched_at: {last_fetched_at}")
response = requests.post(
    url="http://api.litellm.ai/get_model_list",
    headers={"content-type": "application/json"},
    data=json.dumps({"user_email": user_email}),
)
print_verbose(f"get_model_list response: {response.text}")
data = response.json()
# update model list
model_list = data["model_list"]
return model_list
return None
return None # return None by default
except:
print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
####### EXCEPTION MAPPING ################
def exception_type(model, original_exception, custom_llm_provider):
global user_logger_fn
global user_logger_fn, liteDebuggerClient
exception_mapping_worked = False
try:
if isinstance(original_exception, OriginalError):
@@ -1232,12 +1312,16 @@ def exception_type(model, original_exception, custom_llm_provider):
},
exception=e,
)
## AUTH ERROR
if isinstance(e, AuthenticationError) and (litellm.email or "LITELLM_EMAIL" in os.environ):
threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start()
if exception_mapping_worked:
raise e
else: # don't let an error with mapping interrupt the user from receiving an error from the llm api calls
raise original_exception
####### CRASH REPORTING ################
def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None):
data = {
"model": model,
@@ -1273,7 +1357,7 @@ def litellm_telemetry(data):
payload = {
"uuid": uuid_value,
"data": data,
"version": pkg_resources.get_distribution("litellm").version,
"version": importlib.metadata.version("litellm"),
}
# Make the POST request to litellm logging api
response = requests.post(
@@ -1443,7 +1527,7 @@ async def stream_to_string(generator):
return response
########## Together AI streaming #############################
########## Together AI streaming ############################# [TODO] move together ai to its own llm class
async def together_ai_completion_streaming(json_data, headers):
session = aiohttp.ClientSession()
url = "https://api.together.xyz/inference"
@@ -1480,3 +1564,49 @@ async def together_ai_completion_streaming(json_data, headers):
pass
finally:
await session.close()
def completion_with_fallbacks(**kwargs):
response = None
rate_limited_models = set()
model_expiration_times = {}
start_time = time.time()
fallbacks = [kwargs["model"]] + kwargs["fallbacks"]
del kwargs["fallbacks"] # remove fallbacks so it's not recursive
while response is None and time.time() - start_time < 45:
for model in fallbacks:
# loop thru all models
try:
if (
model in rate_limited_models
): # check if model is currently cooling down
if (
model_expiration_times.get(model)
and time.time() >= model_expiration_times[model]
):
rate_limited_models.remove(
model
) # check if it's been 60s of cool down and remove model
else:
continue # skip model
# delete model from kwargs if it exists
if kwargs.get("model"):
del kwargs["model"]
print("making completion call", model)
response = litellm.completion(**kwargs, model=model)
if response is not None:
return response
except Exception as e:
print(f"got exception {e} for model {model}")
rate_limited_models.add(model)
model_expiration_times[model] = (
time.time() + 60
) # cool down this selected model
# print(f"rate_limited_models {rate_limited_models}")
pass
return response
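# Illustrative usage sketch for completion_with_fallbacks; the model names and message
# below are placeholders chosen for illustration.
def _example_completion_with_fallbacks():
    return completion_with_fallbacks(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, world"}],
        fallbacks=["gpt-4", "claude-instant-1"],
    )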