fix(proxy/utils.py): security fix

use typed dict for spendlogs payload. assert no sensitive information logged.
Krrish Dholakia 2024-06-07 13:44:11 -07:00
parent 78474b1ce7
commit f7f8bcb21b
4 changed files with 314 additions and 98 deletions
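The core of the change: spend-log metadata is now built from an explicit SpendLogsMetadata TypedDict and serialized to a JSON string, so only an allow-listed set of keys can ever reach the LiteLLM_SpendLogs table. Below is a minimal, self-contained sketch of that allow-list pattern; the filter_metadata helper is illustrative only (the commit implements the same idea inline in get_logging_payload).

import json
from typing import Optional, TypedDict


class SpendLogsMetadata(TypedDict):
    user_api_key: Optional[str]
    user_api_key_alias: Optional[str]
    user_api_key_team_id: Optional[str]
    user_api_key_user_id: Optional[str]
    user_api_key_team_alias: Optional[str]


def filter_metadata(metadata: dict) -> str:
    # Keep only the keys declared on the TypedDict; everything else
    # (request headers, raw_request, etc.) is dropped before logging.
    clean = {k: metadata.get(k) for k in SpendLogsMetadata.__annotations__}
    return json.dumps(clean)


print(filter_metadata({"user_api_key": "hashed-token", "headers": {"authorization": "Bearer sk-…"}}))
# -> {"user_api_key": "hashed-token", "user_api_key_alias": null, ...}  (headers are gone)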

View file

@@ -1,7 +1,7 @@
 from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict
 from dataclasses import fields
 import enum
-from typing import Optional, List, Union, Dict, Literal, Any
+from typing import Optional, List, Union, Dict, Literal, Any, TypedDict
 from datetime import datetime
 import uuid, json, sys, os
 from litellm.types.router import UpdateRouterConfig
@@ -1268,7 +1268,7 @@ class LiteLLM_SpendLogs(LiteLLMBase):
     startTime: Union[str, datetime, None]
     endTime: Union[str, datetime, None]
     user: Optional[str] = ""
-    metadata: Optional[dict] = {}
+    metadata: Optional[Json] = {}
     cache_hit: Optional[str] = "False"
     cache_key: Optional[str] = None
     request_tags: Optional[Json] = None
@@ -1446,3 +1446,39 @@ class AllCallbacks(LiteLLMBase):
         litellm_callback_params=["DD_API_KEY", "DD_SITE"],
         ui_callback_name="Datadog",
     )
+
+
+class SpendLogsMetadata(TypedDict):
+    """
+    Specific metadata k,v pairs logged to spendlogs for easier cost tracking
+    """
+
+    user_api_key: Optional[str]
+    user_api_key_alias: Optional[str]
+    user_api_key_team_id: Optional[str]
+    user_api_key_user_id: Optional[str]
+    user_api_key_team_alias: Optional[str]
+
+
+class SpendLogsPayload(TypedDict):
+    request_id: str
+    call_type: str
+    api_key: str
+    spend: float
+    total_tokens: int
+    prompt_tokens: int
+    completion_tokens: int
+    startTime: datetime
+    endTime: datetime
+    completionStartTime: Optional[datetime]
+    model: str
+    model_id: Optional[str]
+    model_group: Optional[str]
+    api_base: str
+    user: str
+    metadata: str  # json str
+    cache_hit: str
+    cache_key: str
+    request_tags: str  # json str
+    team_id: Optional[str]
+    end_user: Optional[str]
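One thing to keep in mind when reading the rest of the diff: TypedDict classes are erased at runtime, so SpendLogsPayload(...) just builds a plain dict with no validation; the key/type contract is enforced by static type checking and by the new unit test at the end of this commit. A standalone illustration (a cut-down two-field type, not the real one):

from typing import TypedDict


class ExamplePayload(TypedDict):
    request_id: str
    metadata: str  # json str


payload = ExamplePayload(request_id="abc-123", metadata="{}")
print(type(payload))        # <class 'dict'> - no runtime validation happens here
print(payload["metadata"])  # "{}" - already a JSON string, written to the DB as-is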

View file

@@ -168,11 +168,11 @@ model LiteLLM_Config {
   param_value Json?
 }
 
-// View spend, model, api_key per request
+// View spend, model, hashed api_key per request
 model LiteLLM_SpendLogs {
   request_id    String @id
   call_type     String
-  api_key       String @default ("")
+  api_key       String @default ("") // Hashed API Token. Not the actual Virtual Key. Equivalent to 'token' column in LiteLLM_VerificationToken
   spend         Float  @default(0.0)
   total_tokens  Int    @default(0)
   prompt_tokens Int    @default(0)
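The new schema comment makes explicit that LiteLLM_SpendLogs.api_key stores the hashed token (the same value as the token column in LiteLLM_VerificationToken), never the raw virtual key. A sketch of what that means for anything reading spend logs, assuming the proxy's hash_token helper is a plain SHA-256 hex digest (an assumption made for illustration):

import hashlib


def hash_token(token: str) -> str:
    # Assumption: virtual keys are hashed with SHA-256 before being stored,
    # so spend logs only ever see this digest, never the raw key.
    return hashlib.sha256(token.encode()).hexdigest()


print(hash_token("sk-1234"))  # 64-char hex digest, safe to persist in LiteLLM_SpendLogs.api_key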

View file

@@ -1,24 +1,29 @@
 from typing import Optional, List, Any, Literal, Union
-import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx, time
-import litellm, backoff, traceback
+import os
+import subprocess
+import hashlib
+import importlib
+import asyncio
+import copy
+import json
+import httpx
+import time
+import litellm
+import backoff
+import traceback
+from pydantic import BaseModel
 from litellm.proxy._types import (
     UserAPIKeyAuth,
     DynamoDBArgs,
-    LiteLLM_VerificationToken,
     LiteLLM_VerificationTokenView,
-    LiteLLM_SpendLogs,
-    LiteLLM_UserTable,
-    LiteLLM_EndUserTable,
-    LiteLLM_TeamTable,
-    Member,
     CallInfo,
-    WebhookEvent,
     AlertType,
     ResetTeamBudgetRequest,
     LitellmUserRoles,
+    SpendLogsMetadata,
+    SpendLogsPayload,
 )
 from litellm.caching import DualCache, RedisCache
-from litellm.router import Deployment, ModelInfo, LiteLLM_Params
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
@@ -29,24 +34,18 @@ from litellm import (
     ModelResponse,
     EmbeddingResponse,
     ImageResponse,
-    TranscriptionResponse,
-    TextCompletionResponse,
-    CustomStreamWrapper,
-    TextCompletionStreamWrapper,
 )
-from litellm.utils import ModelResponseIterator
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.proxy.db.base_client import CustomDB
 from litellm._logging import verbose_proxy_logger
 from fastapi import HTTPException, status
-import smtplib, re
+import smtplib
+import re
 from email.mime.text import MIMEText
 from email.mime.multipart import MIMEMultipart
 from datetime import datetime, timedelta
 from litellm.integrations.slack_alerting import SlackAlerting
-from typing_extensions import overload
 
 
 def print_verbose(print_statement):
@@ -1895,16 +1894,15 @@ def hash_token(token: str):
 def get_logging_payload(
     kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
-):
+) -> SpendLogsPayload:
     from litellm.proxy._types import LiteLLM_SpendLogs
     from pydantic import Json
-    import uuid
 
     verbose_proxy_logger.debug(
         f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n"
     )
 
-    if kwargs == None:
+    if kwargs is None:
         kwargs = {}
     # standardize this function to be used across, s3, dynamoDB, langfuse logging
     litellm_params = kwargs.get("litellm_params", {})
@@ -1927,94 +1925,82 @@ def get_logging_payload
     _model_group = metadata.get("model_group", "")
 
     # clean up litellm metadata
+    clean_metadata = SpendLogsMetadata(
+        user_api_key=None,
+        user_api_key_alias=None,
+        user_api_key_team_id=None,
+        user_api_key_user_id=None,
+        user_api_key_team_alias=None,
+    )
     if isinstance(metadata, dict):
-        clean_metadata = {}
         verbose_proxy_logger.debug(
-            f"getting payload for SpendLogs, available keys in metadata: "
+            "getting payload for SpendLogs, available keys in metadata: "
             + str(list(metadata.keys()))
         )
-        for key in metadata:
-            if key in [
-                "headers",
-                "endpoint",
-                "model_group",
-                "deployment",
-                "model_info",
-                "caching_groups",
-                "previous_models",
-            ]:
-                continue
-            else:
-                clean_metadata[key] = metadata[key]
+
+        # Filter the metadata dictionary to include only the specified keys
+        clean_metadata = SpendLogsMetadata(
+            **{  # type: ignore
+                key: metadata[key]
+                for key in SpendLogsMetadata.__annotations__.keys()
+                if key in metadata
+            }
+        )
 
     if litellm.cache is not None:
         cache_key = litellm.cache.get_cache_key(**kwargs)
     else:
         cache_key = "Cache OFF"
-    if cache_hit == True:
+    if cache_hit is True:
         import time
 
         id = f"{id}_cache_hit{time.time()}"  # SpendLogs does not allow duplicate request_id
 
-    payload = {
-        "request_id": id,
-        "call_type": call_type,
-        "api_key": api_key,
-        "cache_hit": cache_hit,
-        "startTime": start_time,
-        "endTime": end_time,
-        "completionStartTime": completion_start_time,
-        "model": kwargs.get("model", ""),
-        "user": kwargs.get("litellm_params", {})
-        .get("metadata", {})
-        .get("user_api_key_user_id", ""),
-        "team_id": kwargs.get("litellm_params", {})
-        .get("metadata", {})
-        .get("user_api_key_team_id", ""),
-        "metadata": clean_metadata,
-        "cache_key": cache_key,
-        "spend": kwargs.get("response_cost", 0),
-        "total_tokens": usage.get("total_tokens", 0),
-        "prompt_tokens": usage.get("prompt_tokens", 0),
-        "completion_tokens": usage.get("completion_tokens", 0),
-        "request_tags": metadata.get("tags", []),
-        "end_user": end_user_id or "",
-        "api_base": litellm_params.get("api_base", ""),
-        "model_group": _model_group,
-        "model_id": _model_id,
-    }
-
-    verbose_proxy_logger.debug("SpendTable: created payload - payload: %s\n\n", payload)
-
-    json_fields = [
-        field
-        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
-        if field_type == Json or field_type == Optional[Json]
-    ]
-    str_fields = [
-        field
-        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
-        if field_type == str or field_type == Optional[str]
-    ]
-    datetime_fields = [
-        field
-        for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
-        if field_type == datetime
-    ]
-
-    for param in json_fields:
-        if param in payload and type(payload[param]) != Json:
-            if type(payload[param]) == litellm.ModelResponse:
-                payload[param] = payload[param].model_dump_json()
-            if type(payload[param]) == litellm.EmbeddingResponse:
-                payload[param] = payload[param].model_dump_json()
-            else:
-                payload[param] = json.dumps(payload[param])
-
-    for param in str_fields:
-        if param in payload and type(payload[param]) != str:
-            payload[param] = str(payload[param])
-
-    return payload
+    try:
+        payload: SpendLogsPayload = SpendLogsPayload(
+            request_id=str(id),
+            call_type=call_type or "",
+            api_key=str(api_key),
+            cache_hit=str(cache_hit),
+            startTime=start_time,
+            endTime=end_time,
+            completionStartTime=completion_start_time,
+            model=kwargs.get("model", ""),
+            user=kwargs.get("litellm_params", {})
+            .get("metadata", {})
+            .get("user_api_key_user_id", ""),
+            team_id=kwargs.get("litellm_params", {})
+            .get("metadata", {})
+            .get("user_api_key_team_id", ""),
+            metadata=json.dumps(clean_metadata),
+            cache_key=cache_key,
+            spend=kwargs.get("response_cost", 0),
+            total_tokens=usage.get("total_tokens", 0),
+            prompt_tokens=usage.get("prompt_tokens", 0),
+            completion_tokens=usage.get("completion_tokens", 0),
+            request_tags=(
+                json.dumps(metadata.get("tags", []))
+                if isinstance(metadata.get("tags", []), dict)
+                else "[]"
+            ),
+            end_user=end_user_id or "",
+            api_base=litellm_params.get("api_base", ""),
+            model_group=_model_group,
+            model_id=_model_id,
+        )
+
+        verbose_proxy_logger.debug(
+            "SpendTable: created payload - payload: %s\n\n", payload
+        )
+
+        return payload
+    except Exception as e:
+        verbose_proxy_logger.error(
+            "Error creating spendlogs object - {}\n{}".format(
+                str(e), traceback.format_exc()
+            )
+        )
+        raise e
 
 
 def _duration_in_seconds(duration: str):
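For a request shaped like the unit test below, the rewritten get_logging_payload returns something like the following (a hypothetical, abbreviated result, not output captured from the commit); note that metadata is already a JSON string containing only the SpendLogsMetadata keys:

import json

# Hypothetical result shape (values abbreviated) after this change:
payload = {
    "request_id": "chatcmpl-...",
    "api_key": "88dc28d0f030c55e...",  # hashed token only
    "metadata": json.dumps(
        {
            "user_api_key": "88dc28d0f030c55e...",
            "user_api_key_alias": None,
            "user_api_key_team_id": None,
            "user_api_key_user_id": "116544810872468347480",
            "user_api_key_team_alias": None,
        }
    ),
    "request_tags": "[]",
}

# Request headers, raw_request, and other free-form metadata from the call
# never make it into payload["metadata"].
assert "headers" not in json.loads(payload["metadata"])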

View file

@@ -0,0 +1,194 @@
import sys, os
import traceback, uuid
from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute

load_dotenv()
import os, io, time

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm, asyncio
import json
import datetime

from litellm.proxy.utils import (
    get_logging_payload,
    SpendLogsPayload,
    SpendLogsMetadata,
)  # noqa: E402


def test_spend_logs_payload():
    """
    Ensure only expected values are logged in spend logs payload.
    """

    input_args: dict = {
        "kwargs": {
            "model": "chatgpt-v-2",
            "messages": [
                {"role": "system", "content": "you are a helpful assistant.\n"},
                {"role": "user", "content": "bom dia"},
            ],
            "optional_params": {
                "stream": False,
                "max_tokens": 10,
                "user": "116544810872468347480",
                "extra_body": {},
            },
            "litellm_params": {
                "acompletion": True,
                "api_key": "23c217a5b59f41b6b7a198017f4792f2",
                "force_timeout": 600,
                "logger_fn": None,
                "verbose": False,
                "custom_llm_provider": "azure",
                "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com//openai/",
                "litellm_call_id": "b9929bf6-7b80-4c8c-b486-034e6ac0c8b7",
                "model_alias_map": {},
                "completion_call_id": None,
                "metadata": {
                    "user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
                    "user_api_key_alias": None,
                    "user_api_end_user_max_budget": None,
                    "litellm_api_version": "0.0.0",
                    "global_max_parallel_requests": None,
                    "user_api_key_user_id": "116544810872468347480",
                    "user_api_key_org_id": None,
                    "user_api_key_team_id": None,
                    "user_api_key_team_alias": None,
                    "user_api_key_metadata": {},
                    "headers": {
                        "content-type": "application/json",
                        "user-agent": "PostmanRuntime/7.32.3",
                        "accept": "*/*",
                        "postman-token": "92300061-eeaa-423b-a420-0b44896ecdc4",
                        "host": "localhost:4000",
                        "accept-encoding": "gzip, deflate, br",
                        "connection": "keep-alive",
                        "content-length": "163",
                    },
                    "endpoint": "http://localhost:4000/chat/completions",
                    "model_group": "gpt-3.5-turbo",
                    "deployment": "azure/chatgpt-v-2",
                    "model_info": {
                        "id": "4bad40a1eb6bebd1682800f16f44b9f06c52a6703444c99c7f9f32e9de3693b4",
                        "db_model": False,
                    },
                    "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
                    "caching_groups": None,
                    "raw_request": "\n\nPOST Request Sent from LiteLLM:\ncurl -X POST \\\nhttps://openai-gpt-4-test-v-1.openai.azure.com//openai/ \\\n-H 'Authorization: *****' \\\n-d '{'model': 'chatgpt-v-2', 'messages': [{'role': 'system', 'content': 'you are a helpful assistant.\\n'}, {'role': 'user', 'content': 'bom dia'}], 'stream': False, 'max_tokens': 10, 'user': '116544810872468347480', 'extra_body': {}}'\n",
                },
                "model_info": {
                    "id": "4bad40a1eb6bebd1682800f16f44b9f06c52a6703444c99c7f9f32e9de3693b4",
                    "db_model": False,
                },
                "proxy_server_request": {
                    "url": "http://localhost:4000/chat/completions",
                    "method": "POST",
                    "headers": {
                        "content-type": "application/json",
                        "authorization": "Bearer sk-1234",
                        "user-agent": "PostmanRuntime/7.32.3",
                        "accept": "*/*",
                        "postman-token": "92300061-eeaa-423b-a420-0b44896ecdc4",
                        "host": "localhost:4000",
                        "accept-encoding": "gzip, deflate, br",
                        "connection": "keep-alive",
                        "content-length": "163",
                    },
                    "body": {
                        "messages": [
                            {
                                "role": "system",
                                "content": "you are a helpful assistant.\n",
                            },
                            {"role": "user", "content": "bom dia"},
                        ],
                        "model": "gpt-3.5-turbo",
                        "max_tokens": 10,
                    },
                },
                "preset_cache_key": None,
                "no-log": False,
                "stream_response": {},
                "input_cost_per_token": None,
                "input_cost_per_second": None,
                "output_cost_per_token": None,
                "output_cost_per_second": None,
            },
            "start_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 307665),
            "stream": False,
            "user": "116544810872468347480",
            "call_type": "acompletion",
            "litellm_call_id": "b9929bf6-7b80-4c8c-b486-034e6ac0c8b7",
            "completion_start_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
            "max_tokens": 10,
            "extra_body": {},
            "custom_llm_provider": "azure",
            "input": [
                {"role": "system", "content": "you are a helpful assistant.\n"},
                {"role": "user", "content": "bom dia"},
            ],
            "api_key": "1234",
            "original_response": "",
            "additional_args": {
                "headers": {"Authorization": "Bearer 1234"},
                "api_base": "openai-gpt-4-test-v-1.openai.azure.com",
                "acompletion": True,
                "complete_input_dict": {
                    "model": "chatgpt-v-2",
                    "messages": [
                        {"role": "system", "content": "you are a helpful assistant.\n"},
                        {"role": "user", "content": "bom dia"},
                    ],
                    "stream": False,
                    "max_tokens": 10,
                    "user": "116544810872468347480",
                    "extra_body": {},
                },
            },
            "log_event_type": "post_api_call",
            "end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
            "cache_hit": None,
            "response_cost": 2.4999999999999998e-05,
        },
        "response_obj": litellm.ModelResponse(
            id="chatcmpl-9XZmkzS1uPhRCoVdGQvBqqIbSgECt",
            choices=[
                litellm.Choices(
                    finish_reason="length",
                    index=0,
                    message=litellm.Message(
                        content="Bom dia! Como posso ajudar você", role="assistant"
                    ),
                )
            ],
            created=1717789410,
            model="gpt-35-turbo",
            object="chat.completion",
            system_fingerprint=None,
            usage=litellm.Usage(
                completion_tokens=10, prompt_tokens=20, total_tokens=30
            ),
        ),
        "start_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 308604),
        "end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
        "end_user_id": None,
    }

    payload: SpendLogsPayload = get_logging_payload(**input_args)

    # Define the expected metadata keys
    expected_metadata_keys = SpendLogsMetadata.__annotations__.keys()

    # Validate only specified metadata keys are logged
    assert "metadata" in payload
    assert isinstance(payload["metadata"], str)
    payload["metadata"] = json.loads(payload["metadata"])
    assert set(payload["metadata"].keys()) == set(expected_metadata_keys)
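    # The commit message also calls for asserting that no sensitive information
    # is logged. Illustrative extra assertions (not part of the committed test):
    # confirm request headers / raw request details never reach the spend log.
    for sensitive_key in ["headers", "raw_request", "endpoint", "deployment"]:
        assert sensitive_key not in payload["metadata"]
    assert "Bearer sk-1234" not in json.dumps(payload, default=str)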