fix: fix linting errors

This commit is contained in:
Krrish Dholakia 2024-11-30 17:18:00 -08:00
parent ddf56b8935
commit 84f3ac7d25
2 changed files with 63 additions and 55 deletions

View file

@ -43,7 +43,7 @@ from litellm.proxy.utils import handle_exception_on_proxy
router = APIRouter()
def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict:
def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> dict:
if "user_id" in data_json and data_json["user_id"] is None:
data_json["user_id"] = str(uuid.uuid4())
auto_create_key = data_json.pop("auto_create_key", True)
@ -146,7 +146,7 @@ async def new_user(
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
data_json = data.json() # type: ignore
data_json = _update_internal_user_params(data_json, data)
data_json = _update_internal_new_user_params(data_json, data)
response = await generate_key_helper_fn(request_type="user", **data_json)
# Admin UI Logic
@ -439,6 +439,52 @@ async def user_info( # noqa: PLR0915
raise handle_exception_on_proxy(e)
def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> dict:
non_default_values = {}
for k, v in data_json.items():
if (
v is not None
and v
not in (
[],
{},
0,
)
and k not in LiteLLM_ManagementEndpoint_MetadataFields
): # models default to [], spend defaults to 0, we should not reset these values
non_default_values[k] = v
is_internal_user = False
if data.user_role == LitellmUserRoles.INTERNAL_USER:
is_internal_user = True
if "budget_duration" in non_default_values:
duration_s = duration_in_seconds(duration=non_default_values["budget_duration"])
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = user_reset_at
if "max_budget" not in non_default_values:
if (
is_internal_user and litellm.max_internal_user_budget is not None
): # applies internal user limits, if user role updated
non_default_values["max_budget"] = litellm.max_internal_user_budget
if (
"budget_duration" not in non_default_values
): # applies internal user limits, if user role updated
if is_internal_user and litellm.internal_user_budget_duration is not None:
non_default_values["budget_duration"] = (
litellm.internal_user_budget_duration
)
duration_s = duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = user_reset_at
return non_default_values
@router.post(
"/user/update",
tags=["Internal User management"],
@ -504,51 +550,9 @@ async def user_update(
raise Exception("Not connected to DB!")
# get non default values for key
non_default_values = {}
for k, v in data_json.items():
if (
v is not None
and v
not in (
[],
{},
0,
)
and k not in LiteLLM_ManagementEndpoint_MetadataFields
): # models default to [], spend defaults to 0, we should not reset these values
non_default_values[k] = v
is_internal_user = False
if data.user_role == LitellmUserRoles.INTERNAL_USER:
is_internal_user = True
if "budget_duration" in non_default_values:
duration_s = duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = user_reset_at
if "max_budget" not in non_default_values:
if (
is_internal_user and litellm.max_internal_user_budget is not None
): # applies internal user limits, if user role updated
non_default_values["max_budget"] = litellm.max_internal_user_budget
if (
"budget_duration" not in non_default_values
): # applies internal user limits, if user role updated
if is_internal_user and litellm.internal_user_budget_duration is not None:
non_default_values["budget_duration"] = (
litellm.internal_user_budget_duration
)
duration_s = duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(
seconds=duration_s
)
non_default_values["budget_reset_at"] = user_reset_at
non_default_values = _update_internal_user_params(
data_json=data_json, data=data
)
existing_user_row = await prisma_client.get_data(
user_id=data.user_id, table_name="user", query_type="find_unique"

View file

@ -17,7 +17,7 @@ import secrets
import traceback
import uuid
from datetime import datetime, timedelta, timezone
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, cast
import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
@ -469,20 +469,22 @@ def prepare_metadata_fields(
non_default_values["metadata"] = non_default_values["metadata"].copy()
non_default_values["metadata"].update(existing_metadata)
casted_metadata = cast(dict, non_default_values["metadata"])
data_json = data.model_dump(exclude_unset=True)
try:
for k, v in data_json.items():
if k == "model_tpm_limit" or k == "model_rpm_limit":
if k not in non_default_values["metadata"]:
non_default_values["metadata"][k] = {}
non_default_values["metadata"][k].update(v)
if k not in casted_metadata:
casted_metadata[k] = {}
casted_metadata[k].update(v)
if k == "tags" or k == "guardrails":
if k not in non_default_values["metadata"]:
non_default_values["metadata"][k] = []
seen = set(non_default_values["metadata"][k])
non_default_values["metadata"][k].extend(
if k not in casted_metadata:
casted_metadata[k] = []
seen = set(casted_metadata[k])
casted_metadata[k].extend(
x for x in v if x not in seen and not seen.add(x) # type: ignore
) # prevent duplicates from being added + maintain initial order
@ -492,6 +494,8 @@ def prepare_metadata_fields(
str(e)
)
)
non_default_values["metadata"] = casted_metadata
return non_default_values