Merge pull request #1498 from BerriAI/litellm_spend_tracking_logs

[Feat] Proxy - Add Spend tracking logs
This commit is contained in:
Ishaan Jaff 2024-01-18 14:21:51 -08:00 committed by GitHub
commit a26267851f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 231 additions and 14 deletions

View file

@ -1,8 +1,8 @@
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator, Json
import enum import enum
from typing import Optional, List, Union, Dict, Literal from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime from datetime import datetime
import uuid, json import uuid, json, sys, os
class LiteLLMBase(BaseModel): class LiteLLMBase(BaseModel):
@ -196,6 +196,7 @@ class DynamoDBArgs(LiteLLMBase):
user_table_name: str = "LiteLLM_UserTable" user_table_name: str = "LiteLLM_UserTable"
key_table_name: str = "LiteLLM_VerificationToken" key_table_name: str = "LiteLLM_VerificationToken"
config_table_name: str = "LiteLLM_Config" config_table_name: str = "LiteLLM_Config"
spend_table_name: str = "LiteLLM_SpendLogs"
class ConfigGeneralSettings(LiteLLMBase): class ConfigGeneralSettings(LiteLLMBase):
@ -314,3 +315,20 @@ class LiteLLM_UserTable(LiteLLMBase):
if values.get("models") is None: if values.get("models") is None:
values.update({"models", []}) values.update({"models", []})
return values return values
class LiteLLM_SpendLogs(LiteLLMBase):
request_id: str
api_key: str
model: Optional[str] = ""
call_type: str
spend: Optional[float] = 0.0
startTime: Union[str, datetime, None]
endTime: Union[str, datetime, None]
user: Optional[str] = ""
modelParameters: Optional[Json] = {}
messages: Optional[Json] = []
response: Optional[Json] = {}
usage: Optional[Json] = {}
metadata: Optional[Json] = {}
cache_hit: Optional[str] = "False"

View file

@ -131,10 +131,27 @@ class DynamoDBWrapper(CustomDB):
raise Exception( raise Exception(
f"Failed to create table - {self.database_arguments.config_table_name}.\nPlease create a new table called {self.database_arguments.config_table_name}\nAND set `hash_key` as 'param_name'" f"Failed to create table - {self.database_arguments.config_table_name}.\nPlease create a new table called {self.database_arguments.config_table_name}\nAND set `hash_key` as 'param_name'"
) )
## Spend
try:
verbose_proxy_logger.debug("DynamoDB Wrapper - Creating Spend Table")
error_occurred = False
table = client.table(self.database_arguments.spend_table_name)
if not await table.exists():
await table.create(
self.throughput_type,
KeySchema(hash_key=KeySpec("request_id", KeyType.string)),
)
except Exception as e:
error_occurred = True
if error_occurred == True:
raise Exception(
f"Failed to create table - {self.database_arguments.key_table_name}.\nPlease create a new table called {self.database_arguments.key_table_name}\nAND set `hash_key` as 'token'"
)
verbose_proxy_logger.debug("DynamoDB Wrapper - Done connecting()") verbose_proxy_logger.debug("DynamoDB Wrapper - Done connecting()")
async def insert_data( async def insert_data(
self, value: Any, table_name: Literal["user", "key", "config"] self, value: Any, table_name: Literal["user", "key", "config", "spend"]
): ):
from aiodynamo.client import Client from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials from aiodynamo.credentials import Credentials, StaticCredentials
@ -166,6 +183,8 @@ class DynamoDBWrapper(CustomDB):
table = client.table(self.database_arguments.key_table_name) table = client.table(self.database_arguments.key_table_name)
elif table_name == "config": elif table_name == "config":
table = client.table(self.database_arguments.config_table_name) table = client.table(self.database_arguments.config_table_name)
elif table_name == "spend":
table = client.table(self.database_arguments.spend_table_name)
for k, v in value.items(): for k, v in value.items():
if isinstance(v, datetime): if isinstance(v, datetime):

View file

@ -61,8 +61,8 @@ litellm_settings:
# setting callback class # setting callback class
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance] # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
# general_settings: general_settings:
# master_key: sk-1234 master_key: sk-1234
# database_type: "dynamo_db" # database_type: "dynamo_db"
# database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190 # database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
# "billing_mode": "PAY_PER_REQUEST", # "billing_mode": "PAY_PER_REQUEST",

View file

@ -72,6 +72,7 @@ from litellm.proxy.utils import (
ProxyLogging, ProxyLogging,
_cache_user_row, _cache_user_row,
send_email, send_email,
get_logging_payload,
) )
from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.google_kms import load_google_kms
import pydantic import pydantic
@ -518,6 +519,7 @@ async def track_cost_callback(
global prisma_client, custom_db_client global prisma_client, custom_db_client
try: try:
# check if it has collected an entire stream response # check if it has collected an entire stream response
verbose_proxy_logger.debug(f"Proxy: In track_cost_callback for {kwargs}")
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
) )
@ -538,7 +540,13 @@ async def track_cost_callback(
prisma_client is not None or custom_db_client is not None prisma_client is not None or custom_db_client is not None
): ):
await update_database( await update_database(
token=user_api_key, response_cost=response_cost, user_id=user_id token=user_api_key,
response_cost=response_cost,
user_id=user_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
) )
elif kwargs["stream"] == False: # for non streaming responses elif kwargs["stream"] == False: # for non streaming responses
response_cost = litellm.completion_cost( response_cost = litellm.completion_cost(
@ -554,13 +562,27 @@ async def track_cost_callback(
prisma_client is not None or custom_db_client is not None prisma_client is not None or custom_db_client is not None
): ):
await update_database( await update_database(
token=user_api_key, response_cost=response_cost, user_id=user_id token=user_api_key,
response_cost=response_cost,
user_id=user_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}") verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
async def update_database(token, response_cost, user_id=None): async def update_database(
token,
response_cost,
user_id=None,
kwargs=None,
completion_response=None,
start_time=None,
end_time=None,
):
try: try:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"Enters prisma db call, token: {token}; user_id: {user_id}" f"Enters prisma db call, token: {token}; user_id: {user_id}"
@ -630,9 +652,28 @@ async def update_database(token, response_cost, user_id=None):
key=token, value={"spend": new_spend}, table_name="key" key=token, value={"spend": new_spend}, table_name="key"
) )
async def _insert_spend_log_to_db():
# Helper to generate payload to log
verbose_proxy_logger.debug("inserting spend log to db")
payload = get_logging_payload(
kwargs=kwargs,
response_obj=completion_response,
start_time=start_time,
end_time=end_time,
)
payload["spend"] = response_cost
if prisma_client is not None:
await prisma_client.insert_data(data=payload, table_name="spend")
elif custom_db_client is not None:
await custom_db_client.insert_data(payload, table_name="spend")
tasks = [] tasks = []
tasks.append(_update_user_db()) tasks.append(_update_user_db())
tasks.append(_update_key_db()) tasks.append(_update_key_db())
tasks.append(_insert_spend_log_to_db())
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
except Exception as e: except Exception as e:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(

View file

@ -32,3 +32,20 @@ model LiteLLM_Config {
param_name String @id param_name String @id
param_value Json? param_value Json?
} }
model LiteLLM_SpendLogs {
request_id String @unique
call_type String
api_key String @default ("")
spend Float @default(0.0)
startTime DateTime // Assuming start_time is a DateTime field
endTime DateTime // Assuming end_time is a DateTime field
model String @default("")
user String @default("")
modelParameters Json @default("{}")// Assuming optional_params is a JSON field
messages Json @default("[]")
response Json @default("{}")
usage Json @default("{}")
metadata Json @default("{}")
cache_hit String @default("")
}

View file

@ -1,7 +1,12 @@
from typing import Optional, List, Any, Literal, Union from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
import litellm, backoff import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs, LiteLLM_VerificationToken from litellm.proxy._types import (
UserAPIKeyAuth,
DynamoDBArgs,
LiteLLM_VerificationToken,
LiteLLM_SpendLogs,
)
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
@ -316,7 +321,7 @@ class PrismaClient:
self, self,
key: str, key: str,
value: Any, value: Any,
table_name: Literal["users", "keys", "config"], table_name: Literal["users", "keys", "config", "spend"],
): ):
""" """
Generic implementation of get data Generic implementation of get data
@ -334,6 +339,10 @@ class PrismaClient:
response = await self.db.litellm_config.find_first( # type: ignore response = await self.db.litellm_config.find_first( # type: ignore
where={key: value} # type: ignore where={key: value} # type: ignore
) )
elif table_name == "spend":
response = await self.db.l.find_first( # type: ignore
where={key: value} # type: ignore
)
return response return response
except Exception as e: except Exception as e:
asyncio.create_task( asyncio.create_task(
@ -417,7 +426,7 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def insert_data( async def insert_data(
self, data: dict, table_name: Literal["user", "key", "config"] self, data: dict, table_name: Literal["user", "key", "config", "spend"]
): ):
""" """
Add a key to the database. If it already exists, do nothing. Add a key to the database. If it already exists, do nothing.
@ -473,8 +482,18 @@ class PrismaClient:
) )
tasks.append(updated_table_row) tasks.append(updated_table_row)
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
elif table_name == "spend":
db_data = self.jsonify_object(data=data)
new_spend_row = await self.db.litellm_spendlogs.upsert(
where={"request_id": data["request_id"]},
data={
"create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists
},
)
return new_spend_row
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM Prisma Client Exception: {e}") print_verbose(f"LiteLLM Prisma Client Exception: {e}")
asyncio.create_task( asyncio.create_task(
@ -760,3 +779,85 @@ async def send_email(sender_name, sender_email, receiver_email, subject, html):
except Exception as e: except Exception as e:
print_verbose("An error occurred while sending the email:", str(e)) print_verbose("An error occurred while sending the email:", str(e))
def hash_token(token: str):
import hashlib
# Hash the string using SHA-256
hashed_token = hashlib.sha256(token.encode()).hexdigest()
return hashed_token
def get_logging_payload(kwargs, response_obj, start_time, end_time):
from litellm.proxy._types import LiteLLM_SpendLogs
from pydantic import Json
import uuid
if kwargs == None:
kwargs = {}
# standardize this function to be used across, s3, dynamoDB, langfuse logging
litellm_params = kwargs.get("litellm_params", {})
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
messages = kwargs.get("messages")
optional_params = kwargs.get("optional_params", {})
call_type = kwargs.get("call_type", "litellm.completion")
cache_hit = kwargs.get("cache_hit", False)
usage = response_obj["usage"]
id = response_obj.get("id", str(uuid.uuid4()))
api_key = metadata.get("user_api_key", "")
if api_key is not None and type(api_key) == str:
# hash the api_key
api_key = hash_token(api_key)
payload = {
"request_id": id,
"call_type": call_type,
"api_key": api_key,
"cache_hit": cache_hit,
"startTime": start_time,
"endTime": end_time,
"model": kwargs.get("model", ""),
"user": kwargs.get("user", ""),
"modelParameters": optional_params,
"messages": messages,
"response": response_obj,
"usage": usage,
"metadata": metadata,
}
json_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == Json or field_type == Optional[Json]
]
str_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == str or field_type == Optional[str]
]
datetime_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == datetime
]
for param in json_fields:
if param in payload and type(payload[param]) != Json:
if type(payload[param]) == litellm.ModelResponse:
payload[param] = payload[param].model_dump_json()
if type(payload[param]) == litellm.EmbeddingResponse:
payload[param] = payload[param].model_dump_json()
elif type(payload[param]) == litellm.Usage:
payload[param] = payload[param].model_dump_json()
else:
payload[param] = json.dumps(payload[param])
for param in str_fields:
if param in payload and type(payload[param]) != str:
payload[param] = str(payload[param])
return payload

View file

@ -179,6 +179,10 @@ def test_call_with_key_over_budget(custom_db_client):
# 5. Make a call with a key over budget, expect to fail # 5. Make a call with a key over budget, expect to fail
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client) setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
from litellm._logging import verbose_proxy_logger
import logging
verbose_proxy_logger.setLevel(logging.DEBUG)
try: try:
async def test(): async def test():

View file

@ -32,3 +32,20 @@ model LiteLLM_Config {
param_name String @id param_name String @id
param_value Json? param_value Json?
} }
model LiteLLM_SpendLogs {
request_id String @unique
api_key String @default ("")
call_type String
spend Float @default(0.0)
startTime DateTime // Assuming start_time is a DateTime field
endTime DateTime // Assuming end_time is a DateTime field
model String @default("")
user String @default("")
modelParameters Json @default("{}")// Assuming optional_params is a JSON field
messages Json @default("[]")
response Json @default("{}")
usage Json @default("{}")
metadata Json @default("{}")
cache_hit String @default("")
}