mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
LiteLLM Minor Fixes & Improvements (09/21/2024) (#5819)
* fix(router.py): fix error message * Litellm disable keys (#5814) * build(schema.prisma): allow blocking/unblocking keys Fixes https://github.com/BerriAI/litellm/issues/5328 * fix(key_management_endpoints.py): fix pop * feat(auth_checks.py): allow admin to enable/disable virtual keys Closes https://github.com/BerriAI/litellm/issues/5328 * docs(vertex.md): add auth section for vertex ai Addresses - https://github.com/BerriAI/litellm/issues/5768#issuecomment-2365284223 * build(model_prices_and_context_window.json): show which models support prompt_caching Closes https://github.com/BerriAI/litellm/issues/5776 * fix(router.py): allow setting default priority for requests * fix(router.py): add 'retry-after' header for concurrent request limit errors Fixes https://github.com/BerriAI/litellm/issues/5783 * fix(router.py): correctly raise and use retry-after header from azure+openai Fixes https://github.com/BerriAI/litellm/issues/5783 * fix(user_api_key_auth.py): fix valid token being none * fix(auth_checks.py): fix model dump for cache management object * fix(user_api_key_auth.py): pass prisma_client to obj * test(test_otel.py): update test for new key check * test: fix test
This commit is contained in:
parent
f0543a6f9d
commit
f3fa2160a0
25 changed files with 1006 additions and 182 deletions
|
@ -25,6 +25,11 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, s
|
|||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
_cache_key_object,
|
||||
_delete_cache_key_object,
|
||||
get_key_object,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.proxy.utils import _duration_in_seconds
|
||||
|
@ -302,15 +307,18 @@ async def prepare_key_update_data(
|
|||
data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row
|
||||
):
|
||||
data_json: dict = data.dict(exclude_unset=True)
|
||||
key = data_json.pop("key", None)
|
||||
|
||||
data_json.pop("key", None)
|
||||
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if k in _metadata_fields:
|
||||
continue
|
||||
if v is not None and v not in ([], {}, 0):
|
||||
non_default_values[k] = v
|
||||
if v is not None:
|
||||
if not isinstance(v, bool) and v in ([], {}, 0):
|
||||
pass
|
||||
else:
|
||||
non_default_values[k] = v
|
||||
|
||||
if "duration" in non_default_values:
|
||||
duration = non_default_values.pop("duration")
|
||||
if duration and (isinstance(duration, str)) and len(duration) > 0:
|
||||
|
@ -364,12 +372,10 @@ async def update_key_fn(
|
|||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
create_audit_log_for_update,
|
||||
general_settings,
|
||||
litellm_proxy_admin_name,
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
user_custom_key_generate,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -399,9 +405,11 @@ async def update_key_fn(
|
|||
|
||||
# Delete - key from cache, since it's been updated!
|
||||
# key updated - a new model could have been added to this key. it should not block requests after this is done
|
||||
user_api_key_cache.delete_cache(key)
|
||||
hashed_token = hash_token(key)
|
||||
user_api_key_cache.delete_cache(hashed_token)
|
||||
await _delete_cache_key_object(
|
||||
hashed_token=hash_token(key),
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||
if litellm.store_audit_logs is True:
|
||||
|
@ -434,6 +442,11 @@ async def update_key_fn(
|
|||
return {"key": key, **response["data"]}
|
||||
# update based on remaining passed in values
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.update_key_fn(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
|
@ -771,6 +784,7 @@ async def generate_key_helper_fn(
|
|||
float
|
||||
] = None, # soft_budget is used to set soft Budgets Per user
|
||||
max_budget: Optional[float] = None, # max_budget is used to Budget Per user
|
||||
blocked: Optional[bool] = None,
|
||||
budget_duration: Optional[str] = None, # max_budget is used to Budget Per user
|
||||
token: Optional[str] = None,
|
||||
key: Optional[
|
||||
|
@ -899,6 +913,7 @@ async def generate_key_helper_fn(
|
|||
"permissions": permissions_json,
|
||||
"model_max_budget": model_max_budget_json,
|
||||
"budget_id": budget_id,
|
||||
"blocked": blocked,
|
||||
}
|
||||
|
||||
if (
|
||||
|
@ -1047,6 +1062,7 @@ async def regenerate_key_fn(
|
|||
hash_token,
|
||||
premium_user,
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
|
@ -1124,10 +1140,18 @@ async def regenerate_key_fn(
|
|||
### 3. remove existing key entry from cache
|
||||
######################################################################
|
||||
if key:
|
||||
user_api_key_cache.delete_cache(key)
|
||||
await _delete_cache_key_object(
|
||||
hashed_token=hash_token(key),
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if hashed_api_key:
|
||||
user_api_key_cache.delete_cache(hashed_api_key)
|
||||
await _delete_cache_key_object(
|
||||
hashed_token=hash_token(key),
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
return GenerateKeyResponse(
|
||||
**updated_token_dict,
|
||||
|
@ -1240,3 +1264,187 @@ async def list_keys(
|
|||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/key/block", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def block_key(
    data: BlockKeyRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Block a virtual key: sets `blocked=True` on the key's DB row and on the
    cached key object, so subsequent calls made with this key are rejected.

    Parameters:
    - data: BlockKeyRequest - carries the key to block; may be the raw
      `sk-...` key or an already-hashed token.
    - http_request: Request - incoming FastAPI request (unused directly here;
      required by the management-endpoint wrapper).
    - user_api_key_dict: UserAPIKeyAuth - the authenticated caller, used for
      audit-log attribution.
    - litellm_changed_by: Optional[str] - optional header identifying who the
      caller is acting on behalf of, for the audit trail.

    Returns the updated LiteLLM_VerificationToken DB record.

    Raises:
    - Exception - if the DB (prisma_client) is not connected.
    - ProxyException (404) - when audit logging is enabled and the key does
      not exist.
    """
    # Imported lazily from proxy_server to avoid a circular import at module load.
    from litellm.proxy.proxy_server import (
        create_audit_log_for_update,
        hash_token,
        litellm_proxy_admin_name,
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    if prisma_client is None:
        raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value))

    # Raw keys start with "sk-"; anything else is assumed to already be the
    # hashed token stored in the DB.
    if data.key.startswith("sk-"):
        hashed_token = hash_token(token=data.key)
    else:
        hashed_token = data.key

    if litellm.store_audit_logs is True:
        # make an audit log for key update
        # Look the row up first so the audit log can capture the pre-update
        # state; this is also the only existence check (it only runs when
        # audit logging is enabled).
        record = await prisma_client.db.litellm_verificationtoken.find_unique(
            where={"token": hashed_token}
        )
        if record is None:
            raise ProxyException(
                message=f"Key {data.key} not found",
                type=ProxyErrorTypes.bad_request_error,
                param="key",
                code=status.HTTP_404_NOT_FOUND,
            )
        # Fire-and-forget: audit logging must not block the response.
        asyncio.create_task(
            create_audit_log_for_update(
                request_data=LiteLLM_AuditLogs(
                    id=str(uuid.uuid4()),
                    updated_at=datetime.now(timezone.utc),
                    changed_by=litellm_changed_by
                    or user_api_key_dict.user_id
                    or litellm_proxy_admin_name,
                    changed_by_api_key=user_api_key_dict.api_key,
                    table_name=LitellmTableNames.KEY_TABLE_NAME,
                    object_id=hashed_token,
                    action="blocked",
                    updated_values="{}",
                    before_value=record.model_dump_json(),
                )
            )
        )

    record = await prisma_client.db.litellm_verificationtoken.update(
        where={"token": hashed_token}, data={"blocked": True}  # type: ignore
    )

    ## UPDATE KEY CACHE

    ### get cached object ###
    key_object = await get_key_object(
        hashed_token=hashed_token,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        parent_otel_span=None,
        proxy_logging_obj=proxy_logging_obj,
    )

    ### update cached object ###
    key_object.blocked = True

    ### store cached object ###
    # Re-cache so in-flight auth checks see the block without a DB round-trip.
    await _cache_key_object(
        hashed_token=hashed_token,
        user_api_key_obj=key_object,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )

    return record
|
||||
|
||||
|
||||
@router.post(
    "/key/unblock", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def unblock_key(
    data: BlockKeyRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Unblock a previously blocked virtual key: sets `blocked=False` on the
    key's DB row and on the cached key object, re-enabling calls with it.

    Parameters:
    - data: BlockKeyRequest - carries the key to unblock; may be the raw
      `sk-...` key or an already-hashed token.
    - http_request: Request - incoming FastAPI request (unused directly here;
      required by the management-endpoint wrapper).
    - user_api_key_dict: UserAPIKeyAuth - the authenticated caller, used for
      audit-log attribution.
    - litellm_changed_by: Optional[str] - optional header identifying who the
      caller is acting on behalf of, for the audit trail.

    Returns the updated LiteLLM_VerificationToken DB record.

    Raises:
    - Exception - if the DB (prisma_client) is not connected.
    - ProxyException (404) - when audit logging is enabled and the key does
      not exist.
    """
    # Imported lazily from proxy_server to avoid a circular import at module load.
    from litellm.proxy.proxy_server import (
        create_audit_log_for_update,
        hash_token,
        litellm_proxy_admin_name,
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    if prisma_client is None:
        raise Exception("{}".format(CommonProxyErrors.db_not_connected_error.value))

    # Raw keys start with "sk-"; anything else is assumed to already be the
    # hashed token stored in the DB.
    if data.key.startswith("sk-"):
        hashed_token = hash_token(token=data.key)
    else:
        hashed_token = data.key

    if litellm.store_audit_logs is True:
        # make an audit log for key update
        # Fetch the row first so the audit log records the pre-update state.
        record = await prisma_client.db.litellm_verificationtoken.find_unique(
            where={"token": hashed_token}
        )
        if record is None:
            raise ProxyException(
                message=f"Key {data.key} not found",
                type=ProxyErrorTypes.bad_request_error,
                param="key",
                code=status.HTTP_404_NOT_FOUND,
            )
        # Fire-and-forget: audit logging must not block the response.
        asyncio.create_task(
            create_audit_log_for_update(
                request_data=LiteLLM_AuditLogs(
                    id=str(uuid.uuid4()),
                    updated_at=datetime.now(timezone.utc),
                    changed_by=litellm_changed_by
                    or user_api_key_dict.user_id
                    or litellm_proxy_admin_name,
                    changed_by_api_key=user_api_key_dict.api_key,
                    table_name=LitellmTableNames.KEY_TABLE_NAME,
                    object_id=hashed_token,
                    # NOTE(review): action="blocked" looks copy-pasted from
                    # block_key — an "unblocked" action would describe this
                    # operation correctly, but confirm the LiteLLM_AuditLogs
                    # `action` type permits that value before changing it.
                    action="blocked",
                    updated_values="{}",
                    before_value=record.model_dump_json(),
                )
            )
        )

    record = await prisma_client.db.litellm_verificationtoken.update(
        where={"token": hashed_token}, data={"blocked": False}  # type: ignore
    )

    ## UPDATE KEY CACHE

    ### get cached object ###
    key_object = await get_key_object(
        hashed_token=hashed_token,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        parent_otel_span=None,
        proxy_logging_obj=proxy_logging_obj,
    )

    ### update cached object ###
    key_object.blocked = False

    ### store cached object ###
    # Re-cache so in-flight auth checks see the unblock without a DB round-trip.
    await _cache_key_object(
        hashed_token=hashed_token,
        user_api_key_obj=key_object,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )

    return record
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue