fix: fix linting errors

This commit is contained in:
Krrish Dholakia 2024-11-30 17:18:00 -08:00
parent ddf56b8935
commit 84f3ac7d25
2 changed files with 63 additions and 55 deletions

View file

@ -43,7 +43,7 @@ from litellm.proxy.utils import handle_exception_on_proxy
router = APIRouter() router = APIRouter()
def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict: def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> dict:
if "user_id" in data_json and data_json["user_id"] is None: if "user_id" in data_json and data_json["user_id"] is None:
data_json["user_id"] = str(uuid.uuid4()) data_json["user_id"] = str(uuid.uuid4())
auto_create_key = data_json.pop("auto_create_key", True) auto_create_key = data_json.pop("auto_create_key", True)
@ -146,7 +146,7 @@ async def new_user(
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
data_json = data.json() # type: ignore data_json = data.json() # type: ignore
data_json = _update_internal_user_params(data_json, data) data_json = _update_internal_new_user_params(data_json, data)
response = await generate_key_helper_fn(request_type="user", **data_json) response = await generate_key_helper_fn(request_type="user", **data_json)
# Admin UI Logic # Admin UI Logic
@ -439,6 +439,52 @@ async def user_info( # noqa: PLR0915
raise handle_exception_on_proxy(e) raise handle_exception_on_proxy(e)
def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> dict:
non_default_values = {}
for k, v in data_json.items():
if (
v is not None
and v
not in (
[],
{},
0,
)
and k not in LiteLLM_ManagementEndpoint_MetadataFields
): # models default to [], spend defaults to 0, we should not reset these values
non_default_values[k] = v
is_internal_user = False
if data.user_role == LitellmUserRoles.INTERNAL_USER:
is_internal_user = True
if "budget_duration" in non_default_values:
duration_s = duration_in_seconds(duration=non_default_values["budget_duration"])
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = user_reset_at
if "max_budget" not in non_default_values:
if (
is_internal_user and litellm.max_internal_user_budget is not None
): # applies internal user limits, if user role updated
non_default_values["max_budget"] = litellm.max_internal_user_budget
if (
"budget_duration" not in non_default_values
): # applies internal user limits, if user role updated
if is_internal_user and litellm.internal_user_budget_duration is not None:
non_default_values["budget_duration"] = (
litellm.internal_user_budget_duration
)
duration_s = duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = user_reset_at
return non_default_values
@router.post( @router.post(
"/user/update", "/user/update",
tags=["Internal User management"], tags=["Internal User management"],
@ -504,51 +550,9 @@ async def user_update(
raise Exception("Not connected to DB!") raise Exception("Not connected to DB!")
# get non default values for key # get non default values for key
non_default_values = {} non_default_values = _update_internal_user_params(
for k, v in data_json.items(): data_json=data_json, data=data
if (
v is not None
and v
not in (
[],
{},
0,
) )
and k not in LiteLLM_ManagementEndpoint_MetadataFields
): # models default to [], spend defaults to 0, we should not reset these values
non_default_values[k] = v
is_internal_user = False
if data.user_role == LitellmUserRoles.INTERNAL_USER:
is_internal_user = True
if "budget_duration" in non_default_values:
duration_s = duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = user_reset_at
if "max_budget" not in non_default_values:
if (
is_internal_user and litellm.max_internal_user_budget is not None
): # applies internal user limits, if user role updated
non_default_values["max_budget"] = litellm.max_internal_user_budget
if (
"budget_duration" not in non_default_values
): # applies internal user limits, if user role updated
if is_internal_user and litellm.internal_user_budget_duration is not None:
non_default_values["budget_duration"] = (
litellm.internal_user_budget_duration
)
duration_s = duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(
seconds=duration_s
)
non_default_values["budget_reset_at"] = user_reset_at
existing_user_row = await prisma_client.get_data( existing_user_row = await prisma_client.get_data(
user_id=data.user_id, table_name="user", query_type="find_unique" user_id=data.user_id, table_name="user", query_type="find_unique"

View file

@ -17,7 +17,7 @@ import secrets
import traceback import traceback
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, cast
import fastapi import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
@ -469,20 +469,22 @@ def prepare_metadata_fields(
non_default_values["metadata"] = non_default_values["metadata"].copy() non_default_values["metadata"] = non_default_values["metadata"].copy()
non_default_values["metadata"].update(existing_metadata) non_default_values["metadata"].update(existing_metadata)
casted_metadata = cast(dict, non_default_values["metadata"])
data_json = data.model_dump(exclude_unset=True) data_json = data.model_dump(exclude_unset=True)
try: try:
for k, v in data_json.items(): for k, v in data_json.items():
if k == "model_tpm_limit" or k == "model_rpm_limit": if k == "model_tpm_limit" or k == "model_rpm_limit":
if k not in non_default_values["metadata"]: if k not in casted_metadata:
non_default_values["metadata"][k] = {} casted_metadata[k] = {}
non_default_values["metadata"][k].update(v) casted_metadata[k].update(v)
if k == "tags" or k == "guardrails": if k == "tags" or k == "guardrails":
if k not in non_default_values["metadata"]: if k not in casted_metadata:
non_default_values["metadata"][k] = [] casted_metadata[k] = []
seen = set(non_default_values["metadata"][k]) seen = set(casted_metadata[k])
non_default_values["metadata"][k].extend( casted_metadata[k].extend(
x for x in v if x not in seen and not seen.add(x) # type: ignore x for x in v if x not in seen and not seen.add(x) # type: ignore
) # prevent duplicates from being added + maintain initial order ) # prevent duplicates from being added + maintain initial order
@ -492,6 +494,8 @@ def prepare_metadata_fields(
str(e) str(e)
) )
) )
non_default_values["metadata"] = casted_metadata
return non_default_values return non_default_values