Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
Merge branch 'main' into main

Commit dcbf5e2444. 48 changed files with 9611 additions and 239 deletions.

litellm/utils.py (228 lines changed)
@@ -1,7 +1,20 @@
-import aiohttp
-import subprocess
-import importlib
-from typing import List, Dict, Union, Optional
+import sys
+import dotenv, json, traceback, threading
+import subprocess, os
+import litellm, openai
+import random, uuid, requests
+import datetime, time
+import tiktoken
+
+encoding = tiktoken.get_encoding("cl100k_base")
+import importlib.metadata
+from .integrations.helicone import HeliconeLogger
+from .integrations.aispend import AISpendLogger
+from .integrations.berrispend import BerriSpendLogger
+from .integrations.supabase import Supabase
+from .integrations.litedebugger import LiteDebugger
+from openai.error import OpenAIError as OriginalError
+from openai.openai_object import OpenAIObject
 from .exceptions import (
     AuthenticationError,
     InvalidRequestError,
@@ -9,34 +22,10 @@ from .exceptions import (
     ServiceUnavailableError,
     OpenAIError,
 )
-from openai.openai_object import OpenAIObject
-from openai.error import OpenAIError as OriginalError
-from .integrations.llmonitor import LLMonitorLogger
-from .integrations.litedebugger import LiteDebugger
-from .integrations.supabase import Supabase
-from .integrations.berrispend import BerriSpendLogger
-from .integrations.aispend import AISpendLogger
-from .integrations.helicone import HeliconeLogger
-import pkg_resources
-import sys
-import dotenv
-import json
-import traceback
-import threading
-import subprocess
-import os
-import litellm
-import openai
-import random
-import uuid
-import requests
-import datetime
-import time
-import tiktoken
 from typing import List, Dict, Union, Optional
 
 encoding = tiktoken.get_encoding("cl100k_base")
-####### ENVIRONMENT VARIABLES ###################
+####### ENVIRONMENT VARIABLES ####################
 dotenv.load_dotenv()  # Loading env variables using dotenv
 sentry_sdk_instance = None
 capture_exception = None
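Context for these import hunks: the cl100k_base encoding instantiated at module import is tiktoken's tokenizer family for OpenAI chat models, and it is what this module's token counters lean on. A minimal sketch of counting prompt tokens with it (the per-message overhead of real chat accounting is deliberately omitted):

```python
import tiktoken

# Same encoding the diff instantiates at module import.
encoding = tiktoken.get_encoding("cl100k_base")

def count_message_tokens(messages):
    # Rough count: tokens in each message's text content, summed.
    # Real chat-completion accounting adds a small per-message overhead.
    return sum(len(encoding.encode(m["content"])) for m in messages)

print(count_message_tokens([{"role": "user", "content": "Hello World"}]))
```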
@@ -49,12 +38,11 @@ aispendLogger = None
 berrispendLogger = None
 supabaseClient = None
 liteDebuggerClient = None
-llmonitorLogger = None
 callback_list: Optional[List[str]] = []
 user_logger_fn = None
 additional_details: Optional[Dict[str, str]] = {}
 local_cache: Optional[Dict[str, str]] = {}
-
+last_fetched_at = None
 ######## Model Response #########################
 # All liteLLM Model responses will be in this format, Follows the OpenAI Format
 # https://docs.litellm.ai/docs/completion/output
@@ -174,7 +162,7 @@ class Logging:
 
     def pre_call(self, input, api_key, additional_args={}):
         try:
-            print(f"logging pre call for model: {self.model}")
+            print_verbose(f"logging pre call for model: {self.model}")
             self.model_call_details["input"] = input
             self.model_call_details["api_key"] = api_key
             self.model_call_details["additional_args"] = additional_args
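The change here (and in the next two print hunks) swaps an unconditional print for print_verbose. The helper itself is not part of this diff; in litellm it is gated on the litellm.set_verbose flag, roughly:

```python
import litellm

def print_verbose(print_statement):
    # Only emit debug output when the caller has opted in.
    if litellm.set_verbose:
        print(print_statement)
```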
@@ -214,7 +202,7 @@ class Logging:
                 print_verbose("reaches litedebugger for logging!")
                 model = self.model
                 messages = self.messages
-                print(f"liteDebuggerClient: {liteDebuggerClient}")
+                print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
                 liteDebuggerClient.input_log_event(
                     model=model,
                     messages=messages,
@@ -271,7 +259,6 @@ class Logging:
 
     # Add more methods as needed
 
-
 def exception_logging(
     additional_args={},
     logger_fn=None,
@@ -305,20 +292,34 @@ def exception_logging(
 ####### CLIENT ###################
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
 def client(original_function):
     global liteDebuggerClient, get_all_keys
 
     def function_setup(
         *args, **kwargs
     ):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
         try:
             global callback_list, add_breadcrumb, user_logger_fn
-            if (len(litellm.input_callback) > 0
-                    or len(litellm.success_callback) > 0
-                    or len(litellm.failure_callback)
-                    > 0) and len(callback_list) == 0:
+            if litellm.email is not None or os.getenv("LITELLM_EMAIL", None) is not None:  # add to input, success and failure callbacks if user is using hosted product
+                get_all_keys()
+                if "lite_debugger" not in callback_list:
+                    litellm.input_callback.append("lite_debugger")
+                    litellm.success_callback.append("lite_debugger")
+                    litellm.failure_callback.append("lite_debugger")
+            if (
+                len(litellm.input_callback) > 0
+                or len(litellm.success_callback) > 0
+                or len(litellm.failure_callback) > 0
+            ) and len(callback_list) == 0:
                 callback_list = list(
-                    set(litellm.input_callback + litellm.success_callback +
-                        litellm.failure_callback))
-                set_callbacks(callback_list=callback_list, )
+                    set(
+                        litellm.input_callback
+                        + litellm.success_callback
+                        + litellm.failure_callback
+                    )
+                )
+                set_callbacks(
+                    callback_list=callback_list,
+                )
             if add_breadcrumb:
                 add_breadcrumb(
                     category="litellm.llm_call",
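Behaviorally, the reformatted block is the same as the compact one it replaces: the three callback lists are concatenated and deduplicated before set_callbacks runs once. A standalone illustration of that dedup step (the callback names here are just examples):

```python
input_callback = ["lite_debugger", "posthog"]
success_callback = ["posthog", "helicone"]
failure_callback = ["sentry", "lite_debugger"]

# set() removes duplicates across the three lists; ordering is not
# preserved, which is acceptable for one-time callback setup.
callback_list = list(set(input_callback + success_callback + failure_callback))
print(sorted(callback_list))  # ['helicone', 'lite_debugger', 'posthog', 'sentry']
```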
@@ -432,6 +433,11 @@ def client(original_function):
                     kwargs),
             )  # don't interrupt execution of main thread
             my_thread.start()
+            if hasattr(e, "message"):
+                if (
+                    liteDebuggerClient and liteDebuggerClient.dashboard_url != None
+                ):  # make it easy to get to the debugger logs if you've initialized it
+                    e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
             raise e
 
     return wrapper
@@ -515,7 +521,7 @@ def get_litellm_params(
         "verbose": verbose,
         "custom_llm_provider": custom_llm_provider,
         "custom_api_base": custom_api_base,
-        "litellm_call_id": litellm_call_id
+        "litellm_call_id": litellm_call_id,
     }
 
     return litellm_params
@@ -738,7 +744,7 @@ def set_callbacks(callback_list):
                 supabaseClient = Supabase()
             elif callback == "lite_debugger":
                 print(f"instantiating lite_debugger")
-                liteDebuggerClient = LiteDebugger()
+                liteDebuggerClient = LiteDebugger(email=litellm.email)
     except Exception as e:
         raise e
@@ -1036,7 +1042,7 @@ def handle_success(args, kwargs, result, start_time, end_time):
             print_verbose("reaches lite_debugger for logging!")
             model = args[0] if len(args) > 0 else kwargs["model"]
             messages = args[1] if len(args) > 1 else kwargs["messages"]
-            print(f"liteDebuggerClient: {liteDebuggerClient}")
+            print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
             liteDebuggerClient.log_event(
                 model=model,
                 messages=messages,
@@ -1066,6 +1072,9 @@ def handle_success(args, kwargs, result, start_time, end_time):
         )
         pass
 
+def acreate(*args, **kwargs):  ## Thin client to handle the acreate langchain call
+    return litellm.acompletion(*args, **kwargs)
+
 
 def prompt_token_calculator(model, messages):
     # use tiktoken or anthropic's tokenizer depending on the model
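acreate is a thin shim so langchain-style callers can keep their acreate call sites while litellm.acompletion does the work. Assuming the function is imported straight from this module, usage would look like this (model name illustrative):

```python
import asyncio
from litellm.utils import acreate  # thin wrapper over litellm.acompletion

async def main():
    # All args/kwargs are forwarded unchanged to litellm.acompletion.
    response = await acreate(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello World"}],
    )
    print(response)

asyncio.run(main())
```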
@@ -1082,6 +1091,21 @@ def prompt_token_calculator(model, messages):
     return num_tokens
 
 
+def valid_model(model):
+    try:
+        # for a given model name, check if the user has the right permissions to access the model
+        if (
+            model in litellm.open_ai_chat_completion_models
+            or model in litellm.open_ai_text_completion_models
+        ):
+            openai.Model.retrieve(model)
+        else:
+            messages = [{"role": "user", "content": "Hello World"}]
+            litellm.completion(model=model, messages=messages)
+    except:
+        raise InvalidRequestError(message="", model=model, llm_provider="")
+
+
 # integration helper function
 def modify_integration(integration_name, integration_params):
     global supabaseClient
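valid_model is a permissions probe rather than a static check: OpenAI models are verified with the inexpensive openai.Model.retrieve call, while every other provider is probed with a one-message completion, which does hit the provider's API. A hedged usage sketch:

```python
from litellm.utils import valid_model
from litellm.exceptions import InvalidRequestError

try:
    # Raises InvalidRequestError if the configured key cannot access the model.
    # Note: for non-OpenAI models this issues a real completion request.
    valid_model("gpt-3.5-turbo")
    print("model is accessible")
except InvalidRequestError:
    print("model is not accessible with the current credentials")
```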
@@ -1089,9 +1113,65 @@ def modify_integration(integration_name, integration_params):
     if "table_name" in integration_params:
         Supabase.supabase_table_name = integration_params["table_name"]
 
+
+####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
+
+
+def get_all_keys(llm_provider=None):
+    try:
+        global last_fetched_at
+        # if user is using hosted product -> instantiate their env with their hosted api keys - refresh every 5 minutes
+        user_email = os.getenv("LITELLM_EMAIL") or litellm.email
+        if user_email:
+            time_delta = 0
+            if last_fetched_at != None:
+                current_time = time.time()
+                time_delta = current_time - last_fetched_at
+            if time_delta > 300 or last_fetched_at == None or llm_provider:  # if the llm provider is passed in, assume this happening due to an AuthError for that provider
+                # make the api call
+                last_fetched_at = time.time()
+                print(f"last_fetched_at: {last_fetched_at}")
+                response = requests.post(url="http://api.litellm.ai/get_all_keys", headers={"content-type": "application/json"}, data=json.dumps({"user_email": user_email}))
+                print_verbose(f"get model key response: {response.text}")
+                data = response.json()
+                # update model list
+                for key, value in data["model_keys"].items():  # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
+                    os.environ[key] = value
+                return "it worked!"
+            return None
+        # return None by default
+        return None
+    except:
+        print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
+        pass
+
+
+def get_model_list():
+    global last_fetched_at
+    try:
+        # if user is using hosted product -> get their updated model list - refresh every 5 minutes
+        user_email = os.getenv("LITELLM_EMAIL") or litellm.email
+        if user_email:
+            time_delta = 0
+            if last_fetched_at != None:
+                current_time = time.time()
+                time_delta = current_time - last_fetched_at
+            if time_delta > 300 or last_fetched_at == None:
+                # make the api call
+                last_fetched_at = time.time()
+                print(f"last_fetched_at: {last_fetched_at}")
+                response = requests.post(url="http://api.litellm.ai/get_model_list", headers={"content-type": "application/json"}, data=json.dumps({"user_email": user_email}))
+                print_verbose(f"get_model_list response: {response.text}")
+                data = response.json()
+                # update model list
+                model_list = data["model_list"]
+                return model_list
+            return None
+        return None  # return None by default
+    except:
+        print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
+
+
 ####### EXCEPTION MAPPING ################
 def exception_type(model, original_exception, custom_llm_provider):
-    global user_logger_fn
+    global user_logger_fn, liteDebuggerClient
     exception_mapping_worked = False
     try:
         if isinstance(original_exception, OriginalError):
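Both hosted-product helpers share one throttle: a module-level last_fetched_at timestamp, and a network call only when more than 300 seconds have passed (get_all_keys also refreshes immediately when an llm_provider is passed in after an auth error). The pattern in isolation:

```python
import time

last_fetched_at = None  # module-level timestamp shared across calls
REFRESH_INTERVAL = 300  # seconds, the 5-minute window used in the diff

def fetch_if_stale(force=False):
    global last_fetched_at
    now = time.time()
    if force or last_fetched_at is None or now - last_fetched_at > REFRESH_INTERVAL:
        last_fetched_at = now
        return "fetched"  # stand-in for the requests.post(...) call above
    return "cached"       # within the window; skip the network round trip

print(fetch_if_stale())            # "fetched" on the first call
print(fetch_if_stale())            # "cached" inside the 5-minute window
print(fetch_if_stale(force=True))  # an AuthError-driven refresh would force it
```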
@@ -1232,12 +1312,16 @@ def exception_type(model, original_exception, custom_llm_provider):
             },
             exception=e,
         )
+        ## AUTH ERROR
+        if isinstance(e, AuthenticationError) and (litellm.email or "LITELLM_EMAIL" in os.environ):
+            threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start()
         if exception_mapping_worked:
             raise e
         else:  # don't let an error with mapping interrupt the user from receiving an error from the llm api calls
            raise original_exception
 
 
 ####### CRASH REPORTING ################
 def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None):
     data = {
         "model": model,
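The new AUTH ERROR branch refreshes keys on a background thread, so the mapped exception is re-raised immediately instead of blocking on the api.litellm.ai round trip. The same fire-and-forget shape, reduced to its essentials:

```python
import threading

def get_all_keys(llm_provider=None):
    # Stand-in for the network refresh performed by litellm.utils.get_all_keys.
    print(f"refreshing keys for provider: {llm_provider}")

# Fire-and-forget: start the refresh, then let the exception propagate
# to the caller without waiting for the thread to finish.
threading.Thread(target=get_all_keys, args=("openai",)).start()
```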
@@ -1273,7 +1357,7 @@ def litellm_telemetry(data):
     payload = {
         "uuid": uuid_value,
         "data": data,
-        "version": pkg_resources.get_distribution("litellm").version,
+        "version:": importlib.metadata.version("litellm"),
     }
     # Make the POST request to litellm logging api
     response = requests.post(
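The telemetry change swaps a setuptools-era API for the stdlib one; both resolve the installed distribution's version string. For comparison (the commented form is the one being deleted):

```python
import importlib.metadata

# Stdlib lookup, available since Python 3.8; raises
# importlib.metadata.PackageNotFoundError if litellm isn't installed.
print(importlib.metadata.version("litellm"))

# Legacy equivalent via setuptools' pkg_resources, removed by this hunk:
# import pkg_resources
# print(pkg_resources.get_distribution("litellm").version)
```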
@@ -1443,7 +1527,7 @@ async def stream_to_string(generator):
     return response
 
 
-########## Together AI streaming #############################
+########## Together AI streaming ############################# [TODO] move together ai to it's own llm class
 async def together_ai_completion_streaming(json_data, headers):
     session = aiohttp.ClientSession()
     url = "https://api.together.xyz/inference"
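Apart from the [TODO] note, together_ai_completion_streaming keeps its shape: one aiohttp session, one POST, then iterate the response body. A stripped-down sketch of that shape, with payload fields and chunk parsing elided because they sit outside this hunk:

```python
import aiohttp

async def stream_inference(json_data, headers):
    # One session per call, closed in finally, as in the function above.
    session = aiohttp.ClientSession()
    url = "https://api.together.xyz/inference"
    try:
        async with session.post(url, json=json_data, headers=headers) as resp:
            async for chunk, _ in resp.content.iter_chunks():
                yield chunk  # the caller decodes/parses each streamed chunk
    finally:
        await session.close()
```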
@@ -1480,3 +1564,49 @@ async def together_ai_completion_streaming(json_data, headers):
         pass
     finally:
         await session.close()
+
+
+def completion_with_fallbacks(**kwargs):
+    response = None
+    rate_limited_models = set()
+    model_expiration_times = {}
+    start_time = time.time()
+    fallbacks = [kwargs["model"]] + kwargs["fallbacks"]
+    del kwargs["fallbacks"]  # remove fallbacks so it's not recursive
+
+    while response == None and time.time() - start_time < 45:
+        for model in fallbacks:
+            # loop thru all models
+            try:
+                if (
+                    model in rate_limited_models
+                ):  # check if model is currently cooling down
+                    if (
+                        model_expiration_times.get(model)
+                        and time.time() >= model_expiration_times[model]
+                    ):
+                        rate_limited_models.remove(
+                            model
+                        )  # check if it's been 60s of cool down and remove model
+                    else:
+                        continue  # skip model
+
+                # delete model from kwargs if it exists
+                if kwargs.get("model"):
+                    del kwargs["model"]
+
+                print("making completion call", model)
+                response = litellm.completion(**kwargs, model=model)
+
+                if response != None:
+                    return response
+
+            except Exception as e:
+                print(f"got exception {e} for model {model}")
+                rate_limited_models.add(model)
+                model_expiration_times[model] = (
+                    time.time() + 60
+                )  # cool down this selected model
+                # print(f"rate_limited_models {rate_limited_models}")
+                pass
+    return response
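completion_with_fallbacks tries the primary model plus each fallback for up to 45 seconds total, and a model that raises is parked in a 60-second cooldown set so the loop can move on and retry it later. Assuming it is called like the other helpers in this module (model names illustrative):

```python
from litellm.utils import completion_with_fallbacks

response = completion_with_fallbacks(
    model="gpt-4",  # tried first
    messages=[{"role": "user", "content": "Hello World"}],
    fallbacks=["gpt-3.5-turbo", "claude-instant-1"],  # tried on failure
)
print(response)
```

The cooldown set is the design choice worth noting: a transient rate limit on one model skips it for 60 seconds rather than letting retries against it consume the whole 45-second budget.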