Litellm perf improvements 3 (#6573)

* perf: move writing the key to the cache into a background task
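The point is that the request handler no longer waits on cache I/O before returning. A minimal sketch of the pattern using asyncio.create_task, with an illustrative in-memory cache standing in for the proxy's async cache client (the actual proxy code may schedule the write differently):

import asyncio


class InMemoryCache:
    """Stand-in for the proxy's async cache client (illustrative only)."""

    def __init__(self) -> None:
        self._store: dict = {}

    async def async_set_cache(self, key: str, value: dict) -> None:
        self._store[key] = value


async def handle_request(cache: InMemoryCache, key: str, value: dict) -> dict:
    # Fire-and-forget: schedule the cache write instead of awaiting it,
    # so the response is not blocked on the write.
    asyncio.create_task(cache.async_set_cache(key=key, value=value))
    return value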

* perf(litellm_pre_call_utils.py): add otel tracing for pre-call utils

these utils add ~200ms per call when a Postgres DB is connected
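For reference, the generic way to surface a helper's latency with an OpenTelemetry span, using the standard opentelemetry-api; this is just the pattern, not necessarily how LiteLLM's log_to_opentelemetry decorator is implemented, and pre_call_hook is a hypothetical name:

from opentelemetry import trace

tracer = trace.get_tracer(__name__)


async def pre_call_hook(data: dict) -> dict:
    # Wrap the pre-call work in a span so its latency shows up in traces.
    with tracer.start_as_current_span("litellm_pre_call_utils"):
        # ... pre-call enrichment happens here ...
        return data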

* fix(litellm_pre_call_utils.py): rename call_type to the actual call used

* perf(proxy_server.py): remove db logic from _get_config_from_file

this was causing DB calls on every LLM request when a team_id was set on the key

* fix(auth_checks.py): add a check to reduce DB calls when a user/team id does not exist in the DB

reduces latency by ~100ms per call
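The pattern, roughly: remember when each user/team id was last looked up and what came back; if the id was recently confirmed absent, skip the DB for a few seconds instead of re-querying on every request. A standalone sketch of that idea (the actual helpers, _should_check_db and _update_last_db_access_time, appear in the diff below and use a bounded dict):

import time

# maps "user_id:..." / "team_id:..." keys to (value, timestamp)
last_db_access_time: dict = {}
DB_CACHE_EXPIRY = 5  # seconds


def should_check_db(key: str) -> bool:
    if key not in last_db_access_time:
        return True  # never looked up -> query the DB
    value, checked_at = last_db_access_time[key]
    if value is not None:
        return True  # known row -> refresh lookups are allowed
    # known-missing id -> only re-check after the expiry window
    return time.time() - checked_at >= DB_CACHE_EXPIRY


def record_db_access(key: str, value) -> None:
    last_db_access_time[key] = (value, time.time())

A missing id is therefore re-queried at most once per expiry window rather than on every request.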

* fix(proxy_server.py): minor fix for existing_settings not including alerting

* fix(exception_mapping_utils.py): map databricks exception string
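Exception-string mapping generally looks like the sketch below: inspect the provider's error text and raise the matching typed exception. The substring and exception class here are placeholders, not the actual Databricks message or LiteLLM exception this commit maps:

class BadRequestError(Exception):
    """Stand-in for the proxy's typed exception (illustrative only)."""


def map_provider_error(error_str: str, provider: str) -> None:
    # Hypothetical substring; the real mapping keys off the actual
    # Databricks error text, which is not reproduced here.
    if provider == "databricks" and "invalid request" in error_str.lower():
        raise BadRequestError(error_str)
    raise Exception(error_str)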

* fix(auth_checks.py): fix auth check logic

* test: correctly mark flaky test
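Marking a flaky test usually means letting the runner retry it a few times instead of failing the suite outright. A sketch assuming the pytest-rerunfailures plugin; the repo may use a different plugin or marker:

import random

import pytest


def call_remote_service() -> str:
    # Placeholder for a call that intermittently fails.
    return "ok" if random.random() > 0.1 else "timeout"


@pytest.mark.flaky(reruns=3, reruns_delay=1)
def test_remote_service():
    # Retried up to 3 times (with a 1s delay) before being reported as a failure.
    assert call_remote_service() == "ok"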

* fix(utils.py): handle auth token error for tokenizers.from_pretrained
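The failure mode: tokenizers.Tokenizer.from_pretrained can raise when a Hugging Face repo is gated or the auth token is missing/invalid, and that error should not break token counting. A sketch of the fallback pattern, assuming the huggingface tokenizers package; the actual utils.py logic may differ:

from typing import Optional

from tokenizers import Tokenizer


def safe_load_tokenizer(model_id: str) -> Optional[Tokenizer]:
    try:
        return Tokenizer.from_pretrained(model_id)
    except Exception:
        # Gated repo / bad HF token / network error: return None so the
        # caller can fall back to the default tokenizer instead of raising.
        return None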
Krish Dholakia 2024-11-05 03:51:26 +05:30 committed by GitHub
parent 7525b6bbaa
commit 3a6ba0b955
14 changed files with 137 additions and 86 deletions


@@ -18,6 +18,7 @@ from pydantic import BaseModel
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.caching.caching import DualCache
+from litellm.caching.dual_cache import LimitedSizeOrderedDict
 from litellm.proxy._types import (
     LiteLLM_EndUserTable,
     LiteLLM_JWTAuth,
@@ -42,6 +43,10 @@ if TYPE_CHECKING:
 else:
     Span = Any

+last_db_access_time = LimitedSizeOrderedDict(max_size=100)
+db_cache_expiry = 5  # refresh every 5s
+
 all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
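LimitedSizeOrderedDict (imported above) bounds last_db_access_time so it cannot grow without limit; its implementation is not part of this diff. A minimal sketch of the idea, assuming an OrderedDict that evicts the oldest entry once max_size is reached (the real class in litellm/caching/dual_cache.py may differ):

from collections import OrderedDict
from typing import Any


class LimitedSizeOrderedDict(OrderedDict):
    def __init__(self, *args, max_size: int = 100, **kwargs):
        self.max_size = max_size
        super().__init__(*args, **kwargs)

    def __setitem__(self, key: Any, value: Any) -> None:
        # Evict the oldest entry once the size limit is reached.
        if len(self) >= self.max_size and key not in self:
            self.popitem(last=False)
        super().__setitem__(key, value)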
@@ -383,6 +388,32 @@ def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
     return False


+def _should_check_db(
+    key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int
+) -> bool:
+    """
+    Prevent calling db repeatedly for items that don't exist in the db.
+    """
+    current_time = time.time()
+    # if key doesn't exist in last_db_access_time -> check db
+    if key not in last_db_access_time:
+        return True
+    elif (
+        last_db_access_time[key][0] is not None
+    ):  # check db for non-null values (for refresh operations)
+        return True
+    elif last_db_access_time[key][0] is None:
+        # entries are (value, timestamp) tuples; compare against the timestamp
+        if current_time - last_db_access_time[key][1] >= db_cache_expiry:
+            return True
+    return False
+
+
+def _update_last_db_access_time(
+    key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict
+):
+    last_db_access_time[key] = (value, time.time())
+
+
 @log_to_opentelemetry
 async def get_user_object(
     user_id: str,
@@ -412,11 +443,20 @@ async def get_user_object(
     if prisma_client is None:
         raise Exception("No db connected")
     try:
-        response = await prisma_client.db.litellm_usertable.find_unique(
-            where={"user_id": user_id}, include={"organization_memberships": True}
+        db_access_time_key = "user_id:{}".format(user_id)
+        should_check_db = _should_check_db(
+            key=db_access_time_key,
+            last_db_access_time=last_db_access_time,
+            db_cache_expiry=db_cache_expiry,
         )
+        if should_check_db:
+            response = await prisma_client.db.litellm_usertable.find_unique(
+                where={"user_id": user_id}, include={"organization_memberships": True}
+            )
+        else:
+            response = None

         if response is None:
             if user_id_upsert:
                 response = await prisma_client.db.litellm_usertable.create(
@@ -444,6 +484,13 @@ async def get_user_object(
         # save the user object to cache
         await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)

+        # save to db access time
+        _update_last_db_access_time(
+            key=db_access_time_key,
+            value=response_dict,
+            last_db_access_time=last_db_access_time,
+        )
+
         return _response
     except Exception as e:  # if user not in db
         raise ValueError(
@@ -515,6 +562,12 @@ async def _delete_cache_key_object(
+@log_to_opentelemetry
+async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
+    return await prisma_client.db.litellm_teamtable.find_unique(
+        where={"team_id": team_id}
+    )
+
+
 async def get_team_object(
     team_id: str,
     prisma_client: Optional[PrismaClient],
@@ -544,7 +597,7 @@ async def get_team_object(
     ):
         cached_team_obj = (
             await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
-                key=key
+                key=key, parent_otel_span=parent_otel_span
             )
         )
@@ -564,9 +617,18 @@ async def get_team_object(
     # else, check db
     try:
-        response = await prisma_client.db.litellm_teamtable.find_unique(
-            where={"team_id": team_id}
+        db_access_time_key = "team_id:{}".format(team_id)
+        should_check_db = _should_check_db(
+            key=db_access_time_key,
+            last_db_access_time=last_db_access_time,
+            db_cache_expiry=db_cache_expiry,
         )
+        if should_check_db:
+            response = await _get_team_db_check(
+                team_id=team_id, prisma_client=prisma_client
+            )
+        else:
+            response = None

         if response is None:
             raise Exception
@@ -580,6 +642,14 @@ async def get_team_object(
             proxy_logging_obj=proxy_logging_obj,
         )

+        # save to db access time
+        _update_last_db_access_time(
+            key=db_access_time_key,
+            value=_response,
+            last_db_access_time=last_db_access_time,
+        )
+
         return _response
     except Exception:
         raise Exception(
@@ -608,16 +678,16 @@ async def get_key_object(
     # check if in cache
     key = hashed_token
-    cached_team_obj: Optional[UserAPIKeyAuth] = None

-    if cached_team_obj is None:
-        cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
+    cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache(
+        key=key
+    )

-    if cached_team_obj is not None:
-        if isinstance(cached_team_obj, dict):
-            return UserAPIKeyAuth(**cached_team_obj)
-        elif isinstance(cached_team_obj, UserAPIKeyAuth):
-            return cached_team_obj
+    if cached_key_obj is not None:
+        if isinstance(cached_key_obj, dict):
+            return UserAPIKeyAuth(**cached_key_obj)
+        elif isinstance(cached_key_obj, UserAPIKeyAuth):
+            return cached_key_obj

     if check_cache_only:
         raise Exception(