LiteLLM Minor Fixes & Improvements (09/21/2024) (#5819)

* fix(router.py): fix error message

* Litellm disable keys (#5814)

* build(schema.prisma): allow blocking/unblocking keys

Fixes https://github.com/BerriAI/litellm/issues/5328

* fix(key_management_endpoints.py): fix pop

* feat(auth_checks.py): allow admin to enable/disable virtual keys

Closes https://github.com/BerriAI/litellm/issues/5328

* docs(vertex.md): add auth section for vertex ai

Addresses - https://github.com/BerriAI/litellm/issues/5768#issuecomment-2365284223

* build(model_prices_and_context_window.json): show which models support prompt_caching

Closes https://github.com/BerriAI/litellm/issues/5776
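
For reference, a quick way to list the models that carry the new flag. A minimal sketch, assuming each entry exposes a boolean supports_prompt_caching field (the field name is inferred from the commit intent, not verified here):

import json

# Load litellm's pricing/feature map (path assumed to be the repo-root JSON this commit edits)
with open("model_prices_and_context_window.json") as f:
    model_map = json.load(f)

# Keep only entries that explicitly mark prompt caching support
caching_models = [
    name
    for name, info in model_map.items()
    if isinstance(info, dict) and info.get("supports_prompt_caching") is True
]
print(caching_models)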

* fix(router.py): allow setting default priority for requests
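
If the router default is exposed the way this commit message suggests, configuration would look roughly like the sketch below; the parameter name default_priority is an assumption, not confirmed by the diff shown here:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    default_priority=0,  # assumed parameter name: used when a request sets no explicit priority
)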

* fix(router.py): add 'retry-after' header for concurrent request limit errors

Fixes https://github.com/BerriAI/litellm/issues/5783

* fix(router.py): correctly raise and use retry-after header from azure+openai

Fixes https://github.com/BerriAI/litellm/issues/5783
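
With the header in place, callers can back off for exactly as long as the proxy asks instead of guessing. A minimal client-side sketch using httpx; the proxy URL and key are placeholders:

import time

import httpx

resp = httpx.post(
    "http://localhost:4000/v1/chat/completions",  # placeholder proxy address
    headers={"Authorization": "Bearer sk-1234"},  # placeholder key
    json={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
)
if resp.status_code == 429:
    # honor the retry-after value (seconds) that this change attaches to limit errors
    time.sleep(int(resp.headers.get("retry-after", "1")))
    # ... retry the request here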

* fix(user_api_key_auth.py): fix valid token being none

* fix(auth_checks.py): fix model dump for cache management object

* fix(user_api_key_auth.py): pass prisma_client to obj

* test(test_otel.py): update test for new key check

* test: fix test
Krish Dholakia 2024-09-21 18:51:53 -07:00 committed by GitHub
parent f0543a6f9d
commit f3fa2160a0
25 changed files with 1006 additions and 182 deletions

litellm/proxy/management_endpoints/key_management_endpoints.py

@@ -25,6 +25,11 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import *
+from litellm.proxy.auth.auth_checks import (
+    _cache_key_object,
+    _delete_cache_key_object,
+    get_key_object,
+)
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
 from litellm.proxy.utils import _duration_in_seconds
@@ -302,15 +307,18 @@ async def prepare_key_update_data(
     data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row
 ):
     data_json: dict = data.dict(exclude_unset=True)
-    key = data_json.pop("key", None)
+    data_json.pop("key", None)
     _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
     non_default_values = {}
     for k, v in data_json.items():
         if k in _metadata_fields:
             continue
-        if v is not None and v not in ([], {}, 0):
-            non_default_values[k] = v
+        if v is not None:
+            if not isinstance(v, bool) and v in ([], {}, 0):
+                pass
+            else:
+                non_default_values[k] = v
 
     if "duration" in non_default_values:
         duration = non_default_values.pop("duration")
         if duration and (isinstance(duration, str)) and len(duration) > 0:
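
The isinstance guard above exists because Python booleans are integers, so False == 0 and the old membership test silently dropped explicit False values (e.g. blocked=False on a key update). A quick demonstration:

# Old filter: an explicit False never survives, because False == 0 in Python
print(False in ([], {}, 0))  # True

# New filter: booleans skip the empty-value check, so blocked=False is kept
v = False
keep = v is not None and (isinstance(v, bool) or v not in ([], {}, 0))
print(keep)  # True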
@@ -364,12 +372,10 @@ async def update_key_fn(
     """
     from litellm.proxy.proxy_server import (
         create_audit_log_for_update,
-        general_settings,
         litellm_proxy_admin_name,
         prisma_client,
         proxy_logging_obj,
         user_api_key_cache,
-        user_custom_key_generate,
     )
 
     try:
@@ -399,9 +405,11 @@
         # Delete - key from cache, since it's been updated!
         # key updated - a new model could have been added to this key. it should not block requests after this is done
-        user_api_key_cache.delete_cache(key)
-        hashed_token = hash_token(key)
-        user_api_key_cache.delete_cache(hashed_token)
+        await _delete_cache_key_object(
+            hashed_token=hash_token(key),
+            user_api_key_cache=user_api_key_cache,
+            proxy_logging_obj=proxy_logging_obj,
+        )
 
         # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
         if litellm.store_audit_logs is True:
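
Context for the cache change: the proxy caches key objects under a hash of the raw key, so invalidation must target the hashed value, and the new _delete_cache_key_object helper (shown here only through its call sites) takes that hash directly. A minimal sketch of the hashing involved, assuming litellm's usual SHA-256 hex scheme:

import hashlib

def hash_token(token: str) -> str:
    # SHA-256 hex digest of the raw key: the lookup key for cache delete/store (assumed scheme)
    return hashlib.sha256(token.encode()).hexdigest()

print(hash_token("sk-1234"))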
@@ -434,6 +442,11 @@
         return {"key": key, **response["data"]}
         # update based on remaining passed in values
     except Exception as e:
+        verbose_proxy_logger.exception(
+            "litellm.proxy.proxy_server.update_key_fn(): Exception occured - {}".format(
+                str(e)
+            )
+        )
         if isinstance(e, HTTPException):
             raise ProxyException(
                 message=getattr(e, "detail", f"Authentication Error({str(e)})"),
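
The new exception log goes through verbose_proxy_logger, so operators can surface these tracebacks without code changes. A sketch, assuming the proxy logger is the standard library logger registered as "LiteLLM Proxy":

import logging

# Raise proxy log verbosity so update_key_fn failures are printed with tracebacks (logger name assumed)
logging.getLogger("LiteLLM Proxy").setLevel(logging.DEBUG)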
@@ -771,6 +784,7 @@ async def generate_key_helper_fn(
         float
     ] = None,  # soft_budget is used to set soft Budgets Per user
     max_budget: Optional[float] = None,  # max_budget is used to Budget Per user
+    blocked: Optional[bool] = None,
     budget_duration: Optional[str] = None,  # max_budget is used to Budget Per user
     token: Optional[str] = None,
     key: Optional[
@@ -899,6 +913,7 @@
         "permissions": permissions_json,
         "model_max_budget": model_max_budget_json,
         "budget_id": budget_id,
+        "blocked": blocked,
     }
 
     if (
@@ -1047,6 +1062,7 @@ async def regenerate_key_fn(
         hash_token,
         premium_user,
         prisma_client,
+        proxy_logging_obj,
         user_api_key_cache,
     )
@@ -1124,10 +1140,18 @@
     ### 3. remove existing key entry from cache
     ######################################################################
     if key:
-        user_api_key_cache.delete_cache(key)
+        await _delete_cache_key_object(
+            hashed_token=hash_token(key),
+            user_api_key_cache=user_api_key_cache,
+            proxy_logging_obj=proxy_logging_obj,
+        )
 
     if hashed_api_key:
-        user_api_key_cache.delete_cache(hashed_api_key)
+        await _delete_cache_key_object(
+            hashed_token=hashed_api_key,
+            user_api_key_cache=user_api_key_cache,
+            proxy_logging_obj=proxy_logging_obj,
+        )
 
     return GenerateKeyResponse(
         **updated_token_dict,
@@ -1240,3 +1264,187 @@ async def list_keys(
             param=getattr(e, "param", "None"),
             code=status.HTTP_500_INTERNAL_SERVER_ERROR,
         )
+
+
+@router.post(
+    "/key/block", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
+)
+@management_endpoint_wrapper
+async def block_key(
+    data: BlockKeyRequest,
+    http_request: Request,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+    litellm_changed_by: Optional[str] = Header(
+        None,
+        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
+    ),
+):
+    """
+    Blocks all calls from this key.
+    """
+    from litellm.proxy.proxy_server import (
+        create_audit_log_for_update,
+        hash_token,
+        litellm_proxy_admin_name,
+        prisma_client,
+        proxy_logging_obj,
+        user_api_key_cache,
+    )
+
+    if prisma_client is None:
+        raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value))
+
+    if data.key.startswith("sk-"):
+        hashed_token = hash_token(token=data.key)
+    else:
+        hashed_token = data.key
+
+    if litellm.store_audit_logs is True:
+        # make an audit log for key update
+        record = await prisma_client.db.litellm_verificationtoken.find_unique(
+            where={"token": hashed_token}
+        )
+        if record is None:
+            raise ProxyException(
+                message=f"Key {data.key} not found",
+                type=ProxyErrorTypes.bad_request_error,
+                param="key",
+                code=status.HTTP_404_NOT_FOUND,
+            )
+        asyncio.create_task(
+            create_audit_log_for_update(
+                request_data=LiteLLM_AuditLogs(
+                    id=str(uuid.uuid4()),
+                    updated_at=datetime.now(timezone.utc),
+                    changed_by=litellm_changed_by
+                    or user_api_key_dict.user_id
+                    or litellm_proxy_admin_name,
+                    changed_by_api_key=user_api_key_dict.api_key,
+                    table_name=LitellmTableNames.KEY_TABLE_NAME,
+                    object_id=hashed_token,
+                    action="blocked",
+                    updated_values="{}",
+                    before_value=record.model_dump_json(),
+                )
+            )
+        )
+
+    record = await prisma_client.db.litellm_verificationtoken.update(
+        where={"token": hashed_token}, data={"blocked": True}  # type: ignore
+    )
+
+    ## UPDATE KEY CACHE
+
+    ### get cached object ###
+    key_object = await get_key_object(
+        hashed_token=hashed_token,
+        prisma_client=prisma_client,
+        user_api_key_cache=user_api_key_cache,
+        parent_otel_span=None,
+        proxy_logging_obj=proxy_logging_obj,
+    )
+
+    ### update cached object ###
+    key_object.blocked = True
+
+    ### store cached object ###
+    await _cache_key_object(
+        hashed_token=hashed_token,
+        user_api_key_obj=key_object,
+        user_api_key_cache=user_api_key_cache,
+        proxy_logging_obj=proxy_logging_obj,
+    )
+
+    return record
+
+
+@router.post(
+    "/key/unblock", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
+)
+@management_endpoint_wrapper
+async def unblock_key(
+    data: BlockKeyRequest,
+    http_request: Request,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+    litellm_changed_by: Optional[str] = Header(
+        None,
+        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
+    ),
+):
+    """
+    Unblocks all calls from this key.
+    """
+    from litellm.proxy.proxy_server import (
+        create_audit_log_for_update,
+        hash_token,
+        litellm_proxy_admin_name,
+        prisma_client,
+        proxy_logging_obj,
+        user_api_key_cache,
+    )
+
+    if prisma_client is None:
+        raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value))
+
+    if data.key.startswith("sk-"):
+        hashed_token = hash_token(token=data.key)
+    else:
+        hashed_token = data.key
+
+    if litellm.store_audit_logs is True:
+        # make an audit log for key update
+        record = await prisma_client.db.litellm_verificationtoken.find_unique(
+            where={"token": hashed_token}
+        )
+        if record is None:
+            raise ProxyException(
+                message=f"Key {data.key} not found",
+                type=ProxyErrorTypes.bad_request_error,
+                param="key",
+                code=status.HTTP_404_NOT_FOUND,
+            )
+        asyncio.create_task(
+            create_audit_log_for_update(
+                request_data=LiteLLM_AuditLogs(
+                    id=str(uuid.uuid4()),
+                    updated_at=datetime.now(timezone.utc),
+                    changed_by=litellm_changed_by
+                    or user_api_key_dict.user_id
+                    or litellm_proxy_admin_name,
+                    changed_by_api_key=user_api_key_dict.api_key,
+                    table_name=LitellmTableNames.KEY_TABLE_NAME,
+                    object_id=hashed_token,
+                    action="blocked",
+                    updated_values="{}",
+                    before_value=record.model_dump_json(),
+                )
+            )
+        )
+
+    record = await prisma_client.db.litellm_verificationtoken.update(
+        where={"token": hashed_token}, data={"blocked": False}  # type: ignore
+    )
+
+    ## UPDATE KEY CACHE
+
+    ### get cached object ###
+    key_object = await get_key_object(
+        hashed_token=hashed_token,
+        prisma_client=prisma_client,
+        user_api_key_cache=user_api_key_cache,
+        parent_otel_span=None,
+        proxy_logging_obj=proxy_logging_obj,
+    )
+
+    ### update cached object ###
+    key_object.blocked = False
+
+    ### store cached object ###
+    await _cache_key_object(
+        hashed_token=hashed_token,
+        user_api_key_obj=key_object,
+        user_api_key_cache=user_api_key_cache,
+        proxy_logging_obj=proxy_logging_obj,
+    )
+
+    return record
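
End to end, both endpoints accept the same payload: a raw key (sk-...) or its pre-hashed value. A usage sketch with httpx; the proxy address and admin key are placeholders:

import httpx

PROXY = "http://localhost:4000"  # placeholder proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}  # placeholder admin key

# Block: subsequent requests with this key fail until it is unblocked
httpx.post(f"{PROXY}/key/block", headers=HEADERS, json={"key": "sk-target-key"})

# Unblock: restores the key without regenerating it
httpx.post(f"{PROXY}/key/unblock", headers=HEADERS, json={"key": "sk-target-key"})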