LiteLLM Minor Fixes & Improvements (09/21/2024) (#5819)

* fix(router.py): fix error message

* Litellm disable keys (#5814)

* build(schema.prisma): allow blocking/unblocking keys

Fixes https://github.com/BerriAI/litellm/issues/5328

* fix(key_management_endpoints.py): fix pop

* feat(auth_checks.py): allow admin to enable/disable virtual keys

Closes https://github.com/BerriAI/litellm/issues/5328

* docs(vertex.md): add auth section for vertex ai

Addresses - https://github.com/BerriAI/litellm/issues/5768#issuecomment-2365284223

* build(model_prices_and_context_window.json): show which models support prompt_caching

Closes https://github.com/BerriAI/litellm/issues/5776

* fix(router.py): allow setting default priority for requests

* fix(router.py): add 'retry-after' header for concurrent request limit errors

Fixes https://github.com/BerriAI/litellm/issues/5783

* fix(router.py): correctly raise and use retry-after header from azure+openai

Fixes https://github.com/BerriAI/litellm/issues/5783

* fix(user_api_key_auth.py): fix valid token being none

* fix(auth_checks.py): fix model dump for cache management object

* fix(user_api_key_auth.py): pass prisma_client to obj

* test(test_otel.py): update test for new key check

* test: fix test
This commit is contained in:
Krish Dholakia 2024-09-21 18:51:53 -07:00 committed by GitHub
parent 1ca638973f
commit 8039b95aaf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1006 additions and 182 deletions

View file

@ -46,11 +46,13 @@ import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
_cache_key_object,
allowed_routes_check,
can_key_call_model,
common_checks,
get_actual_routes,
get_end_user_object,
get_key_object,
get_org_object,
get_team_object,
get_user_object,
@ -525,9 +527,19 @@ async def user_api_key_auth(
### CHECK IF ADMIN ###
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
## Check CACHE
valid_token: Optional[UserAPIKeyAuth] = user_api_key_cache.get_cache(
key=hash_token(api_key)
)
try:
valid_token = await get_key_object(
hashed_token=hash_token(api_key),
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
check_cache_only=True,
)
except Exception:
verbose_logger.debug("api key not found in cache.")
valid_token = None
if (
valid_token is not None
and isinstance(valid_token, UserAPIKeyAuth)
@ -578,7 +590,7 @@ async def user_api_key_auth(
try:
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore
except Exception as e:
except Exception:
is_master_key_valid = False
## VALIDATE MASTER KEY ##
@ -602,8 +614,11 @@ async def user_api_key_auth(
parent_otel_span=parent_otel_span,
**end_user_params,
)
await user_api_key_cache.async_set_cache(
key=hash_token(master_key), value=_user_api_key_obj
await _cache_key_object(
hashed_token=hash_token(master_key),
user_api_key_obj=_user_api_key_obj,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
return _user_api_key_obj
@ -640,38 +655,31 @@ async def user_api_key_auth(
_user_role = None
if api_key.startswith("sk-"):
api_key = hash_token(token=api_key)
valid_token: Optional[UserAPIKeyAuth] = user_api_key_cache.get_cache( # type: ignore
key=api_key
)
if valid_token is None:
## check db
verbose_proxy_logger.debug("api key: %s", api_key)
if prisma_client is not None:
_valid_token: Optional[BaseModel] = await prisma_client.get_data(
token=api_key,
table_name="combined_view",
try:
valid_token = await get_key_object(
hashed_token=api_key,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# update end-user params on valid token
# These can change per request - it's important to update them here
valid_token.end_user_id = end_user_params.get("end_user_id")
valid_token.end_user_tpm_limit = end_user_params.get(
"end_user_tpm_limit"
)
valid_token.end_user_rpm_limit = end_user_params.get(
"end_user_rpm_limit"
)
valid_token.allowed_model_region = end_user_params.get(
"allowed_model_region"
)
if _valid_token is not None:
## update cached token
valid_token = UserAPIKeyAuth(
**_valid_token.model_dump(exclude_none=True)
)
verbose_proxy_logger.debug("Token from db: %s", valid_token)
elif valid_token is not None and isinstance(valid_token, UserAPIKeyAuth):
verbose_proxy_logger.debug("API Key Cache Hit!")
# update end-user params on valid token
# These can change per request - it's important to update them here
valid_token.end_user_id = end_user_params.get("end_user_id")
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
valid_token.allowed_model_region = end_user_params.get(
"allowed_model_region"
)
except Exception:
valid_token = None
user_obj: Optional[LiteLLM_UserTable] = None
valid_token_dict: dict = {}
@ -689,6 +697,12 @@ async def user_api_key_auth(
# 8. If token spend is under team budget
# 9. If team spend is under team budget
## base case ## key is disabled
if valid_token.blocked is True:
raise Exception(
"Key is blocked. Update via `/key/unblock` if you're admin."
)
# Check 1. If token can call model
_model_alias_map = {}
model: Optional[str] = None
@ -1006,10 +1020,13 @@ async def user_api_key_auth(
api_key = valid_token.token
# Add hashed token to cache
await user_api_key_cache.async_set_cache(
key=api_key,
value=valid_token,
await _cache_key_object(
hashed_token=api_key,
user_api_key_obj=valid_token,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
valid_token_dict = valid_token.model_dump(exclude_none=True)
valid_token_dict.pop("token", None)