fix(proxy/utils.py): security fix

use a typed dict for the SpendLogs payload; assert no sensitive information is logged.
Krrish Dholakia 2024-06-07 13:44:11 -07:00
parent 6024f9e45e
commit b41d7f5d51
4 changed files with 314 additions and 98 deletions
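The change replaces the hand-built spend-log dict with TypedDicts, so the set of loggable keys is fixed by the type itself. Below is a minimal sketch of the idea, using simplified stand-ins for SpendLogsMetadata and SpendLogsPayload (the real definitions live in litellm/proxy/_types.py and carry more fields than shown here):

from typing import Optional
from typing_extensions import TypedDict

class SpendLogsMetadata(TypedDict):
    # Simplified sketch: only keys declared here can ever reach the spend logs table.
    user_api_key: Optional[str]
    user_api_key_alias: Optional[str]
    user_api_key_team_id: Optional[str]
    user_api_key_user_id: Optional[str]
    user_api_key_team_alias: Optional[str]

class SpendLogsPayload(TypedDict):
    # Simplified sketch: the real payload also carries timing, token and model fields.
    request_id: str
    call_type: str
    api_key: str
    spend: float
    metadata: str  # json.dumps of a SpendLogsMetadata dict

def filter_metadata(metadata: dict) -> SpendLogsMetadata:
    # Whitelist filter: keep only keys that SpendLogsMetadata declares,
    # dropping everything else (headers, deployment info, etc.).
    return SpendLogsMetadata(
        **{  # type: ignore
            key: metadata[key]
            for key in SpendLogsMetadata.__annotations__.keys()
            if key in metadata
        }
    )

Because the cleaned metadata and request tags are stored as json.dumps(...) strings, the payload stays a flat dict of primitives that the same function can feed to s3, DynamoDB, or Langfuse logging.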


@@ -1,24 +1,29 @@
from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx, time
import litellm, backoff, traceback
import os
import subprocess
import hashlib
import importlib
import asyncio
import copy
import json
import httpx
import time
import litellm
import backoff
import traceback
from pydantic import BaseModel
from litellm.proxy._types import (
UserAPIKeyAuth,
DynamoDBArgs,
LiteLLM_VerificationToken,
LiteLLM_VerificationTokenView,
LiteLLM_SpendLogs,
LiteLLM_UserTable,
LiteLLM_EndUserTable,
LiteLLM_TeamTable,
Member,
CallInfo,
WebhookEvent,
AlertType,
ResetTeamBudgetRequest,
LitellmUserRoles,
SpendLogsMetadata,
SpendLogsPayload,
)
from litellm.caching import DualCache, RedisCache
from litellm.router import Deployment, ModelInfo, LiteLLM_Params
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
@@ -29,24 +34,18 @@ from litellm import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
TranscriptionResponse,
TextCompletionResponse,
CustomStreamWrapper,
TextCompletionStreamWrapper,
)
from litellm.utils import ModelResponseIterator
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException, status
import smtplib, re
import smtplib
import re
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime, timedelta
from litellm.integrations.slack_alerting import SlackAlerting
from typing_extensions import overload
def print_verbose(print_statement):
@@ -1895,16 +1894,15 @@ def hash_token(token: str):
def get_logging_payload(
kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
):
) -> SpendLogsPayload:
from litellm.proxy._types import LiteLLM_SpendLogs
from pydantic import Json
import uuid
verbose_proxy_logger.debug(
f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n"
)
if kwargs == None:
if kwargs is None:
kwargs = {}
# standardize this function to be used across, s3, dynamoDB, langfuse logging
litellm_params = kwargs.get("litellm_params", {})
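Annotating get_logging_payload to return SpendLogsPayload means a type checker can catch a missing or misspelled key when the payload is built rather than when a row insert fails. A small, hypothetical illustration (the field names below are not from the diff):

from typing_extensions import TypedDict

class _MiniPayload(TypedDict):
    request_id: str
    spend: float

def _build_payload() -> _MiniPayload:
    return _MiniPayload(request_id="req-1", spend=0.0)   # accepted by mypy
    # return {"requestid": "req-1"}                      # rejected: misspelled / missing keys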
@@ -1927,94 +1925,82 @@ def get_logging_payload
_model_group = metadata.get("model_group", "")
# clean up litellm metadata
clean_metadata = SpendLogsMetadata(
user_api_key=None,
user_api_key_alias=None,
user_api_key_team_id=None,
user_api_key_user_id=None,
user_api_key_team_alias=None,
)
if isinstance(metadata, dict):
clean_metadata = {}
verbose_proxy_logger.debug(
f"getting payload for SpendLogs, available keys in metadata: "
"getting payload for SpendLogs, available keys in metadata: "
+ str(list(metadata.keys()))
)
for key in metadata:
if key in [
"headers",
"endpoint",
"model_group",
"deployment",
"model_info",
"caching_groups",
"previous_models",
]:
continue
else:
clean_metadata[key] = metadata[key]
# Filter the metadata dictionary to include only the specified keys
clean_metadata = SpendLogsMetadata(
**{ # type: ignore
key: metadata[key]
for key in SpendLogsMetadata.__annotations__.keys()
if key in metadata
}
)
if litellm.cache is not None:
cache_key = litellm.cache.get_cache_key(**kwargs)
else:
cache_key = "Cache OFF"
if cache_hit == True:
if cache_hit is True:
import time
id = f"{id}_cache_hit{time.time()}" # SpendLogs does not allow duplicate request_id
payload = {
"request_id": id,
"call_type": call_type,
"api_key": api_key,
"cache_hit": cache_hit,
"startTime": start_time,
"endTime": end_time,
"completionStartTime": completion_start_time,
"model": kwargs.get("model", ""),
"user": kwargs.get("litellm_params", {})
.get("metadata", {})
.get("user_api_key_user_id", ""),
"team_id": kwargs.get("litellm_params", {})
.get("metadata", {})
.get("user_api_key_team_id", ""),
"metadata": clean_metadata,
"cache_key": cache_key,
"spend": kwargs.get("response_cost", 0),
"total_tokens": usage.get("total_tokens", 0),
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
"request_tags": metadata.get("tags", []),
"end_user": end_user_id or "",
"api_base": litellm_params.get("api_base", ""),
"model_group": _model_group,
"model_id": _model_id,
}
try:
payload: SpendLogsPayload = SpendLogsPayload(
request_id=str(id),
call_type=call_type or "",
api_key=str(api_key),
cache_hit=str(cache_hit),
startTime=start_time,
endTime=end_time,
completionStartTime=completion_start_time,
model=kwargs.get("model", ""),
user=kwargs.get("litellm_params", {})
.get("metadata", {})
.get("user_api_key_user_id", ""),
team_id=kwargs.get("litellm_params", {})
.get("metadata", {})
.get("user_api_key_team_id", ""),
metadata=json.dumps(clean_metadata),
cache_key=cache_key,
spend=kwargs.get("response_cost", 0),
total_tokens=usage.get("total_tokens", 0),
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
request_tags=(
json.dumps(metadata.get("tags", []))
if isinstance(metadata.get("tags", []), dict)
else "[]"
),
end_user=end_user_id or "",
api_base=litellm_params.get("api_base", ""),
model_group=_model_group,
model_id=_model_id,
)
verbose_proxy_logger.debug("SpendTable: created payload - payload: %s\n\n", payload)
json_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == Json or field_type == Optional[Json]
]
str_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == str or field_type == Optional[str]
]
datetime_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == datetime
]
verbose_proxy_logger.debug(
"SpendTable: created payload - payload: %s\n\n", payload
)
for param in json_fields:
if param in payload and type(payload[param]) != Json:
if type(payload[param]) == litellm.ModelResponse:
payload[param] = payload[param].model_dump_json()
if type(payload[param]) == litellm.EmbeddingResponse:
payload[param] = payload[param].model_dump_json()
else:
payload[param] = json.dumps(payload[param])
for param in str_fields:
if param in payload and type(payload[param]) != str:
payload[param] = str(payload[param])
return payload
return payload
except Exception as e:
verbose_proxy_logger.error(
"Error creating spendlogs object - {}\n{}".format(
str(e), traceback.format_exc()
)
)
raise e
def _duration_in_seconds(duration: str):
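To see what the whitelist in get_logging_payload actually protects against, the same __annotations__ filter can be exercised on its own. A self-contained, hypothetical example (trimmed stand-in type and made-up values, not from the commit's tests):

from typing import Optional
from typing_extensions import TypedDict

class _Meta(TypedDict, total=False):
    # Trimmed stand-in for SpendLogsMetadata.
    user_api_key: Optional[str]
    user_api_key_user_id: Optional[str]

incoming = {
    "user_api_key": "hashed-key",
    "user_api_key_user_id": "user-123",
    "headers": {"authorization": "Bearer <redacted>"},   # sensitive, not whitelisted
    "endpoint": "http://localhost:4000/chat/completions",
}

clean = _Meta(**{k: incoming[k] for k in _Meta.__annotations__ if k in incoming})  # type: ignore

assert "headers" not in clean      # raw request headers never reach the spend log
assert "endpoint" not in clean
assert clean["user_api_key"] == "hashed-key"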