mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Allow team admins to add/update/delete models on UI + show api base and model id on request logs (#9572)
* feat(view_logs.tsx): show model id + api base in request logs easier debugging * fix(index.tsx): fix length of api base easier viewing * refactor(leftnav.tsx): show models tab to team admin * feat(model_dashboard.tsx): add explainer for what the 'models' page is for team admin helps them understand how they can use it * feat(model_management_endpoints.py): restrict model add by team to just team admin allow team admin to add models via non-team keys (e.g. ui token) * test(test_add_update_models.py): update unit testing for new behaviour * fix(model_dashboard.tsx): show user the models * feat(proxy_server.py): add new query param 'user_models_only' to `/v2/model/info` Allows user to retrieve just the models they've added Used in UI to show internal users just the models they've added * feat(model_dashboard.tsx): allow team admins to view their own models * fix: allow ui user to fetch model cost map * feat(add_model_tab.tsx): require team admins to specify team when onboarding models * fix(_types.py): add `/v1/model/info` to info route `/model/info` was already there * fix(model_info_view.tsx): allow user to edit a model they created * fix(model_management_endpoints.py): allow team admin to update team model * feat(model_managament_endpoints.py): allow team admin to delete team models * fix(model_management_endpoints.py): don't require team id to be set when adding a model * fix(proxy_server.py): fix linting error * fix: fix ui linting error * fix(model_management_endpoints.py): ensure consistent auth checks on all model calls * test: remove old test - function no longer exists in same form * test: add updated mock testing
This commit is contained in:
parent
3ca34a181c
commit
63c9f59373
11 changed files with 483 additions and 144 deletions
|
@ -292,6 +292,7 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/team/available",
|
||||
"/user/info",
|
||||
"/model/info",
|
||||
"/v1/model/info",
|
||||
"/v2/model/info",
|
||||
"/v2/key/info",
|
||||
"/model_group/info",
|
||||
|
@ -386,6 +387,7 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/global/predict/spend/logs",
|
||||
"/global/activity",
|
||||
"/health/services",
|
||||
"/get/litellm_model_cost_map",
|
||||
] + info_routes
|
||||
|
||||
internal_user_routes = [
|
||||
|
@ -412,6 +414,8 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/team/member_add",
|
||||
"/team/member_delete",
|
||||
"/model/new",
|
||||
"/model/update",
|
||||
"/model/delete",
|
||||
] # routes that manage their own allowed/disallowed logic
|
||||
|
||||
## Org Admin Routes ##
|
||||
|
|
|
@ -13,7 +13,7 @@ model/{model_id}/update - PATCH endpoint for model update.
|
|||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Optional, cast
|
||||
from typing import Literal, Optional, Union, cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
@ -23,6 +23,7 @@ from litellm.constants import LITELLM_PROXY_ADMIN_NAME
|
|||
from litellm.proxy._types import (
|
||||
CommonProxyErrors,
|
||||
LiteLLM_ProxyModelTable,
|
||||
LiteLLM_TeamTable,
|
||||
LitellmTableNames,
|
||||
LitellmUserRoles,
|
||||
ModelInfoDelete,
|
||||
|
@ -35,6 +36,7 @@ from litellm.proxy._types import (
|
|||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
|
||||
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
|
||||
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
team_model_add,
|
||||
update_team,
|
||||
|
@ -317,22 +319,111 @@ async def _add_team_model_to_db(
|
|||
return model_response
|
||||
|
||||
|
||||
def check_if_team_id_matches_key(
|
||||
team_id: Optional[str], user_api_key_dict: UserAPIKeyAuth
|
||||
) -> bool:
|
||||
can_make_call = True
|
||||
if (
|
||||
user_api_key_dict.user_role
|
||||
and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
|
||||
):
|
||||
class ModelManagementAuthChecks:
|
||||
"""
|
||||
Common auth checks for model management endpoints
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def can_user_make_team_model_call(
|
||||
team_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
team_obj: Optional[LiteLLM_TeamTable] = None,
|
||||
premium_user: bool = False,
|
||||
) -> Literal[True]:
|
||||
if premium_user is False:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"error": CommonProxyErrors.not_premium_user.value},
|
||||
)
|
||||
if (
|
||||
user_api_key_dict.user_role
|
||||
and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
|
||||
):
|
||||
return True
|
||||
elif team_obj is None or not _is_user_team_admin(
|
||||
user_api_key_dict=user_api_key_dict, team_obj=team_obj
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Team ID={} does not match the API key's team ID={}, OR you are not the admin for this team. Check `/user/info` to verify your team admin status.".format(
|
||||
team_id, user_api_key_dict.team_id
|
||||
)
|
||||
},
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def allow_team_model_action(
|
||||
model_params: Union[Deployment, updateDeployment],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
prisma_client: PrismaClient,
|
||||
premium_user: bool,
|
||||
) -> Literal[True]:
|
||||
if model_params.model_info is None or model_params.model_info.team_id is None:
|
||||
return True
|
||||
if model_params.model_info.team_id is not None and premium_user is not True:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"error": CommonProxyErrors.not_premium_user.value},
|
||||
)
|
||||
|
||||
_existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": model_params.model_info.team_id}
|
||||
)
|
||||
|
||||
if _existing_team_row is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Team id={} does not exist in db".format(
|
||||
model_params.model_info.team_id
|
||||
)
|
||||
},
|
||||
)
|
||||
existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
|
||||
|
||||
ModelManagementAuthChecks.can_user_make_team_model_call(
|
||||
team_id=model_params.model_info.team_id,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
team_obj=existing_team_row,
|
||||
premium_user=premium_user,
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def can_user_make_model_call(
|
||||
model_params: Union[Deployment, updateDeployment],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
prisma_client: PrismaClient,
|
||||
premium_user: bool,
|
||||
) -> Literal[True]:
|
||||
|
||||
## Check team model auth
|
||||
if (
|
||||
model_params.model_info is not None
|
||||
and model_params.model_info.team_id is not None
|
||||
):
|
||||
return ModelManagementAuthChecks.can_user_make_team_model_call(
|
||||
team_id=model_params.model_info.team_id,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
premium_user=premium_user,
|
||||
)
|
||||
## Check non-team model auth
|
||||
elif user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "User does not have permission to make this model call. Your role={}. You can only make model calls if you are a PROXY_ADMIN or if you are a team admin, by specifying a team_id in the model_info.".format(
|
||||
user_api_key_dict.user_role
|
||||
)
|
||||
},
|
||||
)
|
||||
else:
|
||||
return True
|
||||
|
||||
return True
|
||||
if team_id is None:
|
||||
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
||||
can_make_call = False
|
||||
else:
|
||||
if user_api_key_dict.team_id != team_id:
|
||||
can_make_call = False
|
||||
return can_make_call
|
||||
|
||||
|
||||
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
|
||||
|
@ -358,6 +449,7 @@ async def delete_model(
|
|||
|
||||
from litellm.proxy.proxy_server import (
|
||||
llm_router,
|
||||
premium_user,
|
||||
prisma_client,
|
||||
store_model_in_db,
|
||||
)
|
||||
|
@ -370,6 +462,23 @@ async def delete_model(
|
|||
},
|
||||
)
|
||||
|
||||
model_in_db = await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
where={"model_id": model_info.id}
|
||||
)
|
||||
if model_in_db is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": f"Model with id={model_info.id} not found in db"},
|
||||
)
|
||||
|
||||
model_params = Deployment(**model_in_db.model_dump())
|
||||
await ModelManagementAuthChecks.can_user_make_model_call(
|
||||
model_params=model_params,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
premium_user=premium_user,
|
||||
)
|
||||
|
||||
# update DB
|
||||
if store_model_in_db is True:
|
||||
"""
|
||||
|
@ -464,19 +573,13 @@ async def add_new_model(
|
|||
},
|
||||
)
|
||||
|
||||
if model_params.model_info.team_id is not None and premium_user is not True:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"error": CommonProxyErrors.not_premium_user.value},
|
||||
)
|
||||
|
||||
if not check_if_team_id_matches_key(
|
||||
team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"error": "Team ID does not match the API key's team ID"},
|
||||
)
|
||||
## Auth check
|
||||
await ModelManagementAuthChecks.can_user_make_model_call(
|
||||
model_params=model_params,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
premium_user=premium_user,
|
||||
)
|
||||
|
||||
model_response: Optional[LiteLLM_ProxyModelTable] = None
|
||||
# update DB
|
||||
|
@ -593,6 +696,7 @@ async def update_model(
|
|||
from litellm.proxy.proxy_server import (
|
||||
LITELLM_PROXY_ADMIN_NAME,
|
||||
llm_router,
|
||||
premium_user,
|
||||
prisma_client,
|
||||
store_model_in_db,
|
||||
)
|
||||
|
@ -606,6 +710,14 @@ async def update_model(
|
|||
"error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
|
||||
},
|
||||
)
|
||||
|
||||
await ModelManagementAuthChecks.can_user_make_model_call(
|
||||
model_params=model_params,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
premium_user=premium_user,
|
||||
)
|
||||
|
||||
# update DB
|
||||
if store_model_in_db is True:
|
||||
_model_id = None
|
||||
|
|
|
@ -215,7 +215,6 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
|
|||
from litellm.proxy.management_endpoints.model_management_endpoints import (
|
||||
_add_model_to_db,
|
||||
_add_team_model_to_db,
|
||||
check_if_team_id_matches_key,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.model_management_endpoints import (
|
||||
router as model_management_router,
|
||||
|
@ -5494,9 +5493,40 @@ async def transform_request(request: TransformRequestBody):
|
|||
return return_raw_request(endpoint=request.call_type, kwargs=request.request_body)
|
||||
|
||||
|
||||
async def _check_if_model_is_user_added(
|
||||
models: List[Dict],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Check if model is in db
|
||||
|
||||
Check if db model is 'created_by' == user_api_key_dict.user_id
|
||||
|
||||
Only return models that match
|
||||
"""
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
filtered_models = []
|
||||
for model in models:
|
||||
id = model.get("model_info", {}).get("id", None)
|
||||
if id is None:
|
||||
continue
|
||||
db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
where={"model_id": id}
|
||||
)
|
||||
if db_model is not None:
|
||||
if db_model.created_by == user_api_key_dict.user_id:
|
||||
filtered_models.append(model)
|
||||
return filtered_models
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v2/model/info",
|
||||
description="v2 - returns all the models set on the config.yaml, shows 'user_access' = True if the user has access to the model. Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
|
||||
description="v2 - returns models available to the user based on their API key permissions. Shows model info from config.yaml (except api key and api base). Filter to just user-added models with ?user_models_only=true",
|
||||
tags=["model management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
include_in_schema=False,
|
||||
|
@ -5506,6 +5536,9 @@ async def model_info_v2(
|
|||
model: Optional[str] = fastapi.Query(
|
||||
None, description="Specify the model name (optional)"
|
||||
),
|
||||
user_models_only: Optional[bool] = fastapi.Query(
|
||||
False, description="Only return models added by this user"
|
||||
),
|
||||
debug: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
|
@ -5536,6 +5569,20 @@ async def model_info_v2(
|
|||
if model is not None:
|
||||
all_models = [m for m in all_models if m["model_name"] == model]
|
||||
|
||||
if user_models_only is True:
|
||||
"""
|
||||
Check if model is in db
|
||||
|
||||
Check if db model is 'created_by' == user_api_key_dict.user_id
|
||||
|
||||
Only return models that match
|
||||
"""
|
||||
all_models = await _check_if_model_is_user_added(
|
||||
models=all_models,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
# fill in model info based on config.yaml and litellm model_prices_and_context_window.json
|
||||
for _model in all_models:
|
||||
# provided model_info in config.yaml
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_TeamTable,
|
||||
LitellmUserRoles,
|
||||
Member,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.model_management_endpoints import (
|
||||
ModelManagementAuthChecks,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.router import Deployment, LiteLLM_Params, updateDeployment
|
||||
|
||||
|
||||
class MockPrismaClient:
|
||||
def __init__(self, team_exists: bool = True):
|
||||
self.team_exists = team_exists
|
||||
self.db = self
|
||||
|
||||
async def find_unique(self, where):
|
||||
if self.team_exists:
|
||||
return LiteLLM_TeamTable(
|
||||
team_id=where["team_id"],
|
||||
team_alias="test_team",
|
||||
members_with_roles=[Member(user_id="test_user", role="admin")],
|
||||
)
|
||||
return None
|
||||
|
||||
@property
|
||||
def litellm_teamtable(self):
|
||||
return self
|
||||
|
||||
|
||||
class TestModelManagementAuthChecks:
|
||||
def setup_method(self):
|
||||
"""Setup test cases"""
|
||||
self.admin_user = UserAPIKeyAuth(
|
||||
user_id="test_admin", user_role=LitellmUserRoles.PROXY_ADMIN
|
||||
)
|
||||
|
||||
self.normal_user = UserAPIKeyAuth(
|
||||
user_id="test_user", user_role=LitellmUserRoles.INTERNAL_USER
|
||||
)
|
||||
|
||||
self.team_admin_user = UserAPIKeyAuth(
|
||||
user_id="test_user",
|
||||
team_id="test_team",
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_user_make_team_model_call_admin_success(self):
|
||||
"""Test that admin users can make team model calls"""
|
||||
result = ModelManagementAuthChecks.can_user_make_team_model_call(
|
||||
team_id="test_team", user_api_key_dict=self.admin_user, premium_user=True
|
||||
)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_user_make_team_model_call_non_premium_fails(self):
|
||||
"""Test that non-premium users cannot make team model calls"""
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
ModelManagementAuthChecks.can_user_make_team_model_call(
|
||||
team_id="test_team",
|
||||
user_api_key_dict=self.admin_user,
|
||||
premium_user=False,
|
||||
)
|
||||
assert "403" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_user_make_team_model_call_team_admin_success(self):
|
||||
"""Test that team admins can make calls for their team"""
|
||||
team_obj = LiteLLM_TeamTable(
|
||||
team_id="test_team",
|
||||
team_alias="test_team",
|
||||
members_with_roles=[
|
||||
Member(user_id=self.team_admin_user.user_id, role="admin")
|
||||
],
|
||||
)
|
||||
|
||||
result = ModelManagementAuthChecks.can_user_make_team_model_call(
|
||||
team_id="test_team",
|
||||
user_api_key_dict=self.team_admin_user,
|
||||
team_obj=team_obj,
|
||||
premium_user=True,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_team_model_action_success(self):
|
||||
"""Test successful team model action"""
|
||||
model_params = Deployment(
|
||||
model_name="test_model",
|
||||
litellm_params=LiteLLM_Params(model="test_model", team_id="test_team"),
|
||||
model_info={"team_id": "test_team"},
|
||||
)
|
||||
prisma_client = MockPrismaClient(team_exists=True)
|
||||
|
||||
result = await ModelManagementAuthChecks.allow_team_model_action(
|
||||
model_params=model_params,
|
||||
user_api_key_dict=self.admin_user,
|
||||
prisma_client=prisma_client,
|
||||
premium_user=True,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_team_model_action_non_premium_fails(self):
|
||||
"""Test team model action fails for non-premium users"""
|
||||
model_params = Deployment(
|
||||
model_name="test_model",
|
||||
litellm_params=LiteLLM_Params(model="test_model", team_id="test_team"),
|
||||
model_info={"team_id": "test_team"},
|
||||
)
|
||||
prisma_client = MockPrismaClient(team_exists=True)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await ModelManagementAuthChecks.allow_team_model_action(
|
||||
model_params=model_params,
|
||||
user_api_key_dict=self.admin_user,
|
||||
prisma_client=prisma_client,
|
||||
premium_user=False,
|
||||
)
|
||||
assert "403" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_team_model_action_nonexistent_team_fails(self):
|
||||
"""Test team model action fails for non-existent team"""
|
||||
model_params = Deployment(
|
||||
model_name="test_model",
|
||||
litellm_params=LiteLLM_Params(
|
||||
model="test_model",
|
||||
),
|
||||
model_info={"team_id": "nonexistent_team"},
|
||||
)
|
||||
prisma_client = MockPrismaClient(team_exists=False)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await ModelManagementAuthChecks.allow_team_model_action(
|
||||
model_params=model_params,
|
||||
user_api_key_dict=self.admin_user,
|
||||
prisma_client=prisma_client,
|
||||
premium_user=True,
|
||||
)
|
||||
assert "400" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_user_make_model_call_admin_success(self):
|
||||
"""Test that admin users can make any model call"""
|
||||
model_params = Deployment(
|
||||
model_name="test_model",
|
||||
litellm_params=LiteLLM_Params(
|
||||
model="test_model",
|
||||
),
|
||||
model_info={"team_id": "test_team"},
|
||||
)
|
||||
prisma_client = MockPrismaClient()
|
||||
|
||||
result = await ModelManagementAuthChecks.can_user_make_model_call(
|
||||
model_params=model_params,
|
||||
user_api_key_dict=self.admin_user,
|
||||
prisma_client=prisma_client,
|
||||
premium_user=True,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_user_make_model_call_normal_user_fails(self):
|
||||
"""Test that normal users cannot make model calls"""
|
||||
model_params = Deployment(
|
||||
model_name="test_model",
|
||||
litellm_params=LiteLLM_Params(
|
||||
model="test_model",
|
||||
),
|
||||
model_info={"team_id": "test_team"},
|
||||
)
|
||||
prisma_client = MockPrismaClient()
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await ModelManagementAuthChecks.can_user_make_model_call(
|
||||
model_params=model_params,
|
||||
user_api_key_dict=self.normal_user,
|
||||
prisma_client=prisma_client,
|
||||
premium_user=True,
|
||||
)
|
||||
assert "403" in str(exc_info.value)
|
|
@ -110,49 +110,6 @@ async def test_add_new_model(prisma_client):
|
|||
assert _new_model_in_db is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"team_id, key_team_id, user_role, expected_result",
|
||||
[
|
||||
("1234", "1234", LitellmUserRoles.PROXY_ADMIN.value, True),
|
||||
(
|
||||
"1234",
|
||||
"1235",
|
||||
LitellmUserRoles.PROXY_ADMIN.value,
|
||||
True,
|
||||
), # proxy admin can add models for any team
|
||||
(None, "1234", LitellmUserRoles.PROXY_ADMIN.value, True),
|
||||
(None, None, LitellmUserRoles.PROXY_ADMIN.value, True),
|
||||
(
|
||||
"1234",
|
||||
"1234",
|
||||
LitellmUserRoles.INTERNAL_USER.value,
|
||||
True,
|
||||
), # internal users can add models for their team
|
||||
("1234", "1235", LitellmUserRoles.INTERNAL_USER.value, False),
|
||||
(None, "1234", LitellmUserRoles.INTERNAL_USER.value, False),
|
||||
(
|
||||
None,
|
||||
None,
|
||||
LitellmUserRoles.INTERNAL_USER.value,
|
||||
False,
|
||||
), # internal users cannot add models by default
|
||||
],
|
||||
)
|
||||
def test_can_add_model(team_id, key_team_id, user_role, expected_result):
|
||||
from litellm.proxy.proxy_server import check_if_team_id_matches_key
|
||||
|
||||
args = {
|
||||
"team_id": team_id,
|
||||
"user_api_key_dict": UserAPIKeyAuth(
|
||||
user_role=user_role,
|
||||
api_key="sk-1234",
|
||||
team_id=key_team_id,
|
||||
),
|
||||
}
|
||||
|
||||
assert check_if_team_id_matches_key(**args) is expected_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="new feature, tests passing locally")
|
||||
async def test_add_update_model(prisma_client):
|
||||
|
|
|
@ -13,6 +13,8 @@ import ConnectionErrorDisplay from "./model_connection_test";
|
|||
import { TEST_MODES } from "./add_model_modes";
|
||||
import { Row, Col } from "antd";
|
||||
import { Text, TextInput } from "@tremor/react";
|
||||
import TeamDropdown from "../common_components/team_dropdown";
|
||||
import { all_admin_roles } from "@/utils/roles";
|
||||
|
||||
interface AddModelTabProps {
|
||||
form: FormInstance;
|
||||
|
@ -28,6 +30,7 @@ interface AddModelTabProps {
|
|||
teams: Team[] | null;
|
||||
credentials: CredentialItem[];
|
||||
accessToken: string;
|
||||
userRole: string;
|
||||
}
|
||||
|
||||
const { Title, Link } = Typography;
|
||||
|
@ -46,6 +49,7 @@ const AddModelTab: React.FC<AddModelTabProps> = ({
|
|||
teams,
|
||||
credentials,
|
||||
accessToken,
|
||||
userRole,
|
||||
}) => {
|
||||
// State for test mode and connection testing
|
||||
const [testMode, setTestMode] = useState<string>("chat");
|
||||
|
@ -64,6 +68,8 @@ const AddModelTab: React.FC<AddModelTabProps> = ({
|
|||
setIsResultModalVisible(true);
|
||||
};
|
||||
|
||||
const isAdmin = all_admin_roles.includes(userRole);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Title level={2}>Add new model</Title>
|
||||
|
@ -217,6 +223,25 @@ const AddModelTab: React.FC<AddModelTabProps> = ({
|
|||
);
|
||||
}}
|
||||
</Form.Item>
|
||||
<div className="flex items-center my-4">
|
||||
<div className="flex-grow border-t border-gray-200"></div>
|
||||
<span className="px-4 text-gray-500 text-sm">Team Settings</span>
|
||||
<div className="flex-grow border-t border-gray-200"></div>
|
||||
</div>
|
||||
<Form.Item
|
||||
label="Team"
|
||||
name="team_id"
|
||||
className="mb-4"
|
||||
tooltip="Only keys for this team, will be able to call this model."
|
||||
rules={[
|
||||
{
|
||||
required: !isAdmin, // Required if not admin
|
||||
message: 'Please select a team.'
|
||||
}
|
||||
]}
|
||||
>
|
||||
<TeamDropdown teams={teams} />
|
||||
</Form.Item>
|
||||
<AdvancedSettings
|
||||
showAdvancedSettings={showAdvancedSettings}
|
||||
setShowAdvancedSettings={setShowAdvancedSettings}
|
||||
|
|
|
@ -91,13 +91,6 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
|
|||
</AccordionHeader>
|
||||
<AccordionBody>
|
||||
<div className="bg-white rounded-lg">
|
||||
<Form.Item
|
||||
label="Team"
|
||||
name="team_id"
|
||||
className="mb-4"
|
||||
>
|
||||
<TeamDropdown teams={teams} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label="Custom Pricing"
|
||||
|
|
|
@ -47,7 +47,7 @@ interface MenuItem {
|
|||
const menuItems: MenuItem[] = [
|
||||
{ key: "1", page: "api-keys", label: "Virtual Keys", icon: <KeyOutlined /> },
|
||||
{ key: "3", page: "llm-playground", label: "Test Key", icon: <PlayCircleOutlined />, roles: rolesWithWriteAccess },
|
||||
{ key: "2", page: "models", label: "Models", icon: <BlockOutlined />, roles: all_admin_roles },
|
||||
{ key: "2", page: "models", label: "Models", icon: <BlockOutlined />, roles: rolesWithWriteAccess },
|
||||
{ key: "4", page: "usage", label: "Usage", icon: <BarChartOutlined /> },
|
||||
{ key: "6", page: "teams", label: "Teams", icon: <TeamOutlined /> },
|
||||
{ key: "17", page: "organizations", label: "Organizations", icon: <BankOutlined />, roles: all_admin_roles },
|
||||
|
|
|
@ -111,6 +111,7 @@ import ModelInfoView from "./model_info_view";
|
|||
import AddModelTab from "./add_model/add_model_tab";
|
||||
import { ModelDataTable } from "./model_dashboard/table";
|
||||
import { columns } from "./model_dashboard/columns";
|
||||
import { all_admin_roles } from "@/utils/roles";
|
||||
|
||||
interface ModelDashboardProps {
|
||||
accessToken: string | null;
|
||||
|
@ -479,9 +480,6 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
|||
}
|
||||
const fetchData = async () => {
|
||||
try {
|
||||
const _providerSettings = await modelSettingsCall(accessToken);
|
||||
setProviderSettings(_providerSettings);
|
||||
|
||||
// Replace with your actual API call for model data
|
||||
const modelDataResponse = await modelInfoCall(
|
||||
accessToken,
|
||||
|
@ -490,6 +488,12 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
|||
);
|
||||
console.log("Model data response:", modelDataResponse.data);
|
||||
setModelData(modelDataResponse);
|
||||
const _providerSettings = await modelSettingsCall(accessToken);
|
||||
if (_providerSettings) {
|
||||
setProviderSettings(_providerSettings);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// loop through modelDataResponse and get all`model_name` values
|
||||
let all_model_groups: Set<string> = new Set();
|
||||
|
@ -1001,7 +1005,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
|||
);
|
||||
|
||||
let dynamicProviderForm: ProviderSettings | undefined = undefined;
|
||||
if (providerKey) {
|
||||
if (providerKey && providerSettings) {
|
||||
dynamicProviderForm = providerSettings.find(
|
||||
(provider) => provider.name === provider_map[providerKey]
|
||||
);
|
||||
|
@ -1055,16 +1059,17 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
|||
/>
|
||||
) : (
|
||||
<TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2">
|
||||
|
||||
<TabList className="flex justify-between mt-2 w-full items-center">
|
||||
<div className="flex">
|
||||
<Tab>All Models</Tab>
|
||||
{all_admin_roles.includes(userRole) ? <Tab>All Models</Tab> : <Tab>Your Models</Tab>}
|
||||
<Tab>Add Model</Tab>
|
||||
<Tab>LLM Credentials</Tab>
|
||||
<Tab>
|
||||
{all_admin_roles.includes(userRole) && <Tab>LLM Credentials</Tab>}
|
||||
{all_admin_roles.includes(userRole) && <Tab>
|
||||
<pre>/health Models</pre>
|
||||
</Tab>
|
||||
<Tab>Model Analytics</Tab>
|
||||
<Tab>Model Retry Settings</Tab>
|
||||
</Tab>}
|
||||
{all_admin_roles.includes(userRole) && <Tab>Model Analytics</Tab>}
|
||||
{all_admin_roles.includes(userRole) && <Tab>Model Retry Settings</Tab>}
|
||||
|
||||
</div>
|
||||
|
||||
|
@ -1081,26 +1086,25 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
|||
</TabList>
|
||||
<TabPanels>
|
||||
<TabPanel>
|
||||
<Grid>
|
||||
<div className="flex items-center">
|
||||
<Text>Filter by Public Model Name</Text>
|
||||
<div className="flex justify-between items-center mb-6">
|
||||
{/* Left side - Title and description */}
|
||||
<div>
|
||||
<Title>Model Management</Title>
|
||||
{!all_admin_roles.includes(userRole) && <Text className="text-tremor-content">
|
||||
Add models for teams you are an admin for.
|
||||
</Text>}
|
||||
</div>
|
||||
|
||||
{/* Right side - Filter */}
|
||||
<div className="flex items-center gap-2">
|
||||
<Text>Filter by Public Model Name:</Text>
|
||||
<Select
|
||||
className="mb-4 mt-2 ml-2 w-50"
|
||||
defaultValue={
|
||||
selectedModelGroup
|
||||
? selectedModelGroup
|
||||
: undefined
|
||||
}
|
||||
onValueChange={(value) =>
|
||||
setSelectedModelGroup(value === "all" ? "all" : value)
|
||||
}
|
||||
value={
|
||||
selectedModelGroup
|
||||
? selectedModelGroup
|
||||
: undefined
|
||||
}
|
||||
className="w-64"
|
||||
defaultValue={selectedModelGroup ?? "all"}
|
||||
onValueChange={(value) => setSelectedModelGroup(value === "all" ? "all" : value)}
|
||||
value={selectedModelGroup ?? "all"}
|
||||
>
|
||||
<SelectItem value={"all"}>All Models</SelectItem>
|
||||
<SelectItem value="all">All Models</SelectItem>
|
||||
{availableModelGroups.map((group, idx) => (
|
||||
<SelectItem
|
||||
key={idx}
|
||||
|
@ -1112,30 +1116,24 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
|||
))}
|
||||
</Select>
|
||||
</div>
|
||||
<ModelDataTable
|
||||
columns={columns(
|
||||
premiumUser,
|
||||
setSelectedModelId,
|
||||
setSelectedTeamId,
|
||||
getDisplayModelName,
|
||||
handleEditClick,
|
||||
handleRefreshClick,
|
||||
setEditModel
|
||||
)}
|
||||
data={modelData.data.filter(
|
||||
(model: any) =>
|
||||
selectedModelGroup === "all" ||
|
||||
model.model_name === selectedModelGroup ||
|
||||
!selectedModelGroup
|
||||
)}
|
||||
isLoading={false} // Add loading state if needed
|
||||
/>
|
||||
</Grid>
|
||||
<EditModelModal
|
||||
visible={editModalVisible}
|
||||
onCancel={handleEditCancel}
|
||||
model={selectedModel}
|
||||
onSubmit={(data: FormData) => handleEditModelSubmit(data, accessToken, setEditModalVisible, setSelectedModel)}
|
||||
</div>
|
||||
<ModelDataTable
|
||||
columns={columns(
|
||||
premiumUser,
|
||||
setSelectedModelId,
|
||||
setSelectedTeamId,
|
||||
getDisplayModelName,
|
||||
handleEditClick,
|
||||
handleRefreshClick,
|
||||
setEditModel
|
||||
)}
|
||||
data={modelData.data.filter(
|
||||
(model: any) =>
|
||||
selectedModelGroup === "all" ||
|
||||
model.model_name === selectedModelGroup ||
|
||||
!selectedModelGroup
|
||||
)}
|
||||
isLoading={false} // Add loading state if needed
|
||||
/>
|
||||
</TabPanel>
|
||||
<TabPanel className="h-full">
|
||||
|
@ -1153,6 +1151,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
|
|||
teams={teams}
|
||||
credentials={credentialsList}
|
||||
accessToken={accessToken}
|
||||
userRole={userRole}
|
||||
/>
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
|
|
|
@ -58,7 +58,7 @@ export default function ModelInfoView({
|
|||
const [isEditing, setIsEditing] = useState(false);
|
||||
const [existingCredential, setExistingCredential] = useState<CredentialItem | null>(null);
|
||||
|
||||
const canEditModel = userRole === "Admin";
|
||||
const canEditModel = userRole === "Admin" || modelData.model_info.created_by === userID;
|
||||
const isAdmin = userRole === "Admin";
|
||||
|
||||
const usingExistingCredential = modelData.litellm_params?.litellm_credential_name != null && modelData.litellm_params?.litellm_credential_name != undefined;
|
||||
|
@ -209,8 +209,8 @@ export default function ModelInfoView({
|
|||
<Title>Public Model Name: {getDisplayModelName(modelData)}</Title>
|
||||
<Text className="text-gray-500 font-mono">{modelData.model_info.id}</Text>
|
||||
</div>
|
||||
{isAdmin && (
|
||||
<div className="flex gap-2">
|
||||
<div className="flex gap-2">
|
||||
{isAdmin && (
|
||||
<TremorButton
|
||||
icon={KeyIcon}
|
||||
variant="secondary"
|
||||
|
@ -219,6 +219,8 @@ export default function ModelInfoView({
|
|||
>
|
||||
Re-use Credentials
|
||||
</TremorButton>
|
||||
)}
|
||||
{canEditModel && (
|
||||
<TremorButton
|
||||
icon={TrashIcon}
|
||||
variant="secondary"
|
||||
|
@ -227,8 +229,8 @@ export default function ModelInfoView({
|
|||
>
|
||||
Delete Model
|
||||
</TremorButton>
|
||||
</div>
|
||||
)}
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<TabGroup>
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
/**
|
||||
* Helper file for calls being made to proxy
|
||||
*/
|
||||
import { all_admin_roles } from "@/utils/roles";
|
||||
import { message } from "antd";
|
||||
|
||||
const isLocal = process.env.NODE_ENV === "development";
|
||||
|
@ -131,9 +132,9 @@ export const modelCreateCall = async (
|
|||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json();
|
||||
const errorData = await response.text();
|
||||
const errorMsg =
|
||||
errorData.error?.message?.error ||
|
||||
errorData||
|
||||
"Network response was not ok";
|
||||
message.error(errorMsg);
|
||||
throw new Error(errorMsg);
|
||||
|
@ -183,9 +184,8 @@ export const modelSettingsCall = async (accessToken: String) => {
|
|||
//message.info("Received model data");
|
||||
return data;
|
||||
// Handle success - you might want to update some state or UI based on the created key
|
||||
} catch (error) {
|
||||
console.error("Failed to get callbacks:", error);
|
||||
throw error;
|
||||
} catch (error: any) {
|
||||
console.error("Failed to get model settings:", error);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1213,8 +1213,12 @@ export const modelInfoCall = async (
|
|||
* Get all models on proxy
|
||||
*/
|
||||
try {
|
||||
console.log("modelInfoCall:", accessToken, userID, userRole);
|
||||
let url = proxyBaseUrl ? `${proxyBaseUrl}/v2/model/info` : `/v2/model/info`;
|
||||
|
||||
if (!all_admin_roles.includes(userRole as string)) { // only show users models they've added
|
||||
url += `?user_models_only=true`;
|
||||
}
|
||||
//message.info("Requesting model data");
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue