fix cost tracking by tags

This commit is contained in:
Ishaan Jaff 2024-06-21 16:49:57 -07:00
parent 9fbc30d4f1
commit aa3f2b3cf9
4 changed files with 139 additions and 129 deletions

View file

@ -165,9 +165,10 @@ from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_secret_manager,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
from litellm.proxy.spend_tracking.spend_management_endpoints import (
router as spend_management_router,
)
from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload
from litellm.proxy.utils import (
DBClient,
PrismaClient,
@ -180,7 +181,6 @@ from litellm.proxy.utils import (
encrypt_value,
get_error_message_str,
get_instance_fn,
get_logging_payload,
hash_token,
html_form,
missing_keys_html_form,

View file

@ -1,13 +1,14 @@
#### SPEND MANAGEMENT #####
from typing import Optional, List
from datetime import datetime, timedelta, timezone
from typing import List, Optional
import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
import litellm
from litellm._logging import verbose_proxy_logger
from datetime import datetime, timedelta, timezone
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
import fastapi
from fastapi import Depends, Request, APIRouter, Header, status
from fastapi import HTTPException
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter()
@ -227,7 +228,7 @@ async def get_global_activity(
start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
from litellm.proxy.proxy_server import prisma_client, llm_router
from litellm.proxy.proxy_server import llm_router, prisma_client
try:
if prisma_client is None:
@ -355,7 +356,7 @@ async def get_global_activity_model(
start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
from litellm.proxy.proxy_server import prisma_client, llm_router, premium_user
from litellm.proxy.proxy_server import llm_router, premium_user, prisma_client
try:
if prisma_client is None:
@ -500,7 +501,7 @@ async def get_global_activity_exceptions_per_deployment(
start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
from litellm.proxy.proxy_server import prisma_client, llm_router, premium_user
from litellm.proxy.proxy_server import llm_router, premium_user, prisma_client
try:
if prisma_client is None:
@ -634,7 +635,7 @@ async def get_global_activity_exceptions(
start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
from litellm.proxy.proxy_server import prisma_client, llm_router
from litellm.proxy.proxy_server import llm_router, prisma_client
try:
if prisma_client is None:
@ -739,7 +740,7 @@ async def get_global_spend_provider(
start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
from litellm.proxy.proxy_server import prisma_client, llm_router
from litellm.proxy.proxy_server import llm_router, prisma_client
try:
if prisma_client is None:
@ -1091,7 +1092,6 @@ async def global_view_spend_tags(
"""
from enterprise.utils import ui_get_spend_by_tags
from litellm.proxy.proxy_server import prisma_client
try:

View file

@ -0,0 +1,125 @@
import json
import traceback
from typing import Optional
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import hash_token
def get_logging_payload(
    kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
) -> SpendLogsPayload:
    """
    Build a SpendLogsPayload row for the SpendLogs table from a completed call.

    Standardized payload builder shared across logging sinks (s3, dynamoDB,
    langfuse, DB spend tracking).

    Args:
        kwargs: the litellm call kwargs (may be None; treated as {}).
        response_obj: the model response; ``usage`` and ``id`` are read from it.
        start_time / end_time: wall-clock bounds of the request.
        end_user_id: optional end-user identifier recorded on the row.

    Returns:
        A fully-populated SpendLogsPayload.

    Raises:
        Re-raises any exception hit while constructing the payload (after
        logging it), so callers can decide how to handle logging failures.
    """
    verbose_proxy_logger.debug(
        f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n"
    )

    if kwargs is None:
        kwargs = {}
    litellm_params = kwargs.get("litellm_params", {})
    metadata = (
        litellm_params.get("metadata", {}) or {}
    )  # guard: litellm_params["metadata"] may be None
    completion_start_time = kwargs.get("completion_start_time", end_time)
    call_type = kwargs.get("call_type")
    cache_hit = kwargs.get("cache_hit", False)
    usage = response_obj["usage"]
    if isinstance(usage, litellm.Usage):
        usage = dict(usage)
    # avoid shadowing the builtin `id`
    request_id = response_obj.get("id", kwargs.get("litellm_call_id"))
    api_key = metadata.get("user_api_key", "")
    if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
        # never persist the raw secret — store the hash only
        api_key = hash_token(api_key)

    _model_id = metadata.get("model_info", {}).get("id", "")
    _model_group = metadata.get("model_group", "")

    # FIX: tags arrive as a *list*; the previous `isinstance(..., dict)` check
    # was always false, so request_tags serialized as "[]" and per-tag cost
    # tracking silently broke.
    _tags = metadata.get("tags", [])
    request_tags = json.dumps(_tags) if isinstance(_tags, list) else "[]"

    # start from an all-None metadata record, then overlay any recognized keys
    clean_metadata = SpendLogsMetadata(
        user_api_key=None,
        user_api_key_alias=None,
        user_api_key_team_id=None,
        user_api_key_user_id=None,
        user_api_key_team_alias=None,
        spend_logs_metadata=None,
    )
    if isinstance(metadata, dict):
        verbose_proxy_logger.debug(
            "getting payload for SpendLogs, available keys in metadata: "
            + str(list(metadata.keys()))
        )
        # keep only keys declared on SpendLogsMetadata
        clean_metadata = SpendLogsMetadata(
            **{  # type: ignore
                key: metadata[key]
                for key in SpendLogsMetadata.__annotations__.keys()
                if key in metadata
            }
        )

    if litellm.cache is not None:
        cache_key = litellm.cache.get_cache_key(**kwargs)
    else:
        cache_key = "Cache OFF"
    if cache_hit is True:
        import time

        # SpendLogs does not allow duplicate request_id — salt cache hits
        request_id = f"{request_id}_cache_hit{time.time()}"

    try:
        payload: SpendLogsPayload = SpendLogsPayload(
            request_id=str(request_id),
            call_type=call_type or "",
            api_key=str(api_key),
            cache_hit=str(cache_hit),
            startTime=start_time,
            endTime=end_time,
            completionStartTime=completion_start_time,
            model=kwargs.get("model", "") or "",
            user=metadata.get("user_api_key_user_id", "") or "",
            team_id=metadata.get("user_api_key_team_id", "") or "",
            metadata=json.dumps(clean_metadata),
            cache_key=cache_key,
            spend=kwargs.get("response_cost", 0),
            total_tokens=usage.get("total_tokens", 0),
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            request_tags=request_tags,
            end_user=end_user_id or "",
            api_base=litellm_params.get("api_base", ""),
            model_group=_model_group,
            model_id=_model_id,
        )

        verbose_proxy_logger.debug(
            "SpendTable: created payload - payload: %s\n\n", payload
        )

        return payload
    except Exception as e:
        verbose_proxy_logger.error(
            "Error creating spendlogs object - {}\n{}".format(
                str(e), traceback.format_exc()
            )
        )
        raise e

View file

@ -2005,121 +2005,6 @@ def hash_token(token: str):
return hashed_token
def get_logging_payload(
    kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
) -> SpendLogsPayload:
    """
    Build a SpendLogsPayload row for the SpendLogs table from a completed call.

    Standardized payload builder shared across logging sinks (s3, dynamoDB,
    langfuse, DB spend tracking).

    Args:
        kwargs: the litellm call kwargs (may be None; treated as {}).
        response_obj: the model response; ``usage`` and ``id`` are read from it.
        start_time / end_time: wall-clock bounds of the request.
        end_user_id: optional end-user identifier recorded on the row.

    Returns:
        A fully-populated SpendLogsPayload.

    Raises:
        Re-raises any exception hit while constructing the payload (after
        logging it), so callers can decide how to handle logging failures.
    """
    verbose_proxy_logger.debug(
        f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n"
    )

    if kwargs is None:
        kwargs = {}
    litellm_params = kwargs.get("litellm_params", {})
    metadata = (
        litellm_params.get("metadata", {}) or {}
    )  # guard: litellm_params["metadata"] may be None
    completion_start_time = kwargs.get("completion_start_time", end_time)
    call_type = kwargs.get("call_type")
    cache_hit = kwargs.get("cache_hit", False)
    usage = response_obj["usage"]
    if isinstance(usage, litellm.Usage):
        usage = dict(usage)
    # avoid shadowing the builtin `id`
    request_id = response_obj.get("id", kwargs.get("litellm_call_id"))
    api_key = metadata.get("user_api_key", "")
    if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
        # never persist the raw secret — store the hash only
        api_key = hash_token(api_key)

    _model_id = metadata.get("model_info", {}).get("id", "")
    _model_group = metadata.get("model_group", "")

    # start from an all-None metadata record, then overlay any recognized keys
    clean_metadata = SpendLogsMetadata(
        user_api_key=None,
        user_api_key_alias=None,
        user_api_key_team_id=None,
        user_api_key_user_id=None,
        user_api_key_team_alias=None,
        spend_logs_metadata=None,
    )
    if isinstance(metadata, dict):
        verbose_proxy_logger.debug(
            "getting payload for SpendLogs, available keys in metadata: "
            + str(list(metadata.keys()))
        )
        # keep only keys declared on SpendLogsMetadata
        clean_metadata = SpendLogsMetadata(
            **{  # type: ignore
                key: metadata[key]
                for key in SpendLogsMetadata.__annotations__.keys()
                if key in metadata
            }
        )

    if litellm.cache is not None:
        cache_key = litellm.cache.get_cache_key(**kwargs)
    else:
        cache_key = "Cache OFF"
    if cache_hit is True:
        import time

        # SpendLogs does not allow duplicate request_id — salt cache hits
        request_id = f"{request_id}_cache_hit{time.time()}"

    # FIX: tags arrive as a *list*; the previous `isinstance(..., dict)` check
    # was always false, so request_tags serialized as "[]" and per-tag cost
    # tracking silently broke.
    _tags = metadata.get("tags", [])
    request_tags = json.dumps(_tags) if isinstance(_tags, list) else "[]"

    try:
        payload: SpendLogsPayload = SpendLogsPayload(
            request_id=str(request_id),
            call_type=call_type or "",
            api_key=str(api_key),
            cache_hit=str(cache_hit),
            startTime=start_time,
            endTime=end_time,
            completionStartTime=completion_start_time,
            model=kwargs.get("model", "") or "",
            user=metadata.get("user_api_key_user_id", "") or "",
            team_id=metadata.get("user_api_key_team_id", "") or "",
            metadata=json.dumps(clean_metadata),
            cache_key=cache_key,
            spend=kwargs.get("response_cost", 0),
            total_tokens=usage.get("total_tokens", 0),
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            request_tags=request_tags,
            end_user=end_user_id or "",
            api_base=litellm_params.get("api_base", ""),
            model_group=_model_group,
            model_id=_model_id,
        )

        verbose_proxy_logger.debug(
            "SpendTable: created payload - payload: %s\n\n", payload
        )

        return payload
    except Exception as e:
        verbose_proxy_logger.error(
            "Error creating spendlogs object - {}\n{}".format(
                str(e), traceback.format_exc()
            )
        )
        raise e
def _extract_from_regex(duration: str) -> Tuple[int, str]:
match = re.match(r"(\d+)(mo|[smhd]?)", duration)