forked from phoenix/litellm-mirror
Merge pull request #2704 from BerriAI/litellm_jwt_auth_improvements_3
fix(handle_jwt.py): enable team-based jwt-auth access
This commit is contained in:
commit
0ab708e6f1
7 changed files with 358 additions and 145 deletions
|
@ -124,7 +124,7 @@ general_settings:
|
|||
### Allowed LiteLLM scopes
|
||||
|
||||
```python
|
||||
class LiteLLMProxyRoles(LiteLLMBase):
|
||||
class LiteLLM_JWTAuth(LiteLLMBase):
|
||||
proxy_admin: str = "litellm_proxy_admin"
|
||||
proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth.
|
||||
```
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
|
||||
from dataclasses import fields
|
||||
import enum
|
||||
from typing import Optional, List, Union, Dict, Literal, Any
|
||||
from datetime import datetime
|
||||
|
@ -37,9 +38,97 @@ class LiteLLMBase(BaseModel):
|
|||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLMProxyRoles(LiteLLMBase):
|
||||
proxy_admin: str = "litellm_proxy_admin"
|
||||
proxy_user: str = "litellm_user"
|
||||
class LiteLLMRoutes(enum.Enum):
|
||||
openai_routes: List = [ # chat completions
|
||||
"/openai/deployments/{model}/chat/completions",
|
||||
"/chat/completions",
|
||||
"/v1/chat/completions",
|
||||
# completions
|
||||
"/openai/deployments/{model}/completions",
|
||||
"/completions",
|
||||
"/v1/completions",
|
||||
# embeddings
|
||||
"/openai/deployments/{model}/embeddings",
|
||||
"/embeddings",
|
||||
"/v1/embeddings",
|
||||
# image generation
|
||||
"/images/generations",
|
||||
"/v1/images/generations",
|
||||
# audio transcription
|
||||
"/audio/transcriptions",
|
||||
"/v1/audio/transcriptions",
|
||||
# moderations
|
||||
"/moderations",
|
||||
"/v1/moderations",
|
||||
# models
|
||||
"/models",
|
||||
"/v1/models",
|
||||
]
|
||||
|
||||
info_routes: List = ["/key/info", "/team/info", "/user/info", "/model/info"]
|
||||
|
||||
management_routes: List = [ # key
|
||||
"/key/generate",
|
||||
"/key/update",
|
||||
"/key/delete",
|
||||
"/key/info",
|
||||
# user
|
||||
"/user/new",
|
||||
"/user/update",
|
||||
"/user/delete",
|
||||
"/user/info",
|
||||
# team
|
||||
"/team/new",
|
||||
"/team/update",
|
||||
"/team/delete",
|
||||
"/team/info",
|
||||
# model
|
||||
"/model/new",
|
||||
"/model/update",
|
||||
"/model/delete",
|
||||
"/model/info",
|
||||
]
|
||||
|
||||
|
||||
class LiteLLM_JWTAuth(LiteLLMBase):
|
||||
"""
|
||||
A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.
|
||||
|
||||
Attributes:
|
||||
- admin_jwt_scope: The JWT scope required for proxy admin roles.
|
||||
- admin_allowed_routes: list of allowed routes for proxy admin roles.
|
||||
- team_jwt_scope: The JWT scope required for proxy team roles.
|
||||
- team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
|
||||
- team_allowed_routes: list of allowed routes for proxy team roles.
|
||||
- end_user_id_jwt_field: Default - `sub`. The field in the JWT token that stores the end-user ID. Turn this off by setting to `None`. Enables end-user cost tracking.
|
||||
|
||||
See `auth_checks.py` for the specific routes
|
||||
"""
|
||||
|
||||
admin_jwt_scope: str = "litellm_proxy_admin"
|
||||
admin_allowed_routes: List[
|
||||
Literal["openai_routes", "info_routes", "management_routes"]
|
||||
] = ["management_routes"]
|
||||
team_jwt_scope: str = "litellm_team"
|
||||
team_id_jwt_field: str = "client_id"
|
||||
team_allowed_routes: List[
|
||||
Literal["openai_routes", "info_routes", "management_routes"]
|
||||
] = ["openai_routes", "info_routes"]
|
||||
end_user_id_jwt_field: Optional[str] = "sub"
|
||||
public_key_ttl: float = 600
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
# get the attribute names for this Pydantic model
|
||||
allowed_keys = self.__annotations__.keys()
|
||||
|
||||
invalid_keys = set(kwargs.keys()) - allowed_keys
|
||||
|
||||
if invalid_keys:
|
||||
raise ValueError(
|
||||
f"Invalid arguments provided: {', '.join(invalid_keys)}. Allowed arguments are: {', '.join(allowed_keys)}."
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class LiteLLMPromptInjectionParams(LiteLLMBase):
|
||||
|
|
|
@ -8,15 +8,23 @@ Run checks for:
|
|||
2. If user is in budget
|
||||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||
"""
|
||||
from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_UserTable,
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_JWTAuth,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLMRoutes,
|
||||
)
|
||||
from typing import Optional, Literal
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.caching import DualCache
|
||||
|
||||
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||
|
||||
|
||||
def common_checks(
|
||||
request_body: dict,
|
||||
user_object: LiteLLM_UserTable,
|
||||
team_object: LiteLLM_TeamTable,
|
||||
end_user_object: Optional[LiteLLM_EndUserTable],
|
||||
) -> bool:
|
||||
"""
|
||||
|
@ -30,19 +38,20 @@ def common_checks(
|
|||
# 1. If user can call model
|
||||
if (
|
||||
_model is not None
|
||||
and len(user_object.models) > 0
|
||||
and _model not in user_object.models
|
||||
and len(team_object.models) > 0
|
||||
and _model not in team_object.models
|
||||
):
|
||||
raise Exception(
|
||||
f"User={user_object.user_id} not allowed to call model={_model}. Allowed user models = {user_object.models}"
|
||||
f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
|
||||
)
|
||||
# 2. If user is in budget
|
||||
# 2. If team is in budget
|
||||
if (
|
||||
user_object.max_budget is not None
|
||||
and user_object.spend > user_object.max_budget
|
||||
team_object.max_budget is not None
|
||||
and team_object.spend is not None
|
||||
and team_object.spend > team_object.max_budget
|
||||
):
|
||||
raise Exception(
|
||||
f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}"
|
||||
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
|
||||
)
|
||||
# 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
|
||||
|
@ -54,52 +63,79 @@ def common_checks(
|
|||
return True
|
||||
|
||||
|
||||
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
|
||||
for allowed_route in allowed_routes:
|
||||
if (
|
||||
allowed_route == LiteLLMRoutes.openai_routes.name
|
||||
and user_route in LiteLLMRoutes.openai_routes.value
|
||||
):
|
||||
return True
|
||||
elif (
|
||||
allowed_route == LiteLLMRoutes.info_routes.name
|
||||
and user_route in LiteLLMRoutes.info_routes.value
|
||||
):
|
||||
return True
|
||||
elif (
|
||||
allowed_route == LiteLLMRoutes.management_routes.name
|
||||
and user_route in LiteLLMRoutes.management_routes.value
|
||||
):
|
||||
return True
|
||||
elif allowed_route == user_route:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def allowed_routes_check(
|
||||
user_role: Literal["proxy_admin", "app_owner"],
|
||||
route: str,
|
||||
allowed_routes: Optional[list] = None,
|
||||
user_role: Literal["proxy_admin", "team"],
|
||||
user_route: str,
|
||||
litellm_proxy_roles: LiteLLM_JWTAuth,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user -> not admin - allowed to access these routes
|
||||
"""
|
||||
openai_routes = [
|
||||
# chat completions
|
||||
"/openai/deployments/{model}/chat/completions",
|
||||
"/chat/completions",
|
||||
"/v1/chat/completions",
|
||||
# completions
|
||||
# embeddings
|
||||
"/openai/deployments/{model}/embeddings",
|
||||
"/embeddings",
|
||||
"/v1/embeddings",
|
||||
# image generation
|
||||
"/images/generations",
|
||||
"/v1/images/generations",
|
||||
# audio transcription
|
||||
"/audio/transcriptions",
|
||||
"/v1/audio/transcriptions",
|
||||
# moderations
|
||||
"/moderations",
|
||||
"/v1/moderations",
|
||||
# models
|
||||
"/models",
|
||||
"/v1/models",
|
||||
]
|
||||
info_routes = ["/key/info", "/team/info", "/user/info", "/model/info"]
|
||||
default_routes = openai_routes + info_routes
|
||||
|
||||
if user_role == "proxy_admin":
|
||||
return True
|
||||
elif user_role == "app_owner":
|
||||
if allowed_routes is None:
|
||||
if route in default_routes: # check default routes
|
||||
return True
|
||||
elif route in allowed_routes:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
if litellm_proxy_roles.admin_allowed_routes is None:
|
||||
is_allowed = _allowed_routes_check(
|
||||
user_route=user_route, allowed_routes=["management_routes"]
|
||||
)
|
||||
return is_allowed
|
||||
elif litellm_proxy_roles.admin_allowed_routes is not None:
|
||||
is_allowed = _allowed_routes_check(
|
||||
user_route=user_route,
|
||||
allowed_routes=litellm_proxy_roles.admin_allowed_routes,
|
||||
)
|
||||
return is_allowed
|
||||
|
||||
elif user_role == "team":
|
||||
if litellm_proxy_roles.team_allowed_routes is None:
|
||||
"""
|
||||
By default allow a team to call openai + info routes
|
||||
"""
|
||||
is_allowed = _allowed_routes_check(
|
||||
user_route=user_route, allowed_routes=["openai_routes", "info_routes"]
|
||||
)
|
||||
return is_allowed
|
||||
elif litellm_proxy_roles.team_allowed_routes is not None:
|
||||
is_allowed = _allowed_routes_check(
|
||||
user_route=user_route,
|
||||
allowed_routes=litellm_proxy_roles.team_allowed_routes,
|
||||
)
|
||||
return is_allowed
|
||||
return False
|
||||
|
||||
|
||||
def get_actual_routes(allowed_routes: list) -> list:
|
||||
actual_routes: list = []
|
||||
for route_name in allowed_routes:
|
||||
try:
|
||||
route_value = LiteLLMRoutes[route_name].value
|
||||
actual_routes = actual_routes + route_value
|
||||
except KeyError:
|
||||
actual_routes.append(route_name)
|
||||
return actual_routes
|
||||
|
||||
|
||||
async def get_end_user_object(
|
||||
end_user_id: Optional[str],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
|
@ -135,3 +171,75 @@ async def get_end_user_object(
|
|||
return LiteLLM_EndUserTable(**response.dict())
|
||||
except Exception as e: # if end-user not in db
|
||||
return None
|
||||
|
||||
|
||||
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
|
||||
"""
|
||||
- Check if user id in proxy User Table
|
||||
- if valid, return LiteLLM_UserTable object with defined limits
|
||||
- if not, then raise an error
|
||||
"""
|
||||
if self.prisma_client is None:
|
||||
raise Exception(
|
||||
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
||||
)
|
||||
|
||||
# check if in cache
|
||||
cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
|
||||
if cached_user_obj is not None:
|
||||
if isinstance(cached_user_obj, dict):
|
||||
return LiteLLM_UserTable(**cached_user_obj)
|
||||
elif isinstance(cached_user_obj, LiteLLM_UserTable):
|
||||
return cached_user_obj
|
||||
# else, check db
|
||||
try:
|
||||
response = await self.prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise Exception
|
||||
|
||||
return LiteLLM_UserTable(**response.dict())
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
|
||||
)
|
||||
|
||||
|
||||
async def get_team_object(
|
||||
team_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
) -> LiteLLM_TeamTable:
|
||||
"""
|
||||
- Check if team id in proxy Team Table
|
||||
- if valid, return LiteLLM_TeamTable object with defined limits
|
||||
- if not, then raise an error
|
||||
"""
|
||||
if prisma_client is None:
|
||||
raise Exception(
|
||||
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
||||
)
|
||||
|
||||
# check if in cache
|
||||
cached_team_obj = user_api_key_cache.async_get_cache(key=team_id)
|
||||
if cached_team_obj is not None:
|
||||
if isinstance(cached_team_obj, dict):
|
||||
return LiteLLM_TeamTable(**cached_team_obj)
|
||||
elif isinstance(cached_team_obj, LiteLLM_TeamTable):
|
||||
return cached_team_obj
|
||||
# else, check db
|
||||
try:
|
||||
response = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise Exception
|
||||
|
||||
return LiteLLM_TeamTable(**response.dict())
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
||||
)
|
||||
|
|
|
@ -12,7 +12,7 @@ import json
|
|||
import os
|
||||
from litellm.caching import DualCache
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable
|
||||
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from typing import Optional
|
||||
|
||||
|
@ -70,68 +70,43 @@ class JWTHandler:
|
|||
self,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
litellm_proxy_roles: LiteLLMProxyRoles,
|
||||
litellm_jwtauth: LiteLLM_JWTAuth,
|
||||
) -> None:
|
||||
self.prisma_client = prisma_client
|
||||
self.user_api_key_cache = user_api_key_cache
|
||||
self.litellm_proxy_roles = litellm_proxy_roles
|
||||
self.litellm_jwtauth = litellm_jwtauth
|
||||
|
||||
def is_jwt(self, token: str):
|
||||
parts = token.split(".")
|
||||
return len(parts) == 3
|
||||
|
||||
def is_admin(self, scopes: list) -> bool:
|
||||
if self.litellm_proxy_roles.proxy_admin in scopes:
|
||||
if self.litellm_jwtauth.admin_jwt_scope in scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_user_id(self, token: dict, default_value: str) -> str:
|
||||
def is_team(self, scopes: list) -> bool:
|
||||
if self.litellm_jwtauth.team_jwt_scope in scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
|
||||
try:
|
||||
user_id = token["sub"]
|
||||
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
|
||||
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
|
||||
else:
|
||||
user_id = None
|
||||
except KeyError:
|
||||
user_id = default_value
|
||||
return user_id
|
||||
|
||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
team_id = token["client_id"]
|
||||
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
|
||||
except KeyError:
|
||||
team_id = default_value
|
||||
return team_id
|
||||
|
||||
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
|
||||
"""
|
||||
- Check if user id in proxy User Table
|
||||
- if valid, return LiteLLM_UserTable object with defined limits
|
||||
- if not, then raise an error
|
||||
"""
|
||||
if self.prisma_client is None:
|
||||
raise Exception(
|
||||
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
||||
)
|
||||
|
||||
# check if in cache
|
||||
cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
|
||||
if cached_user_obj is not None:
|
||||
if isinstance(cached_user_obj, dict):
|
||||
return LiteLLM_UserTable(**cached_user_obj)
|
||||
elif isinstance(cached_user_obj, LiteLLM_UserTable):
|
||||
return cached_user_obj
|
||||
# else, check db
|
||||
try:
|
||||
response = await self.prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise Exception
|
||||
|
||||
return LiteLLM_UserTable(**response.dict())
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
|
||||
)
|
||||
|
||||
def get_scopes(self, token: dict) -> list:
|
||||
try:
|
||||
if isinstance(token["scope"], str):
|
||||
|
@ -162,7 +137,9 @@ class JWTHandler:
|
|||
keys = response.json()["keys"]
|
||||
|
||||
await self.user_api_key_cache.async_set_cache(
|
||||
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins
|
||||
key="litellm_jwt_auth_keys",
|
||||
value=keys,
|
||||
ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins
|
||||
)
|
||||
else:
|
||||
keys = cached_keys
|
||||
|
|
|
@ -113,7 +113,10 @@ from litellm.proxy.hooks.prompt_injection_detection import (
|
|||
from litellm.proxy.auth.auth_checks import (
|
||||
common_checks,
|
||||
get_end_user_object,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
allowed_routes_check,
|
||||
get_actual_routes,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -359,71 +362,99 @@ async def user_api_key_auth(
|
|||
scopes = jwt_handler.get_scopes(token=valid_token)
|
||||
# check if admin
|
||||
is_admin = jwt_handler.is_admin(scopes=scopes)
|
||||
# get user id
|
||||
user_id = jwt_handler.get_user_id(
|
||||
token=valid_token, default_value=litellm_proxy_admin_name
|
||||
# if admin return
|
||||
if is_admin:
|
||||
# check allowed admin routes
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role="proxy_admin",
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed:
|
||||
return UserAPIKeyAuth()
|
||||
else:
|
||||
allowed_routes = (
|
||||
jwt_handler.litellm_jwtauth.admin_allowed_routes
|
||||
)
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
# check if team in scopes
|
||||
is_team = jwt_handler.is_team(scopes=scopes)
|
||||
if is_team == False:
|
||||
raise Exception(
|
||||
f"Missing both Admin and Team scopes from token. Either is required. Admin Scope={jwt_handler.litellm_jwtauth.admin_jwt_scope}, Team Scope={jwt_handler.litellm_jwtauth.team_jwt_scope}"
|
||||
)
|
||||
# get team id
|
||||
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
|
||||
|
||||
if team_id is None:
|
||||
raise Exception(
|
||||
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||
)
|
||||
# check allowed team routes
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role="team",
|
||||
user_route=route,
|
||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||
)
|
||||
if is_allowed == False:
|
||||
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes
|
||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||
raise Exception(
|
||||
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
|
||||
)
|
||||
|
||||
# check if team in db
|
||||
team_object = await get_team_object(
|
||||
team_id=team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
|
||||
end_user_object = None
|
||||
# common checks
|
||||
# allow request
|
||||
|
||||
# get the request body
|
||||
request_data = await _read_request_body(request=request)
|
||||
# get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
|
||||
user_object = await jwt_handler.get_user_object(user_id=user_id)
|
||||
if (
|
||||
request_data.get("user", None)
|
||||
and request_data["user"] != user_object.user_id
|
||||
):
|
||||
|
||||
end_user_object = None
|
||||
end_user_id = jwt_handler.get_end_user_id(
|
||||
token=valid_token, default_value=None
|
||||
)
|
||||
if end_user_id is not None:
|
||||
# get the end-user object
|
||||
end_user_object = await get_end_user_object(
|
||||
end_user_id=request_data["user"],
|
||||
end_user_id=end_user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
# save the end-user object to cache
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=request_data["user"], value=end_user_object
|
||||
key=end_user_id, value=end_user_object
|
||||
)
|
||||
|
||||
# run through common checks
|
||||
_ = common_checks(
|
||||
request_body=request_data,
|
||||
user_object=user_object,
|
||||
team_object=team_object,
|
||||
end_user_object=end_user_object,
|
||||
)
|
||||
# save user object in cache
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=user_object.user_id, value=user_object
|
||||
key=team_object.team_id, value=team_object
|
||||
)
|
||||
|
||||
# return UserAPIKeyAuth object
|
||||
return UserAPIKeyAuth(
|
||||
api_key=None,
|
||||
team_id=team_object.team_id,
|
||||
tpm_limit=team_object.tpm_limit,
|
||||
rpm_limit=team_object.rpm_limit,
|
||||
models=team_object.models,
|
||||
user_role="app_owner",
|
||||
)
|
||||
# if admin return
|
||||
if is_admin:
|
||||
return UserAPIKeyAuth(
|
||||
api_key=api_key,
|
||||
user_role="proxy_admin",
|
||||
user_id=user_id,
|
||||
)
|
||||
else:
|
||||
is_allowed = allowed_routes_check(
|
||||
user_role="app_owner",
|
||||
route=route,
|
||||
allowed_routes=general_settings.get("allowed_routes", None),
|
||||
)
|
||||
if is_allowed:
|
||||
# return UserAPIKeyAuth object
|
||||
return UserAPIKeyAuth(
|
||||
api_key=None,
|
||||
user_id=user_object.user_id,
|
||||
tpm_limit=user_object.tpm_limit,
|
||||
rpm_limit=user_object.rpm_limit,
|
||||
models=user_object.models,
|
||||
user_role="app_owner",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": f"User={user_object.user_id} not allowed to access this route={route}."
|
||||
},
|
||||
)
|
||||
#### ELSE ####
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
|
@ -2698,12 +2729,14 @@ async def startup_event():
|
|||
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
||||
|
||||
## JWT AUTH ##
|
||||
if general_settings.get("litellm_jwtauth", None) is not None:
|
||||
litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
|
||||
else:
|
||||
litellm_jwtauth = LiteLLM_JWTAuth()
|
||||
jwt_handler.update_environment(
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
litellm_proxy_roles=LiteLLMProxyRoles(
|
||||
**general_settings.get("litellm_proxy_roles", {})
|
||||
),
|
||||
litellm_jwtauth=litellm_jwtauth,
|
||||
)
|
||||
|
||||
if use_background_health_checks:
|
||||
|
|
|
@ -12,7 +12,7 @@ sys.path.insert(
|
|||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
from litellm.proxy._types import LiteLLMProxyRoles
|
||||
from litellm.proxy._types import LiteLLM_JWTAuth
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.caching import DualCache
|
||||
from datetime import datetime, timedelta
|
||||
|
@ -28,17 +28,17 @@ public_key = {
|
|||
def test_load_config_with_custom_role_names():
|
||||
config = {
|
||||
"general_settings": {
|
||||
"litellm_proxy_roles": {"proxy_admin": "litellm-proxy-admin"}
|
||||
"litellm_proxy_roles": {"admin_jwt_scope": "litellm-proxy-admin"}
|
||||
}
|
||||
}
|
||||
|
||||
proxy_roles = LiteLLMProxyRoles(
|
||||
proxy_roles = LiteLLM_JWTAuth(
|
||||
**config.get("general_settings", {}).get("litellm_proxy_roles", {})
|
||||
)
|
||||
|
||||
print(f"proxy_roles: {proxy_roles}")
|
||||
|
||||
assert proxy_roles.proxy_admin == "litellm-proxy-admin"
|
||||
assert proxy_roles.admin_jwt_scope == "litellm-proxy-admin"
|
||||
|
||||
|
||||
# test_load_config_with_custom_role_names()
|
||||
|
|
|
@ -500,7 +500,10 @@ class ModelResponse(OpenAIObject):
|
|||
if choices is not None and isinstance(choices, list):
|
||||
new_choices = []
|
||||
for choice in choices:
|
||||
_new_choice = StreamingChoices(**choice)
|
||||
if isinstance(choice, StreamingChoices):
|
||||
_new_choice = choice
|
||||
elif isinstance(choice, dict):
|
||||
_new_choice = StreamingChoices(**choice)
|
||||
new_choices.append(_new_choice)
|
||||
choices = new_choices
|
||||
else:
|
||||
|
@ -513,7 +516,10 @@ class ModelResponse(OpenAIObject):
|
|||
if choices is not None and isinstance(choices, list):
|
||||
new_choices = []
|
||||
for choice in choices:
|
||||
_new_choice = Choices(**choice)
|
||||
if isinstance(choice, Choices):
|
||||
_new_choice = choice
|
||||
elif isinstance(choice, dict):
|
||||
_new_choice = Choices(**choice)
|
||||
new_choices.append(_new_choice)
|
||||
choices = new_choices
|
||||
else:
|
||||
|
@ -7231,7 +7237,7 @@ def exception_type(
|
|||
exception_mapping_worked = True
|
||||
raise APIError(
|
||||
status_code=original_exception.status_code,
|
||||
message=f"AnthropicException - {original_exception.message}",
|
||||
message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.",
|
||||
llm_provider="anthropic",
|
||||
model=model,
|
||||
request=original_exception.request,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue