fix(team_endpoints.py): ensure 404 raised when team not found (#9038)

* fix(team_endpoints.py): ensure 404 raised when team not found

* fix(key_management_endpoints.py): fix adding tags to key when metadata is empty

* fix(key_management_endpoints.py): refactor set metadata field to use common function across keys + teams

reduces scope for errors + easier testing

* fix: fix linting error
This commit is contained in:
Krish Dholakia 2025-03-06 22:04:36 -08:00 committed by GitHub
parent 6dc83135ab
commit 274147bc5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 72 additions and 27 deletions

View file

@ -1,4 +1,12 @@
from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth
from typing import Any, Union
from litellm.proxy._types import (
GenerateKeyRequest,
LiteLLM_ManagementEndpoint_MetadataFields_Premium,
LiteLLM_TeamTable,
UserAPIKeyAuth,
)
from litellm.proxy.utils import _premium_user_check
def _is_user_team_admin(
@ -12,3 +20,22 @@ def _is_user_team_admin(
return True
return False
def _set_object_metadata_field(
object_data: Union[LiteLLM_TeamTable, GenerateKeyRequest],
field_name: str,
value: Any,
) -> None:
"""
Helper function to set metadata fields that require premium user checks
Args:
object_data: The team data object to modify
field_name: Name of the metadata field to set
value: Value to set for the field
"""
if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
_premium_user_check()
object_data.metadata = object_data.metadata or {}
object_data.metadata[field_name] = value

View file

@ -35,7 +35,10 @@ from litellm.proxy.auth.auth_checks import (
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
from litellm.proxy.management_endpoints.common_utils import (
_is_user_team_admin,
_set_object_metadata_field,
)
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import (
PrismaClient,
@ -507,6 +510,17 @@ async def generate_key_fn( # noqa: PLR0915
}
)
_budget_id = getattr(_budget, "budget_id", None)
# ADD METADATA FIELDS
# Set Management Endpoint Metadata Fields
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
if getattr(data, field) is not None:
_set_object_metadata_field(
object_data=data,
field_name=field,
value=getattr(data, field),
)
data_json = data.model_dump(exclude_unset=True, exclude_none=True) # type: ignore
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
@ -531,7 +545,8 @@ async def generate_key_fn( # noqa: PLR0915
f"Only premium users can add tags to keys. {CommonProxyErrors.not_premium_user.value}"
)
if data_json["metadata"] is None:
_metadata = data_json.get("metadata")
if not _metadata:
data_json["metadata"] = {"tags": data_json["tags"]}
else:
data_json["metadata"]["tags"] = data_json["tags"]

View file

@ -14,7 +14,7 @@ import json
import traceback
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, List, Optional, Tuple, Union, cast
from typing import List, Optional, Tuple, Union, cast
import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
@ -57,7 +57,10 @@ from litellm.proxy.auth.auth_checks import (
get_team_object,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
from litellm.proxy.management_endpoints.common_utils import (
_is_user_team_admin,
_set_object_metadata_field,
)
from litellm.proxy.management_helpers.utils import (
add_new_member,
management_endpoint_wrapper,
@ -283,8 +286,8 @@ async def new_team( # noqa: PLR0915
# Set Management Endpoint Metadata Fields
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
if getattr(data, field) is not None:
_set_team_metadata_field(
team_data=complete_team_data,
_set_object_metadata_field(
object_data=complete_team_data,
field_name=field,
value=getattr(data, field),
)
@ -1274,9 +1277,13 @@ async def team_info(
)
try:
team_info: BaseModel = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
team_info: Optional[BaseModel] = (
await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)
)
if team_info is None:
raise Exception
except Exception:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -1673,23 +1680,6 @@ def _update_team_metadata_field(updated_kv: dict, field_name: str) -> None:
updated_kv["metadata"] = {field_name: _value}
def _set_team_metadata_field(
team_data: LiteLLM_TeamTable, field_name: str, value: Any
) -> None:
"""
Helper function to set metadata fields that require premium user checks
Args:
team_data: The team data object to modify
field_name: Name of the metadata field to set
value: Value to set for the field
"""
if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
_premium_user_check()
team_data.metadata = team_data.metadata or {}
team_data.metadata[field_name] = value
@router.get(
"/team/filter/ui",
tags=["team management"],

View file

@ -6,6 +6,8 @@ import aiohttp
import time, uuid
from openai import AsyncOpenAI
from typing import Optional
import openai
from unittest.mock import MagicMock, patch
async def get_user_info(session, get_user, call_user, view_all: Optional[bool] = None):
@ -358,6 +360,11 @@ async def get_team_info(session, get_team, call_key):
print(response_text)
print()
if status == 404:
raise openai.NotFoundError(
message="404 received", response=MagicMock(), body=None
)
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json()
@ -549,7 +556,7 @@ async def test_team_delete():
key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"])
key = key_gen["key"]
## Test key
response = await chat_completion(session=session, key=key)
# response = await chat_completion(session=session, key=key)
## Delete team
await delete_team(session=session, i=0, team_id=team_data["team_id"])
@ -559,6 +566,12 @@ async def test_team_delete():
)
assert len(user_info["teams"]) == 0
## ASSERT TEAM INFO NOW RETURNS A 404
with pytest.raises(openai.NotFoundError):
await get_team_info(
session=session, get_team=team_data["team_id"], call_key="sk-1234"
)
@pytest.mark.parametrize("dimension", ["user_id", "user_email"])
@pytest.mark.asyncio