fix(user_api_key_auth.py): update valid token cache with updated team object cache

This commit is contained in:
Krrish Dholakia 2024-07-19 17:06:49 -07:00
parent fa7037e48a
commit 35e640076b
8 changed files with 94 additions and 15 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,5 +1,6 @@
model_list:
- model_name: bad-azure-model
- model_name: azure-chatgpt
litellm_params:
model: gpt-4
request_timeout: 1
model: azure/chatgpt-v-2
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE

View file

@ -59,7 +59,7 @@ def common_checks(
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
"""
_model = request_body.get("model", None)
if team_object is not None and team_object.blocked == True:
if team_object is not None and team_object.blocked is True:
raise Exception(
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
)
@ -349,6 +349,15 @@ async def get_user_object(
)
async def _cache_team_object(
    team_id: str,
    team_table: LiteLLM_TeamTable,
    user_api_key_cache: DualCache,
):
    """Write a team object into the given cache under its ``team_id:<id>`` key.

    Args:
        team_id: Team identifier; interpolated into the cache key.
        team_table: Team object stored as the cached value.
        user_api_key_cache: Cache that receives the entry through
            ``async_set_cache``.
    """
    cache_key = "team_id:{}".format(team_id)
    await user_api_key_cache.async_set_cache(key=cache_key, value=team_table)
@log_to_opentelemetry
async def get_team_object(
team_id: str,
@ -386,7 +395,9 @@ async def get_team_object(
_response = LiteLLM_TeamTable(**response.dict())
# save the team object to cache
await user_api_key_cache.async_set_cache(key=key, value=_response)
await _cache_team_object(
team_id=team_id, team_table=_response, user_api_key_cache=user_api_key_cache
)
return _response
except Exception as e:

View file

@ -453,6 +453,27 @@ async def user_api_key_auth(
return valid_token
if (
valid_token is not None
and isinstance(valid_token, UserAPIKeyAuth)
and valid_token.team_id is not None
and user_api_key_cache.get_cache(
key="team_id:{}".format(valid_token.team_id)
)
is not None
):
## UPDATE TEAM VALUES BASED ON CACHED TEAM OBJECT - allows `/team/update` values to work for cached token
team_obj: LiteLLM_TeamTable = user_api_key_cache.get_cache(
key="team_id:{}".format(valid_token.team_id)
)
team_obj_dict = team_obj.__dict__
for k, v in team_obj_dict.items():
field_name = f"team_{k}"
if field_name in valid_token.__fields__:
setattr(valid_token, field_name, v)
try:
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore
except Exception as e:
@ -504,7 +525,6 @@ async def user_api_key_auth(
raise Exception("No connected db.")
## check for cache hit (In-Memory Cache)
original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility)
_user_role = None
if api_key.startswith("sk-"):
api_key = hash_token(token=api_key)

View file

@ -328,11 +328,13 @@ async def update_team(
}'
```
"""
from litellm.proxy.auth.auth_checks import _cache_team_object
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
litellm_proxy_admin_name,
prisma_client,
user_api_key_cache,
)
if prisma_client is None:
@ -361,11 +363,22 @@ async def update_team(
# set the budget_reset_at in DB
updated_kv["budget_reset_at"] = reset_at
team_row = await prisma_client.update_data(
update_key_values=updated_kv,
data=updated_kv,
table_name="team",
team_id=data.team_id,
team_row: Optional[
LiteLLM_TeamTable
] = await prisma_client.db.litellm_teamtable.update(
where={"team_id": data.team_id}, data=updated_kv # type: ignore
)
if team_row is None or team_row.team_id is None:
raise HTTPException(
status_code=400,
detail={"error": "Team doesn't exist. Got={}".format(team_row)},
)
await _cache_team_object(
team_id=team_row.team_id,
team_table=team_row,
user_api_key_cache=user_api_key_cache,
)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
@ -392,7 +405,7 @@ async def update_team(
)
)
return team_row
return {"team_id": team_row.team_id, "data": team_row}
@router.post(

View file

@ -44,3 +44,40 @@ def test_check_valid_ip(
request = Request(client_ip)
assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore
@pytest.mark.asyncio
async def test_check_blocked_team():
    """
    Auth should not treat the team as blocked when only the cached token says so.

    Setup:
    - cached valid_token obj has team_blocked = true
    - cached team obj has blocked = false

    The test passes when `user_api_key_auth` completes without raising.
    """
    from fastapi import Request
    from starlette.datastructures import URL

    from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    team_id = "1234"
    cached_team = LiteLLM_TeamTable(team_id=team_id, blocked=False)

    raw_key = "sk-12345678"
    hashed_key = hash_token(raw_key)
    cached_token = UserAPIKeyAuth(
        team_id=team_id,
        team_blocked=True,
        token=hashed_key,
    )

    # Seed the shared cache with the blocked token and the unblocked team record.
    user_api_key_cache.set_cache(key=hashed_key, value=cached_token)
    user_api_key_cache.set_cache(
        key="team_id:{}".format(team_id),
        value=cached_team,
    )

    # Install the seeded cache and placeholder settings on the proxy server module.
    proxy_server_module = litellm.proxy.proxy_server
    setattr(proxy_server_module, "user_api_key_cache", user_api_key_cache)
    setattr(proxy_server_module, "master_key", "sk-1234")
    setattr(proxy_server_module, "prisma_client", "hello-world")

    # Minimal HTTP request aimed at the chat completions route.
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    await user_api_key_auth(request=request, api_key="Bearer " + raw_key)