diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index bec92c1e96..13babaac6a 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -4,5 +4,6 @@ model_list:
       model: "openai/*" # passes our validation check that a real provider is given
       api_key: ""
 
-general_settings:
-  completion_model: "gpt-3.5-turbo"
\ No newline at end of file
+litellm_settings:
+  cache: True
+
\ No newline at end of file
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 91d4b1938a..7c5356a379 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -370,10 +370,17 @@ async def _cache_team_object(
     team_id: str,
     team_table: LiteLLM_TeamTable,
     user_api_key_cache: DualCache,
+    proxy_logging_obj: Optional[ProxyLogging],
 ):
     key = "team_id:{}".format(team_id)
     await user_api_key_cache.async_set_cache(key=key, value=team_table)
 
+    ## UPDATE REDIS CACHE ##
+    if proxy_logging_obj is not None:
+        await proxy_logging_obj.internal_usage_cache.async_set_cache(
+            key=key, value=team_table
+        )
+
 
 @log_to_opentelemetry
 async def get_team_object(
@@ -395,7 +402,17 @@ async def get_team_object(
 
     # check if in cache
     key = "team_id:{}".format(team_id)
-    cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
+
+    cached_team_obj: Optional[LiteLLM_TeamTable] = None
+    ## CHECK REDIS CACHE ##
+    if proxy_logging_obj is not None:
+        cached_team_obj = await proxy_logging_obj.internal_usage_cache.async_get_cache(
+            key=key
+        )
+
+    if cached_team_obj is None:
+        cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
+
     if cached_team_obj is not None:
         if isinstance(cached_team_obj, dict):
             return LiteLLM_TeamTable(**cached_team_obj)
@@ -413,7 +430,10 @@ async def get_team_object(
         _response = LiteLLM_TeamTable(**response.dict())
         # save the team object to cache
         await _cache_team_object(
-            team_id=team_id, team_table=_response, user_api_key_cache=user_api_key_cache
+            team_id=team_id,
+            team_table=_response,
+            user_api_key_cache=user_api_key_cache,
+            proxy_logging_obj=proxy_logging_obj,
         )
 
         return _response
diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py
index 9ba76a2032..9c20836d2b 100644
--- a/litellm/proxy/management_endpoints/team_endpoints.py
+++ b/litellm/proxy/management_endpoints/team_endpoints.py
@@ -334,6 +334,7 @@ async def update_team(
         create_audit_log_for_update,
         litellm_proxy_admin_name,
         prisma_client,
+        proxy_logging_obj,
         user_api_key_cache,
     )
 
@@ -380,6 +381,7 @@ async def update_team(
         team_id=team_row.team_id,
         team_table=team_row,
         user_api_key_cache=user_api_key_cache,
+        proxy_logging_obj=proxy_logging_obj,
     )
 
     # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
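Taken together, the `auth_checks.py` hunks implement a write-through / read-through pattern across two cache tiers: the per-worker `user_api_key_cache` and the (optionally redis-backed) `internal_usage_cache`, with the shared tier consulted first on reads. Below is a minimal, self-contained sketch of that pattern, stdlib only; `TeamStore`, `cache_team`, `get_team`, and the dict-backed tiers are illustrative stand-ins, not litellm APIs:

```python
import asyncio
from typing import Any, Dict, Optional


class TeamStore:
    """Illustrative stand-in for the proxy's two cache tiers plus the DB."""

    def __init__(self) -> None:
        self.local: Dict[str, Any] = {}   # per-worker tier, like user_api_key_cache
        self.shared: Dict[str, Any] = {}  # cross-worker tier, like the redis-backed internal_usage_cache

    async def cache_team(self, team_id: str, team: Any) -> None:
        # Mirrors _cache_team_object: write the local tier, then the shared tier.
        key = f"team_id:{team_id}"
        self.local[key] = team
        self.shared[key] = team

    async def get_team(self, team_id: str, db: Dict[str, Any]) -> Optional[Any]:
        # Mirrors get_team_object: shared tier first, local tier second, DB last.
        key = f"team_id:{team_id}"
        team = self.shared.get(key)
        if team is None:
            team = self.local.get(key)
        if team is None:
            team = db.get(team_id)
            if team is not None:
                await self.cache_team(team_id, team)  # write back on a DB hit
        return team


async def main() -> None:
    store, db = TeamStore(), {"1234": {"team_id": "1234"}}
    assert await store.get_team("1234", db) == {"team_id": "1234"}  # DB hit fills both tiers
    db.clear()
    assert await store.get_team("1234", db) is not None  # now served from the shared tier


asyncio.run(main())
```

Checking the shared tier first is what makes `update_team`'s cache refresh effective: a team updated on one proxy worker is visible to every worker on the next read, rather than each worker serving its own stale local copy.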
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index b08d7a30f1..fc47abf9cd 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -862,7 +862,7 @@ class PrismaClient:
                 )
                 """
             )
-            if ret[0]['sum'] == 6:
+            if ret[0]["sum"] == 6:
                 print("All necessary views exist!")  # noqa
                 return
         except Exception:
diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index f3cb69a082..e088f2055d 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -731,3 +731,67 @@ def test_load_router_config(mock_cache, fake_env_vars):
 
 
 # test_load_router_config()
+
+
+@pytest.mark.asyncio
+async def test_team_update_redis():
+    """
+    Tests that a team update also updates the redis cache, if one is set.
+    """
+    from litellm.caching import DualCache, RedisCache
+    from litellm.proxy._types import LiteLLM_TeamTable
+    from litellm.proxy.auth.auth_checks import _cache_team_object
+
+    proxy_logging_obj: ProxyLogging = getattr(
+        litellm.proxy.proxy_server, "proxy_logging_obj"
+    )
+
+    proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()
+
+    with patch.object(
+        proxy_logging_obj.internal_usage_cache.redis_cache,
+        "async_set_cache",
+        new=AsyncMock(),
+    ) as mock_client:
+        await _cache_team_object(
+            team_id="1234",
+            team_table=LiteLLM_TeamTable(),
+            user_api_key_cache=DualCache(),
+            proxy_logging_obj=proxy_logging_obj,
+        )
+
+        mock_client.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_get_team_redis(client_no_auth):
+    """
+    Tests that get_team_object reads from the redis cache, if one is set.
+    """
+    from litellm.caching import DualCache, RedisCache
+    from litellm.proxy._types import LiteLLM_TeamTable
+    from litellm.proxy.auth.auth_checks import get_team_object
+
+    proxy_logging_obj: ProxyLogging = getattr(
+        litellm.proxy.proxy_server, "proxy_logging_obj"
+    )
+
+    proxy_logging_obj.internal_usage_cache.redis_cache = RedisCache()
+
+    with patch.object(
+        proxy_logging_obj.internal_usage_cache.redis_cache,
+        "async_get_cache",
+        new=AsyncMock(),
+    ) as mock_client:
+        try:
+            await get_team_object(
+                team_id="1234",
+                user_api_key_cache=DualCache(),
+                parent_otel_span=None,
+                proxy_logging_obj=proxy_logging_obj,
+                prisma_client=MagicMock(),
+            )
+        except Exception:
+            pass
+
+        mock_client.assert_called_once()
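A note on the mock types in the tests above: `async_set_cache` and `async_get_cache` are coroutine functions, so the patches use `AsyncMock`. Patching them with a plain `MagicMock` would make the `await` raise `TypeError`, since a `MagicMock` call result is not awaitable. A stdlib-only demonstration of the difference:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def set_team_cache(cache) -> None:
    # Stand-in for the call _cache_team_object makes on the cache object.
    await cache.async_set_cache(key="team_id:1234", value={})


async def main() -> None:
    good = AsyncMock()
    await set_team_cache(good)  # AsyncMock attribute calls return awaitables
    good.async_set_cache.assert_awaited_once()

    bad = MagicMock()
    try:
        await set_team_cache(bad)  # MagicMock call results are not awaitable
    except TypeError as err:
        print(f"plain MagicMock fails: {err}")


asyncio.run(main())
```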