forked from phoenix/litellm-mirror
fix(user_api_key_auth.py): update valid token cache with updated team object cache
This commit is contained in:
parent
fa7037e48a
commit
35e640076b
8 changed files with 94 additions and 15 deletions
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1,5 +1,6 @@
|
|||
model_list:
|
||||
- model_name: bad-azure-model
|
||||
- model_name: azure-chatgpt
|
||||
litellm_params:
|
||||
model: gpt-4
|
||||
request_timeout: 1
|
||||
model: azure/chatgpt-v-2
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
|
|
|
@ -59,7 +59,7 @@ def common_checks(
|
|||
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
||||
"""
|
||||
_model = request_body.get("model", None)
|
||||
if team_object is not None and team_object.blocked == True:
|
||||
if team_object is not None and team_object.blocked is True:
|
||||
raise Exception(
|
||||
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
|
||||
)
|
||||
|
@ -349,6 +349,15 @@ async def get_user_object(
|
|||
)
|
||||
|
||||
|
||||
async def _cache_team_object(
|
||||
team_id: str,
|
||||
team_table: LiteLLM_TeamTable,
|
||||
user_api_key_cache: DualCache,
|
||||
):
|
||||
key = "team_id:{}".format(team_id)
|
||||
await user_api_key_cache.async_set_cache(key=key, value=team_table)
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
async def get_team_object(
|
||||
team_id: str,
|
||||
|
@ -386,7 +395,9 @@ async def get_team_object(
|
|||
|
||||
_response = LiteLLM_TeamTable(**response.dict())
|
||||
# save the team object to cache
|
||||
await user_api_key_cache.async_set_cache(key=key, value=_response)
|
||||
await _cache_team_object(
|
||||
team_id=team_id, team_table=_response, user_api_key_cache=user_api_key_cache
|
||||
)
|
||||
|
||||
return _response
|
||||
except Exception as e:
|
||||
|
|
|
@ -453,6 +453,27 @@ async def user_api_key_auth(
|
|||
|
||||
return valid_token
|
||||
|
||||
if (
|
||||
valid_token is not None
|
||||
and isinstance(valid_token, UserAPIKeyAuth)
|
||||
and valid_token.team_id is not None
|
||||
and user_api_key_cache.get_cache(
|
||||
key="team_id:{}".format(valid_token.team_id)
|
||||
)
|
||||
is not None
|
||||
):
|
||||
## UPDATE TEAM VALUES BASED ON CACHED TEAM OBJECT - allows `/team/update` values to work for cached token
|
||||
team_obj: LiteLLM_TeamTable = user_api_key_cache.get_cache(
|
||||
key="team_id:{}".format(valid_token.team_id)
|
||||
)
|
||||
|
||||
team_obj_dict = team_obj.__dict__
|
||||
|
||||
for k, v in team_obj_dict.items():
|
||||
field_name = f"team_{k}"
|
||||
if field_name in valid_token.__fields__:
|
||||
setattr(valid_token, field_name, v)
|
||||
|
||||
try:
|
||||
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore
|
||||
except Exception as e:
|
||||
|
@ -504,7 +525,6 @@ async def user_api_key_auth(
|
|||
raise Exception("No connected db.")
|
||||
|
||||
## check for cache hit (In-Memory Cache)
|
||||
original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility)
|
||||
_user_role = None
|
||||
if api_key.startswith("sk-"):
|
||||
api_key = hash_token(token=api_key)
|
||||
|
|
|
@ -328,11 +328,13 @@ async def update_team(
|
|||
}'
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.auth.auth_checks import _cache_team_object
|
||||
from litellm.proxy.proxy_server import (
|
||||
_duration_in_seconds,
|
||||
create_audit_log_for_update,
|
||||
litellm_proxy_admin_name,
|
||||
prisma_client,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
|
@ -361,11 +363,22 @@ async def update_team(
|
|||
# set the budget_reset_at in DB
|
||||
updated_kv["budget_reset_at"] = reset_at
|
||||
|
||||
team_row = await prisma_client.update_data(
|
||||
update_key_values=updated_kv,
|
||||
data=updated_kv,
|
||||
table_name="team",
|
||||
team_id=data.team_id,
|
||||
team_row: Optional[
|
||||
LiteLLM_TeamTable
|
||||
] = await prisma_client.db.litellm_teamtable.update(
|
||||
where={"team_id": data.team_id}, data=updated_kv # type: ignore
|
||||
)
|
||||
|
||||
if team_row is None or team_row.team_id is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "Team doesn't exist. Got={}".format(team_row)},
|
||||
)
|
||||
|
||||
await _cache_team_object(
|
||||
team_id=team_row.team_id,
|
||||
team_table=team_row,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
|
||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||
|
@ -392,7 +405,7 @@ async def update_team(
|
|||
)
|
||||
)
|
||||
|
||||
return team_row
|
||||
return {"team_id": team_row.team_id, "data": team_row}
|
||||
|
||||
|
||||
@router.post(
|
||||
|
|
|
@ -44,3 +44,40 @@ def test_check_valid_ip(
|
|||
request = Request(client_ip)
|
||||
|
||||
assert _check_valid_ip(allowed_ips, request) == expected_result # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_blocked_team():
|
||||
"""
|
||||
cached valid_token obj has team_blocked = true
|
||||
|
||||
cached team obj has team_blocked = false
|
||||
|
||||
assert team is not blocked
|
||||
"""
|
||||
from fastapi import Request
|
||||
from starlette.datastructures import URL
|
||||
|
||||
from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
|
||||
|
||||
_team_id = "1234"
|
||||
team_obj = LiteLLM_TeamTable(team_id=_team_id, blocked=False)
|
||||
|
||||
user_key = "sk-12345678"
|
||||
|
||||
valid_token = UserAPIKeyAuth(
|
||||
team_id=_team_id, team_blocked=True, token=hash_token(user_key)
|
||||
)
|
||||
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
|
||||
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world")
|
||||
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/chat/completions")
|
||||
|
||||
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue