forked from phoenix/litellm-mirror
Litellm perf improvements 3 (#6573)
* perf: move writing key to cache, to background task * perf(litellm_pre_call_utils.py): add otel tracing for pre-call utils adds 200ms on calls with pgdb connected * fix(litellm_pre_call_utils.py'): rename call_type to actual call used * perf(proxy_server.py): remove db logic from _get_config_from_file was causing db calls to occur on every llm request, if team_id was set on key * fix(auth_checks.py): add check for reducing db calls if user/team id does not exist in db reduces latency/call by ~100ms * fix(proxy_server.py): minor fix on existing_settings not incl alerting * fix(exception_mapping_utils.py): map databricks exception string * fix(auth_checks.py): fix auth check logic * test: correctly mark flaky test * fix(utils.py): handle auth token error for tokenizers.from_pretrained
This commit is contained in:
parent
7525b6bbaa
commit
3a6ba0b955
14 changed files with 137 additions and 86 deletions
|
@ -392,7 +392,7 @@ jobs:
|
|||
pip install click
|
||||
pip install "boto3==1.34.34"
|
||||
pip install jinja2
|
||||
pip install tokenizers
|
||||
pip install tokenizers=="0.20.0"
|
||||
pip install jsonschema
|
||||
- run:
|
||||
name: Run tests
|
||||
|
|
|
@ -70,7 +70,7 @@ class DualCache(BaseCache):
|
|||
self.redis_batch_cache_expiry = (
|
||||
default_redis_batch_cache_expiry
|
||||
or litellm.default_redis_batch_cache_expiry
|
||||
or 5
|
||||
or 10
|
||||
)
|
||||
self.default_in_memory_ttl = (
|
||||
default_in_memory_ttl or litellm.default_in_memory_ttl
|
||||
|
|
|
@ -281,21 +281,6 @@ class OpenTelemetry(CustomLogger):
|
|||
# End Parent OTEL Sspan
|
||||
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
|
||||
):
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
parent_otel_span = user_api_key_dict.parent_otel_span
|
||||
if parent_otel_span is not None:
|
||||
parent_otel_span.set_status(Status(StatusCode.OK))
|
||||
# End Parent OTEL Sspan
|
||||
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
||||
|
||||
def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
|
|
@ -646,6 +646,16 @@ def exception_type( # type: ignore # noqa: PLR0915
|
|||
response=original_exception.response,
|
||||
litellm_debug_info=extra_information,
|
||||
)
|
||||
elif (
|
||||
"The server received an invalid response from an upstream server."
|
||||
in error_str
|
||||
):
|
||||
exception_mapping_worked = True
|
||||
raise litellm.InternalServerError(
|
||||
message=f"{custom_llm_provider}Exception - {original_exception.message}",
|
||||
llm_provider=custom_llm_provider,
|
||||
model=model,
|
||||
)
|
||||
elif hasattr(original_exception, "status_code"):
|
||||
if original_exception.status_code == 500:
|
||||
exception_mapping_worked = True
|
||||
|
|
|
@ -35,13 +35,7 @@ litellm_settings:
|
|||
|
||||
# see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
|
||||
# see https://docs.litellm.ai/docs/proxy/prometheus
|
||||
callbacks: ['prometheus', 'otel']
|
||||
|
||||
# # see https://docs.litellm.ai/docs/proxy/logging#logging-proxy-inputoutput---sentry
|
||||
failure_callback: ['sentry']
|
||||
service_callback: ['prometheus_system']
|
||||
|
||||
# redact_user_api_key_info: true
|
||||
callbacks: ['otel']
|
||||
|
||||
|
||||
router_settings:
|
||||
|
|
|
@ -18,6 +18,7 @@ from pydantic import BaseModel
|
|||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_JWTAuth,
|
||||
|
@ -42,6 +43,10 @@ if TYPE_CHECKING:
|
|||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
last_db_access_time = LimitedSizeOrderedDict(max_size=100)
|
||||
db_cache_expiry = 5 # refresh every 5s
|
||||
|
||||
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||
|
||||
|
||||
|
@ -383,6 +388,32 @@ def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _should_check_db(
|
||||
key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int
|
||||
) -> bool:
|
||||
"""
|
||||
Prevent calling db repeatedly for items that don't exist in the db.
|
||||
"""
|
||||
current_time = time.time()
|
||||
# if key doesn't exist in last_db_access_time -> check db
|
||||
if key not in last_db_access_time:
|
||||
return True
|
||||
elif (
|
||||
last_db_access_time[key][0] is not None
|
||||
): # check db for non-null values (for refresh operations)
|
||||
return True
|
||||
elif last_db_access_time[key][0] is None:
|
||||
if current_time - last_db_access_time[key] >= db_cache_expiry:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _update_last_db_access_time(
|
||||
key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict
|
||||
):
|
||||
last_db_access_time[key] = (value, time.time())
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
async def get_user_object(
|
||||
user_id: str,
|
||||
|
@ -412,11 +443,20 @@ async def get_user_object(
|
|||
if prisma_client is None:
|
||||
raise Exception("No db connected")
|
||||
try:
|
||||
|
||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}, include={"organization_memberships": True}
|
||||
db_access_time_key = "user_id:{}".format(user_id)
|
||||
should_check_db = _should_check_db(
|
||||
key=db_access_time_key,
|
||||
last_db_access_time=last_db_access_time,
|
||||
db_cache_expiry=db_cache_expiry,
|
||||
)
|
||||
|
||||
if should_check_db:
|
||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}, include={"organization_memberships": True}
|
||||
)
|
||||
else:
|
||||
response = None
|
||||
|
||||
if response is None:
|
||||
if user_id_upsert:
|
||||
response = await prisma_client.db.litellm_usertable.create(
|
||||
|
@ -444,6 +484,13 @@ async def get_user_object(
|
|||
# save the user object to cache
|
||||
await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)
|
||||
|
||||
# save to db access time
|
||||
_update_last_db_access_time(
|
||||
key=db_access_time_key,
|
||||
value=response_dict,
|
||||
last_db_access_time=last_db_access_time,
|
||||
)
|
||||
|
||||
return _response
|
||||
except Exception as e: # if user not in db
|
||||
raise ValueError(
|
||||
|
@ -515,6 +562,12 @@ async def _delete_cache_key_object(
|
|||
|
||||
|
||||
@log_to_opentelemetry
|
||||
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
|
||||
return await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
|
||||
|
||||
async def get_team_object(
|
||||
team_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
|
@ -544,7 +597,7 @@ async def get_team_object(
|
|||
):
|
||||
cached_team_obj = (
|
||||
await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
|
||||
key=key
|
||||
key=key, parent_otel_span=parent_otel_span
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -564,9 +617,18 @@ async def get_team_object(
|
|||
|
||||
# else, check db
|
||||
try:
|
||||
response = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
db_access_time_key = "team_id:{}".format(team_id)
|
||||
should_check_db = _should_check_db(
|
||||
key=db_access_time_key,
|
||||
last_db_access_time=last_db_access_time,
|
||||
db_cache_expiry=db_cache_expiry,
|
||||
)
|
||||
if should_check_db:
|
||||
response = await _get_team_db_check(
|
||||
team_id=team_id, prisma_client=prisma_client
|
||||
)
|
||||
else:
|
||||
response = None
|
||||
|
||||
if response is None:
|
||||
raise Exception
|
||||
|
@ -580,6 +642,14 @@ async def get_team_object(
|
|||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# save to db access time
|
||||
# save to db access time
|
||||
_update_last_db_access_time(
|
||||
key=db_access_time_key,
|
||||
value=_response,
|
||||
last_db_access_time=last_db_access_time,
|
||||
)
|
||||
|
||||
return _response
|
||||
except Exception:
|
||||
raise Exception(
|
||||
|
@ -608,16 +678,16 @@ async def get_key_object(
|
|||
|
||||
# check if in cache
|
||||
key = hashed_token
|
||||
cached_team_obj: Optional[UserAPIKeyAuth] = None
|
||||
|
||||
if cached_team_obj is None:
|
||||
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
|
||||
cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache(
|
||||
key=key
|
||||
)
|
||||
|
||||
if cached_team_obj is not None:
|
||||
if isinstance(cached_team_obj, dict):
|
||||
return UserAPIKeyAuth(**cached_team_obj)
|
||||
elif isinstance(cached_team_obj, UserAPIKeyAuth):
|
||||
return cached_team_obj
|
||||
if cached_key_obj is not None:
|
||||
if isinstance(cached_key_obj, dict):
|
||||
return UserAPIKeyAuth(**cached_key_obj)
|
||||
elif isinstance(cached_key_obj, UserAPIKeyAuth):
|
||||
return cached_key_obj
|
||||
|
||||
if check_cache_only:
|
||||
raise Exception(
|
||||
|
|
|
@ -1127,11 +1127,13 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
api_key = valid_token.token
|
||||
|
||||
# Add hashed token to cache
|
||||
await _cache_key_object(
|
||||
hashed_token=api_key,
|
||||
user_api_key_obj=valid_token,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
asyncio.create_task(
|
||||
_cache_key_object(
|
||||
hashed_token=api_key,
|
||||
user_api_key_obj=valid_token,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
)
|
||||
|
||||
valid_token_dict = valid_token.model_dump(exclude_none=True)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import copy
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
@ -6,6 +7,7 @@ from starlette.datastructures import Headers
|
|||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm._service_logger import ServiceLogging
|
||||
from litellm.proxy._types import (
|
||||
AddTeamCallback,
|
||||
CommonProxyErrors,
|
||||
|
@ -16,11 +18,15 @@ from litellm.proxy._types import (
|
|||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.auth_utils import get_request_route
|
||||
from litellm.types.services import ServiceTypes
|
||||
from litellm.types.utils import (
|
||||
StandardLoggingUserAPIKeyMetadata,
|
||||
SupportedCacheControls,
|
||||
)
|
||||
|
||||
service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
|
||||
|
||||
|
@ -471,7 +477,7 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
### END-USER SPECIFIC PARAMS ###
|
||||
if user_api_key_dict.allowed_model_region is not None:
|
||||
data["allowed_model_region"] = user_api_key_dict.allowed_model_region
|
||||
|
||||
start_time = time.time()
|
||||
## [Enterprise Only]
|
||||
# Add User-IP Address
|
||||
requester_ip_address = ""
|
||||
|
@ -539,6 +545,16 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
verbose_proxy_logger.debug(
|
||||
f"[PROXY]returned data from litellm_pre_call_utils: {data}"
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
await service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.PROXY_PRE_CALL,
|
||||
duration=end_time - start_time,
|
||||
call_type="add_litellm_data_to_request",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
|
|
|
@ -1373,9 +1373,6 @@ class ProxyConfig:
|
|||
) -> dict:
|
||||
"""
|
||||
Given a config file path, load the config from the file.
|
||||
|
||||
If `store_model_in_db` is True, then read the DB and update the config with the DB values.
|
||||
|
||||
Args:
|
||||
config_file_path (str): path to the config file
|
||||
Returns:
|
||||
|
@ -1401,40 +1398,6 @@ class ProxyConfig:
|
|||
"litellm_settings": {},
|
||||
}
|
||||
|
||||
## DB
|
||||
if prisma_client is not None and (
|
||||
general_settings.get("store_model_in_db", False) is True
|
||||
or store_model_in_db is True
|
||||
):
|
||||
_tasks = []
|
||||
keys = [
|
||||
"general_settings",
|
||||
"router_settings",
|
||||
"litellm_settings",
|
||||
"environment_variables",
|
||||
]
|
||||
for k in keys:
|
||||
response = prisma_client.get_generic_data(
|
||||
key="param_name", value=k, table_name="config"
|
||||
)
|
||||
_tasks.append(response)
|
||||
|
||||
responses = await asyncio.gather(*_tasks)
|
||||
for response in responses:
|
||||
if response is not None:
|
||||
param_name = getattr(response, "param_name", None)
|
||||
param_value = getattr(response, "param_value", None)
|
||||
if param_name is not None and param_value is not None:
|
||||
# check if param_name is already in the config
|
||||
if param_name in config:
|
||||
if isinstance(config[param_name], dict):
|
||||
config[param_name].update(param_value)
|
||||
else:
|
||||
config[param_name] = param_value
|
||||
else:
|
||||
# if it's not in the config - then add it
|
||||
config[param_name] = param_value
|
||||
|
||||
return config
|
||||
|
||||
async def save_config(self, new_config: dict):
|
||||
|
@ -1500,8 +1463,10 @@ class ProxyConfig:
|
|||
- for a given team id
|
||||
- return the relevant completion() call params
|
||||
"""
|
||||
|
||||
# load existing config
|
||||
config = await self.get_config()
|
||||
|
||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||
litellm_settings = config.get("litellm_settings", {})
|
||||
all_teams_config = litellm_settings.get("default_team_settings", None)
|
||||
|
@ -8824,7 +8789,7 @@ async def update_config(config_info: ConfigYAML): # noqa: PLR0915
|
|||
if k == "alert_to_webhook_url":
|
||||
# check if slack is already enabled. if not, enable it
|
||||
if "alerting" not in _existing_settings:
|
||||
_existing_settings["alerting"].append("slack")
|
||||
_existing_settings = {"alerting": ["slack"]}
|
||||
elif isinstance(_existing_settings["alerting"], list):
|
||||
if "slack" not in _existing_settings["alerting"]:
|
||||
_existing_settings["alerting"].append("slack")
|
||||
|
|
|
@ -1400,6 +1400,7 @@ class PrismaClient:
|
|||
|
||||
return
|
||||
|
||||
@log_to_opentelemetry
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
Exception, # base exception to catch for the backoff
|
||||
|
|
|
@ -16,6 +16,7 @@ class ServiceTypes(str, enum.Enum):
|
|||
LITELLM = "self"
|
||||
ROUTER = "router"
|
||||
AUTH = "auth"
|
||||
PROXY_PRE_CALL = "proxy_pre_call"
|
||||
|
||||
|
||||
class ServiceLoggerPayload(BaseModel):
|
||||
|
|
|
@ -1539,9 +1539,15 @@ def create_pretrained_tokenizer(
|
|||
dict: A dictionary with the tokenizer and its type.
|
||||
"""
|
||||
|
||||
tokenizer = Tokenizer.from_pretrained(
|
||||
identifier, revision=revision, auth_token=auth_token
|
||||
)
|
||||
try:
|
||||
tokenizer = Tokenizer.from_pretrained(
|
||||
identifier, revision=revision, auth_token=auth_token
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
f"Error creating pretrained tokenizer: {e}. Defaulting to version without 'auth_token'."
|
||||
)
|
||||
tokenizer = Tokenizer.from_pretrained(identifier, revision=revision)
|
||||
return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
|
||||
|
||||
|
||||
|
|
|
@ -2717,7 +2717,7 @@ async def test_update_user_role(prisma_client):
|
|||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
# await asyncio.sleep(3)
|
||||
|
||||
# use generated key to auth in
|
||||
print("\n\nMAKING NEW REQUEST WITH UPDATED USER ROLE\n\n")
|
||||
|
|
|
@ -2486,6 +2486,7 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
|
|||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.flaky(retries=6, delay=1)
|
||||
async def test_router_weighted_pick(sync_mode):
|
||||
router = Router(
|
||||
model_list=[
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue