Merge remote-tracking branch 'upstream/main' into fix/reset-end-user-budget-by-duration

This commit is contained in:
Laurien Lummer 2025-03-19 15:05:00 +01:00
commit 7dc0d29fe9
693 changed files with 49813 additions and 14285 deletions

View file

@ -1,7 +1,6 @@
import asyncio
import copy
import hashlib
import importlib
import json
import os
import smtplib
@ -20,6 +19,7 @@ from litellm.proxy._types import (
ProxyErrorTypes,
ProxyException,
)
from litellm.types.guardrails import GuardrailEventHooks
try:
import backoff
@ -33,7 +33,13 @@ from fastapi import HTTPException, status
import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router
from litellm import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
ModelResponseStream,
Router,
)
from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm.caching.caching import DualCache, RedisCache
@ -49,7 +55,6 @@ from litellm.proxy._types import (
CallInfo,
LiteLLM_VerificationTokenView,
Member,
ResetTeamBudgetRequest,
UserAPIKeyAuth,
)
from litellm.proxy.db.create_views import (
@ -539,6 +544,7 @@ class ProxyLogging:
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal[
"completion",
"responses",
"embeddings",
"image_generation",
"moderation",
@ -786,13 +792,17 @@ class ProxyLogging:
else:
_callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger):
await _callback.async_post_call_failure_hook(
request_data=request_data,
user_api_key_dict=user_api_key_dict,
original_exception=original_exception,
asyncio.create_task(
_callback.async_post_call_failure_hook(
request_data=request_data,
user_api_key_dict=user_api_key_dict,
original_exception=original_exception,
)
)
except Exception as e:
raise e
verbose_proxy_logger.exception(
f"[Non-Blocking] Error in post_call_failure_hook: {e}"
)
return
def _is_proxy_only_error(
@ -959,7 +969,9 @@ class ProxyLogging:
async def async_post_call_streaming_hook(
self,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
response: Union[
ModelResponse, EmbeddingResponse, ImageResponse, ModelResponseStream
],
user_api_key_dict: UserAPIKeyAuth,
):
"""
@ -969,7 +981,7 @@ class ProxyLogging:
1. /chat/completions
"""
response_str: Optional[str] = None
if isinstance(response, ModelResponse):
if isinstance(response, (ModelResponse, ModelResponseStream)):
response_str = litellm.get_response_string(response_obj=response)
if response_str is not None:
for callback in litellm.callbacks:
@ -989,6 +1001,40 @@ class ProxyLogging:
raise e
return response
def async_post_call_streaming_iterator_hook(
self,
response,
user_api_key_dict: UserAPIKeyAuth,
request_data: dict,
):
"""
Allow user to modify outgoing streaming data -> Given a whole response iterator.
This hook is best used when you need to modify multiple chunks of the response at once.
Covers:
1. /chat/completions
"""
for callback in litellm.callbacks:
_callback: Optional[CustomLogger] = None
if isinstance(callback, str):
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
callback
)
else:
_callback = callback # type: ignore
if _callback is not None and isinstance(_callback, CustomLogger):
if not isinstance(
_callback, CustomGuardrail
) or _callback.should_run_guardrail(
data=request_data, event_type=GuardrailEventHooks.post_call
):
response = _callback.async_post_call_streaming_iterator_hook(
user_api_key_dict=user_api_key_dict,
response=response,
request_data=request_data,
)
return response
async def post_call_streaming_hook(
self,
response: str,
@ -1726,13 +1772,7 @@ class PrismaClient:
verbose_proxy_logger.info("Data Inserted into User Table")
return new_user_row
elif table_name == "team":
db_data = self.jsonify_object(data=data)
if db_data.get("members_with_roles", None) is not None and isinstance(
db_data["members_with_roles"], list
):
db_data["members_with_roles"] = json.dumps(
db_data["members_with_roles"]
)
db_data = self.jsonify_team_object(db_data=data)
new_team_row = await self.db.litellm_teamtable.upsert(
where={"team_id": data["team_id"]},
data={
@ -2067,8 +2107,8 @@ class PrismaClient:
batcher = self.db.batch_()
for idx, team in enumerate(data_list):
try:
data_json = self.jsonify_object(
data=team.model_dump(exclude_none=True)
data_json = self.jsonify_team_object(
db_data=team.model_dump(exclude_none=True)
)
except Exception:
data_json = self.jsonify_object(
@ -2313,51 +2353,6 @@ class PrismaClient:
)
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
module_name = value
instance_name = None
try:
# Split the path by dots to separate module from instance
parts = value.split(".")
# The module path is all but the last part, and the instance_name is the last part
module_name = ".".join(parts[:-1])
instance_name = parts[-1]
# If config_file_path is provided, use it to determine the module spec and load the module
if config_file_path is not None:
directory = os.path.dirname(config_file_path)
module_file_path = os.path.join(directory, *module_name.split("."))
module_file_path += ".py"
spec = importlib.util.spec_from_file_location(module_name, module_file_path) # type: ignore
if spec is None:
raise ImportError(
f"Could not find a module specification for {module_file_path}"
)
module = importlib.util.module_from_spec(spec) # type: ignore
spec.loader.exec_module(module) # type: ignore
else:
# Dynamically import the module
module = importlib.import_module(module_name)
# Get the instance from the module
instance = getattr(module, instance_name)
return instance
except ImportError as e:
# Re-raise the exception with a user-friendly message
if instance_name and module_name:
raise ImportError(
f"Could not import {instance_name} from {module_name}"
) from e
else:
raise e
except Exception as e:
raise e
### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
"""
@ -2411,12 +2406,6 @@ async def send_email(receiver_email, subject, html):
if smtp_host is None:
raise ValueError("Trying to use SMTP, but SMTP_HOST is not set")
if smtp_username is None:
raise ValueError("Trying to use SMTP, but SMTP_USERNAME is not set")
if smtp_password is None:
raise ValueError("Trying to use SMTP, but SMTP_PASSWORD is not set")
# Attach the body to the email
email_message.attach(MIMEText(html, "html"))
@ -2426,8 +2415,9 @@ async def send_email(receiver_email, subject, html):
if os.getenv("SMTP_TLS", "True") != "False":
server.starttls()
# Login to your email account
server.login(smtp_username, smtp_password) # type: ignore
# Login to your email account only if smtp_username and smtp_password are provided
if smtp_username and smtp_password:
server.login(smtp_username, smtp_password) # type: ignore
# Send the email
server.send_message(email_message)
@ -2687,6 +2677,18 @@ class ProxyUpdateSpend:
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
@staticmethod
def disable_spend_updates() -> bool:
"""
returns True if should not update spend in db
Skips writing spend logs and updates to key, team, user spend to DB
"""
from litellm.proxy.proxy_server import general_settings
if general_settings.get("disable_spend_updates") is True:
return True
return False
async def update_spend( # noqa: PLR0915
prisma_client: PrismaClient,