mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
feat(auth_check.py): support using redis cache for team objects
Allows team update / check logic to work across instances instantly
This commit is contained in:
parent
254b6dc630
commit
487035c970
5 changed files with 92 additions and 5 deletions
|
@ -4,5 +4,6 @@ model_list:
|
||||||
model: "openai/*" # passes our validation check that a real provider is given
|
model: "openai/*" # passes our validation check that a real provider is given
|
||||||
api_key: ""
|
api_key: ""
|
||||||
|
|
||||||
general_settings:
|
litellm_settings:
|
||||||
completion_model: "gpt-3.5-turbo"
|
cache: True
|
||||||
|
|
|
@ -370,10 +370,17 @@ async def _cache_team_object(
|
||||||
team_id: str,
|
team_id: str,
|
||||||
team_table: LiteLLM_TeamTable,
|
team_table: LiteLLM_TeamTable,
|
||||||
user_api_key_cache: DualCache,
|
user_api_key_cache: DualCache,
|
||||||
|
proxy_logging_obj: Optional[ProxyLogging],
|
||||||
):
|
):
|
||||||
key = "team_id:{}".format(team_id)
|
key = "team_id:{}".format(team_id)
|
||||||
await user_api_key_cache.async_set_cache(key=key, value=team_table)
|
await user_api_key_cache.async_set_cache(key=key, value=team_table)
|
||||||
|
|
||||||
|
## UPDATE REDIS CACHE ##
|
||||||
|
if proxy_logging_obj is not None:
|
||||||
|
await proxy_logging_obj.internal_usage_cache.async_set_cache(
|
||||||
|
key=key, value=team_table
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@log_to_opentelemetry
|
@log_to_opentelemetry
|
||||||
async def get_team_object(
|
async def get_team_object(
|
||||||
|
@ -395,7 +402,17 @@ async def get_team_object(
|
||||||
|
|
||||||
# check if in cache
|
# check if in cache
|
||||||
key = "team_id:{}".format(team_id)
|
key = "team_id:{}".format(team_id)
|
||||||
|
|
||||||
|
cached_team_obj: Optional[LiteLLM_TeamTable] = None
|
||||||
|
## CHECK REDIS CACHE ##
|
||||||
|
if proxy_logging_obj is not None:
|
||||||
|
cached_team_obj = await proxy_logging_obj.internal_usage_cache.async_get_cache(
|
||||||
|
key=key
|
||||||
|
)
|
||||||
|
|
||||||
|
if cached_team_obj is None:
|
||||||
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
|
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
|
||||||
|
|
||||||
if cached_team_obj is not None:
|
if cached_team_obj is not None:
|
||||||
if isinstance(cached_team_obj, dict):
|
if isinstance(cached_team_obj, dict):
|
||||||
return LiteLLM_TeamTable(**cached_team_obj)
|
return LiteLLM_TeamTable(**cached_team_obj)
|
||||||
|
@ -413,7 +430,10 @@ async def get_team_object(
|
||||||
_response = LiteLLM_TeamTable(**response.dict())
|
_response = LiteLLM_TeamTable(**response.dict())
|
||||||
# save the team object to cache
|
# save the team object to cache
|
||||||
await _cache_team_object(
|
await _cache_team_object(
|
||||||
team_id=team_id, team_table=_response, user_api_key_cache=user_api_key_cache
|
team_id=team_id,
|
||||||
|
team_table=_response,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _response
|
return _response
|
||||||
|
|
|
@ -334,6 +334,7 @@ async def update_team(
|
||||||
create_audit_log_for_update,
|
create_audit_log_for_update,
|
||||||
litellm_proxy_admin_name,
|
litellm_proxy_admin_name,
|
||||||
prisma_client,
|
prisma_client,
|
||||||
|
proxy_logging_obj,
|
||||||
user_api_key_cache,
|
user_api_key_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -380,6 +381,7 @@ async def update_team(
|
||||||
team_id=team_row.team_id,
|
team_id=team_row.team_id,
|
||||||
team_table=team_row,
|
team_table=team_row,
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||||
|
|
|
@ -862,7 +862,7 @@ class PrismaClient:
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
if ret[0]['sum'] == 6:
|
if ret[0]["sum"] == 6:
|
||||||
print("All necessary views exist!") # noqa
|
print("All necessary views exist!") # noqa
|
||||||
return
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
@ -731,3 +731,67 @@ def test_load_router_config(mock_cache, fake_env_vars):
|
||||||
|
|
||||||
|
|
||||||
# test_load_router_config()
|
# test_load_router_config()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_update_redis():
|
||||||
|
"""
|
||||||
|
Tests if team update, updates the redis cache if set
|
||||||
|
"""
|
||||||
|
from litellm.caching import DualCache, RedisCache
|
||||||
|
from litellm.proxy._types import LiteLLM_TeamTable
|
||||||
|
from litellm.proxy.auth.auth_checks import _cache_team_object
|
||||||
|
|
||||||
|
proxy_logging_obj: ProxyLogging = getattr(
|
||||||
|
litellm.proxy.proxy_server, "proxy_logging_obj"
|
||||||
|
)
|
||||||
|
|
||||||
|
proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
proxy_logging_obj.internal_usage_cache.redis_cache,
|
||||||
|
"async_set_cache",
|
||||||
|
new=MagicMock(),
|
||||||
|
) as mock_client:
|
||||||
|
await _cache_team_object(
|
||||||
|
team_id="1234",
|
||||||
|
team_table=LiteLLM_TeamTable(),
|
||||||
|
user_api_key_cache=DualCache(),
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_team_redis(client_no_auth):
|
||||||
|
"""
|
||||||
|
Tests if get_team_object gets value from redis cache, if set
|
||||||
|
"""
|
||||||
|
from litellm.caching import DualCache, RedisCache
|
||||||
|
from litellm.proxy._types import LiteLLM_TeamTable
|
||||||
|
from litellm.proxy.auth.auth_checks import _cache_team_object, get_team_object
|
||||||
|
|
||||||
|
proxy_logging_obj: ProxyLogging = getattr(
|
||||||
|
litellm.proxy.proxy_server, "proxy_logging_obj"
|
||||||
|
)
|
||||||
|
|
||||||
|
proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
proxy_logging_obj.internal_usage_cache.redis_cache,
|
||||||
|
"async_get_cache",
|
||||||
|
new=AsyncMock(),
|
||||||
|
) as mock_client:
|
||||||
|
try:
|
||||||
|
await get_team_object(
|
||||||
|
team_id="1234",
|
||||||
|
user_api_key_cache=DualCache(),
|
||||||
|
parent_otel_span=None,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
prisma_client=MagicMock(),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_client.assert_called_once()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue