Merge branch 'main' into stevefarthing/bing-search-pass-thru

This commit is contained in:
Steve Farthing 2025-03-11 08:06:56 -04:00 committed by GitHub
commit dbfb7ebdaf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
741 changed files with 66437 additions and 15378 deletions

View file

@ -8,6 +8,7 @@ Returns a UserAPIKeyAuth object if the API key is valid
"""
import asyncio
import re
import secrets
from datetime import datetime, timezone
from typing import Optional, cast
@ -20,19 +21,17 @@ import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm._service_logger import ServiceLogging
from litellm.caching import DualCache
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
_cache_key_object,
_handle_failed_db_connection_for_get_key_object,
_virtual_key_max_budget_check,
_virtual_key_soft_budget_check,
allowed_routes_check,
can_key_call_model,
common_checks,
get_actual_routes,
get_end_user_object,
get_key_object,
get_org_object,
get_team_object,
get_user_object,
is_valid_fallback_model,
@ -46,7 +45,7 @@ from litellm.proxy.auth.auth_utils import (
route_in_additonal_public_routes,
should_run_auth_on_pass_through_provider_route,
)
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.handle_jwt import JWTAuthManager, JWTHandler
from litellm.proxy.auth.oauth2_check import check_oauth2_token
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
from litellm.proxy.auth.route_checks import RouteChecks
@ -286,168 +285,19 @@ def get_rbac_role(jwt_handler: JWTHandler, scopes: List[str]) -> str:
return LitellmUserRoles.TEAM
async def _jwt_auth_user_api_key_auth_builder(
api_key: str,
jwt_handler: JWTHandler,
route: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span],
proxy_logging_obj: ProxyLogging,
) -> JWTAuthBuilderResult:
def get_model_from_request(request_data: dict, route: str) -> Optional[str]:
# check if valid token
jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)
# First try to get model from request_data
model = request_data.get("model")
# check if unmatched token and enforce_rbac is true
if (
jwt_handler.litellm_jwtauth.enforce_rbac is True
and jwt_handler.get_rbac_role(token=jwt_valid_token) is None
):
raise HTTPException(
status_code=403,
detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user. See how to set roles in config here: https://docs.litellm.ai/docs/proxy/token_auth#advanced---spend-tracking-end-users--internal-users--team--org",
)
# get scopes
scopes = jwt_handler.get_scopes(token=jwt_valid_token)
# If model not in request_data, try to extract from route
if model is None:
# Parse model from route that follows the pattern /openai/deployments/{model}/*
match = re.match(r"/openai/deployments/([^/]+)", route)
if match:
model = match.group(1)
# [OPTIONAL] allowed user email domains
valid_user_email: Optional[bool] = None
user_email: Optional[str] = None
if jwt_handler.is_enforced_email_domain():
"""
if 'allowed_email_subdomains' is set,
- checks if token contains 'email' field
- checks if 'email' is from an allowed domain
"""
user_email = jwt_handler.get_user_email(
token=jwt_valid_token, default_value=None
)
if user_email is None:
valid_user_email = False
else:
valid_user_email = jwt_handler.is_allowed_domain(user_email=user_email)
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
user_object = None
user_id = jwt_handler.get_user_id(token=jwt_valid_token, default_value=user_email)
# get org id
org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
# get team id
team_id = jwt_handler.get_team_id(token=jwt_valid_token, default_value=None)
# get end user id
end_user_id = jwt_handler.get_end_user_id(token=jwt_valid_token, default_value=None)
# check if admin
is_admin = jwt_handler.is_admin(scopes=scopes)
# if admin return
if is_admin:
# check allowed admin routes
is_allowed = allowed_routes_check(
user_role=LitellmUserRoles.PROXY_ADMIN,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed:
return JWTAuthBuilderResult(
is_proxy_admin=True,
team_object=None,
user_object=None,
end_user_object=None,
org_object=None,
token=api_key,
team_id=team_id,
user_id=user_id,
end_user_id=end_user_id,
org_id=org_id,
)
else:
allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
if team_id is None and jwt_handler.is_required_team_id() is True:
raise Exception(
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
)
team_object: Optional[LiteLLM_TeamTable] = None
if team_id is not None:
# check allowed team routes
is_allowed = allowed_routes_check(
user_role=LitellmUserRoles.TEAM,
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed is False:
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# check if team in db
team_object = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable`
org_object: Optional[LiteLLM_OrganizationTable] = None
if org_id is not None:
org_object = await get_org_object(
org_id=org_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if user_id is not None:
# get the user object
user_object = await get_user_object(
user_id=user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
user_id_upsert=jwt_handler.is_upsert_user_id(
valid_user_email=valid_user_email
),
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
end_user_object = None
if end_user_id is not None:
# get the end-user object
end_user_object = await get_end_user_object(
end_user_id=end_user_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
return {
"is_proxy_admin": False,
"team_id": team_id,
"team_object": team_object,
"user_id": user_id,
"user_object": user_object,
"org_id": org_id,
"org_object": org_object,
"end_user_id": end_user_id,
"end_user_object": end_user_object,
"token": api_key,
}
return model
async def _user_api_key_auth_builder( # noqa: PLR0915
@ -588,7 +438,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
is_jwt = jwt_handler.is_jwt(token=api_key)
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
if is_jwt:
result = await _jwt_auth_user_api_key_auth_builder(
result = await JWTAuthManager.auth_builder(
request_data=request_data,
general_settings=general_settings,
api_key=api_key,
jwt_handler=jwt_handler,
route=route,
@ -942,6 +794,13 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
)
valid_token = None
if valid_token is None:
raise Exception(
"Invalid proxy server token passed. Received API Key (hashed)={}. Unable to find token in cache or `LiteLLM_VerificationTokenTable`".format(
api_key
)
)
user_obj: Optional[LiteLLM_UserTable] = None
valid_token_dict: dict = {}
if valid_token is not None:
@ -963,21 +822,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
raise Exception(
"Key is blocked. Update via `/key/unblock` if you're admin."
)
# Check 1. If token can call model
_model_alias_map = {}
model: Optional[str] = None
if (
hasattr(valid_token, "team_model_aliases")
and valid_token.team_model_aliases is not None
):
_model_alias_map = {
**valid_token.aliases,
**valid_token.team_model_aliases,
}
else:
_model_alias_map = {**valid_token.aliases}
litellm.model_alias_map = _model_alias_map
config = valid_token.config
if config != {}:
@ -994,7 +838,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
# the validation will occur when checking the team has access to this model
pass
else:
model = request_data.get("model", None)
model = get_model_from_request(request_data, route)
fallback_models = cast(
Optional[List[ALL_FALLBACK_MODEL_VALUES]],
request_data.get("fallbacks", None),
@ -1085,7 +929,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
# Check 3. If token is expired
if valid_token.expires is not None:
current_time = datetime.now(timezone.utc)
expiry_time = datetime.fromisoformat(valid_token.expires)
if isinstance(valid_token.expires, datetime):
expiry_time = valid_token.expires
else:
expiry_time = datetime.fromisoformat(valid_token.expires)
if (
expiry_time.tzinfo is None
or expiry_time.tzinfo.utcoffset(expiry_time) is None
@ -1315,6 +1162,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
)
@tracer.wrap()
async def user_api_key_auth(
request: Request,
api_key: str = fastapi.Security(api_key_header),
@ -1386,6 +1234,7 @@ async def _return_user_api_key_auth_obj(
user_api_key_kwargs.update(
user_tpm_limit=user_obj.tpm_limit,
user_rpm_limit=user_obj.rpm_limit,
user_email=user_obj.user_email,
)
if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
user_api_key_kwargs.update(