(v0 ) working - writing /chat/completion spend tracking

This commit is contained in:
ishaan-jaff 2024-01-18 11:54:15 -08:00
parent 4a5f987512
commit d14d36af9a
6 changed files with 114 additions and 17 deletions

View file

@@ -1,8 +1,8 @@
-from pydantic import BaseModel, Extra, Field, root_validator
+from pydantic import BaseModel, Extra, Field, root_validator, Json
 import enum
-from typing import Optional, List, Union, Dict, Literal
+from typing import Optional, List, Union, Dict, Literal, Any
 from datetime import datetime
-import uuid, json
+import uuid, json, sys, os


 class LiteLLMBase(BaseModel):
@@ -318,13 +318,14 @@ class LiteLLM_UserTable(LiteLLMBase):
 class LiteLLM_SpendLogs(LiteLLMBase):
     request_id: str
     call_type: str
-    startTime: Union[str, None]
-    endTime: Union[str, None]
-    model: str = ""
-    user: str = ""
-    modelParameters: Dict = {}
-    messages: List[str] = []
-    call_cost: float = 0.0
-    response: Dict = {}
-    usage: Dict = {}
-    metadata: Dict = {}
+    startTime: Union[str, datetime, None]
+    endTime: Union[str, datetime, None]
+    model: Optional[str] = ""
+    user: Optional[str] = ""
+    modelParameters: Optional[Json] = {}
+    messages: Optional[Json] = []
+    spend: Optional[float] = 0.0
+    response: Optional[Json] = {}
+    usage: Optional[Json] = {}
+    metadata: Optional[Json] = {}
+    cache_hit: Optional[str] = "False"
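
A minimal sketch of what the reworked model accepts, assuming pydantic v1 semantics where Json-typed fields parse JSON strings into Python objects (all values below are illustrative):

    import json
    from litellm.proxy._types import LiteLLM_SpendLogs

    log = LiteLLM_SpendLogs(
        request_id="chatcmpl-123",        # required
        call_type="litellm.completion",   # required
        startTime="2024-01-18T11:54:15",  # str, datetime, or None all validate now
        model="gpt-3.5-turbo",
        spend=0.0021,                     # renamed from call_cost
        messages=json.dumps([{"role": "user", "content": "hi"}]),
        usage=json.dumps({"prompt_tokens": 1, "completion_tokens": 9, "total_tokens": 10}),
    )
    print(log.usage)  # the Json field has been parsed back into a dict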

View file

@@ -61,8 +61,8 @@ litellm_settings:
   # setting callback class
   # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]

-# general_settings:
-#   master_key: sk-1234
+general_settings:
+  master_key: sk-1234
   # database_type: "dynamo_db"
   # database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
   #   "billing_mode": "PAY_PER_REQUEST",

View file

@@ -510,6 +510,7 @@ async def track_cost_callback(
     global prisma_client, custom_db_client
     try:
         # check if it has collected an entire stream response
+        verbose_proxy_logger.debug(f"Proxy: In track_cost_callback for {kwargs}")
         verbose_proxy_logger.debug(
             f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
         )
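
For context, the proxy registers track_cost_callback as a litellm success callback, which is why it receives (kwargs, completion_response, start_time, end_time). A minimal sketch of that contract with a custom async callback (the printed cost line is illustrative):

    import litellm

    async def my_track_cost(kwargs, completion_response, start_time, end_time):
        # completion_cost is litellm's public helper for pricing a response
        cost = litellm.completion_cost(completion_response=completion_response)
        print(f"model={kwargs.get('model')} cost={cost} latency={end_time - start_time}")

    litellm.success_callback = [my_track_cost]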
@@ -546,13 +547,27 @@
                 prisma_client is not None or custom_db_client is not None
             ):
                 await update_database(
-                    token=user_api_key, response_cost=response_cost, user_id=user_id
+                    token=user_api_key,
+                    response_cost=response_cost,
+                    user_id=user_id,
+                    kwargs=kwargs,
+                    completion_response=completion_response,
+                    start_time=start_time,
+                    end_time=end_time,
                 )
     except Exception as e:
         verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")


-async def update_database(token, response_cost, user_id=None):
+async def update_database(
+    token,
+    response_cost,
+    user_id=None,
+    kwargs=None,
+    completion_response=None,
+    start_time=None,
+    end_time=None,
+):
     try:
         verbose_proxy_logger.debug(
             f"Enters prisma db call, token: {token}; user_id: {user_id}"
@@ -622,9 +637,25 @@ async def update_database(token, response_cost, user_id=None):
                     key=token, value={"spend": new_spend}, table_name="key"
                 )

+        async def _insert_spend_log_to_db():
+            # Helper to generate payload to log
+            verbose_proxy_logger.debug("inserting spend log to db")
+            payload = litellm.utils.get_logging_payload(
+                kwargs=kwargs,
+                response_obj=completion_response,
+                start_time=start_time,
+                end_time=end_time,
+            )
+            payload["spend"] = response_cost
+            if prisma_client is not None:
+                await prisma_client.insert_data(data=payload, table_name="spend")
+
         tasks = []
         tasks.append(_update_user_db())
         tasks.append(_update_key_db())
+        tasks.append(_insert_spend_log_to_db())
         await asyncio.gather(*tasks)
     except Exception as e:
         verbose_proxy_logger.debug(
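
The shape here is worth noting: each write is a nested coroutine and all three are awaited together, so the new spend-log insert does not serialize behind the user and key updates. A self-contained sketch of the pattern, with sleeps standing in for the DB writes:

    import asyncio

    async def _update_user_db():
        await asyncio.sleep(0.1)  # stand-in for the user-table write

    async def _update_key_db():
        await asyncio.sleep(0.1)  # stand-in for the key-table write

    async def _insert_spend_log_to_db():
        await asyncio.sleep(0.1)  # stand-in for the new spend-log insert

    async def update_database():
        # all three run concurrently: total wall time ~0.1s, not ~0.3s
        await asyncio.gather(
            _update_user_db(), _update_key_db(), _insert_spend_log_to_db()
        )

    asyncio.run(update_database())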

View file

@@ -46,4 +46,5 @@ model LiteLLM_SpendLogs {
   response  Json   @default("{}")
   usage     Json   @default("{}")
   metadata  Json   @default("{}")
+  cache_hit String @default("")
 }
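
A hedged read-side sketch, assuming the Python Prisma client (whose generator lowercases model names into accessors) and that the model also carries the spend column written by update_database above:

    import asyncio
    from prisma import Prisma

    async def recent_spend_logs():
        db = Prisma()
        await db.connect()
        rows = await db.litellm_spendlogs.find_many(take=10)
        for row in rows:
            print(row.request_id, row.spend, row.cache_hit)
        await db.disconnect()

    asyncio.run(recent_spend_logs())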

View file

@@ -8423,3 +8423,66 @@ def print_args_passed_to_litellm(original_function, args, kwargs):
     except:
         # This should always be non blocking
         pass
+
+
+def get_logging_payload(kwargs, response_obj, start_time, end_time):
+    from litellm.proxy._types import LiteLLM_SpendLogs
+    from pydantic import Json
+
+    # standardize this function to be used across, s3, dynamoDB, langfuse logging
+    litellm_params = kwargs.get("litellm_params", {})
+    metadata = (
+        litellm_params.get("metadata", {}) or {}
+    )  # if litellm_params['metadata'] == None
+    messages = kwargs.get("messages")
+    optional_params = kwargs.get("optional_params", {})
+    call_type = kwargs.get("call_type", "litellm.completion")
+    cache_hit = kwargs.get("cache_hit", False)
+    usage = response_obj["usage"]
+    id = response_obj.get("id", str(uuid.uuid4()))
+
+    payload = {
+        "request_id": id,
+        "call_type": call_type,
+        "cache_hit": cache_hit,
+        "startTime": start_time,
+        "endTime": end_time,
+        "model": kwargs.get("model", ""),
+        "user": kwargs.get("user", ""),
+        "modelParameters": optional_params,
+        "messages": messages,
+        "response": response_obj,
+        "usage": usage,
+        "metadata": metadata,
+    }
+
+    json_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == Json or field_type == Optional[Json]
+    ]
+    str_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == str or field_type == Optional[str]
+    ]
+    datetime_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == datetime
+    ]
+
+    for param in json_fields:
+        if param in payload and type(payload[param]) != Json:
+            if type(payload[param]) == ModelResponse:
+                payload[param] = payload[param].model_dump_json()
+            elif type(payload[param]) == Usage:
+                payload[param] = payload[param].model_dump_json()
+            else:
+                payload[param] = json.dumps(payload[param])
+
+    for param in str_fields:
+        if param in payload and type(payload[param]) != str:
+            payload[param] = str(payload[param])
+
+    return payload
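
A hedged usage sketch for the new helper: feed it the kwargs/response shape the success callback receives and inspect the serialized payload. The response dict below is illustrative, not a real API response:

    from datetime import datetime
    import litellm

    payload = litellm.utils.get_logging_payload(
        kwargs={
            "model": "gpt-3.5-turbo",
            "user": "user-123",
            "messages": [{"role": "user", "content": "hi"}],
            "litellm_params": {"metadata": {"team": "eng"}},
            "optional_params": {"temperature": 0.2},
            "cache_hit": False,
        },
        response_obj={
            "id": "chatcmpl-123",
            "usage": {"prompt_tokens": 1, "completion_tokens": 9, "total_tokens": 10},
        },
        start_time=datetime.now(),
        end_time=datetime.now(),
    )
    print(payload["request_id"])  # "chatcmpl-123"
    print(payload["usage"])       # now a JSON string, ready for the Json column

Note that cache_hit arrives as a bool and leaves as the string "False", matching the String column added to the schema below.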

View file

@@ -46,4 +46,5 @@ model LiteLLM_SpendLogs {
   response  Json   @default("{}")
   usage     Json   @default("{}")
   metadata  Json   @default("{}")
+  cache_hit String @default("")
 }