Merge branch 'main' into litellm_security_fix

This commit is contained in:
Ishaan Jaff 2024-06-07 16:52:25 -07:00 committed by GitHub
commit 92841dfe1b
31 changed files with 2394 additions and 5332 deletions

View file

@ -1,11 +1,20 @@
from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict
from dataclasses import fields
import enum
from typing import Optional, List, Union, Dict, Literal, Any, TypedDict
from typing import Optional, List, Union, Dict, Literal, Any, TypedDict, TYPE_CHECKING
from datetime import datetime
import uuid, json, sys, os
from litellm.types.router import UpdateRouterConfig
from litellm.types.utils import ProviderField
from typing_extensions import Annotated
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class LitellmUserRoles(str, enum.Enum):
@ -1195,6 +1204,7 @@ class UserAPIKeyAuth(
]
] = None
allowed_model_region: Optional[Literal["eu"]] = None
parent_otel_span: Optional[Span] = None
@model_validator(mode="before")
@classmethod
@ -1207,6 +1217,9 @@ class UserAPIKeyAuth(
values.update({"api_key": hash_token(values.get("api_key"))})
return values
class Config:
arbitrary_types_allowed = True
class LiteLLM_Config(LiteLLMBase):
param_name: str

View file

@ -17,10 +17,19 @@ from litellm.proxy._types import (
LiteLLM_OrganizationTable,
LitellmUserRoles,
)
from typing import Optional, Literal, Union
from litellm.proxy.utils import PrismaClient
from typing import Optional, Literal, TYPE_CHECKING, Any
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
from litellm.caching import DualCache
import litellm
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from datetime import datetime
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
@ -216,10 +225,13 @@ def get_actual_routes(allowed_routes: list) -> list:
return actual_routes
@log_to_opentelemetry
async def get_end_user_object(
end_user_id: Optional[str],
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_EndUserTable]:
"""
Returns end user object, if in db.
@ -279,11 +291,14 @@ async def get_end_user_object(
return None
@log_to_opentelemetry
async def get_user_object(
user_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
user_id_upsert: bool,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_UserTable]:
"""
- Check if user id in proxy User Table
@ -330,10 +345,13 @@ async def get_user_object(
)
@log_to_opentelemetry
async def get_team_object(
team_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> LiteLLM_TeamTable:
"""
- Check if team id in proxy Team Table
@ -372,10 +390,13 @@ async def get_team_object(
)
@log_to_opentelemetry
async def get_org_object(
org_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
):
"""
- Check if org id in proxy Org Table

View file

@ -0,0 +1,130 @@
import copy
from fastapi import Request
from typing import Any, Dict, Optional, TYPE_CHECKING
from litellm.proxy._types import UserAPIKeyAuth
from litellm._logging import verbose_proxy_logger, verbose_logger
if TYPE_CHECKING:
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
ProxyConfig = _ProxyConfig
else:
ProxyConfig = Any
def parse_cache_control(cache_control):
cache_dict = {}
directives = cache_control.split(", ")
for directive in directives:
if "=" in directive:
key, value = directive.split("=")
cache_dict[key] = value
else:
cache_dict[directive] = True
return cache_dict
async def add_litellm_data_to_request(
data: dict,
request: Request,
user_api_key_dict: UserAPIKeyAuth,
proxy_config: ProxyConfig,
general_settings: Optional[Dict[str, Any]] = None,
version: Optional[str] = None,
):
"""
Adds LiteLLM-specific data to the request.
Args:
data (dict): The data dictionary to be modified.
request (Request): The incoming request.
user_api_key_dict (UserAPIKeyAuth): The user API key dictionary.
general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None.
version (Optional[str], optional): Version. Defaults to None.
Returns:
dict: The modified data dictionary.
"""
query_params = dict(request.query_params)
if "api-version" in query_params:
data["api_version"] = query_params["api-version"]
# Include original request and headers in the data
data["proxy_server_request"] = {
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
## Cache Controls
headers = request.headers
verbose_proxy_logger.debug("Request Headers: %s", headers)
cache_control_header = headers.get("Cache-Control", None)
if cache_control_header:
cache_dict = parse_cache_control(cache_control_header)
data["ttl"] = cache_dict.get("s-maxage")
verbose_proxy_logger.debug("receiving data: %s", data)
# users can pass in 'user' param to /chat/completions. Don't override it
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
# if users are using user_api_key_auth, set `user` in `data`
data["user"] = user_api_key_dict.user_id
if "metadata" not in data:
data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["metadata"]["user_api_end_user_max_budget"] = getattr(
user_api_key_dict, "end_user_max_budget", None
)
data["metadata"]["litellm_api_version"] = version
if general_settings is not None:
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["endpoint"] = str(request.url)
# Add the OTEL Parent Trace before sending it LiteLLM
data["metadata"]["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span
### END-USER SPECIFIC PARAMS ###
if user_api_key_dict.allowed_model_region is not None:
data["allowed_model_region"] = user_api_key_dict.allowed_model_region
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
return data

View file

@ -21,10 +21,14 @@ model_list:
general_settings:
master_key: sk-1234
alerting: ["slack"]
litellm_settings:
callbacks: ["otel"]
store_audit_logs: true
redact_messages_in_exceptions: True
enforced_params:
- user
- metadata
- metadata.generation_name
litellm_settings:
store_audit_logs: true

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
from typing import Optional, List, Any, Literal, Union
from typing import Optional, List, Any, Literal, Union, TYPE_CHECKING
import os
import subprocess
import hashlib
@ -46,6 +46,15 @@ from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime, timedelta
from litellm.integrations.slack_alerting import SlackAlerting
from typing_extensions import overload
from functools import wraps
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
def print_verbose(print_statement):
@ -63,6 +72,58 @@ def print_verbose(print_statement):
print(f"LiteLLM Proxy: {print_statement}") # noqa
def safe_deep_copy(data):
"""
Safe Deep Copy
The LiteLLM Request has some object that can-not be pickled / deep copied
Use this function to safely deep copy the LiteLLM Request
"""
# Step 1: Remove the litellm_parent_otel_span
if isinstance(data, dict):
# remove litellm_parent_otel_span since this is not picklable
if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span")
new_data = copy.deepcopy(data)
# Step 2: re-add the litellm_parent_otel_span after doing a deep copy
if isinstance(data, dict):
if "metadata" in data:
data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
return new_data
def log_to_opentelemetry(func):
@wraps(func)
async def wrapper(*args, **kwargs):
start_time = datetime.now()
result = await func(*args, **kwargs)
end_time = datetime.now()
# Log to OTEL only if "parent_otel_span" is in kwargs and is not None
if (
"parent_otel_span" in kwargs
and kwargs["parent_otel_span"] is not None
and "proxy_logging_obj" in kwargs
and kwargs["proxy_logging_obj"] is not None
):
proxy_logging_obj = kwargs["proxy_logging_obj"]
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.DB,
call_type=func.__name__,
parent_otel_span=kwargs["parent_otel_span"],
duration=0.0,
start_time=start_time,
end_time=end_time,
)
# end of logging to otel
return result
return wrapper
### LOGGING ###
class ProxyLogging:
"""
@ -282,7 +343,7 @@ class ProxyLogging:
"""
Runs the CustomLogger's async_moderation_hook()
"""
new_data = copy.deepcopy(data)
new_data = safe_deep_copy(data)
for callback in litellm.callbacks:
try:
if isinstance(callback, CustomLogger):
@ -832,6 +893,7 @@ class PrismaClient:
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
@log_to_opentelemetry
async def get_data(
self,
token: Optional[Union[str, list]] = None,
@ -858,6 +920,8 @@ class PrismaClient:
limit: Optional[
int
] = None, # pagination, number of rows to getch when find_all==True
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
):
args_passed_in = locals()
start_time = time.time()
@ -2829,6 +2893,10 @@ missing_keys_html_form = """
"""
def _to_ns(dt):
return int(dt.timestamp() * 1e9)
def get_error_message_str(e: Exception) -> str:
error_message = ""
if isinstance(e, HTTPException):