Merge branch 'main' into litellm_tpm_rpm_rate_limits

This commit is contained in:
Krrish Dholakia 2024-01-18 19:10:07 -08:00
commit f7694bc193
13 changed files with 378 additions and 31 deletions

View file

@ -72,6 +72,7 @@ from litellm.proxy.utils import (
ProxyLogging,
_cache_user_row,
send_email,
get_logging_payload,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
import pydantic
@ -518,6 +519,7 @@ async def track_cost_callback(
global prisma_client, custom_db_client
try:
# check if it has collected an entire stream response
verbose_proxy_logger.debug(f"Proxy: In track_cost_callback for {kwargs}")
verbose_proxy_logger.debug(
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
)
@ -538,7 +540,13 @@ async def track_cost_callback(
prisma_client is not None or custom_db_client is not None
):
await update_database(
token=user_api_key, response_cost=response_cost, user_id=user_id
token=user_api_key,
response_cost=response_cost,
user_id=user_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
)
elif kwargs["stream"] == False: # for non streaming responses
response_cost = litellm.completion_cost(
@ -554,13 +562,27 @@ async def track_cost_callback(
prisma_client is not None or custom_db_client is not None
):
await update_database(
token=user_api_key, response_cost=response_cost, user_id=user_id
token=user_api_key,
response_cost=response_cost,
user_id=user_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
)
except Exception as e:
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
async def update_database(token, response_cost, user_id=None):
async def update_database(
token,
response_cost,
user_id=None,
kwargs=None,
completion_response=None,
start_time=None,
end_time=None,
):
try:
verbose_proxy_logger.debug(
f"Enters prisma db call, token: {token}; user_id: {user_id}"
@ -630,9 +652,28 @@ async def update_database(token, response_cost, user_id=None):
key=token, value={"spend": new_spend}, table_name="key"
)
async def _insert_spend_log_to_db():
# Helper to generate payload to log
verbose_proxy_logger.debug("inserting spend log to db")
payload = get_logging_payload(
kwargs=kwargs,
response_obj=completion_response,
start_time=start_time,
end_time=end_time,
)
payload["spend"] = response_cost
if prisma_client is not None:
await prisma_client.insert_data(data=payload, table_name="spend")
elif custom_db_client is not None:
await custom_db_client.insert_data(payload, table_name="spend")
tasks = []
tasks.append(_update_user_db())
tasks.append(_update_key_db())
tasks.append(_insert_spend_log_to_db())
await asyncio.gather(*tasks)
except Exception as e:
verbose_proxy_logger.debug(
@ -1037,6 +1078,7 @@ async def generate_key_helper_fn(
max_budget: Optional[float] = None,
token: Optional[str] = None,
user_id: Optional[str] = None,
team_id: Optional[str] = None,
user_email: Optional[str] = None,
max_parallel_requests: Optional[int] = None,
metadata: Optional[dict] = {},
@ -1084,12 +1126,15 @@ async def generate_key_helper_fn(
user_id = user_id or str(uuid.uuid4())
tpm_limit = tpm_limit or sys.maxsize
rpm_limit = rpm_limit or sys.maxsize
if type(team_id) is not str:
team_id = str(team_id)
try:
# Create a new verification token (you may want to enhance this logic based on your needs)
user_data = {
"max_budget": max_budget,
"user_email": user_email,
"user_id": user_id,
"team_id": team_id,
"spend": spend,
"models": models,
"max_parallel_requests": max_parallel_requests,
@ -1104,6 +1149,7 @@ async def generate_key_helper_fn(
"config": config_json,
"spend": spend,
"user_id": user_id,
"team_id": team_id,
"max_parallel_requests": max_parallel_requests,
"metadata": metadata_json,
"tpm_limit": tpm_limit,
@ -2051,6 +2097,7 @@ async def generate_key_fn(
Parameters:
- duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
- team_id: Optional[str] - The team id of the user
- models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
- aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml