forked from phoenix/litellm-mirror
Merge pull request #2970 from BerriAI/litellm_keys
fix(handle_jwt.py): User cost tracking via JWT Auth
This commit is contained in:
commit
d89644d46c
7 changed files with 263 additions and 35 deletions
|
@ -40,9 +40,10 @@ general_settings:
|
||||||
allow_user_auth: true
|
allow_user_auth: true
|
||||||
alerting: ["slack"]
|
alerting: ["slack"]
|
||||||
# store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
|
# store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
|
||||||
# proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
|
||||||
enable_jwt_auth: True
|
enable_jwt_auth: True
|
||||||
alerting: ["slack"]
|
alerting: ["slack"]
|
||||||
litellm_jwtauth:
|
litellm_jwtauth:
|
||||||
admin_jwt_scope: "litellm_proxy_admin"
|
admin_jwt_scope: "litellm_proxy_admin"
|
||||||
public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL
|
public_key_ttl: os.environ/LITELLM_PUBLIC_KEY_TTL
|
||||||
|
user_id_jwt_field: "sub"
|
|
@ -124,7 +124,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
|
||||||
- team_jwt_scope: The JWT scope required for proxy team roles.
|
- team_jwt_scope: The JWT scope required for proxy team roles.
|
||||||
- team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
|
- team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
|
||||||
- team_allowed_routes: list of allowed routes for proxy team roles.
|
- team_allowed_routes: list of allowed routes for proxy team roles.
|
||||||
- end_user_id_jwt_field: Default - `sub`. The field in the JWT token that stores the end-user ID. Turn this off by setting to `None`. Enables end-user cost tracking.
|
- user_id_jwt_field: The field in the JWT token that stores the user id (maps to `LiteLLMUserTable`). Use this for internal employees.
|
||||||
|
- end_user_id_jwt_field: The field in the JWT token that stores the end-user ID (maps to `LiteLLMEndUserTable`). Turn this off by setting to `None`. Enables end-user cost tracking. Use this for external customers.
|
||||||
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
|
- public_key_ttl: Default - 600s. TTL for caching public JWT keys.
|
||||||
|
|
||||||
See `auth_checks.py` for the specific routes
|
See `auth_checks.py` for the specific routes
|
||||||
|
@ -139,7 +140,8 @@ class LiteLLM_JWTAuth(LiteLLMBase):
|
||||||
team_allowed_routes: List[
|
team_allowed_routes: List[
|
||||||
Literal["openai_routes", "info_routes", "management_routes"]
|
Literal["openai_routes", "info_routes", "management_routes"]
|
||||||
] = ["openai_routes", "info_routes"]
|
] = ["openai_routes", "info_routes"]
|
||||||
end_user_id_jwt_field: Optional[str] = "sub"
|
user_id_jwt_field: Optional[str] = None
|
||||||
|
end_user_id_jwt_field: Optional[str] = None
|
||||||
public_key_ttl: float = 600
|
public_key_ttl: float = 600
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
|
|
@ -26,6 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
|
||||||
def common_checks(
|
def common_checks(
|
||||||
request_body: dict,
|
request_body: dict,
|
||||||
team_object: LiteLLM_TeamTable,
|
team_object: LiteLLM_TeamTable,
|
||||||
|
user_object: Optional[LiteLLM_UserTable],
|
||||||
end_user_object: Optional[LiteLLM_EndUserTable],
|
end_user_object: Optional[LiteLLM_EndUserTable],
|
||||||
global_proxy_spend: Optional[float],
|
global_proxy_spend: Optional[float],
|
||||||
general_settings: dict,
|
general_settings: dict,
|
||||||
|
@ -37,7 +38,8 @@ def common_checks(
|
||||||
1. If team is blocked
|
1. If team is blocked
|
||||||
2. If team can call model
|
2. If team can call model
|
||||||
3. If team is in budget
|
3. If team is in budget
|
||||||
4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
5. If user passed in (JWT or key.user_id) - is in budget
|
||||||
|
4. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
|
5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
|
||||||
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
||||||
"""
|
"""
|
||||||
|
@ -69,14 +71,20 @@ def common_checks(
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
|
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
|
||||||
)
|
)
|
||||||
# 4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
if user_object is not None and user_object.max_budget is not None:
|
||||||
|
user_budget = user_object.max_budget
|
||||||
|
if user_budget > user_object.spend:
|
||||||
|
raise Exception(
|
||||||
|
f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}"
|
||||||
|
)
|
||||||
|
# 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
|
if end_user_object is not None and end_user_object.litellm_budget_table is not None:
|
||||||
end_user_budget = end_user_object.litellm_budget_table.max_budget
|
end_user_budget = end_user_object.litellm_budget_table.max_budget
|
||||||
if end_user_budget is not None and end_user_object.spend > end_user_budget:
|
if end_user_budget is not None and end_user_object.spend > end_user_budget:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
|
f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
|
||||||
)
|
)
|
||||||
# 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
|
# 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
|
||||||
if (
|
if (
|
||||||
general_settings.get("enforce_user_param", None) is not None
|
general_settings.get("enforce_user_param", None) is not None
|
||||||
and general_settings["enforce_user_param"] == True
|
and general_settings["enforce_user_param"] == True
|
||||||
|
@ -85,7 +93,7 @@ def common_checks(
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
|
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
|
||||||
)
|
)
|
||||||
# 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
# 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
||||||
if litellm.max_budget > 0 and global_proxy_spend is not None:
|
if litellm.max_budget > 0 and global_proxy_spend is not None:
|
||||||
if global_proxy_spend > litellm.max_budget:
|
if global_proxy_spend > litellm.max_budget:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -204,19 +212,24 @@ async def get_end_user_object(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
|
async def get_user_object(
|
||||||
|
user_id: str,
|
||||||
|
prisma_client: Optional[PrismaClient],
|
||||||
|
user_api_key_cache: DualCache,
|
||||||
|
) -> Optional[LiteLLM_UserTable]:
|
||||||
"""
|
"""
|
||||||
- Check if user id in proxy User Table
|
- Check if user id in proxy User Table
|
||||||
- if valid, return LiteLLM_UserTable object with defined limits
|
- if valid, return LiteLLM_UserTable object with defined limits
|
||||||
- if not, then raise an error
|
- if not, then raise an error
|
||||||
"""
|
"""
|
||||||
if self.prisma_client is None:
|
if prisma_client is None:
|
||||||
raise Exception(
|
raise Exception("No db connected")
|
||||||
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
|
||||||
)
|
if user_id is None:
|
||||||
|
return None
|
||||||
|
|
||||||
# check if in cache
|
# check if in cache
|
||||||
cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
|
cached_user_obj = user_api_key_cache.async_get_cache(key=user_id)
|
||||||
if cached_user_obj is not None:
|
if cached_user_obj is not None:
|
||||||
if isinstance(cached_user_obj, dict):
|
if isinstance(cached_user_obj, dict):
|
||||||
return LiteLLM_UserTable(**cached_user_obj)
|
return LiteLLM_UserTable(**cached_user_obj)
|
||||||
|
@ -224,7 +237,7 @@ async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
|
||||||
return cached_user_obj
|
return cached_user_obj
|
||||||
# else, check db
|
# else, check db
|
||||||
try:
|
try:
|
||||||
response = await self.prisma_client.db.litellm_usertable.find_unique(
|
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||||
where={"user_id": user_id}
|
where={"user_id": user_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -232,9 +245,9 @@ async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
|
||||||
raise Exception
|
raise Exception
|
||||||
|
|
||||||
return LiteLLM_UserTable(**response.dict())
|
return LiteLLM_UserTable(**response.dict())
|
||||||
except Exception as e:
|
except Exception as e: # if end-user not in db
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
|
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -74,6 +74,16 @@ class JWTHandler:
|
||||||
team_id = default_value
|
team_id = default_value
|
||||||
return team_id
|
return team_id
|
||||||
|
|
||||||
|
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
if self.litellm_jwtauth.user_id_jwt_field is not None:
|
||||||
|
user_id = token[self.litellm_jwtauth.user_id_jwt_field]
|
||||||
|
else:
|
||||||
|
user_id = None
|
||||||
|
except KeyError:
|
||||||
|
user_id = default_value
|
||||||
|
return user_id
|
||||||
|
|
||||||
def get_scopes(self, token: dict) -> list:
|
def get_scopes(self, token: dict) -> list:
|
||||||
try:
|
try:
|
||||||
if isinstance(token["scope"], str):
|
if isinstance(token["scope"], str):
|
||||||
|
@ -101,7 +111,11 @@ class JWTHandler:
|
||||||
if cached_keys is None:
|
if cached_keys is None:
|
||||||
response = await self.http_handler.get(keys_url)
|
response = await self.http_handler.get(keys_url)
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
if "keys" in response_json:
|
||||||
keys = response.json()["keys"]
|
keys = response.json()["keys"]
|
||||||
|
else:
|
||||||
|
keys = response_json
|
||||||
|
|
||||||
await self.user_api_key_cache.async_set_cache(
|
await self.user_api_key_cache.async_set_cache(
|
||||||
key="litellm_jwt_auth_keys",
|
key="litellm_jwt_auth_keys",
|
||||||
|
|
|
@ -422,12 +422,21 @@ async def user_api_key_auth(
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
# common checks
|
# [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable`
|
||||||
# allow request
|
user_object = None
|
||||||
|
user_id = jwt_handler.get_user_id(token=valid_token, default_value=None)
|
||||||
# get the request body
|
if user_id is not None:
|
||||||
request_data = await _read_request_body(request=request)
|
# get the user object
|
||||||
|
user_object = await get_user_object(
|
||||||
|
user_id=user_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
)
|
||||||
|
# save the user object to cache
|
||||||
|
await user_api_key_cache.async_set_cache(
|
||||||
|
key=user_id, value=user_object
|
||||||
|
)
|
||||||
|
# [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable`
|
||||||
end_user_object = None
|
end_user_object = None
|
||||||
end_user_id = jwt_handler.get_end_user_id(
|
end_user_id = jwt_handler.get_end_user_id(
|
||||||
token=valid_token, default_value=None
|
token=valid_token, default_value=None
|
||||||
|
@ -445,7 +454,6 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
|
|
||||||
global_proxy_spend = None
|
global_proxy_spend = None
|
||||||
|
|
||||||
if litellm.max_budget > 0: # user set proxy max budget
|
if litellm.max_budget > 0: # user set proxy max budget
|
||||||
# check cache
|
# check cache
|
||||||
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
global_proxy_spend = await user_api_key_cache.async_get_cache(
|
||||||
|
@ -480,16 +488,20 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# get the request body
|
||||||
|
request_data = await _read_request_body(request=request)
|
||||||
|
|
||||||
# run through common checks
|
# run through common checks
|
||||||
_ = common_checks(
|
_ = common_checks(
|
||||||
request_body=request_data,
|
request_body=request_data,
|
||||||
team_object=team_object,
|
team_object=team_object,
|
||||||
|
user_object=user_object,
|
||||||
end_user_object=end_user_object,
|
end_user_object=end_user_object,
|
||||||
general_settings=general_settings,
|
general_settings=general_settings,
|
||||||
global_proxy_spend=global_proxy_spend,
|
global_proxy_spend=global_proxy_spend,
|
||||||
route=route,
|
route=route,
|
||||||
)
|
)
|
||||||
# save user object in cache
|
# save team object in cache
|
||||||
await user_api_key_cache.async_set_cache(
|
await user_api_key_cache.async_set_cache(
|
||||||
key=team_object.team_id, value=team_object
|
key=team_object.team_id, value=team_object
|
||||||
)
|
)
|
||||||
|
@ -502,6 +514,7 @@ async def user_api_key_auth(
|
||||||
team_rpm_limit=team_object.rpm_limit,
|
team_rpm_limit=team_object.rpm_limit,
|
||||||
team_models=team_object.models,
|
team_models=team_object.models,
|
||||||
user_role="app_owner",
|
user_role="app_owner",
|
||||||
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
#### ELSE ####
|
#### ELSE ####
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
|
@ -954,6 +967,7 @@ async def user_api_key_auth(
|
||||||
_ = common_checks(
|
_ = common_checks(
|
||||||
request_body=request_data,
|
request_body=request_data,
|
||||||
team_object=_team_obj,
|
team_object=_team_obj,
|
||||||
|
user_object=None,
|
||||||
end_user_object=_end_user_object,
|
end_user_object=_end_user_object,
|
||||||
general_settings=general_settings,
|
general_settings=general_settings,
|
||||||
global_proxy_spend=global_proxy_spend,
|
global_proxy_spend=global_proxy_spend,
|
||||||
|
@ -1328,8 +1342,6 @@ async def update_database(
|
||||||
existing_token_obj = await user_api_key_cache.async_get_cache(
|
existing_token_obj = await user_api_key_cache.async_get_cache(
|
||||||
key=hashed_token
|
key=hashed_token
|
||||||
)
|
)
|
||||||
if existing_token_obj is None:
|
|
||||||
return
|
|
||||||
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||||
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
|
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
|
||||||
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
|
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
|
||||||
|
@ -1351,7 +1363,9 @@ async def update_database(
|
||||||
if end_user_id is not None:
|
if end_user_id is not None:
|
||||||
prisma_client.end_user_list_transactons[end_user_id] = (
|
prisma_client.end_user_list_transactons[end_user_id] = (
|
||||||
response_cost
|
response_cost
|
||||||
+ prisma_client.user_list_transactons.get(end_user_id, 0)
|
+ prisma_client.end_user_list_transactons.get(
|
||||||
|
end_user_id, 0
|
||||||
|
)
|
||||||
)
|
)
|
||||||
elif custom_db_client is not None:
|
elif custom_db_client is not None:
|
||||||
for id in user_ids:
|
for id in user_ids:
|
||||||
|
|
|
@ -345,3 +345,187 @@ async def test_team_token_output(prisma_client):
|
||||||
assert team_result.team_tpm_limit == 100
|
assert team_result.team_tpm_limit == 100
|
||||||
assert team_result.team_rpm_limit == 99
|
assert team_result.team_rpm_limit == 99
|
||||||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_token_output(prisma_client):
|
||||||
|
"""
|
||||||
|
- If user required, check if it exists
|
||||||
|
- fail initial request (when user doesn't exist)
|
||||||
|
- create user
|
||||||
|
- retry -> it should pass now
|
||||||
|
"""
|
||||||
|
import jwt, json
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from fastapi import Request
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
from litellm.proxy.proxy_server import user_api_key_auth, new_team, new_user
|
||||||
|
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest
|
||||||
|
import litellm
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
|
||||||
|
# Generate a private / public key pair using RSA algorithm
|
||||||
|
key = rsa.generate_private_key(
|
||||||
|
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||||
|
)
|
||||||
|
# Get private key in PEM format
|
||||||
|
private_key = key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get public key in PEM format
|
||||||
|
public_key = key.public_key().public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
public_key_obj = serialization.load_pem_public_key(
|
||||||
|
public_key, backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert RSA public key object to JWK (JSON Web Key)
|
||||||
|
public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
|
||||||
|
|
||||||
|
assert isinstance(public_jwk, dict)
|
||||||
|
|
||||||
|
# set cache
|
||||||
|
cache = DualCache()
|
||||||
|
|
||||||
|
await cache.async_set_cache(key="litellm_jwt_auth_keys", value=[public_jwk])
|
||||||
|
|
||||||
|
jwt_handler = JWTHandler()
|
||||||
|
|
||||||
|
jwt_handler.user_api_key_cache = cache
|
||||||
|
|
||||||
|
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
||||||
|
|
||||||
|
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
|
||||||
|
|
||||||
|
# VALID TOKEN
|
||||||
|
## GENERATE A TOKEN
|
||||||
|
# Assuming the current time is in UTC
|
||||||
|
expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
|
||||||
|
|
||||||
|
team_id = f"team123_{uuid.uuid4()}"
|
||||||
|
user_id = f"user123_{uuid.uuid4()}"
|
||||||
|
payload = {
|
||||||
|
"sub": user_id,
|
||||||
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
|
"scope": "litellm_team",
|
||||||
|
"client_id": team_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate the JWT token
|
||||||
|
# But before, you should convert bytes to string
|
||||||
|
private_key_str = private_key.decode("utf-8")
|
||||||
|
|
||||||
|
## team token
|
||||||
|
token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
|
||||||
|
## admin token
|
||||||
|
payload = {
|
||||||
|
"sub": user_id,
|
||||||
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
|
"scope": "litellm_proxy_admin",
|
||||||
|
}
|
||||||
|
|
||||||
|
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
|
||||||
|
## VERIFY IT WORKS
|
||||||
|
|
||||||
|
# verify token
|
||||||
|
|
||||||
|
response = await jwt_handler.auth_jwt(token=token)
|
||||||
|
|
||||||
|
## RUN IT THROUGH USER API KEY AUTH
|
||||||
|
|
||||||
|
"""
|
||||||
|
- 1. Initial call should fail -> team doesn't exist
|
||||||
|
- 2. Create team via admin token
|
||||||
|
- 3. 2nd call w/ same team -> call should fail -> user doesn't exist
|
||||||
|
- 4. Create user via admin token
|
||||||
|
- 5. 3rd call w/ same team, same user -> call should succeed
|
||||||
|
- 6. assert user api key auth format
|
||||||
|
"""
|
||||||
|
|
||||||
|
bearer_token = "Bearer " + token
|
||||||
|
|
||||||
|
request = Request(scope={"type": "http"})
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
|
||||||
|
## 1. INITIAL TEAM CALL - should fail
|
||||||
|
# use generated key to auth in
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
|
||||||
|
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
|
||||||
|
try:
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
pytest.fail("Team doesn't exist. This should fail")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
## 2. CREATE TEAM W/ ADMIN TOKEN - should succeed
|
||||||
|
try:
|
||||||
|
bearer_token = "Bearer " + admin_token
|
||||||
|
|
||||||
|
request._url = URL(url="/team/new")
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
await new_team(
|
||||||
|
data=NewTeamRequest(
|
||||||
|
team_id=team_id,
|
||||||
|
tpm_limit=100,
|
||||||
|
rpm_limit=99,
|
||||||
|
models=["gpt-3.5-turbo", "gpt-4"],
|
||||||
|
),
|
||||||
|
user_api_key_dict=result,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"This should not fail - {str(e)}")
|
||||||
|
|
||||||
|
## 3. 2nd CALL W/ TEAM TOKEN - should fail
|
||||||
|
bearer_token = "Bearer " + token
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
try:
|
||||||
|
team_result: UserAPIKeyAuth = await user_api_key_auth(
|
||||||
|
request=request, api_key=bearer_token
|
||||||
|
)
|
||||||
|
pytest.fail(f"User doesn't exist. this should fail")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
## 4. Create user
|
||||||
|
try:
|
||||||
|
bearer_token = "Bearer " + admin_token
|
||||||
|
|
||||||
|
request._url = URL(url="/team/new")
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
await new_user(
|
||||||
|
data=NewUserRequest(
|
||||||
|
user_id=user_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"This should not fail - {str(e)}")
|
||||||
|
|
||||||
|
## 5. 3rd call w/ same team, same user -> call should succeed
|
||||||
|
bearer_token = "Bearer " + token
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
try:
|
||||||
|
team_result: UserAPIKeyAuth = await user_api_key_auth(
|
||||||
|
request=request, api_key=bearer_token
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Team exists. This should not fail - {e}")
|
||||||
|
|
||||||
|
## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
|
||||||
|
|
||||||
|
assert team_result.team_tpm_limit == 100
|
||||||
|
assert team_result.team_rpm_limit == 99
|
||||||
|
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||||
|
assert team_result.user_id == user_id
|
||||||
|
|
|
@ -64,12 +64,12 @@ litellm_settings:
|
||||||
telemetry: False
|
telemetry: False
|
||||||
context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]
|
context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]
|
||||||
|
|
||||||
router_settings:
|
# router_settings:
|
||||||
routing_strategy: usage-based-routing-v2
|
# routing_strategy: usage-based-routing-v2
|
||||||
redis_host: os.environ/REDIS_HOST
|
# redis_host: os.environ/REDIS_HOST
|
||||||
redis_password: os.environ/REDIS_PASSWORD
|
# redis_password: os.environ/REDIS_PASSWORD
|
||||||
redis_port: os.environ/REDIS_PORT
|
# redis_port: os.environ/REDIS_PORT
|
||||||
enable_pre_call_checks: true
|
# enable_pre_call_checks: true
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
|
master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue