From 3a6ba0b9558ca0a754cc558fbfddd2bb7b11fda5 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Tue, 5 Nov 2024 03:51:26 +0530 Subject: [PATCH] Litellm perf improvements 3 (#6573) * perf: move writing key to cache, to background task * perf(litellm_pre_call_utils.py): add otel tracing for pre-call utils adds 200ms on calls with pgdb connected * fix(litellm_pre_call_utils.py): rename call_type to actual call used * perf(proxy_server.py): remove db logic from _get_config_from_file was causing db calls to occur on every llm request, if team_id was set on key * fix(auth_checks.py): add check for reducing db calls if user/team id does not exist in db reduces latency/call by ~100ms * fix(proxy_server.py): minor fix on existing_settings not incl alerting * fix(exception_mapping_utils.py): map databricks exception string * fix(auth_checks.py): fix auth check logic * test: correctly mark flaky test * fix(utils.py): handle auth token error for tokenizers.from_pretrained --- .circleci/config.yml | 2 +- litellm/caching/dual_cache.py | 2 +- litellm/integrations/opentelemetry.py | 15 --- .../exception_mapping_utils.py | 10 ++ litellm/proxy/_new_secret_config.yaml | 8 +- litellm/proxy/auth/auth_checks.py | 98 ++++++++++++++++--- litellm/proxy/auth/user_api_key_auth.py | 12 ++- litellm/proxy/litellm_pre_call_utils.py | 18 +++- litellm/proxy/proxy_server.py | 41 +------- litellm/proxy/utils.py | 1 + litellm/types/services.py | 1 + litellm/utils.py | 12 ++- .../local_testing/test_key_generate_prisma.py | 2 +- tests/local_testing/test_router.py | 1 + 14 files changed, 137 insertions(+), 86 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 4734ee2a7..7083be6bd 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -392,7 +392,7 @@ jobs: pip install click pip install "boto3==1.34.34" pip install jinja2 - pip install tokenizers + pip install tokenizers=="0.20.0" pip install jsonschema - run: name: Run tests diff --git a/litellm/caching/dual_cache.py 
b/litellm/caching/dual_cache.py index a55a1a577..ddcd02abe 100644 --- a/litellm/caching/dual_cache.py +++ b/litellm/caching/dual_cache.py @@ -70,7 +70,7 @@ class DualCache(BaseCache): self.redis_batch_cache_expiry = ( default_redis_batch_cache_expiry or litellm.default_redis_batch_cache_expiry - or 5 + or 10 ) self.default_in_memory_ttl = ( default_in_memory_ttl or litellm.default_in_memory_ttl diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index a3bbb244e..a1d4b781a 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -281,21 +281,6 @@ class OpenTelemetry(CustomLogger): # End Parent OTEL Sspan parent_otel_span.end(end_time=self._to_ns(datetime.now())) - async def async_post_call_success_hook( - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse], - ): - from opentelemetry import trace - from opentelemetry.trace import Status, StatusCode - - parent_otel_span = user_api_key_dict.parent_otel_span - if parent_otel_span is not None: - parent_otel_span.set_status(Status(StatusCode.OK)) - # End Parent OTEL Sspan - parent_otel_span.end(end_time=self._to_ns(datetime.now())) - def _handle_sucess(self, kwargs, response_obj, start_time, end_time): from opentelemetry import trace from opentelemetry.trace import Status, StatusCode diff --git a/litellm/litellm_core_utils/exception_mapping_utils.py b/litellm/litellm_core_utils/exception_mapping_utils.py index 94eb5c623..14d5bffdb 100644 --- a/litellm/litellm_core_utils/exception_mapping_utils.py +++ b/litellm/litellm_core_utils/exception_mapping_utils.py @@ -646,6 +646,16 @@ def exception_type( # type: ignore # noqa: PLR0915 response=original_exception.response, litellm_debug_info=extra_information, ) + elif ( + "The server received an invalid response from an upstream server." 
+ in error_str + ): + exception_mapping_worked = True + raise litellm.InternalServerError( + message=f"{custom_llm_provider}Exception - {original_exception.message}", + llm_provider=custom_llm_provider, + model=model, + ) elif hasattr(original_exception, "status_code"): if original_exception.status_code == 500: exception_mapping_worked = True diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index b9315670a..45a379748 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -35,13 +35,7 @@ litellm_settings: # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests # see https://docs.litellm.ai/docs/proxy/prometheus - callbacks: ['prometheus', 'otel'] - - # # see https://docs.litellm.ai/docs/proxy/logging#logging-proxy-inputoutput---sentry - failure_callback: ['sentry'] - service_callback: ['prometheus_system'] - - # redact_user_api_key_info: true + callbacks: ['otel'] router_settings: diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index b3f249d6f..e00d494d9 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -18,6 +18,7 @@ from pydantic import BaseModel import litellm from litellm._logging import verbose_proxy_logger from litellm.caching.caching import DualCache +from litellm.caching.dual_cache import LimitedSizeOrderedDict from litellm.proxy._types import ( LiteLLM_EndUserTable, LiteLLM_JWTAuth, @@ -42,6 +43,10 @@ if TYPE_CHECKING: else: Span = Any + +last_db_access_time = LimitedSizeOrderedDict(max_size=100) +db_cache_expiry = 5 # refresh every 5s + all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value @@ -383,6 +388,32 @@ def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool: return False +def _should_check_db( + key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int +) -> bool: + """ + Prevent calling db 
repeatedly for items that don't exist in the db. + """ + current_time = time.time() + # if key doesn't exist in last_db_access_time -> check db + if key not in last_db_access_time: + return True + elif ( + last_db_access_time[key][0] is not None + ): # check db for non-null values (for refresh operations) + return True + elif last_db_access_time[key][0] is None: + if current_time - last_db_access_time[key] >= db_cache_expiry: + return True + return False + + +def _update_last_db_access_time( + key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict +): + last_db_access_time[key] = (value, time.time()) + + @log_to_opentelemetry async def get_user_object( user_id: str, @@ -412,11 +443,20 @@ async def get_user_object( if prisma_client is None: raise Exception("No db connected") try: - - response = await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_id}, include={"organization_memberships": True} + db_access_time_key = "user_id:{}".format(user_id) + should_check_db = _should_check_db( + key=db_access_time_key, + last_db_access_time=last_db_access_time, + db_cache_expiry=db_cache_expiry, ) + if should_check_db: + response = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_id}, include={"organization_memberships": True} + ) + else: + response = None + if response is None: if user_id_upsert: response = await prisma_client.db.litellm_usertable.create( @@ -444,6 +484,13 @@ async def get_user_object( # save the user object to cache await user_api_key_cache.async_set_cache(key=user_id, value=response_dict) + # save to db access time + _update_last_db_access_time( + key=db_access_time_key, + value=response_dict, + last_db_access_time=last_db_access_time, + ) + return _response except Exception as e: # if user not in db raise ValueError( @@ -515,6 +562,12 @@ async def _delete_cache_key_object( @log_to_opentelemetry +async def _get_team_db_check(team_id: str, prisma_client: PrismaClient): + return await 
prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} + ) + + async def get_team_object( team_id: str, prisma_client: Optional[PrismaClient], @@ -544,7 +597,7 @@ async def get_team_object( ): cached_team_obj = ( await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache( - key=key + key=key, parent_otel_span=parent_otel_span ) ) @@ -564,9 +617,18 @@ async def get_team_object( # else, check db try: - response = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id} + db_access_time_key = "team_id:{}".format(team_id) + should_check_db = _should_check_db( + key=db_access_time_key, + last_db_access_time=last_db_access_time, + db_cache_expiry=db_cache_expiry, ) + if should_check_db: + response = await _get_team_db_check( + team_id=team_id, prisma_client=prisma_client + ) + else: + response = None if response is None: raise Exception @@ -580,6 +642,14 @@ async def get_team_object( proxy_logging_obj=proxy_logging_obj, ) + # save to db access time + # save to db access time + _update_last_db_access_time( + key=db_access_time_key, + value=_response, + last_db_access_time=last_db_access_time, + ) + return _response except Exception: raise Exception( @@ -608,16 +678,16 @@ async def get_key_object( # check if in cache key = hashed_token - cached_team_obj: Optional[UserAPIKeyAuth] = None - if cached_team_obj is None: - cached_team_obj = await user_api_key_cache.async_get_cache(key=key) + cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache( + key=key + ) - if cached_team_obj is not None: - if isinstance(cached_team_obj, dict): - return UserAPIKeyAuth(**cached_team_obj) - elif isinstance(cached_team_obj, UserAPIKeyAuth): - return cached_team_obj + if cached_key_obj is not None: + if isinstance(cached_key_obj, dict): + return UserAPIKeyAuth(**cached_key_obj) + elif isinstance(cached_key_obj, UserAPIKeyAuth): + return cached_key_obj if check_cache_only: raise Exception( diff --git 
a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 995a95f79..d25b6f620 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -1127,11 +1127,13 @@ async def user_api_key_auth( # noqa: PLR0915 api_key = valid_token.token # Add hashed token to cache - await _cache_key_object( - hashed_token=api_key, - user_api_key_obj=valid_token, - user_api_key_cache=user_api_key_cache, - proxy_logging_obj=proxy_logging_obj, + asyncio.create_task( + _cache_key_object( + hashed_token=api_key, + user_api_key_obj=valid_token, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) ) valid_token_dict = valid_token.model_dump(exclude_none=True) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index a34dffccd..789e79f37 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1,4 +1,5 @@ import copy +import time from typing import TYPE_CHECKING, Any, Dict, Optional, Union from fastapi import Request @@ -6,6 +7,7 @@ from starlette.datastructures import Headers import litellm from litellm._logging import verbose_logger, verbose_proxy_logger +from litellm._service_logger import ServiceLogging from litellm.proxy._types import ( AddTeamCallback, CommonProxyErrors, @@ -16,11 +18,15 @@ from litellm.proxy._types import ( UserAPIKeyAuth, ) from litellm.proxy.auth.auth_utils import get_request_route +from litellm.types.services import ServiceTypes from litellm.types.utils import ( StandardLoggingUserAPIKeyMetadata, SupportedCacheControls, ) +service_logger_obj = ServiceLogging() # used for tracking latency on OTEL + + if TYPE_CHECKING: from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig @@ -471,7 +477,7 @@ async def add_litellm_data_to_request( # noqa: PLR0915 ### END-USER SPECIFIC PARAMS ### if user_api_key_dict.allowed_model_region is not None: data["allowed_model_region"] = 
user_api_key_dict.allowed_model_region - + start_time = time.time() ## [Enterprise Only] # Add User-IP Address requester_ip_address = "" @@ -539,6 +545,16 @@ async def add_litellm_data_to_request( # noqa: PLR0915 verbose_proxy_logger.debug( f"[PROXY]returned data from litellm_pre_call_utils: {data}" ) + + end_time = time.time() + await service_logger_obj.async_service_success_hook( + service=ServiceTypes.PROXY_PRE_CALL, + duration=end_time - start_time, + call_type="add_litellm_data_to_request", + start_time=start_time, + end_time=end_time, + parent_otel_span=user_api_key_dict.parent_otel_span, + ) return data diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 363ab4efd..37cbd2b82 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1373,9 +1373,6 @@ class ProxyConfig: ) -> dict: """ Given a config file path, load the config from the file. - - If `store_model_in_db` is True, then read the DB and update the config with the DB values. 
- Args: config_file_path (str): path to the config file Returns: @@ -1401,40 +1398,6 @@ class ProxyConfig: "litellm_settings": {}, } - ## DB - if prisma_client is not None and ( - general_settings.get("store_model_in_db", False) is True - or store_model_in_db is True - ): - _tasks = [] - keys = [ - "general_settings", - "router_settings", - "litellm_settings", - "environment_variables", - ] - for k in keys: - response = prisma_client.get_generic_data( - key="param_name", value=k, table_name="config" - ) - _tasks.append(response) - - responses = await asyncio.gather(*_tasks) - for response in responses: - if response is not None: - param_name = getattr(response, "param_name", None) - param_value = getattr(response, "param_value", None) - if param_name is not None and param_value is not None: - # check if param_name is already in the config - if param_name in config: - if isinstance(config[param_name], dict): - config[param_name].update(param_value) - else: - config[param_name] = param_value - else: - # if it's not in the config - then add it - config[param_name] = param_value - return config async def save_config(self, new_config: dict): @@ -1500,8 +1463,10 @@ class ProxyConfig: - for a given team id - return the relevant completion() call params """ + # load existing config config = await self.get_config() + ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) litellm_settings = config.get("litellm_settings", {}) all_teams_config = litellm_settings.get("default_team_settings", None) @@ -8824,7 +8789,7 @@ async def update_config(config_info: ConfigYAML): # noqa: PLR0915 if k == "alert_to_webhook_url": # check if slack is already enabled. 
if not, enable it if "alerting" not in _existing_settings: - _existing_settings["alerting"].append("slack") + _existing_settings = {"alerting": ["slack"]} elif isinstance(_existing_settings["alerting"], list): if "slack" not in _existing_settings["alerting"]: _existing_settings["alerting"].append("slack") diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 82831b3b2..44243cab0 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1400,6 +1400,7 @@ class PrismaClient: return + @log_to_opentelemetry @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff diff --git a/litellm/types/services.py b/litellm/types/services.py index 5f690f328..cfa427ebc 100644 --- a/litellm/types/services.py +++ b/litellm/types/services.py @@ -16,6 +16,7 @@ class ServiceTypes(str, enum.Enum): LITELLM = "self" ROUTER = "router" AUTH = "auth" + PROXY_PRE_CALL = "proxy_pre_call" class ServiceLoggerPayload(BaseModel): diff --git a/litellm/utils.py b/litellm/utils.py index 70f43e512..8bd001def 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1539,9 +1539,15 @@ def create_pretrained_tokenizer( dict: A dictionary with the tokenizer and its type. """ - tokenizer = Tokenizer.from_pretrained( - identifier, revision=revision, auth_token=auth_token - ) + try: + tokenizer = Tokenizer.from_pretrained( + identifier, revision=revision, auth_token=auth_token + ) + except Exception as e: + verbose_logger.error( + f"Error creating pretrained tokenizer: {e}. Defaulting to version without 'auth_token'." 
+ ) + tokenizer = Tokenizer.from_pretrained(identifier, revision=revision) return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} diff --git a/tests/local_testing/test_key_generate_prisma.py b/tests/local_testing/test_key_generate_prisma.py index 74182c09f..e009e214c 100644 --- a/tests/local_testing/test_key_generate_prisma.py +++ b/tests/local_testing/test_key_generate_prisma.py @@ -2717,7 +2717,7 @@ async def test_update_user_role(prisma_client): ) ) - await asyncio.sleep(2) + # await asyncio.sleep(3) # use generated key to auth in print("\n\nMAKING NEW REQUEST WITH UPDATED USER ROLE\n\n") diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index 7bf0b0bba..5ffdbc7ac 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -2486,6 +2486,7 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode): @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio() +@pytest.mark.flaky(retries=6, delay=1) async def test_router_weighted_pick(sync_mode): router = Router( model_list=[