feat(auth_check.py): support using redis cache for team objects

Allows team update / check logic to work across instances instantly
This commit is contained in:
Krrish Dholakia 2024-07-24 18:14:49 -07:00
parent 254b6dc630
commit 487035c970
5 changed files with 92 additions and 5 deletions

View file

@ -4,5 +4,6 @@ model_list:
model: "openai/*" # passes our validation check that a real provider is given model: "openai/*" # passes our validation check that a real provider is given
api_key: "" api_key: ""
general_settings: litellm_settings:
completion_model: "gpt-3.5-turbo" cache: True

View file

@ -370,10 +370,17 @@ async def _cache_team_object(
team_id: str, team_id: str,
team_table: LiteLLM_TeamTable, team_table: LiteLLM_TeamTable,
user_api_key_cache: DualCache, user_api_key_cache: DualCache,
proxy_logging_obj: Optional[ProxyLogging],
): ):
key = "team_id:{}".format(team_id) key = "team_id:{}".format(team_id)
await user_api_key_cache.async_set_cache(key=key, value=team_table) await user_api_key_cache.async_set_cache(key=key, value=team_table)
## UPDATE REDIS CACHE ##
if proxy_logging_obj is not None:
await proxy_logging_obj.internal_usage_cache.async_set_cache(
key=key, value=team_table
)
@log_to_opentelemetry @log_to_opentelemetry
async def get_team_object( async def get_team_object(
@ -395,7 +402,17 @@ async def get_team_object(
# check if in cache # check if in cache
key = "team_id:{}".format(team_id) key = "team_id:{}".format(team_id)
cached_team_obj: Optional[LiteLLM_TeamTable] = None
## CHECK REDIS CACHE ##
if proxy_logging_obj is not None:
cached_team_obj = await proxy_logging_obj.internal_usage_cache.async_get_cache(
key=key
)
if cached_team_obj is None:
cached_team_obj = await user_api_key_cache.async_get_cache(key=key) cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
if cached_team_obj is not None: if cached_team_obj is not None:
if isinstance(cached_team_obj, dict): if isinstance(cached_team_obj, dict):
return LiteLLM_TeamTable(**cached_team_obj) return LiteLLM_TeamTable(**cached_team_obj)
@ -413,7 +430,10 @@ async def get_team_object(
_response = LiteLLM_TeamTable(**response.dict()) _response = LiteLLM_TeamTable(**response.dict())
# save the team object to cache # save the team object to cache
await _cache_team_object( await _cache_team_object(
team_id=team_id, team_table=_response, user_api_key_cache=user_api_key_cache team_id=team_id,
team_table=_response,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
) )
return _response return _response

View file

@ -334,6 +334,7 @@ async def update_team(
create_audit_log_for_update, create_audit_log_for_update,
litellm_proxy_admin_name, litellm_proxy_admin_name,
prisma_client, prisma_client,
proxy_logging_obj,
user_api_key_cache, user_api_key_cache,
) )
@ -380,6 +381,7 @@ async def update_team(
team_id=team_row.team_id, team_id=team_row.team_id,
team_table=team_row, team_table=team_row,
user_api_key_cache=user_api_key_cache, user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
) )
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True

View file

@ -862,7 +862,7 @@ class PrismaClient:
) )
""" """
) )
if ret[0]['sum'] == 6: if ret[0]["sum"] == 6:
print("All necessary views exist!") # noqa print("All necessary views exist!") # noqa
return return
except Exception: except Exception:

View file

@ -731,3 +731,67 @@ def test_load_router_config(mock_cache, fake_env_vars):
# test_load_router_config() # test_load_router_config()
@pytest.mark.asyncio
async def test_team_update_redis():
"""
Tests if team update, updates the redis cache if set
"""
from litellm.caching import DualCache, RedisCache
from litellm.proxy._types import LiteLLM_TeamTable
from litellm.proxy.auth.auth_checks import _cache_team_object
proxy_logging_obj: ProxyLogging = getattr(
litellm.proxy.proxy_server, "proxy_logging_obj"
)
proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()
with patch.object(
proxy_logging_obj.internal_usage_cache.redis_cache,
"async_set_cache",
new=MagicMock(),
) as mock_client:
await _cache_team_object(
team_id="1234",
team_table=LiteLLM_TeamTable(),
user_api_key_cache=DualCache(),
proxy_logging_obj=proxy_logging_obj,
)
mock_client.assert_called_once()
@pytest.mark.asyncio
async def test_get_team_redis(client_no_auth):
"""
Tests if get_team_object gets value from redis cache, if set
"""
from litellm.caching import DualCache, RedisCache
from litellm.proxy._types import LiteLLM_TeamTable
from litellm.proxy.auth.auth_checks import _cache_team_object, get_team_object
proxy_logging_obj: ProxyLogging = getattr(
litellm.proxy.proxy_server, "proxy_logging_obj"
)
proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()
with patch.object(
proxy_logging_obj.internal_usage_cache.redis_cache,
"async_get_cache",
new=AsyncMock(),
) as mock_client:
try:
await get_team_object(
team_id="1234",
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=proxy_logging_obj,
prisma_client=MagicMock(),
)
except Exception as e:
pass
mock_client.assert_called_once()