forked from phoenix/litellm-mirror
Merge pull request #1498 from BerriAI/litellm_spend_tracking_logs
[Feat] Proxy - Add Spend tracking logs
This commit is contained in:
commit
a26267851f
8 changed files with 231 additions and 14 deletions
|
@ -1,8 +1,8 @@
|
|||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from pydantic import BaseModel, Extra, Field, root_validator, Json
|
||||
import enum
|
||||
from typing import Optional, List, Union, Dict, Literal
|
||||
from typing import Optional, List, Union, Dict, Literal, Any
|
||||
from datetime import datetime
|
||||
import uuid, json
|
||||
import uuid, json, sys, os
|
||||
|
||||
|
||||
class LiteLLMBase(BaseModel):
|
||||
|
@ -196,6 +196,7 @@ class DynamoDBArgs(LiteLLMBase):
|
|||
user_table_name: str = "LiteLLM_UserTable"
|
||||
key_table_name: str = "LiteLLM_VerificationToken"
|
||||
config_table_name: str = "LiteLLM_Config"
|
||||
spend_table_name: str = "LiteLLM_SpendLogs"
|
||||
|
||||
|
||||
class ConfigGeneralSettings(LiteLLMBase):
|
||||
|
@ -314,3 +315,20 @@ class LiteLLM_UserTable(LiteLLMBase):
|
|||
if values.get("models") is None:
|
||||
values.update({"models", []})
|
||||
return values
|
||||
|
||||
|
||||
class LiteLLM_SpendLogs(LiteLLMBase):
    # Pydantic model mirroring the LiteLLM_SpendLogs DB table: one row per
    # proxy request, used for spend/usage tracking (see get_logging_payload,
    # which builds dict payloads with exactly these keys).
    request_id: str  # unique id for the request (response id, or a fresh uuid)
    api_key: str  # SHA-256 hash of the proxy API key that made the request
    model: Optional[str] = ""
    call_type: str  # e.g. "litellm.completion"
    spend: Optional[float] = 0.0  # response cost computed by the proxy
    # camelCase matches the DB column names in the prisma schema
    startTime: Union[str, datetime, None]
    endTime: Union[str, datetime, None]
    user: Optional[str] = ""
    # Json-typed fields are stored as serialized JSON strings in the DB
    modelParameters: Optional[Json] = {}
    messages: Optional[Json] = []
    response: Optional[Json] = {}
    usage: Optional[Json] = {}
    metadata: Optional[Json] = {}
    cache_hit: Optional[str] = "False"  # NOTE: stored as a string, not a bool
|
||||
|
|
|
@ -131,10 +131,27 @@ class DynamoDBWrapper(CustomDB):
|
|||
raise Exception(
|
||||
f"Failed to create table - {self.database_arguments.config_table_name}.\nPlease create a new table called {self.database_arguments.config_table_name}\nAND set `hash_key` as 'param_name'"
|
||||
)
|
||||
|
||||
## Spend
|
||||
try:
|
||||
verbose_proxy_logger.debug("DynamoDB Wrapper - Creating Spend Table")
|
||||
error_occurred = False
|
||||
table = client.table(self.database_arguments.spend_table_name)
|
||||
if not await table.exists():
|
||||
await table.create(
|
||||
self.throughput_type,
|
||||
KeySchema(hash_key=KeySpec("request_id", KeyType.string)),
|
||||
)
|
||||
except Exception as e:
|
||||
error_occurred = True
|
||||
if error_occurred == True:
|
||||
raise Exception(
|
||||
f"Failed to create table - {self.database_arguments.key_table_name}.\nPlease create a new table called {self.database_arguments.key_table_name}\nAND set `hash_key` as 'token'"
|
||||
)
|
||||
verbose_proxy_logger.debug("DynamoDB Wrapper - Done connecting()")
|
||||
|
||||
async def insert_data(
|
||||
self, value: Any, table_name: Literal["user", "key", "config"]
|
||||
self, value: Any, table_name: Literal["user", "key", "config", "spend"]
|
||||
):
|
||||
from aiodynamo.client import Client
|
||||
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||
|
@ -166,6 +183,8 @@ class DynamoDBWrapper(CustomDB):
|
|||
table = client.table(self.database_arguments.key_table_name)
|
||||
elif table_name == "config":
|
||||
table = client.table(self.database_arguments.config_table_name)
|
||||
elif table_name == "spend":
|
||||
table = client.table(self.database_arguments.spend_table_name)
|
||||
|
||||
for k, v in value.items():
|
||||
if isinstance(v, datetime):
|
||||
|
|
|
@ -61,8 +61,8 @@ litellm_settings:
|
|||
# setting callback class
|
||||
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
|
||||
|
||||
# general_settings:
|
||||
# master_key: sk-1234
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
# database_type: "dynamo_db"
|
||||
# database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
|
||||
# "billing_mode": "PAY_PER_REQUEST",
|
||||
|
|
|
@ -72,6 +72,7 @@ from litellm.proxy.utils import (
|
|||
ProxyLogging,
|
||||
_cache_user_row,
|
||||
send_email,
|
||||
get_logging_payload,
|
||||
)
|
||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||
import pydantic
|
||||
|
@ -518,6 +519,7 @@ async def track_cost_callback(
|
|||
global prisma_client, custom_db_client
|
||||
try:
|
||||
# check if it has collected an entire stream response
|
||||
verbose_proxy_logger.debug(f"Proxy: In track_cost_callback for {kwargs}")
|
||||
verbose_proxy_logger.debug(
|
||||
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
|
||||
)
|
||||
|
@ -538,7 +540,13 @@ async def track_cost_callback(
|
|||
prisma_client is not None or custom_db_client is not None
|
||||
):
|
||||
await update_database(
|
||||
token=user_api_key, response_cost=response_cost, user_id=user_id
|
||||
token=user_api_key,
|
||||
response_cost=response_cost,
|
||||
user_id=user_id,
|
||||
kwargs=kwargs,
|
||||
completion_response=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
elif kwargs["stream"] == False: # for non streaming responses
|
||||
response_cost = litellm.completion_cost(
|
||||
|
@ -554,13 +562,27 @@ async def track_cost_callback(
|
|||
prisma_client is not None or custom_db_client is not None
|
||||
):
|
||||
await update_database(
|
||||
token=user_api_key, response_cost=response_cost, user_id=user_id
|
||||
token=user_api_key,
|
||||
response_cost=response_cost,
|
||||
user_id=user_id,
|
||||
kwargs=kwargs,
|
||||
completion_response=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
|
||||
|
||||
|
||||
async def update_database(token, response_cost, user_id=None):
|
||||
async def update_database(
|
||||
token,
|
||||
response_cost,
|
||||
user_id=None,
|
||||
kwargs=None,
|
||||
completion_response=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Enters prisma db call, token: {token}; user_id: {user_id}"
|
||||
|
@ -630,9 +652,28 @@ async def update_database(token, response_cost, user_id=None):
|
|||
key=token, value={"spend": new_spend}, table_name="key"
|
||||
)
|
||||
|
||||
async def _insert_spend_log_to_db():
|
||||
# Helper to generate payload to log
|
||||
verbose_proxy_logger.debug("inserting spend log to db")
|
||||
payload = get_logging_payload(
|
||||
kwargs=kwargs,
|
||||
response_obj=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
payload["spend"] = response_cost
|
||||
|
||||
if prisma_client is not None:
|
||||
await prisma_client.insert_data(data=payload, table_name="spend")
|
||||
|
||||
elif custom_db_client is not None:
|
||||
await custom_db_client.insert_data(payload, table_name="spend")
|
||||
|
||||
tasks = []
|
||||
tasks.append(_update_user_db())
|
||||
tasks.append(_update_key_db())
|
||||
tasks.append(_insert_spend_log_to_db())
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
|
|
|
@ -31,4 +31,21 @@ model LiteLLM_VerificationToken {
|
|||
model LiteLLM_Config {
|
||||
param_name String @id
|
||||
param_value Json?
|
||||
}
|
||||
|
||||
// Spend/usage log table: one row per request tracked by the proxy.
// Mirrors the LiteLLM_SpendLogs pydantic model in litellm/proxy/_types.py.
model LiteLLM_SpendLogs {
  request_id String @unique // response id, or a generated uuid
  call_type String
  api_key String @default ("") // SHA-256 hash of the proxy key, never the raw key
  spend Float @default(0.0)
  startTime DateTime // request start time
  endTime DateTime // request end time
  model String @default("")
  user String @default("")
  modelParameters Json @default("{}") // optional_params from the request
  messages Json @default("[]")
  response Json @default("{}")
  usage Json @default("{}")
  metadata Json @default("{}")
  cache_hit String @default("") // "True"/"False" stored as a string
}
|
|
@ -1,7 +1,12 @@
|
|||
from typing import Optional, List, Any, Literal, Union
|
||||
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
|
||||
import litellm, backoff
|
||||
from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs, LiteLLM_VerificationToken
|
||||
from litellm.proxy._types import (
|
||||
UserAPIKeyAuth,
|
||||
DynamoDBArgs,
|
||||
LiteLLM_VerificationToken,
|
||||
LiteLLM_SpendLogs,
|
||||
)
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
|
||||
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
|
||||
|
@ -316,7 +321,7 @@ class PrismaClient:
|
|||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
table_name: Literal["users", "keys", "config"],
|
||||
table_name: Literal["users", "keys", "config", "spend"],
|
||||
):
|
||||
"""
|
||||
Generic implementation of get data
|
||||
|
@ -334,6 +339,10 @@ class PrismaClient:
|
|||
response = await self.db.litellm_config.find_first( # type: ignore
|
||||
where={key: value} # type: ignore
|
||||
)
|
||||
elif table_name == "spend":
|
||||
response = await self.db.l.find_first( # type: ignore
|
||||
where={key: value} # type: ignore
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
|
@ -417,7 +426,7 @@ class PrismaClient:
|
|||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def insert_data(
|
||||
self, data: dict, table_name: Literal["user", "key", "config"]
|
||||
self, data: dict, table_name: Literal["user", "key", "config", "spend"]
|
||||
):
|
||||
"""
|
||||
Add a key to the database. If it already exists, do nothing.
|
||||
|
@ -473,8 +482,18 @@ class PrismaClient:
|
|||
)
|
||||
|
||||
tasks.append(updated_table_row)
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
elif table_name == "spend":
|
||||
db_data = self.jsonify_object(data=data)
|
||||
new_spend_row = await self.db.litellm_spendlogs.upsert(
|
||||
where={"request_id": data["request_id"]},
|
||||
data={
|
||||
"create": {**db_data}, # type: ignore
|
||||
"update": {}, # don't do anything if it already exists
|
||||
},
|
||||
)
|
||||
return new_spend_row
|
||||
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
|
||||
asyncio.create_task(
|
||||
|
@ -760,3 +779,85 @@ async def send_email(sender_name, sender_email, receiver_email, subject, html):
|
|||
|
||||
except Exception as e:
|
||||
print_verbose("An error occurred while sending the email:", str(e))
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
    """Return the SHA-256 hex digest of ``token``.

    Used so raw API keys are never written to the spend logs — only their
    hash is stored (see get_logging_payload).

    Args:
        token: the raw API key string.

    Returns:
        64-character lowercase hex digest.
    """
    # Removed the redundant function-local `import hashlib` — hashlib is
    # already imported at the top of this file.
    return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def get_logging_payload(kwargs, response_obj, start_time, end_time):
    """Build a standardized spend-log row from a litellm success callback.

    Shared by the DB spend-log writers (Prisma / DynamoDB / s3 / langfuse);
    the returned dict's keys mirror the LiteLLM_SpendLogs pydantic model.
    Json-typed fields are serialized to JSON strings and str-typed fields
    are coerced to str, so the row can be inserted directly.

    Args:
        kwargs: the litellm callback kwargs dict (may be None).
        response_obj: the completion/embedding response; assumed to support
            ``["usage"]`` indexing and ``.get()`` (e.g. litellm.ModelResponse)
            — TODO confirm callers never pass None.
        start_time: request start (datetime or str).
        end_time: request end (datetime or str).

    Returns:
        dict ready to insert into the spend-logs table.
    """
    from litellm.proxy._types import LiteLLM_SpendLogs
    from pydantic import Json
    import uuid

    if kwargs is None:  # was `kwargs == None`; identity check is correct here
        kwargs = {}
    # standardize this function to be used across, s3, dynamoDB, langfuse logging
    litellm_params = kwargs.get("litellm_params", {})
    # litellm_params["metadata"] may be explicitly None — fall back to {}
    metadata = litellm_params.get("metadata", {}) or {}
    messages = kwargs.get("messages")
    optional_params = kwargs.get("optional_params", {})
    call_type = kwargs.get("call_type", "litellm.completion")
    cache_hit = kwargs.get("cache_hit", False)
    usage = response_obj["usage"]
    # renamed from `id` to avoid shadowing the builtin
    request_id = response_obj.get("id", str(uuid.uuid4()))
    api_key = metadata.get("user_api_key", "")
    if api_key is not None and isinstance(api_key, str):
        # never store the raw key — only its SHA-256 hash
        api_key = hash_token(api_key)

    payload = {
        "request_id": request_id,
        "call_type": call_type,
        "api_key": api_key,
        "cache_hit": cache_hit,
        "startTime": start_time,
        "endTime": end_time,
        "model": kwargs.get("model", ""),
        "user": kwargs.get("user", ""),
        "modelParameters": optional_params,
        "messages": messages,
        "response": response_obj,
        "usage": usage,
        "metadata": metadata,
    }

    # Derive which payload keys map to Json / str columns from the pydantic
    # model's annotations, so this stays in sync with LiteLLM_SpendLogs.
    # (The original also built an unused `datetime_fields` list — removed.)
    json_fields = [
        field
        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
        if field_type == Json or field_type == Optional[Json]
    ]
    str_fields = [
        field
        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
        if field_type == str or field_type == Optional[str]
    ]

    for param in json_fields:
        # `type(...) != Json` kept from the original: payload values are never
        # pydantic Json instances, so this is effectively always true.
        if param in payload and type(payload[param]) != Json:
            # BUG FIX: the original used a separate `if` for ModelResponse
            # followed by an `if/elif/else` chain, so a ModelResponse was
            # serialized via model_dump_json() and then json.dumps'd AGAIN by
            # the trailing `else`, double-encoding the JSON. One chain fixes it.
            if isinstance(
                payload[param],
                (litellm.ModelResponse, litellm.EmbeddingResponse, litellm.Usage),
            ):
                payload[param] = payload[param].model_dump_json()
            else:
                payload[param] = json.dumps(payload[param])

    for param in str_fields:
        if param in payload and not isinstance(payload[param], str):
            payload[param] = str(payload[param])

    return payload
|
||||
|
|
|
@ -179,6 +179,10 @@ def test_call_with_key_over_budget(custom_db_client):
|
|||
# 5. Make a call with a key over budget, expect to fail
|
||||
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
import logging
|
||||
|
||||
verbose_proxy_logger.setLevel(logging.DEBUG)
|
||||
try:
|
||||
|
||||
async def test():
|
||||
|
|
|
@ -31,4 +31,21 @@ model LiteLLM_VerificationToken {
|
|||
model LiteLLM_Config {
  param_name String @id
  param_value Json?
}
// BUG FIX: removed a stray duplicate `}` that followed the LiteLLM_Config
// model — two consecutive closing braces are a prisma schema syntax error.

// Spend/usage log table: one row per request tracked by the proxy.
// Mirrors the LiteLLM_SpendLogs pydantic model in litellm/proxy/_types.py.
model LiteLLM_SpendLogs {
  request_id      String   @unique // response id, or a generated uuid
  api_key         String   @default("") // SHA-256 hash of the proxy key
  call_type       String
  spend           Float    @default(0.0)
  startTime       DateTime // request start time
  endTime         DateTime // request end time
  model           String   @default("")
  user            String   @default("")
  modelParameters Json     @default("{}") // optional_params from the request
  messages        Json     @default("[]")
  response        Json     @default("{}")
  usage           Json     @default("{}")
  metadata        Json     @default("{}")
  cache_hit       String   @default("") // "True"/"False" stored as a string
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue