mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-28 04:04:31 +00:00
Merge branch 'main' into litellm_security_fix
This commit is contained in:
commit
92841dfe1b
31 changed files with 2394 additions and 5332 deletions
|
@ -1,11 +1,20 @@
|
|||
from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict
|
||||
from dataclasses import fields
|
||||
import enum
|
||||
from typing import Optional, List, Union, Dict, Literal, Any, TypedDict
|
||||
from typing import Optional, List, Union, Dict, Literal, Any, TypedDict, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
import uuid, json, sys, os
|
||||
from litellm.types.router import UpdateRouterConfig
|
||||
from litellm.types.utils import ProviderField
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = _Span
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class LitellmUserRoles(str, enum.Enum):
|
||||
|
@ -1195,6 +1204,7 @@ class UserAPIKeyAuth(
|
|||
]
|
||||
] = None
|
||||
allowed_model_region: Optional[Literal["eu"]] = None
|
||||
parent_otel_span: Optional[Span] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
@ -1207,6 +1217,9 @@ class UserAPIKeyAuth(
|
|||
values.update({"api_key": hash_token(values.get("api_key"))})
|
||||
return values
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LiteLLM_Config(LiteLLMBase):
|
||||
param_name: str
|
||||
|
|
|
@ -17,10 +17,19 @@ from litellm.proxy._types import (
|
|||
LiteLLM_OrganizationTable,
|
||||
LitellmUserRoles,
|
||||
)
|
||||
from typing import Optional, Literal, Union
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from typing import Optional, Literal, TYPE_CHECKING, Any
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
|
||||
from litellm.caching import DualCache
|
||||
import litellm
|
||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||
from datetime import datetime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = _Span
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||
|
||||
|
@ -216,10 +225,13 @@ def get_actual_routes(allowed_routes: list) -> list:
|
|||
return actual_routes
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
async def get_end_user_object(
|
||||
end_user_id: Optional[str],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||
) -> Optional[LiteLLM_EndUserTable]:
|
||||
"""
|
||||
Returns end user object, if in db.
|
||||
|
@ -279,11 +291,14 @@ async def get_end_user_object(
|
|||
return None
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
async def get_user_object(
|
||||
user_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
user_id_upsert: bool,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
"""
|
||||
- Check if user id in proxy User Table
|
||||
|
@ -330,10 +345,13 @@ async def get_user_object(
|
|||
)
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
async def get_team_object(
|
||||
team_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||
) -> LiteLLM_TeamTable:
|
||||
"""
|
||||
- Check if team id in proxy Team Table
|
||||
|
@ -372,10 +390,13 @@ async def get_team_object(
|
|||
)
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
async def get_org_object(
|
||||
org_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||
):
|
||||
"""
|
||||
- Check if org id in proxy Org Table
|
||||
|
|
130
litellm/proxy/litellm_pre_call_utils.py
Normal file
130
litellm/proxy/litellm_pre_call_utils.py
Normal file
|
@ -0,0 +1,130 @@
|
|||
import copy
|
||||
from fastapi import Request
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm._logging import verbose_proxy_logger, verbose_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
|
||||
|
||||
ProxyConfig = _ProxyConfig
|
||||
else:
|
||||
ProxyConfig = Any
|
||||
|
||||
|
||||
def parse_cache_control(cache_control):
    """Parse a ``Cache-Control`` header value into a dict.

    Directives carrying a value (e.g. ``"s-maxage=600"``) map to their string
    value; bare directives (e.g. ``"no-cache"``) map to ``True``.

    Args:
        cache_control (str): raw header value, directives separated by ", ".

    Returns:
        dict: directive name -> str value or True.
    """
    cache_dict = {}
    directives = cache_control.split(", ")

    for directive in directives:
        if "=" in directive:
            # Split only on the FIRST '=' so a value that itself contains
            # '=' doesn't raise ValueError on tuple unpacking.
            key, value = directive.split("=", 1)
            cache_dict[key] = value
        else:
            cache_dict[directive] = True

    return cache_dict
|
||||
|
||||
|
||||
async def add_litellm_data_to_request(
    data: dict,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth,
    proxy_config: ProxyConfig,
    general_settings: Optional[Dict[str, Any]] = None,
    version: Optional[str] = None,
):
    """
    Adds LiteLLM-specific data to the request.

    Mutates ``data`` in place (query params, proxy_server_request snapshot,
    cache TTL, user id, and a ``metadata`` sub-dict of auth/key attributes),
    then may rebuild it once more with team-specific config merged in.

    Args:
        data (dict): The data dictionary to be modified.
        request (Request): The incoming request.
        user_api_key_dict (UserAPIKeyAuth): The user API key dictionary.
        proxy_config (ProxyConfig): Proxy config object; used to load
            team-specific overrides when the key belongs to a team.
        general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None.
        version (Optional[str], optional): Version. Defaults to None.

    Returns:
        dict: The modified data dictionary.

    """
    query_params = dict(request.query_params)
    # Azure-style clients pass the API version as a query param; surface it
    # as a body field so downstream routing can use it.
    if "api-version" in query_params:
        data["api_version"] = query_params["api-version"]

    # Include original request and headers in the data
    data["proxy_server_request"] = {
        "url": str(request.url),
        "method": request.method,
        "headers": dict(request.headers),
        "body": copy.copy(data),  # use copy instead of deepcopy
    }

    ## Cache Controls
    headers = request.headers
    verbose_proxy_logger.debug("Request Headers: %s", headers)
    cache_control_header = headers.get("Cache-Control", None)
    if cache_control_header:
        cache_dict = parse_cache_control(cache_control_header)
        # s-maxage (string or None) drives LiteLLM cache TTL
        data["ttl"] = cache_dict.get("s-maxage")

    verbose_proxy_logger.debug("receiving data: %s", data)
    # users can pass in 'user' param to /chat/completions. Don't override it
    if data.get("user", None) is None and user_api_key_dict.user_id is not None:
        # if users are using user_api_key_auth, set `user` in `data`
        data["user"] = user_api_key_dict.user_id

    # Attach auth-key attributes under metadata so logging/spend-tracking
    # callbacks can read them later in the call lifecycle.
    if "metadata" not in data:
        data["metadata"] = {}
    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
    data["metadata"]["user_api_key_alias"] = getattr(
        user_api_key_dict, "key_alias", None
    )
    data["metadata"]["user_api_end_user_max_budget"] = getattr(
        user_api_key_dict, "end_user_max_budget", None
    )
    data["metadata"]["litellm_api_version"] = version

    if general_settings is not None:
        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
            "global_max_parallel_requests", None
        )

    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
    data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
    data["metadata"]["user_api_key_team_id"] = getattr(
        user_api_key_dict, "team_id", None
    )
    data["metadata"]["user_api_key_team_alias"] = getattr(
        user_api_key_dict, "team_alias", None
    )
    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
    _headers = dict(request.headers)
    _headers.pop(
        "authorization", None
    )  # do not store the original `sk-..` api key in the db
    data["metadata"]["headers"] = _headers
    data["metadata"]["endpoint"] = str(request.url)
    # Add the OTEL Parent Trace before sending it LiteLLM
    data["metadata"]["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span

    ### END-USER SPECIFIC PARAMS ###
    if user_api_key_dict.allowed_model_region is not None:
        data["allowed_model_region"] = user_api_key_dict.allowed_model_region

    ### TEAM-SPECIFIC PARAMS ###
    if user_api_key_dict.team_id is not None:
        team_config = await proxy_config.load_team_config(
            team_id=user_api_key_dict.team_id
        )
        if len(team_config) == 0:
            pass
        else:
            team_id = team_config.pop("team_id", None)
            data["metadata"]["team_id"] = team_id
            # NOTE: request values win over team defaults (**data last)
            data = {
                **team_config,
                **data,
            }  # add the team-specific configs to the completion call

    return data
|
|
@ -21,10 +21,14 @@ model_list:
|
|||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
alerting: ["slack"]
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["otel"]
|
||||
store_audit_logs: true
|
||||
redact_messages_in_exceptions: True
|
||||
enforced_params:
|
||||
- user
|
||||
- metadata
|
||||
- metadata.generation_name
|
||||
|
||||
litellm_settings:
|
||||
store_audit_logs: true
|
File diff suppressed because it is too large
Load diff
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List, Any, Literal, Union
|
||||
from typing import Optional, List, Any, Literal, Union, TYPE_CHECKING
|
||||
import os
|
||||
import subprocess
|
||||
import hashlib
|
||||
|
@ -46,6 +46,15 @@ from email.mime.text import MIMEText
|
|||
from email.mime.multipart import MIMEMultipart
|
||||
from datetime import datetime, timedelta
|
||||
from litellm.integrations.slack_alerting import SlackAlerting
|
||||
from typing_extensions import overload
|
||||
from functools import wraps
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = _Span
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
|
@ -63,6 +72,58 @@ def print_verbose(print_statement):
|
|||
print(f"LiteLLM Proxy: {print_statement}") # noqa
|
||||
|
||||
|
||||
def safe_deep_copy(data):
    """
    Safe Deep Copy.

    The LiteLLM request can carry objects that cannot be pickled / deep
    copied (the OTEL parent span stored under
    ``metadata["litellm_parent_otel_span"]``). Temporarily remove that value,
    deep copy the payload, then restore it on the original object.

    Args:
        data: the request payload — usually a dict, but any deep-copyable
            object is accepted and simply deep-copied.

    Returns:
        A deep copy of ``data`` with ``metadata["litellm_parent_otel_span"]``
        omitted; the original ``data`` is left unchanged on return.
    """
    litellm_parent_otel_span = None
    removed_span = False

    # Step 1: Remove the litellm_parent_otel_span (not picklable)
    if isinstance(data, dict):
        metadata = data.get("metadata")
        if isinstance(metadata, dict) and "litellm_parent_otel_span" in metadata:
            litellm_parent_otel_span = metadata.pop("litellm_parent_otel_span")
            removed_span = True

    new_data = copy.deepcopy(data)

    # Step 2: re-add the span ONLY if we actually removed it.
    # (Previously it was re-added whenever "metadata" existed, which raised
    # NameError when the key had never been present.)
    if removed_span:
        data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span

    return new_data
|
||||
|
||||
|
||||
def log_to_opentelemetry(func):
    """Decorator for async DB helpers: report a service-success event to OTEL.

    The event is emitted only when the call supplies BOTH a non-None
    ``parent_otel_span`` and a non-None ``proxy_logging_obj`` keyword
    argument; otherwise the wrapped coroutine's result is returned untouched.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        started_at = datetime.now()
        result = await func(*args, **kwargs)
        finished_at = datetime.now()

        parent_otel_span = kwargs.get("parent_otel_span")
        proxy_logging_obj = kwargs.get("proxy_logging_obj")

        # Log to OTEL only when both the span and the logging object are given
        if parent_otel_span is not None and proxy_logging_obj is not None:
            service_logger = proxy_logging_obj.service_logging_obj
            await service_logger.async_service_success_hook(
                service=ServiceTypes.DB,
                call_type=func.__name__,
                parent_otel_span=parent_otel_span,
                duration=0.0,
                start_time=started_at,
                end_time=finished_at,
            )
        return result

    return wrapper
|
||||
|
||||
|
||||
### LOGGING ###
|
||||
class ProxyLogging:
|
||||
"""
|
||||
|
@ -282,7 +343,7 @@ class ProxyLogging:
|
|||
"""
|
||||
Runs the CustomLogger's async_moderation_hook()
|
||||
"""
|
||||
new_data = copy.deepcopy(data)
|
||||
new_data = safe_deep_copy(data)
|
||||
for callback in litellm.callbacks:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
|
@ -832,6 +893,7 @@ class PrismaClient:
|
|||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
@log_to_opentelemetry
|
||||
async def get_data(
|
||||
self,
|
||||
token: Optional[Union[str, list]] = None,
|
||||
|
@ -858,6 +920,8 @@ class PrismaClient:
|
|||
limit: Optional[
|
||||
int
|
||||
] = None, # pagination, number of rows to getch when find_all==True
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
proxy_logging_obj: Optional[ProxyLogging] = None,
|
||||
):
|
||||
args_passed_in = locals()
|
||||
start_time = time.time()
|
||||
|
@ -2829,6 +2893,10 @@ missing_keys_html_form = """
|
|||
"""
|
||||
|
||||
|
||||
def _to_ns(dt):
|
||||
return int(dt.timestamp() * 1e9)
|
||||
|
||||
|
||||
def get_error_message_str(e: Exception) -> str:
|
||||
error_message = ""
|
||||
if isinstance(e, HTTPException):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue