fix(team_endpoints.py): ensure 404 raised when team not found (#9038)

* fix(team_endpoints.py): ensure 404 raised when team not found

* fix(key_management_endpoints.py): fix adding tags to key when metadata is empty

* fix(key_management_endpoints.py): refactor set metadata field to use common function across keys + teams

reduces scope for errors + easier testing

* fix: fix linting error
This commit is contained in:
Krish Dholakia 2025-03-06 22:04:36 -08:00 committed by GitHub
parent 6dc83135ab
commit 274147bc5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 72 additions and 27 deletions

View file

@ -1,4 +1,12 @@
from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth from typing import Any, Union
from litellm.proxy._types import (
GenerateKeyRequest,
LiteLLM_ManagementEndpoint_MetadataFields_Premium,
LiteLLM_TeamTable,
UserAPIKeyAuth,
)
from litellm.proxy.utils import _premium_user_check
def _is_user_team_admin( def _is_user_team_admin(
@ -12,3 +20,22 @@ def _is_user_team_admin(
return True return True
return False return False
def _set_object_metadata_field(
object_data: Union[LiteLLM_TeamTable, GenerateKeyRequest],
field_name: str,
value: Any,
) -> None:
"""
Helper function to set metadata fields that require premium user checks
Args:
object_data: The team data object to modify
field_name: Name of the metadata field to set
value: Value to set for the field
"""
if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
_premium_user_check()
object_data.metadata = object_data.metadata or {}
object_data.metadata[field_name] = value

View file

@ -35,7 +35,10 @@ from litellm.proxy.auth.auth_checks import (
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin from litellm.proxy.management_endpoints.common_utils import (
_is_user_team_admin,
_set_object_metadata_field,
)
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import ( from litellm.proxy.utils import (
PrismaClient, PrismaClient,
@ -507,6 +510,17 @@ async def generate_key_fn( # noqa: PLR0915
} }
) )
_budget_id = getattr(_budget, "budget_id", None) _budget_id = getattr(_budget, "budget_id", None)
# ADD METADATA FIELDS
# Set Management Endpoint Metadata Fields
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
if getattr(data, field) is not None:
_set_object_metadata_field(
object_data=data,
field_name=field,
value=getattr(data, field),
)
data_json = data.model_dump(exclude_unset=True, exclude_none=True) # type: ignore data_json = data.model_dump(exclude_unset=True, exclude_none=True) # type: ignore
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users # if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
@ -531,7 +545,8 @@ async def generate_key_fn( # noqa: PLR0915
f"Only premium users can add tags to keys. {CommonProxyErrors.not_premium_user.value}" f"Only premium users can add tags to keys. {CommonProxyErrors.not_premium_user.value}"
) )
if data_json["metadata"] is None: _metadata = data_json.get("metadata")
if not _metadata:
data_json["metadata"] = {"tags": data_json["tags"]} data_json["metadata"] = {"tags": data_json["tags"]}
else: else:
data_json["metadata"]["tags"] = data_json["tags"] data_json["metadata"]["tags"] = data_json["tags"]

View file

@ -14,7 +14,7 @@ import json
import traceback import traceback
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, List, Optional, Tuple, Union, cast from typing import List, Optional, Tuple, Union, cast
import fastapi import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
@ -57,7 +57,10 @@ from litellm.proxy.auth.auth_checks import (
get_team_object, get_team_object,
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin from litellm.proxy.management_endpoints.common_utils import (
_is_user_team_admin,
_set_object_metadata_field,
)
from litellm.proxy.management_helpers.utils import ( from litellm.proxy.management_helpers.utils import (
add_new_member, add_new_member,
management_endpoint_wrapper, management_endpoint_wrapper,
@ -283,8 +286,8 @@ async def new_team( # noqa: PLR0915
# Set Management Endpoint Metadata Fields # Set Management Endpoint Metadata Fields
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium: for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
if getattr(data, field) is not None: if getattr(data, field) is not None:
_set_team_metadata_field( _set_object_metadata_field(
team_data=complete_team_data, object_data=complete_team_data,
field_name=field, field_name=field,
value=getattr(data, field), value=getattr(data, field),
) )
@ -1274,9 +1277,13 @@ async def team_info(
) )
try: try:
team_info: BaseModel = await prisma_client.db.litellm_teamtable.find_unique( team_info: Optional[BaseModel] = (
where={"team_id": team_id} await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)
) )
if team_info is None:
raise Exception
except Exception: except Exception:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -1673,23 +1680,6 @@ def _update_team_metadata_field(updated_kv: dict, field_name: str) -> None:
updated_kv["metadata"] = {field_name: _value} updated_kv["metadata"] = {field_name: _value}
def _set_team_metadata_field(
team_data: LiteLLM_TeamTable, field_name: str, value: Any
) -> None:
"""
Helper function to set metadata fields that require premium user checks
Args:
team_data: The team data object to modify
field_name: Name of the metadata field to set
value: Value to set for the field
"""
if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
_premium_user_check()
team_data.metadata = team_data.metadata or {}
team_data.metadata[field_name] = value
@router.get( @router.get(
"/team/filter/ui", "/team/filter/ui",
tags=["team management"], tags=["team management"],

View file

@ -6,6 +6,8 @@ import aiohttp
import time, uuid import time, uuid
from openai import AsyncOpenAI from openai import AsyncOpenAI
from typing import Optional from typing import Optional
import openai
from unittest.mock import MagicMock, patch
async def get_user_info(session, get_user, call_user, view_all: Optional[bool] = None): async def get_user_info(session, get_user, call_user, view_all: Optional[bool] = None):
@ -358,6 +360,11 @@ async def get_team_info(session, get_team, call_key):
print(response_text) print(response_text)
print() print()
if status == 404:
raise openai.NotFoundError(
message="404 received", response=MagicMock(), body=None
)
if status != 200: if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}") raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json() return await response.json()
@ -549,7 +556,7 @@ async def test_team_delete():
key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"]) key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"])
key = key_gen["key"] key = key_gen["key"]
## Test key ## Test key
response = await chat_completion(session=session, key=key) # response = await chat_completion(session=session, key=key)
## Delete team ## Delete team
await delete_team(session=session, i=0, team_id=team_data["team_id"]) await delete_team(session=session, i=0, team_id=team_data["team_id"])
@ -559,6 +566,12 @@ async def test_team_delete():
) )
assert len(user_info["teams"]) == 0 assert len(user_info["teams"]) == 0
## ASSERT TEAM INFO NOW RETURNS A 404
with pytest.raises(openai.NotFoundError):
await get_team_info(
session=session, get_team=team_data["team_id"], call_key="sk-1234"
)
@pytest.mark.parametrize("dimension", ["user_id", "user_email"]) @pytest.mark.parametrize("dimension", ["user_id", "user_email"])
@pytest.mark.asyncio @pytest.mark.asyncio