mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge branch 'litellm_dynamo_db_keys'
This commit is contained in:
commit
8a7a745549
12 changed files with 547 additions and 58 deletions
|
@ -67,6 +67,7 @@ def generate_feedback_box():
|
|||
import litellm
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient,
|
||||
DBClient,
|
||||
get_instance_fn,
|
||||
ProxyLogging,
|
||||
_cache_user_row,
|
||||
|
@ -141,6 +142,7 @@ worker_config = None
|
|||
master_key = None
|
||||
otel_logging = False
|
||||
prisma_client: Optional[PrismaClient] = None
|
||||
custom_db_client: Optional[DBClient] = None
|
||||
user_api_key_cache = DualCache()
|
||||
user_custom_auth = None
|
||||
use_background_health_checks = None
|
||||
|
@ -184,12 +186,12 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
|
|||
async def user_api_key_auth(
|
||||
request: Request, api_key: str = fastapi.Security(api_key_header)
|
||||
) -> UserAPIKeyAuth:
|
||||
global master_key, prisma_client, llm_model_list, user_custom_auth
|
||||
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
|
||||
try:
|
||||
if isinstance(api_key, str):
|
||||
api_key = _get_bearer_token(api_key=api_key)
|
||||
### USER-DEFINED AUTH FUNCTION ###
|
||||
if user_custom_auth:
|
||||
if user_custom_auth is not None:
|
||||
response = await user_custom_auth(request=request, api_key=api_key)
|
||||
return UserAPIKeyAuth.model_validate(response)
|
||||
### LITELLM-DEFINED AUTH FUNCTION ###
|
||||
|
@ -231,7 +233,7 @@ async def user_api_key_auth(
|
|||
)
|
||||
|
||||
if (
|
||||
prisma_client is None
|
||||
prisma_client is None and custom_db_client is None
|
||||
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
||||
raise Exception("No connected db.")
|
||||
|
||||
|
@ -241,15 +243,22 @@ async def user_api_key_auth(
|
|||
if valid_token is None:
|
||||
## check db
|
||||
verbose_proxy_logger.debug(f"api key: {api_key}")
|
||||
valid_token = await prisma_client.get_data(
|
||||
token=api_key,
|
||||
)
|
||||
expires = datetime.utcnow().replace(tzinfo=timezone.utc)
|
||||
if prisma_client is not None:
|
||||
valid_token = await prisma_client.get_data(
|
||||
token=api_key,
|
||||
)
|
||||
|
||||
expires = datetime.utcnow().replace(tzinfo=timezone.utc)
|
||||
elif custom_db_client is not None:
|
||||
valid_token = await custom_db_client.get_data(
|
||||
key=api_key, table_name="key"
|
||||
)
|
||||
# Token exists, now check expiration.
|
||||
if valid_token.expires is not None and expires is not None:
|
||||
if valid_token.expires >= expires:
|
||||
if valid_token.expires is not None:
|
||||
expiry_time = datetime.fromisoformat(valid_token.expires)
|
||||
if expiry_time >= datetime.utcnow():
|
||||
# Token exists and is not expired.
|
||||
return valid_token
|
||||
return response
|
||||
else:
|
||||
# Token exists but is expired.
|
||||
raise HTTPException(
|
||||
|
@ -291,13 +300,22 @@ async def user_api_key_auth(
|
|||
|
||||
This makes the user row data accessible to pre-api call hooks.
|
||||
"""
|
||||
asyncio.create_task(
|
||||
_cache_user_row(
|
||||
user_id=valid_token.user_id,
|
||||
cache=user_api_key_cache,
|
||||
db=prisma_client,
|
||||
if prisma_client is not None:
|
||||
asyncio.create_task(
|
||||
_cache_user_row(
|
||||
user_id=valid_token.user_id,
|
||||
cache=user_api_key_cache,
|
||||
db=prisma_client,
|
||||
)
|
||||
)
|
||||
elif custom_db_client is not None:
|
||||
asyncio.create_task(
|
||||
_cache_user_row(
|
||||
user_id=valid_token.user_id,
|
||||
cache=user_api_key_cache,
|
||||
db=custom_db_client,
|
||||
)
|
||||
)
|
||||
)
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||
else:
|
||||
raise Exception(f"Invalid token")
|
||||
|
@ -371,8 +389,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
|||
|
||||
|
||||
def cost_tracking():
|
||||
global prisma_client
|
||||
if prisma_client is not None:
|
||||
global prisma_client, custom_db_client
|
||||
if prisma_client is not None or custom_db_client is not None:
|
||||
if isinstance(litellm.success_callback, list):
|
||||
verbose_proxy_logger.debug("setting litellm success callback to track cost")
|
||||
if (track_cost_callback) not in litellm.success_callback: # type: ignore
|
||||
|
@ -385,7 +403,7 @@ async def track_cost_callback(
|
|||
start_time=None,
|
||||
end_time=None, # start/end time for completion
|
||||
):
|
||||
global prisma_client
|
||||
global prisma_client, custom_db_client
|
||||
try:
|
||||
# check if it has collected an entire stream response
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -404,10 +422,10 @@ async def track_cost_callback(
|
|||
user_id = kwargs["litellm_params"]["metadata"].get(
|
||||
"user_api_key_user_id", None
|
||||
)
|
||||
if user_api_key and prisma_client:
|
||||
await update_prisma_database(
|
||||
token=user_api_key, response_cost=response_cost
|
||||
)
|
||||
if user_api_key and (
|
||||
prisma_client is not None or custom_db_client is not None
|
||||
):
|
||||
await update_database(token=user_api_key, response_cost=response_cost)
|
||||
elif kwargs["stream"] == False: # for non streaming responses
|
||||
response_cost = litellm.completion_cost(
|
||||
completion_response=completion_response
|
||||
|
@ -418,15 +436,17 @@ async def track_cost_callback(
|
|||
user_id = kwargs["litellm_params"]["metadata"].get(
|
||||
"user_api_key_user_id", None
|
||||
)
|
||||
if user_api_key and prisma_client:
|
||||
await update_prisma_database(
|
||||
if user_api_key and (
|
||||
prisma_client is not None or custom_db_client is not None
|
||||
):
|
||||
await update_database(
|
||||
token=user_api_key, response_cost=response_cost, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
|
||||
|
||||
|
||||
async def update_prisma_database(token, response_cost, user_id=None):
|
||||
async def update_database(token, response_cost, user_id=None):
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Enters prisma db call, token: {token}; user_id: {user_id}"
|
||||
|
@ -436,7 +456,12 @@ async def update_prisma_database(token, response_cost, user_id=None):
|
|||
async def _update_user_db():
|
||||
if user_id is None:
|
||||
return
|
||||
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
|
||||
if prisma_client is not None:
|
||||
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
|
||||
elif custom_db_client is not None:
|
||||
existing_spend_obj = await custom_db_client.get_data(
|
||||
key=user_id, table_name="user"
|
||||
)
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
|
@ -447,23 +472,49 @@ async def update_prisma_database(token, response_cost, user_id=None):
|
|||
|
||||
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||
# Update the cost column for the given user id
|
||||
await prisma_client.update_data(user_id=user_id, data={"spend": new_spend})
|
||||
if prisma_client is not None:
|
||||
await prisma_client.update_data(
|
||||
user_id=user_id, data={"spend": new_spend}
|
||||
)
|
||||
elif custom_db_client is not None:
|
||||
await custom_db_client.update_data(
|
||||
key=user_id, value={"spend": new_spend}, table_name="user"
|
||||
)
|
||||
|
||||
### UPDATE KEY SPEND ###
|
||||
async def _update_key_db():
|
||||
# Fetch the existing cost for the given token
|
||||
existing_spend_obj = await prisma_client.get_data(token=token)
|
||||
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
if prisma_client is not None:
|
||||
# Fetch the existing cost for the given token
|
||||
existing_spend_obj = await prisma_client.get_data(token=token)
|
||||
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
|
||||
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||
# Update the cost column for the given token
|
||||
await prisma_client.update_data(token=token, data={"spend": new_spend})
|
||||
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||
# Update the cost column for the given token
|
||||
await prisma_client.update_data(token=token, data={"spend": new_spend})
|
||||
elif custom_db_client is not None:
|
||||
# Fetch the existing cost for the given token
|
||||
existing_spend_obj = await custom_db_client.get_data(
|
||||
key=token, table_name="key"
|
||||
)
|
||||
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
|
||||
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||
# Update the cost column for the given token
|
||||
await custom_db_client.update_data(
|
||||
key=token, value={"spend": new_spend}, table_name="key"
|
||||
)
|
||||
|
||||
tasks = []
|
||||
tasks.append(_update_user_db())
|
||||
|
@ -622,7 +673,7 @@ class ProxyConfig:
|
|||
"""
|
||||
Load config values into proxy global state
|
||||
"""
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client
|
||||
|
||||
# Load existing config
|
||||
config = await self.get_config(config_file_path=config_file_path)
|
||||
|
@ -787,8 +838,6 @@ class ProxyConfig:
|
|||
verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!")
|
||||
database_url = litellm.get_secret(database_url)
|
||||
verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}")
|
||||
## COST TRACKING ##
|
||||
cost_tracking()
|
||||
### MASTER KEY ###
|
||||
master_key = general_settings.get(
|
||||
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||
|
@ -796,11 +845,23 @@ class ProxyConfig:
|
|||
if master_key and master_key.startswith("os.environ/"):
|
||||
master_key = litellm.get_secret(master_key)
|
||||
### CUSTOM API KEY AUTH ###
|
||||
## pass filepath
|
||||
custom_auth = general_settings.get("custom_auth", None)
|
||||
if custom_auth:
|
||||
if custom_auth is not None:
|
||||
user_custom_auth = get_instance_fn(
|
||||
value=custom_auth, config_file_path=config_file_path
|
||||
)
|
||||
## dynamodb
|
||||
database_type = general_settings.get("database_type", None)
|
||||
if database_type is not None and (
|
||||
database_type == "dynamo_db" or database_type == "dynamodb"
|
||||
):
|
||||
database_args = general_settings.get("database_args", None)
|
||||
custom_db_client = DBClient(
|
||||
custom_db_args=database_args, custom_db_type=database_type
|
||||
)
|
||||
## COST TRACKING ##
|
||||
cost_tracking()
|
||||
### BACKGROUND HEALTH CHECKS ###
|
||||
# Enable background health checks
|
||||
use_background_health_checks = general_settings.get(
|
||||
|
@ -867,9 +928,9 @@ async def generate_key_helper_fn(
|
|||
max_parallel_requests: Optional[int] = None,
|
||||
metadata: Optional[dict] = {},
|
||||
):
|
||||
global prisma_client
|
||||
global prisma_client, custom_db_client
|
||||
|
||||
if prisma_client is None:
|
||||
if prisma_client is None and custom_db_client is None:
|
||||
raise Exception(
|
||||
f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys "
|
||||
)
|
||||
|
@ -908,7 +969,13 @@ async def generate_key_helper_fn(
|
|||
user_id = user_id or str(uuid.uuid4())
|
||||
try:
|
||||
# Create a new verification token (you may want to enhance this logic based on your needs)
|
||||
verification_token_data = {
|
||||
user_data = {
|
||||
"max_budget": max_budget,
|
||||
"user_email": user_email,
|
||||
"user_id": user_id,
|
||||
"spend": spend,
|
||||
}
|
||||
key_data = {
|
||||
"token": token,
|
||||
"expires": expires,
|
||||
"models": models,
|
||||
|
@ -918,19 +985,23 @@ async def generate_key_helper_fn(
|
|||
"user_id": user_id,
|
||||
"max_parallel_requests": max_parallel_requests,
|
||||
"metadata": metadata_json,
|
||||
"max_budget": max_budget,
|
||||
"user_email": user_email,
|
||||
}
|
||||
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
|
||||
new_verification_token = await prisma_client.insert_data(
|
||||
data=verification_token_data
|
||||
)
|
||||
if prisma_client is not None:
|
||||
verification_token_data = dict(key_data)
|
||||
verification_token_data.update(user_data)
|
||||
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
|
||||
await prisma_client.insert_data(data=verification_token_data)
|
||||
elif custom_db_client is not None:
|
||||
## CREATE USER (If necessary)
|
||||
await custom_db_client.insert_data(value=user_data, table_name="user")
|
||||
## CREATE KEY
|
||||
await custom_db_client.insert_data(value=key_data, table_name="key")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
return {
|
||||
"token": token,
|
||||
"expires": new_verification_token.expires,
|
||||
"expires": expires,
|
||||
"user_id": user_id,
|
||||
"max_budget": max_budget,
|
||||
}
|
||||
|
@ -1174,12 +1245,21 @@ async def startup_event():
|
|||
if prisma_client is not None:
|
||||
await prisma_client.connect()
|
||||
|
||||
if custom_db_client is not None:
|
||||
await custom_db_client.connect()
|
||||
|
||||
if prisma_client is not None and master_key is not None:
|
||||
# add master key to db
|
||||
await generate_key_helper_fn(
|
||||
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
||||
)
|
||||
|
||||
if custom_db_client is not None and master_key is not None:
|
||||
# add master key to db
|
||||
await generate_key_helper_fn(
|
||||
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
||||
)
|
||||
|
||||
|
||||
#### API ENDPOINTS ####
|
||||
@router.get(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue