mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(team_endpoints.py): ensure 404 raised when team not found (#9038)
* fix(team_endpoints.py): ensure 404 raised when team not found * fix(key_management_endpoints.py): fix adding tags to key when metadata is empty * fix(key_management_endpoints.py): refactor set metadata field to use common function across keys + teams reduces scope for errors + easier testing * fix: fix linting error
This commit is contained in:
parent
6dc83135ab
commit
274147bc5e
4 changed files with 72 additions and 27 deletions
|
@ -1,4 +1,12 @@
|
||||||
from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from litellm.proxy._types import (
|
||||||
|
GenerateKeyRequest,
|
||||||
|
LiteLLM_ManagementEndpoint_MetadataFields_Premium,
|
||||||
|
LiteLLM_TeamTable,
|
||||||
|
UserAPIKeyAuth,
|
||||||
|
)
|
||||||
|
from litellm.proxy.utils import _premium_user_check
|
||||||
|
|
||||||
|
|
||||||
def _is_user_team_admin(
|
def _is_user_team_admin(
|
||||||
|
@ -12,3 +20,22 @@ def _is_user_team_admin(
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _set_object_metadata_field(
|
||||||
|
object_data: Union[LiteLLM_TeamTable, GenerateKeyRequest],
|
||||||
|
field_name: str,
|
||||||
|
value: Any,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Helper function to set metadata fields that require premium user checks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_data: The team data object to modify
|
||||||
|
field_name: Name of the metadata field to set
|
||||||
|
value: Value to set for the field
|
||||||
|
"""
|
||||||
|
if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
|
||||||
|
_premium_user_check()
|
||||||
|
object_data.metadata = object_data.metadata or {}
|
||||||
|
object_data.metadata[field_name] = value
|
||||||
|
|
|
@ -35,7 +35,10 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||||
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
|
from litellm.proxy.management_endpoints.common_utils import (
|
||||||
|
_is_user_team_admin,
|
||||||
|
_set_object_metadata_field,
|
||||||
|
)
|
||||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||||
from litellm.proxy.utils import (
|
from litellm.proxy.utils import (
|
||||||
PrismaClient,
|
PrismaClient,
|
||||||
|
@ -507,6 +510,17 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
_budget_id = getattr(_budget, "budget_id", None)
|
_budget_id = getattr(_budget, "budget_id", None)
|
||||||
|
|
||||||
|
# ADD METADATA FIELDS
|
||||||
|
# Set Management Endpoint Metadata Fields
|
||||||
|
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
|
||||||
|
if getattr(data, field) is not None:
|
||||||
|
_set_object_metadata_field(
|
||||||
|
object_data=data,
|
||||||
|
field_name=field,
|
||||||
|
value=getattr(data, field),
|
||||||
|
)
|
||||||
|
|
||||||
data_json = data.model_dump(exclude_unset=True, exclude_none=True) # type: ignore
|
data_json = data.model_dump(exclude_unset=True, exclude_none=True) # type: ignore
|
||||||
|
|
||||||
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
|
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
|
||||||
|
@ -531,7 +545,8 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
f"Only premium users can add tags to keys. {CommonProxyErrors.not_premium_user.value}"
|
f"Only premium users can add tags to keys. {CommonProxyErrors.not_premium_user.value}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_json["metadata"] is None:
|
_metadata = data_json.get("metadata")
|
||||||
|
if not _metadata:
|
||||||
data_json["metadata"] = {"tags": data_json["tags"]}
|
data_json["metadata"] = {"tags": data_json["tags"]}
|
||||||
else:
|
else:
|
||||||
data_json["metadata"]["tags"] = data_json["tags"]
|
data_json["metadata"]["tags"] = data_json["tags"]
|
||||||
|
|
|
@ -14,7 +14,7 @@ import json
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, List, Optional, Tuple, Union, cast
|
from typing import List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||||
|
@ -57,7 +57,10 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
get_team_object,
|
get_team_object,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
|
from litellm.proxy.management_endpoints.common_utils import (
|
||||||
|
_is_user_team_admin,
|
||||||
|
_set_object_metadata_field,
|
||||||
|
)
|
||||||
from litellm.proxy.management_helpers.utils import (
|
from litellm.proxy.management_helpers.utils import (
|
||||||
add_new_member,
|
add_new_member,
|
||||||
management_endpoint_wrapper,
|
management_endpoint_wrapper,
|
||||||
|
@ -283,8 +286,8 @@ async def new_team( # noqa: PLR0915
|
||||||
# Set Management Endpoint Metadata Fields
|
# Set Management Endpoint Metadata Fields
|
||||||
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
|
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
|
||||||
if getattr(data, field) is not None:
|
if getattr(data, field) is not None:
|
||||||
_set_team_metadata_field(
|
_set_object_metadata_field(
|
||||||
team_data=complete_team_data,
|
object_data=complete_team_data,
|
||||||
field_name=field,
|
field_name=field,
|
||||||
value=getattr(data, field),
|
value=getattr(data, field),
|
||||||
)
|
)
|
||||||
|
@ -1274,9 +1277,13 @@ async def team_info(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
team_info: BaseModel = await prisma_client.db.litellm_teamtable.find_unique(
|
team_info: Optional[BaseModel] = (
|
||||||
|
await prisma_client.db.litellm_teamtable.find_unique(
|
||||||
where={"team_id": team_id}
|
where={"team_id": team_id}
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
if team_info is None:
|
||||||
|
raise Exception
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
@ -1673,23 +1680,6 @@ def _update_team_metadata_field(updated_kv: dict, field_name: str) -> None:
|
||||||
updated_kv["metadata"] = {field_name: _value}
|
updated_kv["metadata"] = {field_name: _value}
|
||||||
|
|
||||||
|
|
||||||
def _set_team_metadata_field(
|
|
||||||
team_data: LiteLLM_TeamTable, field_name: str, value: Any
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Helper function to set metadata fields that require premium user checks
|
|
||||||
|
|
||||||
Args:
|
|
||||||
team_data: The team data object to modify
|
|
||||||
field_name: Name of the metadata field to set
|
|
||||||
value: Value to set for the field
|
|
||||||
"""
|
|
||||||
if field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
|
|
||||||
_premium_user_check()
|
|
||||||
team_data.metadata = team_data.metadata or {}
|
|
||||||
team_data.metadata[field_name] = value
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/team/filter/ui",
|
"/team/filter/ui",
|
||||||
tags=["team management"],
|
tags=["team management"],
|
||||||
|
|
|
@ -6,6 +6,8 @@ import aiohttp
|
||||||
import time, uuid
|
import time, uuid
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import openai
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
async def get_user_info(session, get_user, call_user, view_all: Optional[bool] = None):
|
async def get_user_info(session, get_user, call_user, view_all: Optional[bool] = None):
|
||||||
|
@ -358,6 +360,11 @@ async def get_team_info(session, get_team, call_key):
|
||||||
print(response_text)
|
print(response_text)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
if status == 404:
|
||||||
|
raise openai.NotFoundError(
|
||||||
|
message="404 received", response=MagicMock(), body=None
|
||||||
|
)
|
||||||
|
|
||||||
if status != 200:
|
if status != 200:
|
||||||
raise Exception(f"Request did not return a 200 status code: {status}")
|
raise Exception(f"Request did not return a 200 status code: {status}")
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
@ -549,7 +556,7 @@ async def test_team_delete():
|
||||||
key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"])
|
key_gen = await generate_key(session=session, i=0, team_id=team_data["team_id"])
|
||||||
key = key_gen["key"]
|
key = key_gen["key"]
|
||||||
## Test key
|
## Test key
|
||||||
response = await chat_completion(session=session, key=key)
|
# response = await chat_completion(session=session, key=key)
|
||||||
## Delete team
|
## Delete team
|
||||||
await delete_team(session=session, i=0, team_id=team_data["team_id"])
|
await delete_team(session=session, i=0, team_id=team_data["team_id"])
|
||||||
|
|
||||||
|
@ -559,6 +566,12 @@ async def test_team_delete():
|
||||||
)
|
)
|
||||||
assert len(user_info["teams"]) == 0
|
assert len(user_info["teams"]) == 0
|
||||||
|
|
||||||
|
## ASSERT TEAM INFO NOW RETURNS A 404
|
||||||
|
with pytest.raises(openai.NotFoundError):
|
||||||
|
await get_team_info(
|
||||||
|
session=session, get_team=team_data["team_id"], call_key="sk-1234"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dimension", ["user_id", "user_email"])
|
@pytest.mark.parametrize("dimension", ["user_id", "user_email"])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue