mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
feat(auth_check.py): support using redis cache for team objects
Allows team update / check logic to work across instances instantly
This commit is contained in:
parent
254b6dc630
commit
487035c970
5 changed files with 92 additions and 5 deletions
|
@ -4,5 +4,6 @@ model_list:
|
|||
model: "openai/*" # passes our validation check that a real provider is given
|
||||
api_key: ""
|
||||
|
||||
general_settings:
|
||||
completion_model: "gpt-3.5-turbo"
|
||||
litellm_settings:
|
||||
cache: True
|
||||
|
|
@ -370,10 +370,17 @@ async def _cache_team_object(
|
|||
team_id: str,
|
||||
team_table: LiteLLM_TeamTable,
|
||||
user_api_key_cache: DualCache,
|
||||
proxy_logging_obj: Optional[ProxyLogging],
|
||||
):
|
||||
key = "team_id:{}".format(team_id)
|
||||
await user_api_key_cache.async_set_cache(key=key, value=team_table)
|
||||
|
||||
## UPDATE REDIS CACHE ##
|
||||
if proxy_logging_obj is not None:
|
||||
await proxy_logging_obj.internal_usage_cache.async_set_cache(
|
||||
key=key, value=team_table
|
||||
)
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
async def get_team_object(
|
||||
|
@ -395,7 +402,17 @@ async def get_team_object(
|
|||
|
||||
# check if in cache
|
||||
key = "team_id:{}".format(team_id)
|
||||
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
|
||||
|
||||
cached_team_obj: Optional[LiteLLM_TeamTable] = None
|
||||
## CHECK REDIS CACHE ##
|
||||
if proxy_logging_obj is not None:
|
||||
cached_team_obj = await proxy_logging_obj.internal_usage_cache.async_get_cache(
|
||||
key=key
|
||||
)
|
||||
|
||||
if cached_team_obj is None:
|
||||
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
|
||||
|
||||
if cached_team_obj is not None:
|
||||
if isinstance(cached_team_obj, dict):
|
||||
return LiteLLM_TeamTable(**cached_team_obj)
|
||||
|
@ -413,7 +430,10 @@ async def get_team_object(
|
|||
_response = LiteLLM_TeamTable(**response.dict())
|
||||
# save the team object to cache
|
||||
await _cache_team_object(
|
||||
team_id=team_id, team_table=_response, user_api_key_cache=user_api_key_cache
|
||||
team_id=team_id,
|
||||
team_table=_response,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
return _response
|
||||
|
|
|
@ -334,6 +334,7 @@ async def update_team(
|
|||
create_audit_log_for_update,
|
||||
litellm_proxy_admin_name,
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
|
@ -380,6 +381,7 @@ async def update_team(
|
|||
team_id=team_row.team_id,
|
||||
team_table=team_row,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||
|
|
|
@ -862,7 +862,7 @@ class PrismaClient:
|
|||
)
|
||||
"""
|
||||
)
|
||||
if ret[0]['sum'] == 6:
|
||||
if ret[0]["sum"] == 6:
|
||||
print("All necessary views exist!") # noqa
|
||||
return
|
||||
except Exception:
|
||||
|
|
|
@ -731,3 +731,67 @@ def test_load_router_config(mock_cache, fake_env_vars):
|
|||
|
||||
|
||||
# test_load_router_config()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_team_update_redis():
    """
    Tests if team update, updates the redis cache if set.

    Patches the redis layer of the proxy's internal usage cache and checks
    that `_cache_team_object` writes the team object through to redis.
    """
    from litellm.caching import DualCache, RedisCache
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.auth.auth_checks import _cache_team_object

    proxy_logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

    proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()

    # `async_set_cache` is awaited inside the redis write path, so the
    # replacement must be awaitable: AsyncMock, not MagicMock.
    with patch.object(
        proxy_logging_obj.internal_usage_cache.redis_cache,
        "async_set_cache",
        new=AsyncMock(),
    ) as mock_client:
        await _cache_team_object(
            team_id="1234",
            team_table=LiteLLM_TeamTable(),
            user_api_key_cache=DualCache(),
            proxy_logging_obj=proxy_logging_obj,
        )

        # The redis cache must have been written to exactly once.
        mock_client.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_team_redis(client_no_auth):
    """
    Tests if get_team_object gets value from redis cache, if set.

    Patches the redis layer of the proxy's internal usage cache and checks
    that `get_team_object` performs a redis lookup before falling back to
    the in-memory cache / DB.
    """
    from litellm.caching import DualCache, RedisCache
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.auth.auth_checks import _cache_team_object, get_team_object

    proxy_logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

    proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()

    with patch.object(
        proxy_logging_obj.internal_usage_cache.redis_cache,
        "async_get_cache",
        new=AsyncMock(),
    ) as mock_client:
        try:
            await get_team_object(
                team_id="1234",
                user_api_key_cache=DualCache(),
                parent_otel_span=None,
                proxy_logging_obj=proxy_logging_obj,
                prisma_client=MagicMock(),
            )
        except Exception:
            # A downstream failure (cache miss -> mocked prisma client) is
            # expected and irrelevant; this test only verifies that the
            # redis lookup happened.
            pass

        mock_client.assert_called_once()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue