forked from phoenix/litellm-mirror
(v0) working - writing /chat/completion spend tracking
parent 4a5f987512
commit d14d36af9a
6 changed files with 114 additions and 17 deletions
litellm/proxy/_types.py
@@ -1,8 +1,8 @@
-from pydantic import BaseModel, Extra, Field, root_validator
+from pydantic import BaseModel, Extra, Field, root_validator, Json
 import enum
-from typing import Optional, List, Union, Dict, Literal
+from typing import Optional, List, Union, Dict, Literal, Any
 from datetime import datetime
-import uuid, json
+import uuid, json, sys, os
 
 
 class LiteLLMBase(BaseModel):
@@ -318,13 +318,14 @@ class LiteLLM_UserTable(LiteLLMBase):
 class LiteLLM_SpendLogs(LiteLLMBase):
     request_id: str
     call_type: str
-    startTime: Union[str, None]
+    startTime: Union[str, datetime, None]
-    endTime: Union[str, None]
+    endTime: Union[str, datetime, None]
-    model: str = ""
+    model: Optional[str] = ""
-    user: str = ""
+    user: Optional[str] = ""
-    modelParameters: Dict = {}
+    modelParameters: Optional[Json] = {}
-    messages: List[str] = []
+    messages: Optional[Json] = []
-    call_cost: float = 0.0
+    spend: Optional[float] = 0.0
-    response: Dict = {}
+    response: Optional[Json] = {}
-    usage: Dict = {}
+    usage: Optional[Json] = {}
-    metadata: Dict = {}
+    metadata: Optional[Json] = {}
+    cache_hit: Optional[str] = "False"
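A note on the `Optional[Json]` fields above: in pydantic v1, a `Json`-typed field expects a JSON-encoded string and parses it back into a Python object on validation, so callers now pass `json.dumps(...)` output rather than raw dicts and lists. A minimal sketch of constructing the updated model (all field values are illustrative, not from this commit):

import json
from litellm.proxy._types import LiteLLM_SpendLogs

# Illustrative row; the Json-typed fields take JSON strings, which
# pydantic parses back into Python objects during validation.
row = LiteLLM_SpendLogs(
    request_id="chatcmpl-123",  # hypothetical id
    call_type="litellm.completion",
    startTime="2024-01-17T00:00:00",
    endTime="2024-01-17T00:00:01",
    model="gpt-3.5-turbo",
    user="user-1",
    modelParameters=json.dumps({"temperature": 0.7}),
    messages=json.dumps([{"role": "user", "content": "hi"}]),
    spend=0.00021,
    response=json.dumps({"choices": []}),
    usage=json.dumps({"total_tokens": 10}),
    metadata=json.dumps({}),
    cache_hit="False",
)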
litellm/proxy/proxy_config.yaml
@@ -61,8 +61,8 @@ litellm_settings:
   # setting callback class
   # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
 
-# general_settings:
-#   master_key: sk-1234
+general_settings:
+  master_key: sk-1234
   # database_type: "dynamo_db"
   # database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
   #   "billing_mode": "PAY_PER_REQUEST",
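With `general_settings.master_key` now active, every request to the proxy must present the key as a bearer token. A quick smoke test against a locally running proxy (the base URL and port are assumptions; adjust to your deployment):

import openai

# Assumes the proxy was started with this config, e.g.:
#   litellm --config proxy_config.yaml
client = openai.OpenAI(
    api_key="sk-1234",               # the master_key from general_settings
    base_url="http://0.0.0.0:8000",  # assumed local proxy address
)
resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
print(resp.choices[0].message.content)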
litellm/proxy/proxy_server.py
@@ -510,6 +510,7 @@ async def track_cost_callback(
     global prisma_client, custom_db_client
     try:
         # check if it has collected an entire stream response
+        verbose_proxy_logger.debug(f"Proxy: In track_cost_callback for {kwargs}")
         verbose_proxy_logger.debug(
             f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
         )
@@ -546,13 +547,27 @@ async def track_cost_callback(
             prisma_client is not None or custom_db_client is not None
         ):
-            await update_database(
-                token=user_api_key, response_cost=response_cost, user_id=user_id
-            )
+            await update_database(
+                token=user_api_key,
+                response_cost=response_cost,
+                user_id=user_id,
+                kwargs=kwargs,
+                completion_response=completion_response,
+                start_time=start_time,
+                end_time=end_time,
+            )
     except Exception as e:
         verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
 
 
-async def update_database(token, response_cost, user_id=None):
+async def update_database(
+    token,
+    response_cost,
+    user_id=None,
+    kwargs=None,
+    completion_response=None,
+    start_time=None,
+    end_time=None,
+):
     try:
         verbose_proxy_logger.debug(
             f"Enters prisma db call, token: {token}; user_id: {user_id}"
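For context on the extra arguments being threaded through: litellm's custom success callbacks use the `(kwargs, completion_response, start_time, end_time)` shape, and `update_database` now receives all four so the spend-log payload can be rebuilt from them downstream. A rough standalone sketch of a callback with that same shape (the logging body is illustrative, not the proxy's code):

import litellm

def my_cost_logger(kwargs, completion_response, start_time, end_time):
    # Same (kwargs, response, start, end) signature that the proxy's
    # track_cost_callback above is built around.
    cost = litellm.completion_cost(completion_response=completion_response)
    print(f"model={kwargs.get('model')} cost=${cost:.6f} latency={end_time - start_time}")

litellm.success_callback = [my_cost_logger]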
@@ -622,9 +637,25 @@ async def update_database(token, response_cost, user_id=None):
                 key=token, value={"spend": new_spend}, table_name="key"
             )
 
+        async def _insert_spend_log_to_db():
+            # Helper to generate payload to log
+            verbose_proxy_logger.debug("inserting spend log to db")
+            payload = litellm.utils.get_logging_payload(
+                kwargs=kwargs,
+                response_obj=completion_response,
+                start_time=start_time,
+                end_time=end_time,
+            )
+
+            payload["spend"] = response_cost
+
+            if prisma_client is not None:
+                await prisma_client.insert_data(data=payload, table_name="spend")
+
         tasks = []
         tasks.append(_update_user_db())
         tasks.append(_update_key_db())
+        tasks.append(_insert_spend_log_to_db())
         await asyncio.gather(*tasks)
     except Exception as e:
         verbose_proxy_logger.debug(
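The three DB writes are deliberately fired concurrently. An isolated sketch of the `asyncio.gather` fan-out used above (stand-in coroutines, not the proxy's real writes); note that with the default `return_exceptions=False`, a failure in any one write propagates to the surrounding `except` in `update_database`:

import asyncio

async def update_user_db():    # stand-in for _update_user_db
    await asyncio.sleep(0.1)

async def update_key_db():     # stand-in for _update_key_db
    await asyncio.sleep(0.1)

async def insert_spend_log():  # stand-in for _insert_spend_log_to_db
    await asyncio.sleep(0.1)

async def main():
    # All three writes run concurrently; total wall time is ~0.1s, not 0.3s.
    await asyncio.gather(update_user_db(), update_key_db(), insert_spend_log())

asyncio.run(main())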
litellm/proxy/schema.prisma
@@ -46,4 +46,5 @@ model LiteLLM_SpendLogs {
   response Json @default("{}")
   usage Json @default("{}")
   metadata Json @default("{}")
+  cache_hit String @default("")
 }
litellm/utils.py
@@ -8423,3 +8423,66 @@ def print_args_passed_to_litellm(original_function, args, kwargs):
     except:
         # This should always be non blocking
         pass
+
+
+def get_logging_payload(kwargs, response_obj, start_time, end_time):
+    from litellm.proxy._types import LiteLLM_SpendLogs
+    from pydantic import Json
+
+    # standardize this function to be used across, s3, dynamoDB, langfuse logging
+    litellm_params = kwargs.get("litellm_params", {})
+    metadata = (
+        litellm_params.get("metadata", {}) or {}
+    )  # if litellm_params['metadata'] == None
+    messages = kwargs.get("messages")
+    optional_params = kwargs.get("optional_params", {})
+    call_type = kwargs.get("call_type", "litellm.completion")
+    cache_hit = kwargs.get("cache_hit", False)
+    usage = response_obj["usage"]
+    id = response_obj.get("id", str(uuid.uuid4()))
+
+    payload = {
+        "request_id": id,
+        "call_type": call_type,
+        "cache_hit": cache_hit,
+        "startTime": start_time,
+        "endTime": end_time,
+        "model": kwargs.get("model", ""),
+        "user": kwargs.get("user", ""),
+        "modelParameters": optional_params,
+        "messages": messages,
+        "response": response_obj,
+        "usage": usage,
+        "metadata": metadata,
+    }
+
+    json_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == Json or field_type == Optional[Json]
+    ]
+    str_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == str or field_type == Optional[str]
+    ]
+    datetime_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == datetime
+    ]
+
+    for param in json_fields:
+        if param in payload and type(payload[param]) != Json:
+            if type(payload[param]) == ModelResponse:
+                payload[param] = payload[param].model_dump_json()
+            elif type(payload[param]) == Usage:
+                payload[param] = payload[param].model_dump_json()
+            else:
+                payload[param] = json.dumps(payload[param])
+
+    for param in str_fields:
+        if param in payload and type(payload[param]) != str:
+            payload[param] = str(payload[param])
+
+    return payload
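A quick illustration of calling `get_logging_payload` directly (all values made up): a plain dict works as `response_obj` here, since the function only touches `["usage"]` and `.get("id")`.

from datetime import datetime
import litellm

kwargs = {
    "model": "gpt-3.5-turbo",
    "user": "user-1",
    "messages": [{"role": "user", "content": "hi"}],
    "optional_params": {"temperature": 0.7},
    "litellm_params": {"metadata": {"team": "eng"}},
    "call_type": "litellm.completion",
    "cache_hit": False,
}
response_obj = {
    "id": "chatcmpl-abc123",
    "usage": {"prompt_tokens": 8, "completion_tokens": 2, "total_tokens": 10},
}

payload = litellm.utils.get_logging_payload(
    kwargs=kwargs,
    response_obj=response_obj,
    start_time=datetime.now(),
    end_time=datetime.now(),
)
payload["spend"] = 0.00021  # update_database fills this from response_cost
# Json-typed fields come back JSON-encoded and str-typed fields stringified,
# e.g. payload["messages"] == '[{"role": "user", "content": "hi"}]'
# and payload["cache_hit"] == "False".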
schema.prisma
@@ -46,4 +46,5 @@ model LiteLLM_SpendLogs {
   response Json @default("{}")
   usage Json @default("{}")
   metadata Json @default("{}")
+  cache_hit String @default("")
 }