forked from phoenix/litellm-mirror
fix(proxy_server.py): enable jwt-auth for users
allow a user to auth into the proxy via jwt's and call allowed routes
This commit is contained in:
parent
9bf086386e
commit
d06b9a5a47
4 changed files with 86 additions and 12 deletions
|
@ -603,6 +603,7 @@ general_settings:
|
||||||
"disable_spend_logs": "boolean", # turn off writing each transaction to the db
|
"disable_spend_logs": "boolean", # turn off writing each transaction to the db
|
||||||
"disable_reset_budget": "boolean", # turn off reset budget scheduled task
|
"disable_reset_budget": "boolean", # turn off reset budget scheduled task
|
||||||
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
|
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
|
||||||
|
"allowed_routes": "list", # list of allowed proxy API routes - a user can access. (currently JWT-Auth only)
|
||||||
"key_management_system": "google_kms", # either google_kms or azure_kms
|
"key_management_system": "google_kms", # either google_kms or azure_kms
|
||||||
"master_key": "string",
|
"master_key": "string",
|
||||||
"database_url": "string",
|
"database_url": "string",
|
||||||
|
|
|
@ -531,6 +531,9 @@ class ConfigGeneralSettings(LiteLLMBase):
|
||||||
ui_access_mode: Optional[Literal["admin_only", "all"]] = Field(
|
ui_access_mode: Optional[Literal["admin_only", "all"]] = Field(
|
||||||
"all", description="Control access to the Proxy UI"
|
"all", description="Control access to the Proxy UI"
|
||||||
)
|
)
|
||||||
|
allowed_routes: Optional[List] = Field(
|
||||||
|
None, description="Proxy API Endpoints you want users to be able to access"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConfigYAML(LiteLLMBase):
|
class ConfigYAML(LiteLLMBase):
|
||||||
|
|
|
@ -9,7 +9,7 @@ Run checks for:
|
||||||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
"""
|
"""
|
||||||
from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable
|
from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable
|
||||||
from typing import Optional
|
from typing import Optional, Literal
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
@ -19,6 +19,13 @@ def common_checks(
|
||||||
user_object: LiteLLM_UserTable,
|
user_object: LiteLLM_UserTable,
|
||||||
end_user_object: Optional[LiteLLM_EndUserTable],
|
end_user_object: Optional[LiteLLM_EndUserTable],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Common checks across jwt + key-based auth.
|
||||||
|
|
||||||
|
1. If user can call model
|
||||||
|
2. If user is in budget
|
||||||
|
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
|
"""
|
||||||
_model = request_body.get("model", None)
|
_model = request_body.get("model", None)
|
||||||
# 1. If user can call model
|
# 1. If user can call model
|
||||||
if (
|
if (
|
||||||
|
@ -47,6 +54,52 @@ def common_checks(
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def allowed_routes_check(
|
||||||
|
user_role: Literal["proxy_admin", "app_owner"],
|
||||||
|
route: str,
|
||||||
|
allowed_routes: Optional[list] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if user -> not admin - allowed to access these routes
|
||||||
|
"""
|
||||||
|
openai_routes = [
|
||||||
|
# chat completions
|
||||||
|
"/openai/deployments/{model}/chat/completions",
|
||||||
|
"/chat/completions",
|
||||||
|
"/v1/chat/completions",
|
||||||
|
# completions
|
||||||
|
# embeddings
|
||||||
|
"/openai/deployments/{model}/embeddings",
|
||||||
|
"/embeddings",
|
||||||
|
"/v1/embeddings",
|
||||||
|
# image generation
|
||||||
|
"/images/generations",
|
||||||
|
"/v1/images/generations",
|
||||||
|
# audio transcription
|
||||||
|
"/audio/transcriptions",
|
||||||
|
"/v1/audio/transcriptions",
|
||||||
|
# moderations
|
||||||
|
"/moderations",
|
||||||
|
"/v1/moderations",
|
||||||
|
# models
|
||||||
|
"/models",
|
||||||
|
"/v1/models",
|
||||||
|
]
|
||||||
|
info_routes = ["/key/info", "/team/info", "/user/info", "/model/info"]
|
||||||
|
default_routes = openai_routes + info_routes
|
||||||
|
if user_role == "proxy_admin":
|
||||||
|
return True
|
||||||
|
elif user_role == "app_owner":
|
||||||
|
if allowed_routes is None:
|
||||||
|
if route in default_routes: # check default routes
|
||||||
|
return True
|
||||||
|
elif route in allowed_routes:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def get_end_user_object(
|
async def get_end_user_object(
|
||||||
end_user_id: Optional[str],
|
end_user_id: Optional[str],
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
|
|
|
@ -110,7 +110,11 @@ from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
from litellm.proxy.hooks.prompt_injection_detection import (
|
from litellm.proxy.hooks.prompt_injection_detection import (
|
||||||
_OPTIONAL_PromptInjectionDetection,
|
_OPTIONAL_PromptInjectionDetection,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.auth_checks import common_checks, get_end_user_object
|
from litellm.proxy.auth.auth_checks import (
|
||||||
|
common_checks,
|
||||||
|
get_end_user_object,
|
||||||
|
allowed_routes_check,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from litellm._version import version
|
from litellm._version import version
|
||||||
|
@ -332,7 +336,7 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
|
||||||
async def user_api_key_auth(
|
async def user_api_key_auth(
|
||||||
request: Request, api_key: str = fastapi.Security(api_key_header)
|
request: Request, api_key: str = fastapi.Security(api_key_header)
|
||||||
) -> UserAPIKeyAuth:
|
) -> UserAPIKeyAuth:
|
||||||
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
|
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client, general_settings
|
||||||
try:
|
try:
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
passed_in_key = api_key
|
passed_in_key = api_key
|
||||||
|
@ -354,6 +358,7 @@ async def user_api_key_auth(
|
||||||
enable_jwt_auth: true
|
enable_jwt_auth: true
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
route: str = request.url.path
|
||||||
if general_settings.get("enable_jwt_auth", False) == True:
|
if general_settings.get("enable_jwt_auth", False) == True:
|
||||||
is_jwt = jwt_handler.is_jwt(token=api_key)
|
is_jwt = jwt_handler.is_jwt(token=api_key)
|
||||||
verbose_proxy_logger.debug(f"is_jwt: {is_jwt}")
|
verbose_proxy_logger.debug(f"is_jwt: {is_jwt}")
|
||||||
|
@ -407,6 +412,12 @@ async def user_api_key_auth(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
is_allowed = allowed_routes_check(
|
||||||
|
user_role="app_owner",
|
||||||
|
route=route,
|
||||||
|
allowed_routes=general_settings.get("allowed_routes", None),
|
||||||
|
)
|
||||||
|
if is_allowed:
|
||||||
# return UserAPIKeyAuth object
|
# return UserAPIKeyAuth object
|
||||||
return UserAPIKeyAuth(
|
return UserAPIKeyAuth(
|
||||||
api_key=None,
|
api_key=None,
|
||||||
|
@ -416,6 +427,13 @@ async def user_api_key_auth(
|
||||||
models=user_object.models,
|
models=user_object.models,
|
||||||
user_role="app_owner",
|
user_role="app_owner",
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail={
|
||||||
|
"error": f"User={user_object.user_id} not allowed to access this route={route}."
|
||||||
|
},
|
||||||
|
)
|
||||||
#### ELSE ####
|
#### ELSE ####
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
|
@ -423,7 +441,6 @@ async def user_api_key_auth(
|
||||||
else:
|
else:
|
||||||
return UserAPIKeyAuth()
|
return UserAPIKeyAuth()
|
||||||
|
|
||||||
route: str = request.url.path
|
|
||||||
if route == "/user/auth":
|
if route == "/user/auth":
|
||||||
if general_settings.get("allow_user_auth", False) == True:
|
if general_settings.get("allow_user_auth", False) == True:
|
||||||
return UserAPIKeyAuth()
|
return UserAPIKeyAuth()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue