mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
fix(team_endpoints.py): ensure 404 raised when team not found (#9038)
* fix(team_endpoints.py): ensure 404 raised when team not found * fix(key_management_endpoints.py): fix adding tags to key when metadata is empty * fix(key_management_endpoints.py): refactor set metadata field to use common function across keys + teams reduces scope for errors + easier testing * fix: fix linting error
This commit is contained in:
parent
6dc83135ab
commit
274147bc5e
4 changed files with 72 additions and 27 deletions
|
@ -1,4 +1,12 @@
|
|||
from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth
|
||||
from typing import Any, Union
|
||||
|
||||
from litellm.proxy._types import (
|
||||
GenerateKeyRequest,
|
||||
LiteLLM_ManagementEndpoint_MetadataFields_Premium,
|
||||
LiteLLM_TeamTable,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.utils import _premium_user_check
|
||||
|
||||
|
||||
def _is_user_team_admin(
|
||||
|
@ -12,3 +20,22 @@ def _is_user_team_admin(
|
|||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _set_object_metadata_field(
|
||||
object_data: Union[LiteLLM_TeamTable, GenerateKeyRequest],
|
||||
field_name: str,
|
||||
value: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Helper function to set metadata fields that require premium user checks
|
||||
|
||||
Args:
|
||||
object_data: The team data object to modify
|
||||
field_name: Name of the metadata field to set
|
||||
value: Value to set for the field
|
||||
"""
|
||||
if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
|
||||
_premium_user_check()
|
||||
object_data.metadata = object_data.metadata or {}
|
||||
object_data.metadata[field_name] = value
|
||||
|
|
|
@ -35,7 +35,10 @@ from litellm.proxy.auth.auth_checks import (
|
|||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
|
||||
from litellm.proxy.management_endpoints.common_utils import (
|
||||
_is_user_team_admin,
|
||||
_set_object_metadata_field,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient,
|
||||
|
@ -507,6 +510,17 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
}
|
||||
)
|
||||
_budget_id = getattr(_budget, "budget_id", None)
|
||||
|
||||
# ADD METADATA FIELDS
|
||||
# Set Management Endpoint Metadata Fields
|
||||
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
|
||||
if getattr(data, field) is not None:
|
||||
_set_object_metadata_field(
|
||||
object_data=data,
|
||||
field_name=field,
|
||||
value=getattr(data, field),
|
||||
)
|
||||
|
||||
data_json = data.model_dump(exclude_unset=True, exclude_none=True) # type: ignore
|
||||
|
||||
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
|
||||
|
@ -531,7 +545,8 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
f"Only premium users can add tags to keys. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
|
||||
if data_json["metadata"] is None:
|
||||
_metadata = data_json.get("metadata")
|
||||
if not _metadata:
|
||||
data_json["metadata"] = {"tags": data_json["tags"]}
|
||||
else:
|
||||
data_json["metadata"]["tags"] = data_json["tags"]
|
||||
|
|
|
@ -14,7 +14,7 @@ import json
|
|||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, List, Optional, Tuple, Union, cast
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
import fastapi
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
|
@ -57,7 +57,10 @@ from litellm.proxy.auth.auth_checks import (
|
|||
get_team_object,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
|
||||
from litellm.proxy.management_endpoints.common_utils import (
|
||||
_is_user_team_admin,
|
||||
_set_object_metadata_field,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
add_new_member,
|
||||
management_endpoint_wrapper,
|
||||
|
@ -283,8 +286,8 @@ async def new_team( # noqa: PLR0915
|
|||
# Set Management Endpoint Metadata Fields
|
||||
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
|
||||
if getattr(data, field) is not None:
|
||||
_set_team_metadata_field(
|
||||
team_data=complete_team_data,
|
||||
_set_object_metadata_field(
|
||||
object_data=complete_team_data,
|
||||
field_name=field,
|
||||
value=getattr(data, field),
|
||||
)
|
||||
|
@ -1274,9 +1277,13 @@ async def team_info(
|
|||
)
|
||||
|
||||
try:
|
||||
team_info: BaseModel = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
team_info: Optional[BaseModel] = (
|
||||
await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
)
|
||||
if team_info is None:
|
||||
raise Exception
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
@ -1673,23 +1680,6 @@ def _update_team_metadata_field(updated_kv: dict, field_name: str) -> None:
|
|||
updated_kv["metadata"] = {field_name: _value}
|
||||
|
||||
|
||||
def _set_team_metadata_field(
|
||||
team_data: LiteLLM_TeamTable, field_name: str, value: Any
|
||||
) -> None:
|
||||
"""
|
||||
Helper function to set metadata fields that require premium user checks
|
||||
|
||||
Args:
|
||||
team_data: The team data object to modify
|
||||
field_name: Name of the metadata field to set
|
||||
value: Value to set for the field
|
||||
"""
|
||||
if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
|
||||
_premium_user_check()
|
||||
team_data.metadata = team_data.metadata or {}
|
||||
team_data.metadata[field_name] = value
|
||||
|
||||
|
||||
@router.get(
|
||||
"/team/filter/ui",
|
||||
tags=["team management"],
|
||||
|
|
|
@ -6,6 +6,8 @@ import aiohttp
|
|||
import time, uuid
|
||||
from openai import AsyncOpenAI
|
||||
from typing import Optional
|
||||
import openai
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
async def get_user_info(session, get_user, call_user, view_all: Optional[bool] = None):
|
||||
|
@ -358,6 +360,11 @@ async def get_team_info(session, get_team, call_key):
|
|||
print(response_text)
|
||||
print()
|
||||
|
||||
if status == 404:
|
||||
raise openai.NotFoundError(
|
||||
message="404 received", response=MagicMock(), body=None
|
||||
)
|
||||
|
||||
if status != 200:
|
||||
raise Exception(f"Request did not return a 200 status code: {status}")
|
||||
return await response.json()
|
||||
|
@ -549,7 +556,7 @@ async def test_team_delete():
|
|||
key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"])
|
||||
key = key_gen["key"]
|
||||
## Test key
|
||||
response = await chat_completion(session=session, key=key)
|
||||
# response = await chat_completion(session=session, key=key)
|
||||
## Delete team
|
||||
await delete_team(session=session, i=0, team_id=team_data["team_id"])
|
||||
|
||||
|
@ -559,6 +566,12 @@ async def test_team_delete():
|
|||
)
|
||||
assert len(user_info["teams"]) == 0
|
||||
|
||||
## ASSERT TEAM INFO NOW RETURNS A 404
|
||||
with pytest.raises(openai.NotFoundError):
|
||||
await get_team_info(
|
||||
session=session, get_team=team_data["team_id"], call_key="sk-1234"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dimension", ["user_id", "user_email"])
|
||||
@pytest.mark.asyncio
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue