Litellm perf improvements 3 (#6573)

* perf: move writing key to cache, to background task

* perf(litellm_pre_call_utils.py): add otel tracing for pre-call utils

the pre-call utils step adds ~200ms per call when a Postgres DB is connected

* fix(litellm_pre_call_utils.py): rename call_type to the actual call used

* perf(proxy_server.py): remove db logic from _get_config_from_file

it was causing a db call on every llm request if a team_id was set on the key

* fix(auth_checks.py): add check for reducing db calls if user/team id does not exist in db

reduces latency by ~100ms per call

* fix(proxy_server.py): minor fix for existing_settings not including 'alerting'

* fix(exception_mapping_utils.py): map databricks exception string

* fix(auth_checks.py): fix auth check logic

* test: correctly mark flaky test

* fix(utils.py): handle auth token error for tokenizers.from_pretrained
Krish Dholakia · 2024-11-05 03:51:26 +05:30 · committed by GitHub
commit 3a6ba0b955 · parent 7525b6bbaa
14 changed files with 137 additions and 86 deletions

@@ -392,7 +392,7 @@ jobs:
             pip install click
             pip install "boto3==1.34.34"
             pip install jinja2
-            pip install tokenizers
+            pip install tokenizers=="0.20.0"
             pip install jsonschema
       - run:
           name: Run tests

@@ -70,7 +70,7 @@ class DualCache(BaseCache):
         self.redis_batch_cache_expiry = (
             default_redis_batch_cache_expiry
             or litellm.default_redis_batch_cache_expiry
-            or 5
+            or 10
         )
         self.default_in_memory_ttl = (
             default_in_memory_ttl or litellm.default_in_memory_ttl

@@ -281,21 +281,6 @@ class OpenTelemetry(CustomLogger):
             # End Parent OTEL Sspan
             parent_otel_span.end(end_time=self._to_ns(datetime.now()))
-
-    async def async_post_call_success_hook(
-        self,
-        data: dict,
-        user_api_key_dict: UserAPIKeyAuth,
-        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
-    ):
-        from opentelemetry import trace
-        from opentelemetry.trace import Status, StatusCode
-
-        parent_otel_span = user_api_key_dict.parent_otel_span
-        if parent_otel_span is not None:
-            parent_otel_span.set_status(Status(StatusCode.OK))
-            # End Parent OTEL Sspan
-            parent_otel_span.end(end_time=self._to_ns(datetime.now()))
 
     def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
         from opentelemetry import trace
         from opentelemetry.trace import Status, StatusCode

@@ -646,6 +646,16 @@ def exception_type(  # type: ignore  # noqa: PLR0915
                     response=original_exception.response,
                     litellm_debug_info=extra_information,
                 )
+            elif (
+                "The server received an invalid response from an upstream server."
+                in error_str
+            ):
+                exception_mapping_worked = True
+                raise litellm.InternalServerError(
+                    message=f"{custom_llm_provider}Exception - {original_exception.message}",
+                    llm_provider=custom_llm_provider,
+                    model=model,
+                )
             elif hasattr(original_exception, "status_code"):
                 if original_exception.status_code == 500:
                     exception_mapping_worked = True
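
The new branch catches a Databricks failure mode where a gateway error arrives as a plain message string with no usable status code, and maps it to `litellm.InternalServerError`. A minimal standalone sketch of the same string-match mapping pattern; the exception class and helper below are stand-ins for illustration, not litellm's real API:

class InternalServerError(Exception):
    """Stand-in for litellm.InternalServerError."""

UPSTREAM_ERROR = "The server received an invalid response from an upstream server."

def map_provider_error(error_str: str, provider: str, model: str) -> None:
    # Some providers (here: Databricks) surface this failure only as a
    # message string, so a substring check is the only available signal.
    if UPSTREAM_ERROR in error_str:
        raise InternalServerError(f"{provider}Exception - {error_str} (model={model})")

# usage sketch: map_provider_error(str(original_exception), "databricks", "dbrx-instruct")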

@@ -35,13 +35,7 @@ litellm_settings:
   # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
   # see https://docs.litellm.ai/docs/proxy/prometheus
-  callbacks: ['prometheus', 'otel']
-  # # see https://docs.litellm.ai/docs/proxy/logging#logging-proxy-inputoutput---sentry
-  failure_callback: ['sentry']
-  service_callback: ['prometheus_system']
-  # redact_user_api_key_info: true
+  callbacks: ['otel']
 
 router_settings:

@@ -18,6 +18,7 @@ from pydantic import BaseModel
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.caching.caching import DualCache
+from litellm.caching.dual_cache import LimitedSizeOrderedDict
 from litellm.proxy._types import (
     LiteLLM_EndUserTable,
     LiteLLM_JWTAuth,
@@ -42,6 +43,10 @@ if TYPE_CHECKING:
 else:
     Span = Any
 
+last_db_access_time = LimitedSizeOrderedDict(max_size=100)
+db_cache_expiry = 5  # refresh every 5s
+
 all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
@@ -383,6 +388,32 @@ def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
     return False
 
+def _should_check_db(
+    key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int
+) -> bool:
+    """
+    Prevent calling db repeatedly for items that don't exist in the db.
+    """
+    current_time = time.time()
+    # if key doesn't exist in last_db_access_time -> check db
+    if key not in last_db_access_time:
+        return True
+    elif (
+        last_db_access_time[key][0] is not None
+    ):  # check db for non-null values (for refresh operations)
+        return True
+    elif last_db_access_time[key][0] is None:
+        if current_time - last_db_access_time[key][1] >= db_cache_expiry:
+            return True
+    return False
+
+def _update_last_db_access_time(
+    key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict
+):
+    last_db_access_time[key] = (value, time.time())
 
 @log_to_opentelemetry
 async def get_user_object(
     user_id: str,
@@ -412,11 +443,20 @@ async def get_user_object(
     if prisma_client is None:
         raise Exception("No db connected")
     try:
-        response = await prisma_client.db.litellm_usertable.find_unique(
-            where={"user_id": user_id}, include={"organization_memberships": True}
-        )
+        db_access_time_key = "user_id:{}".format(user_id)
+        should_check_db = _should_check_db(
+            key=db_access_time_key,
+            last_db_access_time=last_db_access_time,
+            db_cache_expiry=db_cache_expiry,
+        )
+        if should_check_db:
+            response = await prisma_client.db.litellm_usertable.find_unique(
+                where={"user_id": user_id}, include={"organization_memberships": True}
+            )
+        else:
+            response = None
 
         if response is None:
             if user_id_upsert:
                 response = await prisma_client.db.litellm_usertable.create(
@@ -444,6 +484,13 @@
         # save the user object to cache
         await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)
 
+        # save to db access time
+        _update_last_db_access_time(
+            key=db_access_time_key,
+            value=response_dict,
+            last_db_access_time=last_db_access_time,
+        )
+
         return _response
     except Exception as e:  # if user not in db
         raise ValueError(
@@ -515,6 +562,12 @@ async def _delete_cache_key_object(
 
+@log_to_opentelemetry
+async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
+    return await prisma_client.db.litellm_teamtable.find_unique(
+        where={"team_id": team_id}
+    )
+
 async def get_team_object(
     team_id: str,
     prisma_client: Optional[PrismaClient],
@@ -544,7 +597,7 @@
     ):
         cached_team_obj = (
             await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
-                key=key
+                key=key, parent_otel_span=parent_otel_span
             )
         )
@@ -564,9 +617,18 @@
     # else, check db
     try:
-        response = await prisma_client.db.litellm_teamtable.find_unique(
-            where={"team_id": team_id}
-        )
+        db_access_time_key = "team_id:{}".format(team_id)
+        should_check_db = _should_check_db(
+            key=db_access_time_key,
+            last_db_access_time=last_db_access_time,
+            db_cache_expiry=db_cache_expiry,
+        )
+        if should_check_db:
+            response = await _get_team_db_check(
+                team_id=team_id, prisma_client=prisma_client
+            )
+        else:
+            response = None
 
         if response is None:
             raise Exception
@@ -580,6 +642,14 @@
             proxy_logging_obj=proxy_logging_obj,
         )
 
+        # save to db access time
+        _update_last_db_access_time(
+            key=db_access_time_key,
+            value=_response,
+            last_db_access_time=last_db_access_time,
+        )
+
         return _response
     except Exception:
         raise Exception(
@@ -608,16 +678,16 @@ async def get_key_object(
     # check if in cache
     key = hashed_token
-    cached_team_obj: Optional[UserAPIKeyAuth] = None
-    if cached_team_obj is None:
-        cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
-    if cached_team_obj is not None:
-        if isinstance(cached_team_obj, dict):
-            return UserAPIKeyAuth(**cached_team_obj)
-        elif isinstance(cached_team_obj, UserAPIKeyAuth):
-            return cached_team_obj
+    cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache(
+        key=key
+    )
+    if cached_key_obj is not None:
+        if isinstance(cached_key_obj, dict):
+            return UserAPIKeyAuth(**cached_key_obj)
+        elif isinstance(cached_key_obj, UserAPIKeyAuth):
+            return cached_key_obj
 
     if check_cache_only:
         raise Exception(
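
Taken together, `_should_check_db` and `_update_last_db_access_time` form a short-lived negative cache: an id that was recently looked up and found missing skips the DB for `db_cache_expiry` seconds instead of re-querying on every request, while ids that did exist are always refreshed. A self-contained sketch of the pattern, assuming a size-capped OrderedDict in place of litellm's LimitedSizeOrderedDict:

import time
from collections import OrderedDict
from typing import Any, Optional, Tuple

# key -> (last_fetched_value, fetch_timestamp); bounded like LimitedSizeOrderedDict
_last_db_access: "OrderedDict[str, Tuple[Optional[Any], float]]" = OrderedDict()
_MAX_SIZE = 100
_DB_CACHE_EXPIRY = 5  # seconds before re-checking a known-missing key

def should_check_db(key: str) -> bool:
    if key not in _last_db_access:
        return True  # never seen -> must hit the db
    value, fetched_at = _last_db_access[key]
    if value is not None:
        return True  # existed last time -> allow refresh reads
    # known-missing: skip the db until the negative-cache window expires
    return time.time() - fetched_at >= _DB_CACHE_EXPIRY

def record_db_access(key: str, value: Optional[Any]) -> None:
    _last_db_access[key] = (value, time.time())
    while len(_last_db_access) > _MAX_SIZE:
        _last_db_access.popitem(last=False)  # evict the oldest entry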

@@ -1127,11 +1127,13 @@ async def user_api_key_auth(  # noqa: PLR0915
             api_key = valid_token.token
 
             # Add hashed token to cache
-            await _cache_key_object(
-                hashed_token=api_key,
-                user_api_key_obj=valid_token,
-                user_api_key_cache=user_api_key_cache,
-                proxy_logging_obj=proxy_logging_obj,
-            )
+            asyncio.create_task(
+                _cache_key_object(
+                    hashed_token=api_key,
+                    user_api_key_obj=valid_token,
+                    user_api_key_cache=user_api_key_cache,
+                    proxy_logging_obj=proxy_logging_obj,
+                )
+            )
             valid_token_dict = valid_token.model_dump(exclude_none=True)
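
The change above swaps an awaited cache write for `asyncio.create_task`, turning it into a fire-and-forget background task so the auth path no longer pays the write latency. A toy sketch of the trade-off; the function names are illustrative, not litellm's:

import asyncio

async def slow_cache_write(key: str, value: str) -> None:
    await asyncio.sleep(0.1)  # stand-in for a Redis/DB round-trip
    print(f"cached {key}")

async def handle_request() -> str:
    # Scheduled, not awaited: the request returns immediately and the
    # write completes whenever the event loop gets to it.
    asyncio.create_task(slow_cache_write("hashed-token", "auth-object"))
    return "response"

async def main() -> None:
    print(await handle_request())
    await asyncio.sleep(0.2)  # keep the loop alive so the background write lands

asyncio.run(main())

The caveat of this pattern: exceptions in an un-awaited task only surface through the event loop's exception handler, and the write is lost if the process exits first. That is acceptable here because the cache entry is best-effort and gets rebuilt on the next request.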

@@ -1,4 +1,5 @@
 import copy
+import time
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 from fastapi import Request
@@ -6,6 +7,7 @@ from starlette.datastructures import Headers
 import litellm
 from litellm._logging import verbose_logger, verbose_proxy_logger
+from litellm._service_logger import ServiceLogging
 from litellm.proxy._types import (
     AddTeamCallback,
     CommonProxyErrors,
@@ -16,11 +18,15 @@ from litellm.proxy._types import (
     UserAPIKeyAuth,
 )
 from litellm.proxy.auth.auth_utils import get_request_route
+from litellm.types.services import ServiceTypes
 from litellm.types.utils import (
     StandardLoggingUserAPIKeyMetadata,
     SupportedCacheControls,
 )
 
+service_logger_obj = ServiceLogging()  # used for tracking latency on OTEL
+
 if TYPE_CHECKING:
     from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
@@ -471,7 +477,7 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
     ### END-USER SPECIFIC PARAMS ###
     if user_api_key_dict.allowed_model_region is not None:
         data["allowed_model_region"] = user_api_key_dict.allowed_model_region
+    start_time = time.time()
 
     ## [Enterprise Only]
     # Add User-IP Address
     requester_ip_address = ""
@@ -539,6 +545,16 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
     verbose_proxy_logger.debug(
         f"[PROXY]returned data from litellm_pre_call_utils: {data}"
     )
+
+    end_time = time.time()
+    await service_logger_obj.async_service_success_hook(
+        service=ServiceTypes.PROXY_PRE_CALL,
+        duration=end_time - start_time,
+        call_type="add_litellm_data_to_request",
+        start_time=start_time,
+        end_time=end_time,
+        parent_otel_span=user_api_key_dict.parent_otel_span,
+    )
+
     return data
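
The instrumentation is plain wall-clock timing around the function body, emitted through the module-level `ServiceLogging` object as a `PROXY_PRE_CALL` service entry so it shows up on the request's OTEL trace. A generic sketch of the same idea as a reusable async context manager; `report` is a hypothetical stand-in for `async_service_success_hook`, and unlike the code above it also fires when the body raises:

import time
from contextlib import asynccontextmanager
from typing import AsyncIterator, Awaitable, Callable

@asynccontextmanager
async def timed(
    service: str, report: Callable[[str, float], Awaitable[None]]
) -> AsyncIterator[None]:
    start = time.time()
    try:
        yield
    finally:
        await report(service, time.time() - start)  # emit even on failure

# usage sketch:
# async with timed("proxy_pre_call", my_async_reporter):
#     data = await add_litellm_data_to_request(...)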

@@ -1373,9 +1373,6 @@
     ) -> dict:
         """
         Given a config file path, load the config from the file.
-
-        If `store_model_in_db` is True, then read the DB and update the config with the DB values.
-
         Args:
             config_file_path (str): path to the config file
 
         Returns:
@@ -1401,40 +1398,6 @@
             "litellm_settings": {},
         }
 
-        ## DB
-        if prisma_client is not None and (
-            general_settings.get("store_model_in_db", False) is True
-            or store_model_in_db is True
-        ):
-            _tasks = []
-            keys = [
-                "general_settings",
-                "router_settings",
-                "litellm_settings",
-                "environment_variables",
-            ]
-            for k in keys:
-                response = prisma_client.get_generic_data(
-                    key="param_name", value=k, table_name="config"
-                )
-                _tasks.append(response)
-
-            responses = await asyncio.gather(*_tasks)
-            for response in responses:
-                if response is not None:
-                    param_name = getattr(response, "param_name", None)
-                    param_value = getattr(response, "param_value", None)
-                    if param_name is not None and param_value is not None:
-                        # check if param_name is already in the config
-                        if param_name in config:
-                            if isinstance(config[param_name], dict):
-                                config[param_name].update(param_value)
-                            else:
-                                config[param_name] = param_value
-                        else:
-                            # if it's not in the config - then add it
-                            config[param_name] = param_value
-
         return config
async def save_config(self, new_config: dict):
@@ -1500,8 +1463,10 @@
         - for a given team id
         - return the relevant completion() call params
         """
+        # load existing config
         config = await self.get_config()
+
         ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
         litellm_settings = config.get("litellm_settings", {})
         all_teams_config = litellm_settings.get("default_team_settings", None)
@@ -8824,7 +8789,7 @@ async def update_config(config_info: ConfigYAML):  # noqa: PLR0915
             if k == "alert_to_webhook_url":
                 # check if slack is already enabled. if not, enable it
                 if "alerting" not in _existing_settings:
-                    _existing_settings["alerting"].append("slack")
+                    _existing_settings = {"alerting": ["slack"]}
                 elif isinstance(_existing_settings["alerting"], list):
                     if "slack" not in _existing_settings["alerting"]:
                         _existing_settings["alerting"].append("slack")

@@ -1400,6 +1400,7 @@ class PrismaClient:
             return
 
+    @log_to_opentelemetry
     @backoff.on_exception(
         backoff.expo,
         Exception,  # base exception to catch for the backoff

@@ -16,6 +16,7 @@ class ServiceTypes(str, enum.Enum):
     LITELLM = "self"
     ROUTER = "router"
     AUTH = "auth"
+    PROXY_PRE_CALL = "proxy_pre_call"
 
 class ServiceLoggerPayload(BaseModel):

@@ -1539,9 +1539,15 @@ def create_pretrained_tokenizer(
         dict: A dictionary with the tokenizer and its type.
     """
 
-    tokenizer = Tokenizer.from_pretrained(
-        identifier, revision=revision, auth_token=auth_token
-    )
+    try:
+        tokenizer = Tokenizer.from_pretrained(
+            identifier, revision=revision, auth_token=auth_token
+        )
+    except Exception as e:
+        verbose_logger.error(
+            f"Error creating pretrained tokenizer: {e}. Defaulting to version without 'auth_token'."
+        )
+        tokenizer = Tokenizer.from_pretrained(identifier, revision=revision)
 
     return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}

@@ -2717,7 +2717,7 @@ async def test_update_user_role(prisma_client):
         )
     )
 
-    await asyncio.sleep(2)
+    # await asyncio.sleep(3)
 
     # use generated key to auth in
     print("\n\nMAKING NEW REQUEST WITH UPDATED USER ROLE\n\n")

@@ -2486,6 +2486,7 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio()
+@pytest.mark.flaky(retries=6, delay=1)
 async def test_router_weighted_pick(sync_mode):
     router = Router(
         model_list=[