mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Merge pull request #5420 from BerriAI/litellm_add_oauth2_mapping
[Feat-Proxy] Add hook for oauth2 proxy headers
This commit is contained in:
commit
a27cf9960b
3 changed files with 58 additions and 0 deletions
45
litellm/proxy/auth/oauth2_proxy_hook.py
Normal file
45
litellm/proxy/auth/oauth2_proxy_hook.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
from typing import Any, Dict
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
async def handle_oauth2_proxy_request(request: Request) -> UserAPIKeyAuth:
|
||||
"""
|
||||
Handle request from oauth2 proxy.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
verbose_proxy_logger.debug("Handling oauth2 proxy request")
|
||||
# Define the OAuth2 config mappings
|
||||
oauth2_config_mappings: Dict[str, str] = general_settings.get(
|
||||
"oauth2_config_mappings", None
|
||||
)
|
||||
verbose_proxy_logger.debug(f"Oauth2 config mappings: {oauth2_config_mappings}")
|
||||
|
||||
if not oauth2_config_mappings:
|
||||
raise ValueError("Oauth2 config mappings not found in general_settings")
|
||||
# Initialize a dictionary to store the mapped values
|
||||
auth_data: Dict[str, Any] = {}
|
||||
|
||||
# Extract values from headers based on the mappings
|
||||
for key, header in oauth2_config_mappings.items():
|
||||
value = request.headers.get(header)
|
||||
if value:
|
||||
# Convert max_budget to float if present
|
||||
if key == "max_budget":
|
||||
auth_data[key] = float(value)
|
||||
# Convert models to list if present
|
||||
elif key == "models":
|
||||
auth_data[key] = [model.strip() for model in value.split(",")]
|
||||
else:
|
||||
auth_data[key] = value
|
||||
verbose_proxy_logger.debug(
|
||||
f"Auth data before creating UserAPIKeyAuth object: {auth_data}"
|
||||
)
|
||||
user_api_key_auth = UserAPIKeyAuth(**auth_data)
|
||||
verbose_proxy_logger.debug(f"UserAPIKeyAuth object created: {user_api_key_auth}")
|
||||
# Create and return UserAPIKeyAuth object
|
||||
return user_api_key_auth
|
|
@ -64,6 +64,7 @@ from litellm.proxy.auth.auth_utils import (
|
|||
route_in_additonal_public_routes,
|
||||
)
|
||||
from litellm.proxy.auth.oauth2_check import check_oauth2_token
|
||||
from litellm.proxy.auth.oauth2_proxy_hook import handle_oauth2_proxy_request
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.utils import _to_ns
|
||||
|
||||
|
@ -217,6 +218,9 @@ async def user_api_key_auth(
|
|||
|
||||
return await check_oauth2_token(token=api_key)
|
||||
|
||||
if general_settings.get("enable_oauth2_proxy_auth", False) is True:
|
||||
return await handle_oauth2_proxy_request(request=request)
|
||||
|
||||
if general_settings.get("enable_jwt_auth", False) is True:
|
||||
is_jwt = jwt_handler.is_jwt(token=api_key)
|
||||
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
|
||||
|
|
|
@ -13,6 +13,15 @@ model_list:
|
|||
model: cohere/rerank-english-v3.0
|
||||
api_key: os.environ/COHERE_API_KEY
|
||||
|
||||
general_settings:
|
||||
enable_oauth2_proxy_auth: True
|
||||
oauth2_config_mappings:
|
||||
token: X-Auth-Token
|
||||
user_id: X-Auth-Client-ID
|
||||
team_id: X-Auth-Team-ID
|
||||
max_budget: X-Auth-Max-Budget
|
||||
models: X-Auth-Allowed-Models
|
||||
|
||||
# default off mode
|
||||
litellm_settings:
|
||||
set_verbose: True
|
Loading…
Add table
Add a link
Reference in a new issue