(v0 ) working - writing /chat/completion spend tracking

This commit is contained in:
ishaan-jaff 2024-01-18 11:54:15 -08:00
parent 4a5f987512
commit d14d36af9a
6 changed files with 114 additions and 17 deletions

View file

@ -1,8 +1,8 @@
from pydantic import BaseModel, Extra, Field, root_validator
from pydantic import BaseModel, Extra, Field, root_validator, Json
import enum
from typing import Optional, List, Union, Dict, Literal
from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime
import uuid, json
import uuid, json, sys, os
class LiteLLMBase(BaseModel):
@ -318,13 +318,14 @@ class LiteLLM_UserTable(LiteLLMBase):
class LiteLLM_SpendLogs(LiteLLMBase):
request_id: str
call_type: str
startTime: Union[str, None]
endTime: Union[str, None]
model: str = ""
user: str = ""
modelParameters: Dict = {}
messages: List[str] = []
call_cost: float = 0.0
response: Dict = {}
usage: Dict = {}
metadata: Dict = {}
startTime: Union[str, datetime, None]
endTime: Union[str, datetime, None]
model: Optional[str] = ""
user: Optional[str] = ""
modelParameters: Optional[Json] = {}
messages: Optional[Json] = []
spend: Optional[float] = 0.0
response: Optional[Json] = {}
usage: Optional[Json] = {}
metadata: Optional[Json] = {}
cache_hit: Optional[str] = "False"

View file

@ -61,8 +61,8 @@ litellm_settings:
# setting callback class
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
# general_settings:
# master_key: sk-1234
general_settings:
master_key: sk-1234
# database_type: "dynamo_db"
# database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
# "billing_mode": "PAY_PER_REQUEST",

View file

@ -510,6 +510,7 @@ async def track_cost_callback(
global prisma_client, custom_db_client
try:
# check if it has collected an entire stream response
verbose_proxy_logger.debug(f"Proxy: In track_cost_callback for {kwargs}")
verbose_proxy_logger.debug(
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
)
@ -546,13 +547,27 @@ async def track_cost_callback(
prisma_client is not None or custom_db_client is not None
):
await update_database(
token=user_api_key, response_cost=response_cost, user_id=user_id
token=user_api_key,
response_cost=response_cost,
user_id=user_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
)
except Exception as e:
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
async def update_database(token, response_cost, user_id=None):
async def update_database(
token,
response_cost,
user_id=None,
kwargs=None,
completion_response=None,
start_time=None,
end_time=None,
):
try:
verbose_proxy_logger.debug(
f"Enters prisma db call, token: {token}; user_id: {user_id}"
@ -622,9 +637,25 @@ async def update_database(token, response_cost, user_id=None):
key=token, value={"spend": new_spend}, table_name="key"
)
async def _insert_spend_log_to_db():
# Helper to generate payload to log
verbose_proxy_logger.debug("inserting spend log to db")
payload = litellm.utils.get_logging_payload(
kwargs=kwargs,
response_obj=completion_response,
start_time=start_time,
end_time=end_time,
)
payload["spend"] = response_cost
if prisma_client is not None:
await prisma_client.insert_data(data=payload, table_name="spend")
tasks = []
tasks.append(_update_user_db())
tasks.append(_update_key_db())
tasks.append(_insert_spend_log_to_db())
await asyncio.gather(*tasks)
except Exception as e:
verbose_proxy_logger.debug(

View file

@ -46,4 +46,5 @@ model LiteLLM_SpendLogs {
response Json @default("{}")
usage Json @default("{}")
metadata Json @default("{}")
cache_hit String @default("")
}

View file

@ -8423,3 +8423,66 @@ def print_args_passed_to_litellm(original_function, args, kwargs):
except:
# This should always be non blocking
pass
def get_logging_payload(kwargs, response_obj, start_time, end_time):
from litellm.proxy._types import LiteLLM_SpendLogs
from pydantic import Json
# standardize this function to be used across, s3, dynamoDB, langfuse logging
litellm_params = kwargs.get("litellm_params", {})
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
messages = kwargs.get("messages")
optional_params = kwargs.get("optional_params", {})
call_type = kwargs.get("call_type", "litellm.completion")
cache_hit = kwargs.get("cache_hit", False)
usage = response_obj["usage"]
id = response_obj.get("id", str(uuid.uuid4()))
payload = {
"request_id": id,
"call_type": call_type,
"cache_hit": cache_hit,
"startTime": start_time,
"endTime": end_time,
"model": kwargs.get("model", ""),
"user": kwargs.get("user", ""),
"modelParameters": optional_params,
"messages": messages,
"response": response_obj,
"usage": usage,
"metadata": metadata,
}
json_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == Json or field_type == Optional[Json]
]
str_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == str or field_type == Optional[str]
]
datetime_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == datetime
]
for param in json_fields:
if param in payload and type(payload[param]) != Json:
if type(payload[param]) == ModelResponse:
payload[param] = payload[param].model_dump_json()
elif type(payload[param]) == Usage:
payload[param] = payload[param].model_dump_json()
else:
payload[param] = json.dumps(payload[param])
for param in str_fields:
if param in payload and type(payload[param]) != str:
payload[param] = str(payload[param])
return payload

View file

@ -46,4 +46,5 @@ model LiteLLM_SpendLogs {
response Json @default("{}")
usage Json @default("{}")
metadata Json @default("{}")
cache_hit String @default("")
}