Merge branch 'main' into litellm_security_fix

This commit is contained in:
Ishaan Jaff 2024-06-07 16:52:25 -07:00 committed by GitHub
commit 92841dfe1b
31 changed files with 2394 additions and 5332 deletions

View file

@ -1,4 +1,4 @@
from typing import Optional, List, Any, Literal, Union
from typing import Optional, List, Any, Literal, Union, TYPE_CHECKING
import os
import subprocess
import hashlib
@ -46,6 +46,15 @@ from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime, timedelta
from litellm.integrations.slack_alerting import SlackAlerting
from typing_extensions import overload
from functools import wraps
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
def print_verbose(print_statement):
@ -63,6 +72,58 @@ def print_verbose(print_statement):
print(f"LiteLLM Proxy: {print_statement}") # noqa
def safe_deep_copy(data):
"""
Safe Deep Copy
The LiteLLM Request has some object that can-not be pickled / deep copied
Use this function to safely deep copy the LiteLLM Request
"""
# Step 1: Remove the litellm_parent_otel_span
if isinstance(data, dict):
# remove litellm_parent_otel_span since this is not picklable
if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span")
new_data = copy.deepcopy(data)
# Step 2: re-add the litellm_parent_otel_span after doing a deep copy
if isinstance(data, dict):
if "metadata" in data:
data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
return new_data
def log_to_opentelemetry(func):
@wraps(func)
async def wrapper(*args, **kwargs):
start_time = datetime.now()
result = await func(*args, **kwargs)
end_time = datetime.now()
# Log to OTEL only if "parent_otel_span" is in kwargs and is not None
if (
"parent_otel_span" in kwargs
and kwargs["parent_otel_span"] is not None
and "proxy_logging_obj" in kwargs
and kwargs["proxy_logging_obj"] is not None
):
proxy_logging_obj = kwargs["proxy_logging_obj"]
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.DB,
call_type=func.__name__,
parent_otel_span=kwargs["parent_otel_span"],
duration=0.0,
start_time=start_time,
end_time=end_time,
)
# end of logging to otel
return result
return wrapper
### LOGGING ###
class ProxyLogging:
"""
@ -282,7 +343,7 @@ class ProxyLogging:
"""
Runs the CustomLogger's async_moderation_hook()
"""
new_data = copy.deepcopy(data)
new_data = safe_deep_copy(data)
for callback in litellm.callbacks:
try:
if isinstance(callback, CustomLogger):
@ -832,6 +893,7 @@ class PrismaClient:
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
@log_to_opentelemetry
async def get_data(
self,
token: Optional[Union[str, list]] = None,
@ -858,6 +920,8 @@ class PrismaClient:
limit: Optional[
int
] = None, # pagination, number of rows to getch when find_all==True
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
):
args_passed_in = locals()
start_time = time.time()
@ -2829,6 +2893,10 @@ missing_keys_html_form = """
"""
def _to_ns(dt):
return int(dt.timestamp() * 1e9)
def get_error_message_str(e: Exception) -> str:
error_message = ""
if isinstance(e, HTTPException):