forked from phoenix/litellm-mirror
fix: fix linting errors
This commit is contained in:
parent
ddf56b8935
commit
84f3ac7d25
2 changed files with 63 additions and 55 deletions
|
@ -43,7 +43,7 @@ from litellm.proxy.utils import handle_exception_on_proxy
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict:
|
def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> dict:
|
||||||
if "user_id" in data_json and data_json["user_id"] is None:
|
if "user_id" in data_json and data_json["user_id"] is None:
|
||||||
data_json["user_id"] = str(uuid.uuid4())
|
data_json["user_id"] = str(uuid.uuid4())
|
||||||
auto_create_key = data_json.pop("auto_create_key", True)
|
auto_create_key = data_json.pop("auto_create_key", True)
|
||||||
|
@ -146,7 +146,7 @@ async def new_user(
|
||||||
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
|
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
|
||||||
|
|
||||||
data_json = data.json() # type: ignore
|
data_json = data.json() # type: ignore
|
||||||
data_json = _update_internal_user_params(data_json, data)
|
data_json = _update_internal_new_user_params(data_json, data)
|
||||||
response = await generate_key_helper_fn(request_type="user", **data_json)
|
response = await generate_key_helper_fn(request_type="user", **data_json)
|
||||||
|
|
||||||
# Admin UI Logic
|
# Admin UI Logic
|
||||||
|
@ -439,6 +439,52 @@ async def user_info( # noqa: PLR0915
|
||||||
raise handle_exception_on_proxy(e)
|
raise handle_exception_on_proxy(e)
|
||||||
|
|
||||||
|
|
||||||
|
def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> dict:
|
||||||
|
non_default_values = {}
|
||||||
|
for k, v in data_json.items():
|
||||||
|
if (
|
||||||
|
v is not None
|
||||||
|
and v
|
||||||
|
not in (
|
||||||
|
[],
|
||||||
|
{},
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
and k not in LiteLLM_ManagementEndpoint_MetadataFields
|
||||||
|
): # models default to [], spend defaults to 0, we should not reset these values
|
||||||
|
non_default_values[k] = v
|
||||||
|
|
||||||
|
is_internal_user = False
|
||||||
|
if data.user_role == LitellmUserRoles.INTERNAL_USER:
|
||||||
|
is_internal_user = True
|
||||||
|
|
||||||
|
if "budget_duration" in non_default_values:
|
||||||
|
duration_s = duration_in_seconds(duration=non_default_values["budget_duration"])
|
||||||
|
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||||
|
non_default_values["budget_reset_at"] = user_reset_at
|
||||||
|
|
||||||
|
if "max_budget" not in non_default_values:
|
||||||
|
if (
|
||||||
|
is_internal_user and litellm.max_internal_user_budget is not None
|
||||||
|
): # applies internal user limits, if user role updated
|
||||||
|
non_default_values["max_budget"] = litellm.max_internal_user_budget
|
||||||
|
|
||||||
|
if (
|
||||||
|
"budget_duration" not in non_default_values
|
||||||
|
): # applies internal user limits, if user role updated
|
||||||
|
if is_internal_user and litellm.internal_user_budget_duration is not None:
|
||||||
|
non_default_values["budget_duration"] = (
|
||||||
|
litellm.internal_user_budget_duration
|
||||||
|
)
|
||||||
|
duration_s = duration_in_seconds(
|
||||||
|
duration=non_default_values["budget_duration"]
|
||||||
|
)
|
||||||
|
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||||
|
non_default_values["budget_reset_at"] = user_reset_at
|
||||||
|
|
||||||
|
return non_default_values
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/user/update",
|
"/user/update",
|
||||||
tags=["Internal User management"],
|
tags=["Internal User management"],
|
||||||
|
@ -504,51 +550,9 @@ async def user_update(
|
||||||
raise Exception("Not connected to DB!")
|
raise Exception("Not connected to DB!")
|
||||||
|
|
||||||
# get non default values for key
|
# get non default values for key
|
||||||
non_default_values = {}
|
non_default_values = _update_internal_user_params(
|
||||||
for k, v in data_json.items():
|
data_json=data_json, data=data
|
||||||
if (
|
|
||||||
v is not None
|
|
||||||
and v
|
|
||||||
not in (
|
|
||||||
[],
|
|
||||||
{},
|
|
||||||
0,
|
|
||||||
)
|
)
|
||||||
and k not in LiteLLM_ManagementEndpoint_MetadataFields
|
|
||||||
): # models default to [], spend defaults to 0, we should not reset these values
|
|
||||||
non_default_values[k] = v
|
|
||||||
|
|
||||||
is_internal_user = False
|
|
||||||
if data.user_role == LitellmUserRoles.INTERNAL_USER:
|
|
||||||
is_internal_user = True
|
|
||||||
|
|
||||||
if "budget_duration" in non_default_values:
|
|
||||||
duration_s = duration_in_seconds(
|
|
||||||
duration=non_default_values["budget_duration"]
|
|
||||||
)
|
|
||||||
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
|
||||||
non_default_values["budget_reset_at"] = user_reset_at
|
|
||||||
|
|
||||||
if "max_budget" not in non_default_values:
|
|
||||||
if (
|
|
||||||
is_internal_user and litellm.max_internal_user_budget is not None
|
|
||||||
): # applies internal user limits, if user role updated
|
|
||||||
non_default_values["max_budget"] = litellm.max_internal_user_budget
|
|
||||||
|
|
||||||
if (
|
|
||||||
"budget_duration" not in non_default_values
|
|
||||||
): # applies internal user limits, if user role updated
|
|
||||||
if is_internal_user and litellm.internal_user_budget_duration is not None:
|
|
||||||
non_default_values["budget_duration"] = (
|
|
||||||
litellm.internal_user_budget_duration
|
|
||||||
)
|
|
||||||
duration_s = duration_in_seconds(
|
|
||||||
duration=non_default_values["budget_duration"]
|
|
||||||
)
|
|
||||||
user_reset_at = datetime.now(timezone.utc) + timedelta(
|
|
||||||
seconds=duration_s
|
|
||||||
)
|
|
||||||
non_default_values["budget_reset_at"] = user_reset_at
|
|
||||||
|
|
||||||
existing_user_row = await prisma_client.get_data(
|
existing_user_row = await prisma_client.get_data(
|
||||||
user_id=data.user_id, table_name="user", query_type="find_unique"
|
user_id=data.user_id, table_name="user", query_type="find_unique"
|
||||||
|
|
|
@ -17,7 +17,7 @@ import secrets
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple, cast
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
|
||||||
|
@ -469,20 +469,22 @@ def prepare_metadata_fields(
|
||||||
non_default_values["metadata"] = non_default_values["metadata"].copy()
|
non_default_values["metadata"] = non_default_values["metadata"].copy()
|
||||||
non_default_values["metadata"].update(existing_metadata)
|
non_default_values["metadata"].update(existing_metadata)
|
||||||
|
|
||||||
|
casted_metadata = cast(dict, non_default_values["metadata"])
|
||||||
|
|
||||||
data_json = data.model_dump(exclude_unset=True)
|
data_json = data.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for k, v in data_json.items():
|
for k, v in data_json.items():
|
||||||
if k == "model_tpm_limit" or k == "model_rpm_limit":
|
if k == "model_tpm_limit" or k == "model_rpm_limit":
|
||||||
if k not in non_default_values["metadata"]:
|
if k not in casted_metadata:
|
||||||
non_default_values["metadata"][k] = {}
|
casted_metadata[k] = {}
|
||||||
non_default_values["metadata"][k].update(v)
|
casted_metadata[k].update(v)
|
||||||
|
|
||||||
if k == "tags" or k == "guardrails":
|
if k == "tags" or k == "guardrails":
|
||||||
if k not in non_default_values["metadata"]:
|
if k not in casted_metadata:
|
||||||
non_default_values["metadata"][k] = []
|
casted_metadata[k] = []
|
||||||
seen = set(non_default_values["metadata"][k])
|
seen = set(casted_metadata[k])
|
||||||
non_default_values["metadata"][k].extend(
|
casted_metadata[k].extend(
|
||||||
x for x in v if x not in seen and not seen.add(x) # type: ignore
|
x for x in v if x not in seen and not seen.add(x) # type: ignore
|
||||||
) # prevent duplicates from being added + maintain initial order
|
) # prevent duplicates from being added + maintain initial order
|
||||||
|
|
||||||
|
@ -492,6 +494,8 @@ def prepare_metadata_fields(
|
||||||
str(e)
|
str(e)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
non_default_values["metadata"] = casted_metadata
|
||||||
return non_default_values
|
return non_default_values
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue