forked from phoenix/litellm-mirror
fix: fix linting errors
This commit is contained in:
parent
ddf56b8935
commit
84f3ac7d25
2 changed files with 63 additions and 55 deletions
|
@ -43,7 +43,7 @@ from litellm.proxy.utils import handle_exception_on_proxy
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict:
|
||||
def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> dict:
|
||||
if "user_id" in data_json and data_json["user_id"] is None:
|
||||
data_json["user_id"] = str(uuid.uuid4())
|
||||
auto_create_key = data_json.pop("auto_create_key", True)
|
||||
|
@ -146,7 +146,7 @@ async def new_user(
|
|||
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
|
||||
|
||||
data_json = data.json() # type: ignore
|
||||
data_json = _update_internal_user_params(data_json, data)
|
||||
data_json = _update_internal_new_user_params(data_json, data)
|
||||
response = await generate_key_helper_fn(request_type="user", **data_json)
|
||||
|
||||
# Admin UI Logic
|
||||
|
@ -439,6 +439,52 @@ async def user_info( # noqa: PLR0915
|
|||
raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> dict:
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if (
|
||||
v is not None
|
||||
and v
|
||||
not in (
|
||||
[],
|
||||
{},
|
||||
0,
|
||||
)
|
||||
and k not in LiteLLM_ManagementEndpoint_MetadataFields
|
||||
): # models default to [], spend defaults to 0, we should not reset these values
|
||||
non_default_values[k] = v
|
||||
|
||||
is_internal_user = False
|
||||
if data.user_role == LitellmUserRoles.INTERNAL_USER:
|
||||
is_internal_user = True
|
||||
|
||||
if "budget_duration" in non_default_values:
|
||||
duration_s = duration_in_seconds(duration=non_default_values["budget_duration"])
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
|
||||
if "max_budget" not in non_default_values:
|
||||
if (
|
||||
is_internal_user and litellm.max_internal_user_budget is not None
|
||||
): # applies internal user limits, if user role updated
|
||||
non_default_values["max_budget"] = litellm.max_internal_user_budget
|
||||
|
||||
if (
|
||||
"budget_duration" not in non_default_values
|
||||
): # applies internal user limits, if user role updated
|
||||
if is_internal_user and litellm.internal_user_budget_duration is not None:
|
||||
non_default_values["budget_duration"] = (
|
||||
litellm.internal_user_budget_duration
|
||||
)
|
||||
duration_s = duration_in_seconds(
|
||||
duration=non_default_values["budget_duration"]
|
||||
)
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
|
||||
return non_default_values
|
||||
|
||||
|
||||
@router.post(
|
||||
"/user/update",
|
||||
tags=["Internal User management"],
|
||||
|
@ -504,51 +550,9 @@ async def user_update(
|
|||
raise Exception("Not connected to DB!")
|
||||
|
||||
# get non default values for key
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if (
|
||||
v is not None
|
||||
and v
|
||||
not in (
|
||||
[],
|
||||
{},
|
||||
0,
|
||||
)
|
||||
and k not in LiteLLM_ManagementEndpoint_MetadataFields
|
||||
): # models default to [], spend defaults to 0, we should not reset these values
|
||||
non_default_values[k] = v
|
||||
|
||||
is_internal_user = False
|
||||
if data.user_role == LitellmUserRoles.INTERNAL_USER:
|
||||
is_internal_user = True
|
||||
|
||||
if "budget_duration" in non_default_values:
|
||||
duration_s = duration_in_seconds(
|
||||
duration=non_default_values["budget_duration"]
|
||||
)
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
|
||||
if "max_budget" not in non_default_values:
|
||||
if (
|
||||
is_internal_user and litellm.max_internal_user_budget is not None
|
||||
): # applies internal user limits, if user role updated
|
||||
non_default_values["max_budget"] = litellm.max_internal_user_budget
|
||||
|
||||
if (
|
||||
"budget_duration" not in non_default_values
|
||||
): # applies internal user limits, if user role updated
|
||||
if is_internal_user and litellm.internal_user_budget_duration is not None:
|
||||
non_default_values["budget_duration"] = (
|
||||
litellm.internal_user_budget_duration
|
||||
)
|
||||
duration_s = duration_in_seconds(
|
||||
duration=non_default_values["budget_duration"]
|
||||
)
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=duration_s
|
||||
)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
non_default_values = _update_internal_user_params(
|
||||
data_json=data_json, data=data
|
||||
)
|
||||
|
||||
existing_user_row = await prisma_client.get_data(
|
||||
user_id=data.user_id, table_name="user", query_type="find_unique"
|
||||
|
|
|
@ -17,7 +17,7 @@ import secrets
|
|||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, cast
|
||||
|
||||
import fastapi
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
|
||||
|
@ -469,20 +469,22 @@ def prepare_metadata_fields(
|
|||
non_default_values["metadata"] = non_default_values["metadata"].copy()
|
||||
non_default_values["metadata"].update(existing_metadata)
|
||||
|
||||
casted_metadata = cast(dict, non_default_values["metadata"])
|
||||
|
||||
data_json = data.model_dump(exclude_unset=True)
|
||||
|
||||
try:
|
||||
for k, v in data_json.items():
|
||||
if k == "model_tpm_limit" or k == "model_rpm_limit":
|
||||
if k not in non_default_values["metadata"]:
|
||||
non_default_values["metadata"][k] = {}
|
||||
non_default_values["metadata"][k].update(v)
|
||||
if k not in casted_metadata:
|
||||
casted_metadata[k] = {}
|
||||
casted_metadata[k].update(v)
|
||||
|
||||
if k == "tags" or k == "guardrails":
|
||||
if k not in non_default_values["metadata"]:
|
||||
non_default_values["metadata"][k] = []
|
||||
seen = set(non_default_values["metadata"][k])
|
||||
non_default_values["metadata"][k].extend(
|
||||
if k not in casted_metadata:
|
||||
casted_metadata[k] = []
|
||||
seen = set(casted_metadata[k])
|
||||
casted_metadata[k].extend(
|
||||
x for x in v if x not in seen and not seen.add(x) # type: ignore
|
||||
) # prevent duplicates from being added + maintain initial order
|
||||
|
||||
|
@ -492,6 +494,8 @@ def prepare_metadata_fields(
|
|||
str(e)
|
||||
)
|
||||
)
|
||||
|
||||
non_default_values["metadata"] = casted_metadata
|
||||
return non_default_values
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue