Litellm dev 11 02 2024 (#6561)

* fix(dual_cache.py): update in-memory check for redis batch get cache

Fixes latency delay for async_batch_redis_cache
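
A minimal sketch of the fixed lookup path (illustrative names, not LiteLLM's exact API): serve keys from the in-memory layer first, batch only the misses to Redis, and backfill memory so repeat calls skip Redis entirely - which is what the call_count == 1 assertion in the dual cache test below checks.

    from typing import Dict, List, Optional

    async def async_batch_get_cache(
        in_memory: Dict[str, str],  # stand-in for the in-memory cache
        redis_batch_get,            # stand-in for RedisCache.async_batch_get_cache
        keys: List[str],
    ) -> List[Optional[str]]:
        # 1) Answer whatever the in-memory layer already has.
        result = {key: in_memory.get(key) for key in keys}
        # 2) Send only the misses to Redis, in one batch call.
        missing = [k for k, v in result.items() if v is None]
        if missing:
            redis_values = await redis_batch_get(keys=missing)
            for k, v in (redis_values or {}).items():
                if v is not None:
                    result[k] = v
                    in_memory[k] = v  # backfill: the next call skips Redis
        # 3) Return values in the caller's key order.
        return [result[k] for k in keys]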

* fix(service_logger.py): fix race condition that caused otel service logging to be overwritten when service_callbacks is set

* feat(user_api_key_auth.py): add parent otel component for auth

allows us to isolate how much latency is added by auth checks
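
A minimal sketch using the standard OpenTelemetry Python API (the tracer/span names and the _run_auth_checks helper are illustrative, not LiteLLM's actual ones):

    from opentelemetry import trace

    tracer = trace.get_tracer("litellm.proxy")  # illustrative tracer name

    async def user_api_key_auth(api_key: str):
        # A dedicated span makes auth latency show up as its own
        # component under the request's parent span.
        with tracer.start_as_current_span("user_api_key_auth"):
            return await _run_auth_checks(api_key)  # hypothetical helper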

* perf(parallel_request_limiter.py): move async_set_cache_pipeline (from the max parallel request limiter) out of the execution path and into a background task

reduces latency by 200ms
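
Roughly, the change moves a pattern like this off the hot path (a sketch; cache and cache_list stand in for the limiter's actual objects). The await asyncio.sleep(1) lines added to the tests below exist because the write now lands shortly after the hook returns:

    import asyncio

    async def async_pre_call_hook(cache, cache_list):
        # Before: the pipeline write sat on the request path.
        #     await cache.async_set_cache_pipeline(cache_list)
        # After: fire-and-forget, so the request proceeds immediately.
        # (Hold a reference so the task isn't garbage-collected early.)
        task = asyncio.create_task(cache.async_set_cache_pipeline(cache_list))
        return task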

* feat(user_api_key_auth.py): have the user api key auth object return user tpm/rpm limits - reduces redis calls in the downstream task (parallel_request_limiter)

Reduces latency by 400-800ms
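
The shape of the change, grounded in the test diff below: UserAPIKeyAuth now carries user_tpm_limit / user_rpm_limit, so the parallel request limiter reads them off the auth object instead of re-fetching the user row from Redis. A sketch (the literal values mirror the test):

    from litellm.proxy._types import UserAPIKeyAuth

    user_api_key_dict = UserAPIKeyAuth(
        api_key="<hashed key>",  # the tests hash raw keys via hash_token()
        user_id="my-user-id",
        user_tpm_limit=9,
        user_rpm_limit=10,
    )
    # Downstream, limits are available without another Redis call:
    assert user_api_key_dict.user_tpm_limit == 9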

* fix(parallel_request_limiter.py): use batch get cache to reduce user/key/team usage object calls

reduces latency by 50-100ms
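
Sketched with illustrative names - the key format matches the {api_key}::{minute}::request_count pattern visible in the tests below:

    async def get_usage_counters(cache, api_key, user_id, team_id, minute):
        # Before: three sequential async_get_cache calls, one round-trip each.
        # After: one batched read for the key / user / team counters.
        keys = [
            f"{api_key}::{minute}::request_count",
            f"{user_id}::{minute}::request_count",
            f"{team_id}::{minute}::request_count",
        ]
        key_usage, user_usage, team_usage = await cache.async_batch_get_cache(
            keys=keys
        )
        return key_usage, user_usage, team_usage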

* fix: fix linting error

* fix(_service_logger.py): fix import

* fix(user_api_key_auth.py): fix service logging

* fix(dual_cache.py): don't pass 'self'

* fix: fix python3.8 error

* fix: fix init

Krish Dholakia 2024-11-04 07:48:20 +05:30 committed by GitHub
parent 587d5fe277
commit d88e8922d4
17 changed files with 303 additions and 157 deletions

@@ -59,12 +59,15 @@ async def test_dual_cache_async_batch_get_cache():
    redis_cache = RedisCache() # get credentials from environment
    dual_cache = DualCache(in_memory_cache=in_memory_cache, redis_cache=redis_cache)
    in_memory_cache.set_cache(key="test_value", value="hello world")
    with patch.object(
        dual_cache.redis_cache, "async_batch_get_cache", new=AsyncMock()
    ) as mock_redis_cache:
        mock_redis_cache.return_value = {"test_value_2": None, "test_value": "hello"}
        result = await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
        await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
        await dual_cache.async_batch_get_cache(keys=["test_value", "test_value_2"])
        assert result[0] == "hello world"
        assert result[1] is None
        assert mock_redis_cache.call_count == 1
def test_dual_cache_batch_get_cache():

@@ -96,6 +96,7 @@ async def test_pre_call_hook():
            key=request_count_api_key
        )
    )
    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key
@@ -110,6 +111,7 @@ async def test_pre_call_hook_rpm_limits():
    Test if error raised on hitting rpm limits
    """
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
    )
@@ -152,6 +154,7 @@ async def test_pre_call_hook_rpm_limits_retry_after():
    Test if rate limit error, returns 'retry_after'
    """
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
    )
@@ -251,6 +254,7 @@ async def test_pre_call_hook_tpm_limits():
    Test if error raised on hitting tpm limits
    """
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=10
    )
@@ -306,9 +310,9 @@ async def test_pre_call_hook_user_tpm_limits():
    local_cache.set_cache(key=user_id, value=user_obj)
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key,
        user_id=user_id,
        api_key=_api_key, user_id=user_id, user_rpm_limit=10, user_tpm_limit=9
    )
    res = dict(user_api_key_dict)
    print("dict user", res)
@@ -372,7 +376,7 @@ async def test_success_call_hook():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key
@@ -416,7 +420,7 @@ async def test_failure_call_hook():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key
@@ -497,7 +501,7 @@ async def test_normal_router_call():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key
@@ -579,7 +583,7 @@ async def test_normal_router_tpm_limit():
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
    print("Test: Checking current_requests for precise_minute=", precise_minute)
    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key
@@ -658,7 +662,7 @@ async def test_streaming_router_call():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key
@@ -736,7 +740,7 @@ async def test_streaming_router_tpm_limit():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key
@@ -814,7 +818,7 @@ async def test_bad_router_call():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache( # type: ignore
            key=request_count_api_key
@@ -890,7 +894,7 @@ async def test_bad_router_tpm_limit():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
    await asyncio.sleep(1)
    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key
@@ -979,7 +983,7 @@ async def test_bad_router_tpm_limit_per_model():
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
    request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"
    await asyncio.sleep(1)
    print(
        "internal usage cache: ",
        parallel_request_handler.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,

@@ -139,6 +139,7 @@ async def test_check_blocked_team():
def test_returned_user_api_key_auth(user_role, expected_role):
    from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles
    from litellm.proxy.auth.user_api_key_auth import _return_user_api_key_auth_obj
    from datetime import datetime

    new_obj = _return_user_api_key_auth_obj(
        user_obj=LiteLLM_UserTable(
@@ -148,6 +149,7 @@ def test_returned_user_api_key_auth(user_role, expected_role):
        parent_otel_span=None,
        valid_token_dict={},
        route="/chat/completion",
        start_time=datetime.now(),
    )
    assert new_obj.user_role == expected_role