mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
refactor: add black formatting
This commit is contained in:
parent
b87d630b0a
commit
4905929de3
156 changed files with 19723 additions and 10869 deletions
|
@ -8,16 +8,19 @@ from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
|
|||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
if litellm.set_verbose:
|
||||
print(f"LiteLLM Proxy: {print_statement}") # noqa
|
||||
### LOGGING ###
|
||||
class ProxyLogging:
|
||||
print(f"LiteLLM Proxy: {print_statement}") # noqa
|
||||
|
||||
|
||||
### LOGGING ###
|
||||
class ProxyLogging:
|
||||
"""
|
||||
Logging/Custom Handlers for proxy.
|
||||
Logging/Custom Handlers for proxy.
|
||||
|
||||
Implemented mainly to:
|
||||
- log successful/failed db read/writes
|
||||
- log successful/failed db read/writes
|
||||
- support the max parallel request integration
|
||||
"""
|
||||
|
||||
|
@ -25,15 +28,15 @@ class ProxyLogging:
|
|||
## INITIALIZE LITELLM CALLBACKS ##
|
||||
self.call_details: dict = {}
|
||||
self.call_details["user_api_key_cache"] = user_api_key_cache
|
||||
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
|
||||
self.max_budget_limiter = MaxBudgetLimiter()
|
||||
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
|
||||
self.max_budget_limiter = MaxBudgetLimiter()
|
||||
pass
|
||||
|
||||
def _init_litellm_callbacks(self):
|
||||
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
|
||||
litellm.callbacks.append(self.max_parallel_request_limiter)
|
||||
litellm.callbacks.append(self.max_budget_limiter)
|
||||
for callback in litellm.callbacks:
|
||||
for callback in litellm.callbacks:
|
||||
if callback not in litellm.input_callback:
|
||||
litellm.input_callback.append(callback)
|
||||
if callback not in litellm.success_callback:
|
||||
|
@ -44,7 +47,7 @@ class ProxyLogging:
|
|||
litellm._async_success_callback.append(callback)
|
||||
if callback not in litellm._async_failure_callback:
|
||||
litellm._async_failure_callback.append(callback)
|
||||
|
||||
|
||||
if (
|
||||
len(litellm.input_callback) > 0
|
||||
or len(litellm.success_callback) > 0
|
||||
|
@ -57,31 +60,41 @@ class ProxyLogging:
|
|||
+ litellm.failure_callback
|
||||
)
|
||||
)
|
||||
litellm.utils.set_callbacks(
|
||||
callback_list=callback_list
|
||||
)
|
||||
litellm.utils.set_callbacks(callback_list=callback_list)
|
||||
|
||||
async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
|
||||
async def pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
data: dict,
|
||||
call_type: Literal["completion", "embeddings"],
|
||||
):
|
||||
"""
|
||||
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
|
||||
|
||||
Covers:
|
||||
Covers:
|
||||
1. /chat/completions
|
||||
2. /embeddings
|
||||
2. /embeddings
|
||||
"""
|
||||
try:
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, CustomLogger) and 'async_pre_call_hook' in vars(callback.__class__):
|
||||
response = await callback.async_pre_call_hook(user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type)
|
||||
if response is not None:
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
|
||||
callback.__class__
|
||||
):
|
||||
response = await callback.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=self.call_details["user_api_key_cache"],
|
||||
data=data,
|
||||
call_type=call_type,
|
||||
)
|
||||
if response is not None:
|
||||
data = response
|
||||
|
||||
print_verbose(f'final data being sent to {call_type} call: {data}')
|
||||
print_verbose(f"final data being sent to {call_type} call: {data}")
|
||||
return data
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def success_handler(self, *args, **kwargs):
|
||||
|
||||
async def success_handler(self, *args, **kwargs):
|
||||
"""
|
||||
Log successful db read/writes
|
||||
"""
|
||||
|
@ -93,26 +106,31 @@ class ProxyLogging:
|
|||
|
||||
Currently only logs exceptions to sentry
|
||||
"""
|
||||
if litellm.utils.capture_exception:
|
||||
if litellm.utils.capture_exception:
|
||||
litellm.utils.capture_exception(error=original_exception)
|
||||
|
||||
async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
|
||||
async def post_call_failure_hook(
|
||||
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
|
||||
):
|
||||
"""
|
||||
Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
|
||||
|
||||
Covers:
|
||||
Covers:
|
||||
1. /chat/completions
|
||||
2. /embeddings
|
||||
2. /embeddings
|
||||
"""
|
||||
|
||||
for callback in litellm.callbacks:
|
||||
try:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
await callback.async_post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=original_exception)
|
||||
except Exception as e:
|
||||
await callback.async_post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=original_exception,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return
|
||||
|
||||
|
||||
|
||||
### DB CONNECTOR ###
|
||||
# Define the retry decorator with backoff strategy
|
||||
|
@ -121,9 +139,12 @@ def on_backoff(details):
|
|||
# The 'tries' key in the details dictionary contains the number of completed tries
|
||||
print_verbose(f"Backing off... this was attempt #{details['tries']}")
|
||||
|
||||
|
||||
class PrismaClient:
|
||||
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
||||
print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
||||
print_verbose(
|
||||
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
||||
)
|
||||
## init logging object
|
||||
self.proxy_logging_obj = proxy_logging_obj
|
||||
|
||||
|
@ -136,23 +157,24 @@ class PrismaClient:
|
|||
os.chdir(dname)
|
||||
|
||||
try:
|
||||
subprocess.run(['prisma', 'generate'])
|
||||
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
|
||||
subprocess.run(["prisma", "generate"])
|
||||
subprocess.run(
|
||||
["prisma", "db", "push", "--accept-data-loss"]
|
||||
) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
# Now you can import the Prisma Client
|
||||
from prisma import Client # type: ignore
|
||||
self.db = Client() #Client to connect to Prisma db
|
||||
from prisma import Client # type: ignore
|
||||
|
||||
|
||||
self.db = Client() # Client to connect to Prisma db
|
||||
|
||||
def hash_token(self, token: str):
|
||||
# Hash the string using SHA-256
|
||||
hashed_token = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
return hashed_token
|
||||
|
||||
def jsonify_object(self, data: dict) -> dict:
|
||||
def jsonify_object(self, data: dict) -> dict:
|
||||
db_data = copy.deepcopy(data)
|
||||
|
||||
for k, v in db_data.items():
|
||||
|
@ -162,233 +184,258 @@ class PrismaClient:
|
|||
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def get_data(self, token: Optional[str]=None, expires: Optional[Any]=None, user_id: Optional[str]=None):
|
||||
try:
|
||||
async def get_data(
|
||||
self,
|
||||
token: Optional[str] = None,
|
||||
expires: Optional[Any] = None,
|
||||
user_id: Optional[str] = None,
|
||||
):
|
||||
try:
|
||||
response = None
|
||||
if token is not None:
|
||||
if token is not None:
|
||||
# check if plain text or hash
|
||||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
if token.startswith("sk-"):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": hashed_token
|
||||
}
|
||||
)
|
||||
where={"token": hashed_token}
|
||||
)
|
||||
if response:
|
||||
# Token exists, now check expiration.
|
||||
if response.expires is not None and expires is not None:
|
||||
if response.expires is not None and expires is not None:
|
||||
if response.expires >= expires:
|
||||
# Token exists and is not expired.
|
||||
return response
|
||||
else:
|
||||
# Token exists but is expired.
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="expired user key")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="expired user key",
|
||||
)
|
||||
return response
|
||||
else:
|
||||
# Token does not exist.
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key")
|
||||
elif user_id is not None:
|
||||
response = await self.db.litellm_usertable.find_unique( # type: ignore
|
||||
where={
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid user key",
|
||||
)
|
||||
elif user_id is not None:
|
||||
response = await self.db.litellm_usertable.find_unique( # type: ignore
|
||||
where={
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
)
|
||||
raise e
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def insert_data(self, data: dict):
|
||||
"""
|
||||
Add a key to the database. If it already exists, do nothing.
|
||||
Add a key to the database. If it already exists, do nothing.
|
||||
"""
|
||||
try:
|
||||
try:
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
db_data = self.jsonify_object(data=data)
|
||||
db_data["token"] = hashed_token
|
||||
max_budget = db_data.pop("max_budget", None)
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
'token': hashed_token,
|
||||
"token": hashed_token,
|
||||
},
|
||||
data={
|
||||
"create": {**db_data}, #type: ignore
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
"create": {**db_data}, # type: ignore
|
||||
"update": {}, # don't do anything if it already exists
|
||||
},
|
||||
)
|
||||
|
||||
new_user_row = await self.db.litellm_usertable.upsert(
|
||||
where={
|
||||
'user_id': data['user_id']
|
||||
},
|
||||
where={"user_id": data["user_id"]},
|
||||
data={
|
||||
"create": {"user_id": data['user_id'], "max_budget": max_budget},
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
"create": {"user_id": data["user_id"], "max_budget": max_budget},
|
||||
"update": {}, # don't do anything if it already exists
|
||||
},
|
||||
)
|
||||
return new_verification_token
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
)
|
||||
raise e
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def update_data(self, token: Optional[str]=None, data: dict={}, user_id: Optional[str]=None):
|
||||
async def update_data(
|
||||
self,
|
||||
token: Optional[str] = None,
|
||||
data: dict = {},
|
||||
user_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Update existing data
|
||||
"""
|
||||
try:
|
||||
try:
|
||||
db_data = self.jsonify_object(data=data)
|
||||
if token is not None:
|
||||
if token is not None:
|
||||
print_verbose(f"token: {token}")
|
||||
# check if plain text or hash
|
||||
if token.startswith("sk-"):
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
db_data["token"] = token
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={
|
||||
"token": token # type: ignore
|
||||
},
|
||||
data={**db_data} # type: ignore
|
||||
where={"token": token}, # type: ignore
|
||||
data={**db_data}, # type: ignore
|
||||
)
|
||||
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
|
||||
return {"token": token, "data": db_data}
|
||||
elif user_id is not None:
|
||||
elif user_id is not None:
|
||||
"""
|
||||
If data['spend'] + data['user'], update the user table with spend info as well
|
||||
"""
|
||||
update_user_row = await self.db.litellm_usertable.update(
|
||||
where={
|
||||
'user_id': user_id # type: ignore
|
||||
},
|
||||
data={**db_data} # type: ignore
|
||||
where={"user_id": user_id}, # type: ignore
|
||||
data={**db_data}, # type: ignore
|
||||
)
|
||||
return {"user_id": user_id, "data": db_data}
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
)
|
||||
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
|
||||
raise e
|
||||
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def delete_data(self, tokens: List):
|
||||
"""
|
||||
Allow user to delete a key(s)
|
||||
"""
|
||||
try:
|
||||
try:
|
||||
hashed_tokens = [self.hash_token(token=token) for token in tokens]
|
||||
await self.db.litellm_verificationtoken.delete_many(
|
||||
where={"token": {"in": hashed_tokens}}
|
||||
)
|
||||
where={"token": {"in": hashed_tokens}}
|
||||
)
|
||||
return {"deleted_keys": tokens}
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
raise e
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def connect(self):
|
||||
try:
|
||||
await self.db.connect()
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
)
|
||||
raise e
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def disconnect(self):
|
||||
async def connect(self):
|
||||
try:
|
||||
await self.db.connect()
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
)
|
||||
raise e
|
||||
|
||||
# Define a retrying strategy with exponential backoff
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
max_tries=3, # maximum number of retries
|
||||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def disconnect(self):
|
||||
try:
|
||||
await self.db.disconnect()
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
### CUSTOM FILE ###
|
||||
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
||||
try:
|
||||
print_verbose(f"value: {value}")
|
||||
# Split the path by dots to separate module from instance
|
||||
parts = value.split(".")
|
||||
|
||||
|
||||
# The module path is all but the last part, and the instance_name is the last part
|
||||
module_name = ".".join(parts[:-1])
|
||||
instance_name = parts[-1]
|
||||
|
||||
|
||||
# If config_file_path is provided, use it to determine the module spec and load the module
|
||||
if config_file_path is not None:
|
||||
directory = os.path.dirname(config_file_path)
|
||||
module_file_path = os.path.join(directory, *module_name.split('.'))
|
||||
module_file_path += '.py'
|
||||
module_file_path = os.path.join(directory, *module_name.split("."))
|
||||
module_file_path += ".py"
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_file_path)
|
||||
if spec is None:
|
||||
raise ImportError(f"Could not find a module specification for {module_file_path}")
|
||||
raise ImportError(
|
||||
f"Could not find a module specification for {module_file_path}"
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
else:
|
||||
# Dynamically import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
|
||||
# Get the instance from the module
|
||||
instance = getattr(module, instance_name)
|
||||
|
||||
|
||||
return instance
|
||||
except ImportError as e:
|
||||
# Re-raise the exception with a user-friendly message
|
||||
raise ImportError(f"Could not import {instance_name} from {module_name}") from e
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
### HELPER FUNCTIONS ###
|
||||
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
|
||||
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
|
||||
"""
|
||||
Check if a user_id exists in cache,
|
||||
if not retrieve it.
|
||||
Check if a user_id exists in cache,
|
||||
if not retrieve it.
|
||||
"""
|
||||
cache_key = f"{user_id}_user_api_key_user_id"
|
||||
response = cache.get_cache(key=cache_key)
|
||||
if response is None: # Cache miss
|
||||
if response is None: # Cache miss
|
||||
user_row = await db.get_data(user_id=user_id)
|
||||
cache_value = user_row.model_dump_json()
|
||||
cache.set_cache(key=cache_key, value=cache_value, ttl=600) # store for 10 minutes
|
||||
return
|
||||
cache.set_cache(
|
||||
key=cache_key, value=cache_value, ttl=600
|
||||
) # store for 10 minutes
|
||||
return
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue