From d14d36af9add19434f96a6ef49482c18acc610b3 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Thu, 18 Jan 2024 11:54:15 -0800
Subject: [PATCH] (v0) working - writing /chat/completion spend tracking

---
 litellm/proxy/_types.py         | 27 +++++++-------
 litellm/proxy/proxy_config.yaml |  4 +--
 litellm/proxy/proxy_server.py   | 35 ++++++++++++++++--
 litellm/proxy/schema.prisma     |  1 +
 litellm/utils.py                | 63 +++++++++++++++++++++++++++++++++
 schema.prisma                   |  1 +
 6 files changed, 114 insertions(+), 17 deletions(-)

diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index cf8350022..9bc6b09b1 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1,8 +1,8 @@
-from pydantic import BaseModel, Extra, Field, root_validator
+from pydantic import BaseModel, Extra, Field, root_validator, Json
 import enum
-from typing import Optional, List, Union, Dict, Literal
+from typing import Optional, List, Union, Dict, Literal, Any
 from datetime import datetime
-import uuid, json
+import uuid, json, sys, os
 
 
 class LiteLLMBase(BaseModel):
@@ -318,13 +318,14 @@ class LiteLLM_UserTable(LiteLLMBase):
 class LiteLLM_SpendLogs(LiteLLMBase):
     request_id: str
     call_type: str
-    startTime: Union[str, None]
-    endTime: Union[str, None]
-    model: str = ""
-    user: str = ""
-    modelParameters: Dict = {}
-    messages: List[str] = []
-    call_cost: float = 0.0
-    response: Dict = {}
-    usage: Dict = {}
-    metadata: Dict = {}
+    startTime: Union[str, datetime, None]
+    endTime: Union[str, datetime, None]
+    model: Optional[str] = ""
+    user: Optional[str] = ""
+    modelParameters: Optional[Json] = {}
+    messages: Optional[Json] = []
+    spend: Optional[float] = 0.0
+    response: Optional[Json] = {}
+    usage: Optional[Json] = {}
+    metadata: Optional[Json] = {}
+    cache_hit: Optional[str] = "False"
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 5b87ab775..8cd2fcec8 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -61,8 +61,8 @@ litellm_settings:
   # setting callback class
   # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
 
-# general_settings:
-  # master_key: sk-1234
+general_settings:
+  master_key: sk-1234
   # database_type: "dynamo_db"
   # database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
   #   "billing_mode": "PAY_PER_REQUEST",
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index d3667892b..fdc81f88e 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -510,6 +510,7 @@ async def track_cost_callback(
     global prisma_client, custom_db_client
     try:
         # check if it has collected an entire stream response
+        verbose_proxy_logger.debug(f"Proxy: In track_cost_callback for {kwargs}")
         verbose_proxy_logger.debug(
             f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
         )
@@ -546,13 +547,27 @@ async def track_cost_callback(
                 prisma_client is not None or custom_db_client is not None
             ):
                 await update_database(
-                    token=user_api_key, response_cost=response_cost, user_id=user_id
+                    token=user_api_key,
+                    response_cost=response_cost,
+                    user_id=user_id,
+                    kwargs=kwargs,
+                    completion_response=completion_response,
+                    start_time=start_time,
+                    end_time=end_time,
                 )
     except Exception as e:
         verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
 
 
-async def update_database(token, response_cost, user_id=None):
+async def update_database(
+    token,
+    response_cost,
+    user_id=None,
+    kwargs=None,
+    completion_response=None,
+    start_time=None,
+    end_time=None,
+):
     try:
         verbose_proxy_logger.debug(
             f"Enters prisma db call, token: {token}; user_id: {user_id}"
@@ -622,9 +637,25 @@ async def update_database(token, response_cost, user_id=None):
                     key=token, value={"spend": new_spend}, table_name="key"
                 )
 
+        async def _insert_spend_log_to_db():
+            # Helper to generate payload to log
+            verbose_proxy_logger.debug("inserting spend log to db")
+            payload = litellm.utils.get_logging_payload(
+                kwargs=kwargs,
+                response_obj=completion_response,
+                start_time=start_time,
+                end_time=end_time,
+            )
+
+            payload["spend"] = response_cost
+
+            if prisma_client is not None:
+                await prisma_client.insert_data(data=payload, table_name="spend")
+
         tasks = []
         tasks.append(_update_user_db())
         tasks.append(_update_key_db())
+        tasks.append(_insert_spend_log_to_db())
         await asyncio.gather(*tasks)
     except Exception as e:
         verbose_proxy_logger.debug(
diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma
index d2e338bd4..9049f953d 100644
--- a/litellm/proxy/schema.prisma
+++ b/litellm/proxy/schema.prisma
@@ -46,4 +46,5 @@ model LiteLLM_SpendLogs {
   response      Json     @default("{}")
   usage         Json     @default("{}")
   metadata      Json     @default("{}")
+  cache_hit     String   @default("")
 }
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index f7cc5d2a5..b22a053ff 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -8423,3 +8423,66 @@ def print_args_passed_to_litellm(original_function, args, kwargs):
     except:
         # This should always be non blocking
         pass
+
+
+def get_logging_payload(kwargs, response_obj, start_time, end_time):
+    from litellm.proxy._types import LiteLLM_SpendLogs
+    from pydantic import Json
+
+    # standardize this function to be used across s3, dynamoDB, langfuse logging
+    litellm_params = kwargs.get("litellm_params", {})
+    metadata = (
+        litellm_params.get("metadata", {}) or {}
+    )  # if litellm_params['metadata'] == None
+    messages = kwargs.get("messages")
+    optional_params = kwargs.get("optional_params", {})
+    call_type = kwargs.get("call_type", "litellm.completion")
+    cache_hit = kwargs.get("cache_hit", False)
+    usage = response_obj["usage"]
+    id = response_obj.get("id", str(uuid.uuid4()))
+
+    payload = {
+        "request_id": id,
+        "call_type": call_type,
+        "cache_hit": cache_hit,
+        "startTime": start_time,
+        "endTime": end_time,
+        "model": kwargs.get("model", ""),
+        "user": kwargs.get("user", ""),
+        "modelParameters": optional_params,
+        "messages": messages,
+        "response": response_obj,
+        "usage": usage,
+        "metadata": metadata,
+    }
+
+    json_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == Json or field_type == Optional[Json]
+    ]
+    str_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == str or field_type == Optional[str]
+    ]
+    datetime_fields = [
+        field
+        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
+        if field_type == datetime
+    ]
+
+    for param in json_fields:
+        if param in payload and type(payload[param]) != Json:
+            if type(payload[param]) == ModelResponse:
+                payload[param] = payload[param].model_dump_json()
+            elif type(payload[param]) == Usage:
+                payload[param] = payload[param].model_dump_json()
+            else:
+                payload[param] = json.dumps(payload[param])
+
+    for param in str_fields:
+        if param in payload and type(payload[param]) != str:
+            payload[param] = str(payload[param])
+
+    return payload
diff --git a/schema.prisma b/schema.prisma
index ed69f67a7..a07dcad08 100644
--- a/schema.prisma
+++ b/schema.prisma
@@ -46,4 +46,5 @@ model LiteLLM_SpendLogs {
   response      Json     @default("{}")
   usage         Json     @default("{}")
   metadata      Json     @default("{}")
+  cache_hit     String   @default("")
 }
\ No newline at end of file
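
Reviewer note (not part of the patch): the new flow is
track_cost_callback -> update_database -> _insert_spend_log_to_db ->
get_logging_payload -> prisma_client.insert_data(table_name="spend").
A minimal sketch of exercising the new helper directly; the kwargs
below are hypothetical stand-ins for what the proxy's success callback
would pass, and the empty ModelResponse() merely satisfies the
response_obj interface (usage + id):

    import litellm
    from datetime import datetime

    response = litellm.ModelResponse()  # placeholder response object
    kwargs = {
        "model": "gpt-3.5-turbo",
        "user": "user-1234",
        "messages": [{"role": "user", "content": "hi"}],
        "litellm_params": {"metadata": {"team": "eng"}},  # hypothetical metadata
        "optional_params": {"temperature": 0.2},
        "call_type": "acompletion",
        "cache_hit": False,
    }

    payload = litellm.utils.get_logging_payload(
        kwargs=kwargs,
        response_obj=response,
        start_time=datetime.now(),
        end_time=datetime.now(),
    )
    payload["spend"] = 0.0  # the proxy sets this from its computed response_cost
    # The Json-typed fields (modelParameters, messages, response, usage,
    # metadata) are now JSON strings matching the LiteLLM_SpendLogs schema,
    # ready for:
    #   await prisma_client.insert_data(data=payload, table_name="spend")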