fix(proxy/utils.py): security fix

use a TypedDict for the SpendLogs payload; assert that no sensitive information is logged.
Krrish Dholakia 2024-06-07 13:44:11 -07:00
parent 78474b1ce7
commit f7f8bcb21b
4 changed files with 314 additions and 98 deletions
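In outline, the change replaces a free-form metadata dict with an allow-list: a TypedDict declares exactly which keys may reach the spend-logs table, and anything else in the request metadata (raw headers, bearer tokens, deployment details) is dropped before logging. A minimal sketch of the idea, using hypothetical trimmed-down names rather than the committed types:

from typing import Optional, TypedDict


# Hypothetical two-field stand-in for the real SpendLogsMetadata allow-list.
class _Metadata(TypedDict):
    user_api_key: Optional[str]
    user_api_key_team_id: Optional[str]


def _filter(raw: dict) -> _Metadata:
    # Keep only the keys declared on the TypedDict; unknown keys
    # (e.g. "headers" carrying an Authorization value) never survive.
    return _Metadata(**{k: raw.get(k) for k in _Metadata.__annotations__})  # type: ignore


raw = {"user_api_key": "<hashed>", "headers": {"authorization": "Bearer sk-..."}}
assert "headers" not in _filter(raw)

The same pattern, keyed off SpendLogsMetadata.__annotations__, appears in get_logging_payload below.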

litellm/proxy/_types.py

@@ -1,7 +1,7 @@
from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict
from dataclasses import fields
import enum
from typing import Optional, List, Union, Dict, Literal, Any
from typing import Optional, List, Union, Dict, Literal, Any, TypedDict
from datetime import datetime
import uuid, json, sys, os
from litellm.types.router import UpdateRouterConfig
@@ -1268,7 +1268,7 @@ class LiteLLM_SpendLogs(LiteLLMBase):
startTime: Union[str, datetime, None]
endTime: Union[str, datetime, None]
user: Optional[str] = ""
metadata: Optional[dict] = {}
metadata: Optional[Json] = {}
cache_hit: Optional[str] = "False"
cache_key: Optional[str] = None
request_tags: Optional[Json] = None
@@ -1446,3 +1446,39 @@ class AllCallbacks(LiteLLMBase):
litellm_callback_params=["DD_API_KEY", "DD_SITE"],
ui_callback_name="Datadog",
)
class SpendLogsMetadata(TypedDict):
"""
Specific metadata k,v pairs logged to spendlogs for easier cost tracking
"""
user_api_key: Optional[str]
user_api_key_alias: Optional[str]
user_api_key_team_id: Optional[str]
user_api_key_user_id: Optional[str]
user_api_key_team_alias: Optional[str]
class SpendLogsPayload(TypedDict):
request_id: str
call_type: str
api_key: str
spend: float
total_tokens: int
prompt_tokens: int
completion_tokens: int
startTime: datetime
endTime: datetime
completionStartTime: Optional[datetime]
model: str
model_id: Optional[str]
model_group: Optional[str]
api_base: str
user: str
metadata: str # json str
cache_hit: str
cache_key: str
request_tags: str # json str
team_id: Optional[str]
end_user: Optional[str]
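Because SpendLogsPayload is a TypedDict rather than a plain dict, a static type checker rejects any key not declared above, which is what makes the allow-list enforceable. A usage sketch with invented field values (the import path comes from this diff):

import json
from datetime import datetime

from litellm.proxy._types import SpendLogsPayload

payload = SpendLogsPayload(
    request_id="chatcmpl-123",
    call_type="acompletion",
    api_key="<hashed token>",  # hashed, never the raw virtual key
    spend=2.5e-05,
    total_tokens=30,
    prompt_tokens=20,
    completion_tokens=10,
    startTime=datetime.now(),
    endTime=datetime.now(),
    completionStartTime=None,
    model="gpt-3.5-turbo",
    model_id=None,
    model_group=None,
    api_base="https://example-endpoint.openai.azure.com",
    user="user-123",
    metadata=json.dumps({"user_api_key": "<hashed token>"}),  # json str
    cache_hit="False",
    cache_key="Cache OFF",
    request_tags="[]",  # json str
    team_id=None,
    end_user=None,
)
# A checker such as mypy would flag an undeclared key like headers=... as an error.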

schema.prisma

@@ -168,11 +168,11 @@ model LiteLLM_Config {
param_value Json?
}
// View spend, model, api_key per request
// View spend, model, hashed api_key per request
model LiteLLM_SpendLogs {
request_id String @id
call_type String
api_key String @default ("")
api_key String @default ("") // Hashed API Token. Not the actual Virtual Key. Equivalent to 'token' column in LiteLLM_VerificationToken
spend Float @default(0.0)
total_tokens Int @default(0)
prompt_tokens Int @default(0)
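The comment above means the column stores a digest, never the key itself, so a leaked spend-logs row cannot be replayed as a credential. A sketch of the scheme (SHA-256 here is an assumption, matching the hash_token helper that appears as context later in this diff):

import hashlib


def hash_token(token: str) -> str:
    # One-way digest of the virtual key; the hex string is what
    # lands in the LiteLLM_SpendLogs.api_key column.
    return hashlib.sha256(token.encode("utf-8")).hexdigest()


print(hash_token("sk-1234"))  # 64-char hex digest, safe to store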

litellm/proxy/utils.py

@@ -1,24 +1,29 @@
from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx, time
import litellm, backoff, traceback
import os
import subprocess
import hashlib
import importlib
import asyncio
import copy
import json
import httpx
import time
import litellm
import backoff
import traceback
from pydantic import BaseModel
from litellm.proxy._types import (
UserAPIKeyAuth,
DynamoDBArgs,
LiteLLM_VerificationToken,
LiteLLM_VerificationTokenView,
LiteLLM_SpendLogs,
LiteLLM_UserTable,
LiteLLM_EndUserTable,
LiteLLM_TeamTable,
Member,
CallInfo,
WebhookEvent,
AlertType,
ResetTeamBudgetRequest,
LitellmUserRoles,
SpendLogsMetadata,
SpendLogsPayload,
)
from litellm.caching import DualCache, RedisCache
from litellm.router import Deployment, ModelInfo, LiteLLM_Params
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
@@ -29,24 +34,18 @@ from litellm import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
TranscriptionResponse,
TextCompletionResponse,
CustomStreamWrapper,
TextCompletionStreamWrapper,
)
from litellm.utils import ModelResponseIterator
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException, status
import smtplib, re
import smtplib
import re
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime, timedelta
from litellm.integrations.slack_alerting import SlackAlerting
from typing_extensions import overload
def print_verbose(print_statement):
@@ -1895,16 +1894,15 @@ def hash_token(token: str):
def get_logging_payload(
kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
):
) -> SpendLogsPayload:
from litellm.proxy._types import LiteLLM_SpendLogs
from pydantic import Json
import uuid
verbose_proxy_logger.debug(
f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n"
)
if kwargs == None:
if kwargs is None:
kwargs = {}
# standardize this function to be used across s3, dynamoDB, and langfuse logging
litellm_params = kwargs.get("litellm_params", {})
@@ -1927,94 +1925,82 @@ def get_logging_payload(
_model_group = metadata.get("model_group", "")
# clean up litellm metadata
clean_metadata = SpendLogsMetadata(
user_api_key=None,
user_api_key_alias=None,
user_api_key_team_id=None,
user_api_key_user_id=None,
user_api_key_team_alias=None,
)
if isinstance(metadata, dict):
clean_metadata = {}
verbose_proxy_logger.debug(
f"getting payload for SpendLogs, available keys in metadata: "
"getting payload for SpendLogs, available keys in metadata: "
+ str(list(metadata.keys()))
)
for key in metadata:
if key in [
"headers",
"endpoint",
"model_group",
"deployment",
"model_info",
"caching_groups",
"previous_models",
]:
continue
else:
clean_metadata[key] = metadata[key]
# Filter the metadata dictionary to include only the specified keys
clean_metadata = SpendLogsMetadata(
**{ # type: ignore
key: metadata[key]
for key in SpendLogsMetadata.__annotations__.keys()
if key in metadata
}
)
if litellm.cache is not None:
cache_key = litellm.cache.get_cache_key(**kwargs)
else:
cache_key = "Cache OFF"
if cache_hit == True:
if cache_hit is True:
import time
id = f"{id}_cache_hit{time.time()}" # SpendLogs does not allow duplicate request_id
payload = {
"request_id": id,
"call_type": call_type,
"api_key": api_key,
"cache_hit": cache_hit,
"startTime": start_time,
"endTime": end_time,
"completionStartTime": completion_start_time,
"model": kwargs.get("model", ""),
"user": kwargs.get("litellm_params", {})
try:
payload: SpendLogsPayload = SpendLogsPayload(
request_id=str(id),
call_type=call_type or "",
api_key=str(api_key),
cache_hit=str(cache_hit),
startTime=start_time,
endTime=end_time,
completionStartTime=completion_start_time,
model=kwargs.get("model", ""),
user=kwargs.get("litellm_params", {})
.get("metadata", {})
.get("user_api_key_user_id", ""),
"team_id": kwargs.get("litellm_params", {})
team_id=kwargs.get("litellm_params", {})
.get("metadata", {})
.get("user_api_key_team_id", ""),
"metadata": clean_metadata,
"cache_key": cache_key,
"spend": kwargs.get("response_cost", 0),
"total_tokens": usage.get("total_tokens", 0),
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
"request_tags": metadata.get("tags", []),
"end_user": end_user_id or "",
"api_base": litellm_params.get("api_base", ""),
"model_group": _model_group,
"model_id": _model_id,
}
metadata=json.dumps(clean_metadata),
cache_key=cache_key,
spend=kwargs.get("response_cost", 0),
total_tokens=usage.get("total_tokens", 0),
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
request_tags=(
json.dumps(metadata.get("tags", []))
if isinstance(metadata.get("tags", []), list)
else "[]"
),
end_user=end_user_id or "",
api_base=litellm_params.get("api_base", ""),
model_group=_model_group,
model_id=_model_id,
)
verbose_proxy_logger.debug("SpendTable: created payload - payload: %s\n\n", payload)
json_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == Json or field_type == Optional[Json]
]
str_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == str or field_type == Optional[str]
]
datetime_fields = [
field
for field, field_type in LiteLLM_SpendLogs.__annotations__.items()
if field_type == datetime
]
for param in json_fields:
if param in payload and type(payload[param]) != Json:
if type(payload[param]) == litellm.ModelResponse:
payload[param] = payload[param].model_dump_json()
if type(payload[param]) == litellm.EmbeddingResponse:
payload[param] = payload[param].model_dump_json()
else:
payload[param] = json.dumps(payload[param])
for param in str_fields:
if param in payload and type(payload[param]) != str:
payload[param] = str(payload[param])
verbose_proxy_logger.debug(
"SpendTable: created payload - payload: %s\n\n", payload
)
return payload
except Exception as e:
verbose_proxy_logger.error(
"Error creating spendlogs object - {}\n{}".format(
str(e), traceback.format_exc()
)
)
raise e
def _duration_in_seconds(duration: str):

litellm/tests/test_spend_logs.py

@@ -0,0 +1,194 @@
import sys, os
import traceback, uuid
from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute
load_dotenv()
import os, io, time
# this file is to test litellm/proxy
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm
import json
import datetime
from litellm.proxy.utils import (
get_logging_payload,
SpendLogsPayload,
SpendLogsMetadata,
) # noqa: E402
def test_spend_logs_payload():
"""
Ensure only expected values are logged in spend logs payload.
"""
input_args: dict = {
"kwargs": {
"model": "chatgpt-v-2",
"messages": [
{"role": "system", "content": "you are a helpful assistant.\n"},
{"role": "user", "content": "bom dia"},
],
"optional_params": {
"stream": False,
"max_tokens": 10,
"user": "116544810872468347480",
"extra_body": {},
},
"litellm_params": {
"acompletion": True,
"api_key": "23c217a5b59f41b6b7a198017f4792f2",
"force_timeout": 600,
"logger_fn": None,
"verbose": False,
"custom_llm_provider": "azure",
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com//openai/",
"litellm_call_id": "b9929bf6-7b80-4c8c-b486-034e6ac0c8b7",
"model_alias_map": {},
"completion_call_id": None,
"metadata": {
"user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
"user_api_key_alias": None,
"user_api_end_user_max_budget": None,
"litellm_api_version": "0.0.0",
"global_max_parallel_requests": None,
"user_api_key_user_id": "116544810872468347480",
"user_api_key_org_id": None,
"user_api_key_team_id": None,
"user_api_key_team_alias": None,
"user_api_key_metadata": {},
"headers": {
"content-type": "application/json",
"user-agent": "PostmanRuntime/7.32.3",
"accept": "*/*",
"postman-token": "92300061-eeaa-423b-a420-0b44896ecdc4",
"host": "localhost:4000",
"accept-encoding": "gzip, deflate, br",
"connection": "keep-alive",
"content-length": "163",
},
"endpoint": "http://localhost:4000/chat/completions",
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
"model_info": {
"id": "4bad40a1eb6bebd1682800f16f44b9f06c52a6703444c99c7f9f32e9de3693b4",
"db_model": False,
},
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"caching_groups": None,
"raw_request": "\n\nPOST Request Sent from LiteLLM:\ncurl -X POST \\\nhttps://openai-gpt-4-test-v-1.openai.azure.com//openai/ \\\n-H 'Authorization: *****' \\\n-d '{'model': 'chatgpt-v-2', 'messages': [{'role': 'system', 'content': 'you are a helpful assistant.\\n'}, {'role': 'user', 'content': 'bom dia'}], 'stream': False, 'max_tokens': 10, 'user': '116544810872468347480', 'extra_body': {}}'\n",
},
"model_info": {
"id": "4bad40a1eb6bebd1682800f16f44b9f06c52a6703444c99c7f9f32e9de3693b4",
"db_model": False,
},
"proxy_server_request": {
"url": "http://localhost:4000/chat/completions",
"method": "POST",
"headers": {
"content-type": "application/json",
"authorization": "Bearer sk-1234",
"user-agent": "PostmanRuntime/7.32.3",
"accept": "*/*",
"postman-token": "92300061-eeaa-423b-a420-0b44896ecdc4",
"host": "localhost:4000",
"accept-encoding": "gzip, deflate, br",
"connection": "keep-alive",
"content-length": "163",
},
"body": {
"messages": [
{
"role": "system",
"content": "you are a helpful assistant.\n",
},
{"role": "user", "content": "bom dia"},
],
"model": "gpt-3.5-turbo",
"max_tokens": 10,
},
},
"preset_cache_key": None,
"no-log": False,
"stream_response": {},
"input_cost_per_token": None,
"input_cost_per_second": None,
"output_cost_per_token": None,
"output_cost_per_second": None,
},
"start_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 307665),
"stream": False,
"user": "116544810872468347480",
"call_type": "acompletion",
"litellm_call_id": "b9929bf6-7b80-4c8c-b486-034e6ac0c8b7",
"completion_start_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
"max_tokens": 10,
"extra_body": {},
"custom_llm_provider": "azure",
"input": [
{"role": "system", "content": "you are a helpful assistant.\n"},
{"role": "user", "content": "bom dia"},
],
"api_key": "1234",
"original_response": "",
"additional_args": {
"headers": {"Authorization": "Bearer 1234"},
"api_base": "openai-gpt-4-test-v-1.openai.azure.com",
"acompletion": True,
"complete_input_dict": {
"model": "chatgpt-v-2",
"messages": [
{"role": "system", "content": "you are a helpful assistant.\n"},
{"role": "user", "content": "bom dia"},
],
"stream": False,
"max_tokens": 10,
"user": "116544810872468347480",
"extra_body": {},
},
},
"log_event_type": "post_api_call",
"end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
"cache_hit": None,
"response_cost": 2.4999999999999998e-05,
},
"response_obj": litellm.ModelResponse(
id="chatcmpl-9XZmkzS1uPhRCoVdGQvBqqIbSgECt",
choices=[
litellm.Choices(
finish_reason="length",
index=0,
message=litellm.Message(
content="Bom dia! Como posso ajudar você", role="assistant"
),
)
],
created=1717789410,
model="gpt-35-turbo",
object="chat.completion",
system_fingerprint=None,
usage=litellm.Usage(
completion_tokens=10, prompt_tokens=20, total_tokens=30
),
),
"start_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 308604),
"end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
"end_user_id": None,
}
payload: SpendLogsPayload = get_logging_payload(**input_args)
# Define the expected metadata keys
expected_metadata_keys = SpendLogsMetadata.__annotations__.keys()
# Validate only specified metadata keys are logged
assert "metadata" in payload
assert isinstance(payload["metadata"], str)
payload["metadata"] = json.loads(payload["metadata"])
assert set(payload["metadata"].keys()) == set(expected_metadata_keys)
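Because expected_metadata_keys is derived from SpendLogsMetadata.__annotations__, the test keeps passing when a new allow-listed field is added, but fails as soon as an undeclared key (headers, raw request, etc.) leaks into the logged metadata. To run just this test from a litellm checkout (assuming pytest is installed):

pytest -k test_spend_logs_payload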