forked from phoenix/litellm-mirror
Litellm perf improvements 3 (#6573)
* perf: move writing key to cache, to background task * perf(litellm_pre_call_utils.py): add otel tracing for pre-call utils adds 200ms on calls with pgdb connected * fix(litellm_pre_call_utils.py'): rename call_type to actual call used * perf(proxy_server.py): remove db logic from _get_config_from_file was causing db calls to occur on every llm request, if team_id was set on key * fix(auth_checks.py): add check for reducing db calls if user/team id does not exist in db reduces latency/call by ~100ms * fix(proxy_server.py): minor fix on existing_settings not incl alerting * fix(exception_mapping_utils.py): map databricks exception string * fix(auth_checks.py): fix auth check logic * test: correctly mark flaky test * fix(utils.py): handle auth token error for tokenizers.from_pretrained
This commit is contained in:
parent
7525b6bbaa
commit
3a6ba0b955
14 changed files with 137 additions and 86 deletions
|
@ -392,7 +392,7 @@ jobs:
|
||||||
pip install click
|
pip install click
|
||||||
pip install "boto3==1.34.34"
|
pip install "boto3==1.34.34"
|
||||||
pip install jinja2
|
pip install jinja2
|
||||||
pip install tokenizers
|
pip install tokenizers=="0.20.0"
|
||||||
pip install jsonschema
|
pip install jsonschema
|
||||||
- run:
|
- run:
|
||||||
name: Run tests
|
name: Run tests
|
||||||
|
|
|
@ -70,7 +70,7 @@ class DualCache(BaseCache):
|
||||||
self.redis_batch_cache_expiry = (
|
self.redis_batch_cache_expiry = (
|
||||||
default_redis_batch_cache_expiry
|
default_redis_batch_cache_expiry
|
||||||
or litellm.default_redis_batch_cache_expiry
|
or litellm.default_redis_batch_cache_expiry
|
||||||
or 5
|
or 10
|
||||||
)
|
)
|
||||||
self.default_in_memory_ttl = (
|
self.default_in_memory_ttl = (
|
||||||
default_in_memory_ttl or litellm.default_in_memory_ttl
|
default_in_memory_ttl or litellm.default_in_memory_ttl
|
||||||
|
|
|
@ -281,21 +281,6 @@ class OpenTelemetry(CustomLogger):
|
||||||
# End Parent OTEL Sspan
|
# End Parent OTEL Sspan
|
||||||
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
||||||
|
|
||||||
async def async_post_call_success_hook(
|
|
||||||
self,
|
|
||||||
data: dict,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
|
||||||
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
|
|
||||||
):
|
|
||||||
from opentelemetry import trace
|
|
||||||
from opentelemetry.trace import Status, StatusCode
|
|
||||||
|
|
||||||
parent_otel_span = user_api_key_dict.parent_otel_span
|
|
||||||
if parent_otel_span is not None:
|
|
||||||
parent_otel_span.set_status(Status(StatusCode.OK))
|
|
||||||
# End Parent OTEL Sspan
|
|
||||||
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
|
|
||||||
|
|
||||||
def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
|
def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from opentelemetry.trace import Status, StatusCode
|
from opentelemetry.trace import Status, StatusCode
|
||||||
|
|
|
@ -646,6 +646,16 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
response=original_exception.response,
|
response=original_exception.response,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
|
elif (
|
||||||
|
"The server received an invalid response from an upstream server."
|
||||||
|
in error_str
|
||||||
|
):
|
||||||
|
exception_mapping_worked = True
|
||||||
|
raise litellm.InternalServerError(
|
||||||
|
message=f"{custom_llm_provider}Exception - {original_exception.message}",
|
||||||
|
llm_provider=custom_llm_provider,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
elif hasattr(original_exception, "status_code"):
|
elif hasattr(original_exception, "status_code"):
|
||||||
if original_exception.status_code == 500:
|
if original_exception.status_code == 500:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
|
|
@ -35,13 +35,7 @@ litellm_settings:
|
||||||
|
|
||||||
# see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
|
# see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
|
||||||
# see https://docs.litellm.ai/docs/proxy/prometheus
|
# see https://docs.litellm.ai/docs/proxy/prometheus
|
||||||
callbacks: ['prometheus', 'otel']
|
callbacks: ['otel']
|
||||||
|
|
||||||
# # see https://docs.litellm.ai/docs/proxy/logging#logging-proxy-inputoutput---sentry
|
|
||||||
failure_callback: ['sentry']
|
|
||||||
service_callback: ['prometheus_system']
|
|
||||||
|
|
||||||
# redact_user_api_key_info: true
|
|
||||||
|
|
||||||
|
|
||||||
router_settings:
|
router_settings:
|
||||||
|
|
|
@ -18,6 +18,7 @@ from pydantic import BaseModel
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
|
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
LiteLLM_EndUserTable,
|
LiteLLM_EndUserTable,
|
||||||
LiteLLM_JWTAuth,
|
LiteLLM_JWTAuth,
|
||||||
|
@ -42,6 +43,10 @@ if TYPE_CHECKING:
|
||||||
else:
|
else:
|
||||||
Span = Any
|
Span = Any
|
||||||
|
|
||||||
|
|
||||||
|
last_db_access_time = LimitedSizeOrderedDict(max_size=100)
|
||||||
|
db_cache_expiry = 5 # refresh every 5s
|
||||||
|
|
||||||
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||||
|
|
||||||
|
|
||||||
|
@ -383,6 +388,32 @@ def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _should_check_db(
|
||||||
|
key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Prevent calling db repeatedly for items that don't exist in the db.
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
# if key doesn't exist in last_db_access_time -> check db
|
||||||
|
if key not in last_db_access_time:
|
||||||
|
return True
|
||||||
|
elif (
|
||||||
|
last_db_access_time[key][0] is not None
|
||||||
|
): # check db for non-null values (for refresh operations)
|
||||||
|
return True
|
||||||
|
elif last_db_access_time[key][0] is None:
|
||||||
|
if current_time - last_db_access_time[key] >= db_cache_expiry:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _update_last_db_access_time(
|
||||||
|
key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict
|
||||||
|
):
|
||||||
|
last_db_access_time[key] = (value, time.time())
|
||||||
|
|
||||||
|
|
||||||
@log_to_opentelemetry
|
@log_to_opentelemetry
|
||||||
async def get_user_object(
|
async def get_user_object(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
@ -412,11 +443,20 @@ async def get_user_object(
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
raise Exception("No db connected")
|
raise Exception("No db connected")
|
||||||
try:
|
try:
|
||||||
|
db_access_time_key = "user_id:{}".format(user_id)
|
||||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
should_check_db = _should_check_db(
|
||||||
where={"user_id": user_id}, include={"organization_memberships": True}
|
key=db_access_time_key,
|
||||||
|
last_db_access_time=last_db_access_time,
|
||||||
|
db_cache_expiry=db_cache_expiry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if should_check_db:
|
||||||
|
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||||
|
where={"user_id": user_id}, include={"organization_memberships": True}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = None
|
||||||
|
|
||||||
if response is None:
|
if response is None:
|
||||||
if user_id_upsert:
|
if user_id_upsert:
|
||||||
response = await prisma_client.db.litellm_usertable.create(
|
response = await prisma_client.db.litellm_usertable.create(
|
||||||
|
@ -444,6 +484,13 @@ async def get_user_object(
|
||||||
# save the user object to cache
|
# save the user object to cache
|
||||||
await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)
|
await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)
|
||||||
|
|
||||||
|
# save to db access time
|
||||||
|
_update_last_db_access_time(
|
||||||
|
key=db_access_time_key,
|
||||||
|
value=response_dict,
|
||||||
|
last_db_access_time=last_db_access_time,
|
||||||
|
)
|
||||||
|
|
||||||
return _response
|
return _response
|
||||||
except Exception as e: # if user not in db
|
except Exception as e: # if user not in db
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -515,6 +562,12 @@ async def _delete_cache_key_object(
|
||||||
|
|
||||||
|
|
||||||
@log_to_opentelemetry
|
@log_to_opentelemetry
|
||||||
|
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
|
||||||
|
return await prisma_client.db.litellm_teamtable.find_unique(
|
||||||
|
where={"team_id": team_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_team_object(
|
async def get_team_object(
|
||||||
team_id: str,
|
team_id: str,
|
||||||
prisma_client: Optional[PrismaClient],
|
prisma_client: Optional[PrismaClient],
|
||||||
|
@ -544,7 +597,7 @@ async def get_team_object(
|
||||||
):
|
):
|
||||||
cached_team_obj = (
|
cached_team_obj = (
|
||||||
await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
|
await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
|
||||||
key=key
|
key=key, parent_otel_span=parent_otel_span
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -564,9 +617,18 @@ async def get_team_object(
|
||||||
|
|
||||||
# else, check db
|
# else, check db
|
||||||
try:
|
try:
|
||||||
response = await prisma_client.db.litellm_teamtable.find_unique(
|
db_access_time_key = "team_id:{}".format(team_id)
|
||||||
where={"team_id": team_id}
|
should_check_db = _should_check_db(
|
||||||
|
key=db_access_time_key,
|
||||||
|
last_db_access_time=last_db_access_time,
|
||||||
|
db_cache_expiry=db_cache_expiry,
|
||||||
)
|
)
|
||||||
|
if should_check_db:
|
||||||
|
response = await _get_team_db_check(
|
||||||
|
team_id=team_id, prisma_client=prisma_client
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = None
|
||||||
|
|
||||||
if response is None:
|
if response is None:
|
||||||
raise Exception
|
raise Exception
|
||||||
|
@ -580,6 +642,14 @@ async def get_team_object(
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# save to db access time
|
||||||
|
# save to db access time
|
||||||
|
_update_last_db_access_time(
|
||||||
|
key=db_access_time_key,
|
||||||
|
value=_response,
|
||||||
|
last_db_access_time=last_db_access_time,
|
||||||
|
)
|
||||||
|
|
||||||
return _response
|
return _response
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -608,16 +678,16 @@ async def get_key_object(
|
||||||
|
|
||||||
# check if in cache
|
# check if in cache
|
||||||
key = hashed_token
|
key = hashed_token
|
||||||
cached_team_obj: Optional[UserAPIKeyAuth] = None
|
|
||||||
|
|
||||||
if cached_team_obj is None:
|
cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache(
|
||||||
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
|
key=key
|
||||||
|
)
|
||||||
|
|
||||||
if cached_team_obj is not None:
|
if cached_key_obj is not None:
|
||||||
if isinstance(cached_team_obj, dict):
|
if isinstance(cached_key_obj, dict):
|
||||||
return UserAPIKeyAuth(**cached_team_obj)
|
return UserAPIKeyAuth(**cached_key_obj)
|
||||||
elif isinstance(cached_team_obj, UserAPIKeyAuth):
|
elif isinstance(cached_key_obj, UserAPIKeyAuth):
|
||||||
return cached_team_obj
|
return cached_key_obj
|
||||||
|
|
||||||
if check_cache_only:
|
if check_cache_only:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
|
@ -1127,11 +1127,13 @@ async def user_api_key_auth( # noqa: PLR0915
|
||||||
api_key = valid_token.token
|
api_key = valid_token.token
|
||||||
|
|
||||||
# Add hashed token to cache
|
# Add hashed token to cache
|
||||||
await _cache_key_object(
|
asyncio.create_task(
|
||||||
hashed_token=api_key,
|
_cache_key_object(
|
||||||
user_api_key_obj=valid_token,
|
hashed_token=api_key,
|
||||||
user_api_key_cache=user_api_key_cache,
|
user_api_key_obj=valid_token,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_token_dict = valid_token.model_dump(exclude_none=True)
|
valid_token_dict = valid_token.model_dump(exclude_none=True)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import copy
|
import copy
|
||||||
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
@ -6,6 +7,7 @@ from starlette.datastructures import Headers
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||||
|
from litellm._service_logger import ServiceLogging
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
AddTeamCallback,
|
AddTeamCallback,
|
||||||
CommonProxyErrors,
|
CommonProxyErrors,
|
||||||
|
@ -16,11 +18,15 @@ from litellm.proxy._types import (
|
||||||
UserAPIKeyAuth,
|
UserAPIKeyAuth,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.auth_utils import get_request_route
|
from litellm.proxy.auth.auth_utils import get_request_route
|
||||||
|
from litellm.types.services import ServiceTypes
|
||||||
from litellm.types.utils import (
|
from litellm.types.utils import (
|
||||||
StandardLoggingUserAPIKeyMetadata,
|
StandardLoggingUserAPIKeyMetadata,
|
||||||
SupportedCacheControls,
|
SupportedCacheControls,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
service_logger_obj = ServiceLogging() # used for tracking latency on OTEL
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
|
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
|
||||||
|
|
||||||
|
@ -471,7 +477,7 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||||
### END-USER SPECIFIC PARAMS ###
|
### END-USER SPECIFIC PARAMS ###
|
||||||
if user_api_key_dict.allowed_model_region is not None:
|
if user_api_key_dict.allowed_model_region is not None:
|
||||||
data["allowed_model_region"] = user_api_key_dict.allowed_model_region
|
data["allowed_model_region"] = user_api_key_dict.allowed_model_region
|
||||||
|
start_time = time.time()
|
||||||
## [Enterprise Only]
|
## [Enterprise Only]
|
||||||
# Add User-IP Address
|
# Add User-IP Address
|
||||||
requester_ip_address = ""
|
requester_ip_address = ""
|
||||||
|
@ -539,6 +545,16 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"[PROXY]returned data from litellm_pre_call_utils: {data}"
|
f"[PROXY]returned data from litellm_pre_call_utils: {data}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
await service_logger_obj.async_service_success_hook(
|
||||||
|
service=ServiceTypes.PROXY_PRE_CALL,
|
||||||
|
duration=end_time - start_time,
|
||||||
|
call_type="add_litellm_data_to_request",
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1373,9 +1373,6 @@ class ProxyConfig:
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Given a config file path, load the config from the file.
|
Given a config file path, load the config from the file.
|
||||||
|
|
||||||
If `store_model_in_db` is True, then read the DB and update the config with the DB values.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_file_path (str): path to the config file
|
config_file_path (str): path to the config file
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -1401,40 +1398,6 @@ class ProxyConfig:
|
||||||
"litellm_settings": {},
|
"litellm_settings": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
## DB
|
|
||||||
if prisma_client is not None and (
|
|
||||||
general_settings.get("store_model_in_db", False) is True
|
|
||||||
or store_model_in_db is True
|
|
||||||
):
|
|
||||||
_tasks = []
|
|
||||||
keys = [
|
|
||||||
"general_settings",
|
|
||||||
"router_settings",
|
|
||||||
"litellm_settings",
|
|
||||||
"environment_variables",
|
|
||||||
]
|
|
||||||
for k in keys:
|
|
||||||
response = prisma_client.get_generic_data(
|
|
||||||
key="param_name", value=k, table_name="config"
|
|
||||||
)
|
|
||||||
_tasks.append(response)
|
|
||||||
|
|
||||||
responses = await asyncio.gather(*_tasks)
|
|
||||||
for response in responses:
|
|
||||||
if response is not None:
|
|
||||||
param_name = getattr(response, "param_name", None)
|
|
||||||
param_value = getattr(response, "param_value", None)
|
|
||||||
if param_name is not None and param_value is not None:
|
|
||||||
# check if param_name is already in the config
|
|
||||||
if param_name in config:
|
|
||||||
if isinstance(config[param_name], dict):
|
|
||||||
config[param_name].update(param_value)
|
|
||||||
else:
|
|
||||||
config[param_name] = param_value
|
|
||||||
else:
|
|
||||||
# if it's not in the config - then add it
|
|
||||||
config[param_name] = param_value
|
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
async def save_config(self, new_config: dict):
|
async def save_config(self, new_config: dict):
|
||||||
|
@ -1500,8 +1463,10 @@ class ProxyConfig:
|
||||||
- for a given team id
|
- for a given team id
|
||||||
- return the relevant completion() call params
|
- return the relevant completion() call params
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# load existing config
|
# load existing config
|
||||||
config = await self.get_config()
|
config = await self.get_config()
|
||||||
|
|
||||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||||
litellm_settings = config.get("litellm_settings", {})
|
litellm_settings = config.get("litellm_settings", {})
|
||||||
all_teams_config = litellm_settings.get("default_team_settings", None)
|
all_teams_config = litellm_settings.get("default_team_settings", None)
|
||||||
|
@ -8824,7 +8789,7 @@ async def update_config(config_info: ConfigYAML): # noqa: PLR0915
|
||||||
if k == "alert_to_webhook_url":
|
if k == "alert_to_webhook_url":
|
||||||
# check if slack is already enabled. if not, enable it
|
# check if slack is already enabled. if not, enable it
|
||||||
if "alerting" not in _existing_settings:
|
if "alerting" not in _existing_settings:
|
||||||
_existing_settings["alerting"].append("slack")
|
_existing_settings = {"alerting": ["slack"]}
|
||||||
elif isinstance(_existing_settings["alerting"], list):
|
elif isinstance(_existing_settings["alerting"], list):
|
||||||
if "slack" not in _existing_settings["alerting"]:
|
if "slack" not in _existing_settings["alerting"]:
|
||||||
_existing_settings["alerting"].append("slack")
|
_existing_settings["alerting"].append("slack")
|
||||||
|
|
|
@ -1400,6 +1400,7 @@ class PrismaClient:
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@log_to_opentelemetry
|
||||||
@backoff.on_exception(
|
@backoff.on_exception(
|
||||||
backoff.expo,
|
backoff.expo,
|
||||||
Exception, # base exception to catch for the backoff
|
Exception, # base exception to catch for the backoff
|
||||||
|
|
|
@ -16,6 +16,7 @@ class ServiceTypes(str, enum.Enum):
|
||||||
LITELLM = "self"
|
LITELLM = "self"
|
||||||
ROUTER = "router"
|
ROUTER = "router"
|
||||||
AUTH = "auth"
|
AUTH = "auth"
|
||||||
|
PROXY_PRE_CALL = "proxy_pre_call"
|
||||||
|
|
||||||
|
|
||||||
class ServiceLoggerPayload(BaseModel):
|
class ServiceLoggerPayload(BaseModel):
|
||||||
|
|
|
@ -1539,9 +1539,15 @@ def create_pretrained_tokenizer(
|
||||||
dict: A dictionary with the tokenizer and its type.
|
dict: A dictionary with the tokenizer and its type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer = Tokenizer.from_pretrained(
|
try:
|
||||||
identifier, revision=revision, auth_token=auth_token
|
tokenizer = Tokenizer.from_pretrained(
|
||||||
)
|
identifier, revision=revision, auth_token=auth_token
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.error(
|
||||||
|
f"Error creating pretrained tokenizer: {e}. Defaulting to version without 'auth_token'."
|
||||||
|
)
|
||||||
|
tokenizer = Tokenizer.from_pretrained(identifier, revision=revision)
|
||||||
return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
|
return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2717,7 +2717,7 @@ async def test_update_user_role(prisma_client):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.sleep(2)
|
# await asyncio.sleep(3)
|
||||||
|
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
print("\n\nMAKING NEW REQUEST WITH UPDATED USER ROLE\n\n")
|
print("\n\nMAKING NEW REQUEST WITH UPDATED USER ROLE\n\n")
|
||||||
|
|
|
@ -2486,6 +2486,7 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
|
||||||
|
|
||||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
@pytest.mark.asyncio()
|
@pytest.mark.asyncio()
|
||||||
|
@pytest.mark.flaky(retries=6, delay=1)
|
||||||
async def test_router_weighted_pick(sync_mode):
|
async def test_router_weighted_pick(sync_mode):
|
||||||
router = Router(
|
router = Router(
|
||||||
model_list=[
|
model_list=[
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue