Allow team admins to add/update/delete models on UI + show api base and model id on request logs (#9572)

* feat(view_logs.tsx): show model id + api base in request logs

easier debugging

* fix(index.tsx): fix length of api base

easier viewing

* refactor(leftnav.tsx): show models tab to team admin

* feat(model_dashboard.tsx): add explainer for what the 'models' page is for team admin

helps them understand how they can use it

* feat(model_management_endpoints.py): restrict model add by team to just team admin

allow team admin to add models via non-team keys (e.g. ui token)

* test(test_add_update_models.py): update unit testing for new behaviour

* fix(model_dashboard.tsx): show user the models

* feat(proxy_server.py): add new query param 'user_models_only' to `/v2/model/info`

Allows user to retrieve just the models they've added

Used in UI to show internal users just the models they've added

* feat(model_dashboard.tsx): allow team admins to view their own models

* fix: allow ui user to fetch model cost map

* feat(add_model_tab.tsx): require team admins to specify team when onboarding models

* fix(_types.py): add `/v1/model/info` to info route

`/model/info` was already there

* fix(model_info_view.tsx): allow user to edit a model they created

* fix(model_management_endpoints.py): allow team admin to update team model

* feat(model_managament_endpoints.py): allow team admin to delete team models

* fix(model_management_endpoints.py): don't require team id to be set when adding a model

* fix(proxy_server.py): fix linting error

* fix: fix ui linting error

* fix(model_management_endpoints.py): ensure consistent auth checks on all model calls

* test: remove old test - function no longer exists in same form

* test: add updated mock testing
This commit is contained in:
Krish Dholakia 2025-03-27 12:06:31 -07:00 committed by GitHub
parent a5fbe50f04
commit ed8c63b51e
11 changed files with 483 additions and 144 deletions

View file

@ -292,6 +292,7 @@ class LiteLLMRoutes(enum.Enum):
"/team/available", "/team/available",
"/user/info", "/user/info",
"/model/info", "/model/info",
"/v1/model/info",
"/v2/model/info", "/v2/model/info",
"/v2/key/info", "/v2/key/info",
"/model_group/info", "/model_group/info",
@ -386,6 +387,7 @@ class LiteLLMRoutes(enum.Enum):
"/global/predict/spend/logs", "/global/predict/spend/logs",
"/global/activity", "/global/activity",
"/health/services", "/health/services",
"/get/litellm_model_cost_map",
] + info_routes ] + info_routes
internal_user_routes = [ internal_user_routes = [
@ -412,6 +414,8 @@ class LiteLLMRoutes(enum.Enum):
"/team/member_add", "/team/member_add",
"/team/member_delete", "/team/member_delete",
"/model/new", "/model/new",
"/model/update",
"/model/delete",
] # routes that manage their own allowed/disallowed logic ] # routes that manage their own allowed/disallowed logic
## Org Admin Routes ## ## Org Admin Routes ##

View file

@ -13,7 +13,7 @@ model/{model_id}/update - PATCH endpoint for model update.
import asyncio import asyncio
import json import json
import uuid import uuid
from typing import Optional, cast from typing import Literal, Optional, Union, cast
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel from pydantic import BaseModel
@ -23,6 +23,7 @@ from litellm.constants import LITELLM_PROXY_ADMIN_NAME
from litellm.proxy._types import ( from litellm.proxy._types import (
CommonProxyErrors, CommonProxyErrors,
LiteLLM_ProxyModelTable, LiteLLM_ProxyModelTable,
LiteLLM_TeamTable,
LitellmTableNames, LitellmTableNames,
LitellmUserRoles, LitellmUserRoles,
ModelInfoDelete, ModelInfoDelete,
@ -35,6 +36,7 @@ from litellm.proxy._types import (
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
from litellm.proxy.management_endpoints.team_endpoints import ( from litellm.proxy.management_endpoints.team_endpoints import (
team_model_add, team_model_add,
update_team, update_team,
@ -317,22 +319,111 @@ async def _add_team_model_to_db(
return model_response return model_response
def check_if_team_id_matches_key( class ModelManagementAuthChecks:
team_id: Optional[str], user_api_key_dict: UserAPIKeyAuth """
) -> bool: Common auth checks for model management endpoints
can_make_call = True """
if (
user_api_key_dict.user_role @staticmethod
and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN def can_user_make_team_model_call(
): team_id: str,
user_api_key_dict: UserAPIKeyAuth,
team_obj: Optional[LiteLLM_TeamTable] = None,
premium_user: bool = False,
) -> Literal[True]:
if premium_user is False:
raise HTTPException(
status_code=403,
detail={"error": CommonProxyErrors.not_premium_user.value},
)
if (
user_api_key_dict.user_role
and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
):
return True
elif team_obj is None or not _is_user_team_admin(
user_api_key_dict=user_api_key_dict, team_obj=team_obj
):
raise HTTPException(
status_code=403,
detail={
"error": "Team ID={} does not match the API key's team ID={}, OR you are not the admin for this team. Check `/user/info` to verify your team admin status.".format(
team_id, user_api_key_dict.team_id
)
},
)
return True
@staticmethod
async def allow_team_model_action(
model_params: Union[Deployment, updateDeployment],
user_api_key_dict: UserAPIKeyAuth,
prisma_client: PrismaClient,
premium_user: bool,
) -> Literal[True]:
if model_params.model_info is None or model_params.model_info.team_id is None:
return True
if model_params.model_info.team_id is not None and premium_user is not True:
raise HTTPException(
status_code=403,
detail={"error": CommonProxyErrors.not_premium_user.value},
)
_existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": model_params.model_info.team_id}
)
if _existing_team_row is None:
raise HTTPException(
status_code=400,
detail={
"error": "Team id={} does not exist in db".format(
model_params.model_info.team_id
)
},
)
existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
ModelManagementAuthChecks.can_user_make_team_model_call(
team_id=model_params.model_info.team_id,
user_api_key_dict=user_api_key_dict,
team_obj=existing_team_row,
premium_user=premium_user,
)
return True
@staticmethod
async def can_user_make_model_call(
model_params: Union[Deployment, updateDeployment],
user_api_key_dict: UserAPIKeyAuth,
prisma_client: PrismaClient,
premium_user: bool,
) -> Literal[True]:
## Check team model auth
if (
model_params.model_info is not None
and model_params.model_info.team_id is not None
):
return ModelManagementAuthChecks.can_user_make_team_model_call(
team_id=model_params.model_info.team_id,
user_api_key_dict=user_api_key_dict,
premium_user=premium_user,
)
## Check non-team model auth
elif user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(
status_code=403,
detail={
"error": "User does not have permission to make this model call. Your role={}. You can only make model calls if you are a PROXY_ADMIN or if you are a team admin, by specifying a team_id in the model_info.".format(
user_api_key_dict.user_role
)
},
)
else:
return True
return True return True
if team_id is None:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
can_make_call = False
else:
if user_api_key_dict.team_id != team_id:
can_make_call = False
return can_make_call
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@ -358,6 +449,7 @@ async def delete_model(
from litellm.proxy.proxy_server import ( from litellm.proxy.proxy_server import (
llm_router, llm_router,
premium_user,
prisma_client, prisma_client,
store_model_in_db, store_model_in_db,
) )
@ -370,6 +462,23 @@ async def delete_model(
}, },
) )
model_in_db = await prisma_client.db.litellm_proxymodeltable.find_unique(
where={"model_id": model_info.id}
)
if model_in_db is None:
raise HTTPException(
status_code=400,
detail={"error": f"Model with id={model_info.id} not found in db"},
)
model_params = Deployment(**model_in_db.model_dump())
await ModelManagementAuthChecks.can_user_make_model_call(
model_params=model_params,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
premium_user=premium_user,
)
# update DB # update DB
if store_model_in_db is True: if store_model_in_db is True:
""" """
@ -464,19 +573,13 @@ async def add_new_model(
}, },
) )
if model_params.model_info.team_id is not None and premium_user is not True: ## Auth check
raise HTTPException( await ModelManagementAuthChecks.can_user_make_model_call(
status_code=403, model_params=model_params,
detail={"error": CommonProxyErrors.not_premium_user.value}, user_api_key_dict=user_api_key_dict,
) prisma_client=prisma_client,
premium_user=premium_user,
if not check_if_team_id_matches_key( )
team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict
):
raise HTTPException(
status_code=403,
detail={"error": "Team ID does not match the API key's team ID"},
)
model_response: Optional[LiteLLM_ProxyModelTable] = None model_response: Optional[LiteLLM_ProxyModelTable] = None
# update DB # update DB
@ -593,6 +696,7 @@ async def update_model(
from litellm.proxy.proxy_server import ( from litellm.proxy.proxy_server import (
LITELLM_PROXY_ADMIN_NAME, LITELLM_PROXY_ADMIN_NAME,
llm_router, llm_router,
premium_user,
prisma_client, prisma_client,
store_model_in_db, store_model_in_db,
) )
@ -606,6 +710,14 @@ async def update_model(
"error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys" "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
}, },
) )
await ModelManagementAuthChecks.can_user_make_model_call(
model_params=model_params,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
premium_user=premium_user,
)
# update DB # update DB
if store_model_in_db is True: if store_model_in_db is True:
_model_id = None _model_id = None

View file

@ -215,7 +215,6 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
from litellm.proxy.management_endpoints.model_management_endpoints import ( from litellm.proxy.management_endpoints.model_management_endpoints import (
_add_model_to_db, _add_model_to_db,
_add_team_model_to_db, _add_team_model_to_db,
check_if_team_id_matches_key,
) )
from litellm.proxy.management_endpoints.model_management_endpoints import ( from litellm.proxy.management_endpoints.model_management_endpoints import (
router as model_management_router, router as model_management_router,
@ -5494,9 +5493,40 @@ async def transform_request(request: TransformRequestBody):
return return_raw_request(endpoint=request.call_type, kwargs=request.request_body) return return_raw_request(endpoint=request.call_type, kwargs=request.request_body)
async def _check_if_model_is_user_added(
models: List[Dict],
user_api_key_dict: UserAPIKeyAuth,
prisma_client: Optional[PrismaClient],
) -> List[Dict]:
"""
Check if model is in db
Check if db model is 'created_by' == user_api_key_dict.user_id
Only return models that match
"""
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
filtered_models = []
for model in models:
id = model.get("model_info", {}).get("id", None)
if id is None:
continue
db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
where={"model_id": id}
)
if db_model is not None:
if db_model.created_by == user_api_key_dict.user_id:
filtered_models.append(model)
return filtered_models
@router.get( @router.get(
"/v2/model/info", "/v2/model/info",
description="v2 - returns all the models set on the config.yaml, shows 'user_access' = True if the user has access to the model. Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", description="v2 - returns models available to the user based on their API key permissions. Shows model info from config.yaml (except api key and api base). Filter to just user-added models with ?user_models_only=true",
tags=["model management"], tags=["model management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
include_in_schema=False, include_in_schema=False,
@ -5506,6 +5536,9 @@ async def model_info_v2(
model: Optional[str] = fastapi.Query( model: Optional[str] = fastapi.Query(
None, description="Specify the model name (optional)" None, description="Specify the model name (optional)"
), ),
user_models_only: Optional[bool] = fastapi.Query(
False, description="Only return models added by this user"
),
debug: Optional[bool] = False, debug: Optional[bool] = False,
): ):
""" """
@ -5536,6 +5569,20 @@ async def model_info_v2(
if model is not None: if model is not None:
all_models = [m for m in all_models if m["model_name"] == model] all_models = [m for m in all_models if m["model_name"] == model]
if user_models_only is True:
"""
Check if model is in db
Check if db model is 'created_by' == user_api_key_dict.user_id
Only return models that match
"""
all_models = await _check_if_model_is_user_added(
models=all_models,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
)
# fill in model info based on config.yaml and litellm model_prices_and_context_window.json # fill in model info based on config.yaml and litellm model_prices_and_context_window.json
for _model in all_models: for _model in all_models:
# provided model_info in config.yaml # provided model_info in config.yaml

View file

@ -0,0 +1,196 @@
import json
import os
import sys
from typing import Optional
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
from litellm.proxy._types import (
LiteLLM_TeamTable,
LitellmUserRoles,
Member,
UserAPIKeyAuth,
)
from litellm.proxy.management_endpoints.model_management_endpoints import (
ModelManagementAuthChecks,
)
from litellm.proxy.utils import PrismaClient
from litellm.types.router import Deployment, LiteLLM_Params, updateDeployment
class MockPrismaClient:
def __init__(self, team_exists: bool = True):
self.team_exists = team_exists
self.db = self
async def find_unique(self, where):
if self.team_exists:
return LiteLLM_TeamTable(
team_id=where["team_id"],
team_alias="test_team",
members_with_roles=[Member(user_id="test_user", role="admin")],
)
return None
@property
def litellm_teamtable(self):
return self
class TestModelManagementAuthChecks:
def setup_method(self):
"""Setup test cases"""
self.admin_user = UserAPIKeyAuth(
user_id="test_admin", user_role=LitellmUserRoles.PROXY_ADMIN
)
self.normal_user = UserAPIKeyAuth(
user_id="test_user", user_role=LitellmUserRoles.INTERNAL_USER
)
self.team_admin_user = UserAPIKeyAuth(
user_id="test_user",
team_id="test_team",
user_role=LitellmUserRoles.INTERNAL_USER,
)
@pytest.mark.asyncio
async def test_can_user_make_team_model_call_admin_success(self):
"""Test that admin users can make team model calls"""
result = ModelManagementAuthChecks.can_user_make_team_model_call(
team_id="test_team", user_api_key_dict=self.admin_user, premium_user=True
)
assert result is True
@pytest.mark.asyncio
async def test_can_user_make_team_model_call_non_premium_fails(self):
"""Test that non-premium users cannot make team model calls"""
with pytest.raises(Exception) as exc_info:
ModelManagementAuthChecks.can_user_make_team_model_call(
team_id="test_team",
user_api_key_dict=self.admin_user,
premium_user=False,
)
assert "403" in str(exc_info.value)
@pytest.mark.asyncio
async def test_can_user_make_team_model_call_team_admin_success(self):
"""Test that team admins can make calls for their team"""
team_obj = LiteLLM_TeamTable(
team_id="test_team",
team_alias="test_team",
members_with_roles=[
Member(user_id=self.team_admin_user.user_id, role="admin")
],
)
result = ModelManagementAuthChecks.can_user_make_team_model_call(
team_id="test_team",
user_api_key_dict=self.team_admin_user,
team_obj=team_obj,
premium_user=True,
)
assert result is True
@pytest.mark.asyncio
async def test_allow_team_model_action_success(self):
"""Test successful team model action"""
model_params = Deployment(
model_name="test_model",
litellm_params=LiteLLM_Params(model="test_model", team_id="test_team"),
model_info={"team_id": "test_team"},
)
prisma_client = MockPrismaClient(team_exists=True)
result = await ModelManagementAuthChecks.allow_team_model_action(
model_params=model_params,
user_api_key_dict=self.admin_user,
prisma_client=prisma_client,
premium_user=True,
)
assert result is True
@pytest.mark.asyncio
async def test_allow_team_model_action_non_premium_fails(self):
"""Test team model action fails for non-premium users"""
model_params = Deployment(
model_name="test_model",
litellm_params=LiteLLM_Params(model="test_model", team_id="test_team"),
model_info={"team_id": "test_team"},
)
prisma_client = MockPrismaClient(team_exists=True)
with pytest.raises(Exception) as exc_info:
await ModelManagementAuthChecks.allow_team_model_action(
model_params=model_params,
user_api_key_dict=self.admin_user,
prisma_client=prisma_client,
premium_user=False,
)
assert "403" in str(exc_info.value)
@pytest.mark.asyncio
async def test_allow_team_model_action_nonexistent_team_fails(self):
"""Test team model action fails for non-existent team"""
model_params = Deployment(
model_name="test_model",
litellm_params=LiteLLM_Params(
model="test_model",
),
model_info={"team_id": "nonexistent_team"},
)
prisma_client = MockPrismaClient(team_exists=False)
with pytest.raises(Exception) as exc_info:
await ModelManagementAuthChecks.allow_team_model_action(
model_params=model_params,
user_api_key_dict=self.admin_user,
prisma_client=prisma_client,
premium_user=True,
)
assert "400" in str(exc_info.value)
@pytest.mark.asyncio
async def test_can_user_make_model_call_admin_success(self):
"""Test that admin users can make any model call"""
model_params = Deployment(
model_name="test_model",
litellm_params=LiteLLM_Params(
model="test_model",
),
model_info={"team_id": "test_team"},
)
prisma_client = MockPrismaClient()
result = await ModelManagementAuthChecks.can_user_make_model_call(
model_params=model_params,
user_api_key_dict=self.admin_user,
prisma_client=prisma_client,
premium_user=True,
)
assert result is True
@pytest.mark.asyncio
async def test_can_user_make_model_call_normal_user_fails(self):
"""Test that normal users cannot make model calls"""
model_params = Deployment(
model_name="test_model",
litellm_params=LiteLLM_Params(
model="test_model",
),
model_info={"team_id": "test_team"},
)
prisma_client = MockPrismaClient()
with pytest.raises(Exception) as exc_info:
await ModelManagementAuthChecks.can_user_make_model_call(
model_params=model_params,
user_api_key_dict=self.normal_user,
prisma_client=prisma_client,
premium_user=True,
)
assert "403" in str(exc_info.value)

View file

@ -110,49 +110,6 @@ async def test_add_new_model(prisma_client):
assert _new_model_in_db is not None assert _new_model_in_db is not None
@pytest.mark.parametrize(
"team_id, key_team_id, user_role, expected_result",
[
("1234", "1234", LitellmUserRoles.PROXY_ADMIN.value, True),
(
"1234",
"1235",
LitellmUserRoles.PROXY_ADMIN.value,
True,
), # proxy admin can add models for any team
(None, "1234", LitellmUserRoles.PROXY_ADMIN.value, True),
(None, None, LitellmUserRoles.PROXY_ADMIN.value, True),
(
"1234",
"1234",
LitellmUserRoles.INTERNAL_USER.value,
True,
), # internal users can add models for their team
("1234", "1235", LitellmUserRoles.INTERNAL_USER.value, False),
(None, "1234", LitellmUserRoles.INTERNAL_USER.value, False),
(
None,
None,
LitellmUserRoles.INTERNAL_USER.value,
False,
), # internal users cannot add models by default
],
)
def test_can_add_model(team_id, key_team_id, user_role, expected_result):
from litellm.proxy.proxy_server import check_if_team_id_matches_key
args = {
"team_id": team_id,
"user_api_key_dict": UserAPIKeyAuth(
user_role=user_role,
api_key="sk-1234",
team_id=key_team_id,
),
}
assert check_if_team_id_matches_key(**args) is expected_result
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.skip(reason="new feature, tests passing locally") @pytest.mark.skip(reason="new feature, tests passing locally")
async def test_add_update_model(prisma_client): async def test_add_update_model(prisma_client):

View file

@ -13,6 +13,8 @@ import ConnectionErrorDisplay from "./model_connection_test";
import { TEST_MODES } from "./add_model_modes"; import { TEST_MODES } from "./add_model_modes";
import { Row, Col } from "antd"; import { Row, Col } from "antd";
import { Text, TextInput } from "@tremor/react"; import { Text, TextInput } from "@tremor/react";
import TeamDropdown from "../common_components/team_dropdown";
import { all_admin_roles } from "@/utils/roles";
interface AddModelTabProps { interface AddModelTabProps {
form: FormInstance; form: FormInstance;
@ -28,6 +30,7 @@ interface AddModelTabProps {
teams: Team[] | null; teams: Team[] | null;
credentials: CredentialItem[]; credentials: CredentialItem[];
accessToken: string; accessToken: string;
userRole: string;
} }
const { Title, Link } = Typography; const { Title, Link } = Typography;
@ -46,6 +49,7 @@ const AddModelTab: React.FC<AddModelTabProps> = ({
teams, teams,
credentials, credentials,
accessToken, accessToken,
userRole,
}) => { }) => {
// State for test mode and connection testing // State for test mode and connection testing
const [testMode, setTestMode] = useState<string>("chat"); const [testMode, setTestMode] = useState<string>("chat");
@ -64,6 +68,8 @@ const AddModelTab: React.FC<AddModelTabProps> = ({
setIsResultModalVisible(true); setIsResultModalVisible(true);
}; };
const isAdmin = all_admin_roles.includes(userRole);
return ( return (
<> <>
<Title level={2}>Add new model</Title> <Title level={2}>Add new model</Title>
@ -217,6 +223,25 @@ const AddModelTab: React.FC<AddModelTabProps> = ({
); );
}} }}
</Form.Item> </Form.Item>
<div className="flex items-center my-4">
<div className="flex-grow border-t border-gray-200"></div>
<span className="px-4 text-gray-500 text-sm">Team Settings</span>
<div className="flex-grow border-t border-gray-200"></div>
</div>
<Form.Item
label="Team"
name="team_id"
className="mb-4"
tooltip="Only keys for this team, will be able to call this model."
rules={[
{
required: !isAdmin, // Required if not admin
message: 'Please select a team.'
}
]}
>
<TeamDropdown teams={teams} />
</Form.Item>
<AdvancedSettings <AdvancedSettings
showAdvancedSettings={showAdvancedSettings} showAdvancedSettings={showAdvancedSettings}
setShowAdvancedSettings={setShowAdvancedSettings} setShowAdvancedSettings={setShowAdvancedSettings}

View file

@ -91,13 +91,6 @@ const AdvancedSettings: React.FC<AdvancedSettingsProps> = ({
</AccordionHeader> </AccordionHeader>
<AccordionBody> <AccordionBody>
<div className="bg-white rounded-lg"> <div className="bg-white rounded-lg">
<Form.Item
label="Team"
name="team_id"
className="mb-4"
>
<TeamDropdown teams={teams} />
</Form.Item>
<Form.Item <Form.Item
label="Custom Pricing" label="Custom Pricing"

View file

@ -47,7 +47,7 @@ interface MenuItem {
const menuItems: MenuItem[] = [ const menuItems: MenuItem[] = [
{ key: "1", page: "api-keys", label: "Virtual Keys", icon: <KeyOutlined /> }, { key: "1", page: "api-keys", label: "Virtual Keys", icon: <KeyOutlined /> },
{ key: "3", page: "llm-playground", label: "Test Key", icon: <PlayCircleOutlined />, roles: rolesWithWriteAccess }, { key: "3", page: "llm-playground", label: "Test Key", icon: <PlayCircleOutlined />, roles: rolesWithWriteAccess },
{ key: "2", page: "models", label: "Models", icon: <BlockOutlined />, roles: all_admin_roles }, { key: "2", page: "models", label: "Models", icon: <BlockOutlined />, roles: rolesWithWriteAccess },
{ key: "4", page: "usage", label: "Usage", icon: <BarChartOutlined /> }, { key: "4", page: "usage", label: "Usage", icon: <BarChartOutlined /> },
{ key: "6", page: "teams", label: "Teams", icon: <TeamOutlined /> }, { key: "6", page: "teams", label: "Teams", icon: <TeamOutlined /> },
{ key: "17", page: "organizations", label: "Organizations", icon: <BankOutlined />, roles: all_admin_roles }, { key: "17", page: "organizations", label: "Organizations", icon: <BankOutlined />, roles: all_admin_roles },

View file

@ -111,6 +111,7 @@ import ModelInfoView from "./model_info_view";
import AddModelTab from "./add_model/add_model_tab"; import AddModelTab from "./add_model/add_model_tab";
import { ModelDataTable } from "./model_dashboard/table"; import { ModelDataTable } from "./model_dashboard/table";
import { columns } from "./model_dashboard/columns"; import { columns } from "./model_dashboard/columns";
import { all_admin_roles } from "@/utils/roles";
interface ModelDashboardProps { interface ModelDashboardProps {
accessToken: string | null; accessToken: string | null;
@ -479,9 +480,6 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
} }
const fetchData = async () => { const fetchData = async () => {
try { try {
const _providerSettings = await modelSettingsCall(accessToken);
setProviderSettings(_providerSettings);
// Replace with your actual API call for model data // Replace with your actual API call for model data
const modelDataResponse = await modelInfoCall( const modelDataResponse = await modelInfoCall(
accessToken, accessToken,
@ -490,6 +488,12 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
); );
console.log("Model data response:", modelDataResponse.data); console.log("Model data response:", modelDataResponse.data);
setModelData(modelDataResponse); setModelData(modelDataResponse);
const _providerSettings = await modelSettingsCall(accessToken);
if (_providerSettings) {
setProviderSettings(_providerSettings);
}
// loop through modelDataResponse and get all`model_name` values // loop through modelDataResponse and get all`model_name` values
let all_model_groups: Set<string> = new Set(); let all_model_groups: Set<string> = new Set();
@ -1001,7 +1005,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
); );
let dynamicProviderForm: ProviderSettings | undefined = undefined; let dynamicProviderForm: ProviderSettings | undefined = undefined;
if (providerKey) { if (providerKey && providerSettings) {
dynamicProviderForm = providerSettings.find( dynamicProviderForm = providerSettings.find(
(provider) => provider.name === provider_map[providerKey] (provider) => provider.name === provider_map[providerKey]
); );
@ -1055,16 +1059,17 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
/> />
) : ( ) : (
<TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2"> <TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2">
<TabList className="flex justify-between mt-2 w-full items-center"> <TabList className="flex justify-between mt-2 w-full items-center">
<div className="flex"> <div className="flex">
<Tab>All Models</Tab> {all_admin_roles.includes(userRole) ? <Tab>All Models</Tab> : <Tab>Your Models</Tab>}
<Tab>Add Model</Tab> <Tab>Add Model</Tab>
<Tab>LLM Credentials</Tab> {all_admin_roles.includes(userRole) && <Tab>LLM Credentials</Tab>}
<Tab> {all_admin_roles.includes(userRole) && <Tab>
<pre>/health Models</pre> <pre>/health Models</pre>
</Tab> </Tab>}
<Tab>Model Analytics</Tab> {all_admin_roles.includes(userRole) && <Tab>Model Analytics</Tab>}
<Tab>Model Retry Settings</Tab> {all_admin_roles.includes(userRole) && <Tab>Model Retry Settings</Tab>}
</div> </div>
@ -1081,26 +1086,25 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
</TabList> </TabList>
<TabPanels> <TabPanels>
<TabPanel> <TabPanel>
<Grid> <div className="flex justify-between items-center mb-6">
<div className="flex items-center"> {/* Left side - Title and description */}
<Text>Filter by Public Model Name</Text> <div>
<Title>Model Management</Title>
{!all_admin_roles.includes(userRole) && <Text className="text-tremor-content">
Add models for teams you are an admin for.
</Text>}
</div>
{/* Right side - Filter */}
<div className="flex items-center gap-2">
<Text>Filter by Public Model Name:</Text>
<Select <Select
className="mb-4 mt-2 ml-2 w-50" className="w-64"
defaultValue={ defaultValue={selectedModelGroup ?? "all"}
selectedModelGroup onValueChange={(value) => setSelectedModelGroup(value === "all" ? "all" : value)}
? selectedModelGroup value={selectedModelGroup ?? "all"}
: undefined
}
onValueChange={(value) =>
setSelectedModelGroup(value === "all" ? "all" : value)
}
value={
selectedModelGroup
? selectedModelGroup
: undefined
}
> >
<SelectItem value={"all"}>All Models</SelectItem> <SelectItem value="all">All Models</SelectItem>
{availableModelGroups.map((group, idx) => ( {availableModelGroups.map((group, idx) => (
<SelectItem <SelectItem
key={idx} key={idx}
@ -1112,30 +1116,24 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
))} ))}
</Select> </Select>
</div> </div>
<ModelDataTable </div>
columns={columns( <ModelDataTable
premiumUser, columns={columns(
setSelectedModelId, premiumUser,
setSelectedTeamId, setSelectedModelId,
getDisplayModelName, setSelectedTeamId,
handleEditClick, getDisplayModelName,
handleRefreshClick, handleEditClick,
setEditModel handleRefreshClick,
)} setEditModel
data={modelData.data.filter( )}
(model: any) => data={modelData.data.filter(
selectedModelGroup === "all" || (model: any) =>
model.model_name === selectedModelGroup || selectedModelGroup === "all" ||
!selectedModelGroup model.model_name === selectedModelGroup ||
)} !selectedModelGroup
isLoading={false} // Add loading state if needed )}
/> isLoading={false} // Add loading state if needed
</Grid>
<EditModelModal
visible={editModalVisible}
onCancel={handleEditCancel}
model={selectedModel}
onSubmit={(data: FormData) => handleEditModelSubmit(data, accessToken, setEditModalVisible, setSelectedModel)}
/> />
</TabPanel> </TabPanel>
<TabPanel className="h-full"> <TabPanel className="h-full">
@ -1153,6 +1151,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
teams={teams} teams={teams}
credentials={credentialsList} credentials={credentialsList}
accessToken={accessToken} accessToken={accessToken}
userRole={userRole}
/> />
</TabPanel> </TabPanel>
<TabPanel> <TabPanel>

View file

@ -58,7 +58,7 @@ export default function ModelInfoView({
const [isEditing, setIsEditing] = useState(false); const [isEditing, setIsEditing] = useState(false);
const [existingCredential, setExistingCredential] = useState<CredentialItem | null>(null); const [existingCredential, setExistingCredential] = useState<CredentialItem | null>(null);
const canEditModel = userRole === "Admin"; const canEditModel = userRole === "Admin" || modelData.model_info.created_by === userID;
const isAdmin = userRole === "Admin"; const isAdmin = userRole === "Admin";
const usingExistingCredential = modelData.litellm_params?.litellm_credential_name != null && modelData.litellm_params?.litellm_credential_name != undefined; const usingExistingCredential = modelData.litellm_params?.litellm_credential_name != null && modelData.litellm_params?.litellm_credential_name != undefined;
@ -209,8 +209,8 @@ export default function ModelInfoView({
<Title>Public Model Name: {getDisplayModelName(modelData)}</Title> <Title>Public Model Name: {getDisplayModelName(modelData)}</Title>
<Text className="text-gray-500 font-mono">{modelData.model_info.id}</Text> <Text className="text-gray-500 font-mono">{modelData.model_info.id}</Text>
</div> </div>
{isAdmin && ( <div className="flex gap-2">
<div className="flex gap-2"> {isAdmin && (
<TremorButton <TremorButton
icon={KeyIcon} icon={KeyIcon}
variant="secondary" variant="secondary"
@ -219,6 +219,8 @@ export default function ModelInfoView({
> >
Re-use Credentials Re-use Credentials
</TremorButton> </TremorButton>
)}
{canEditModel && (
<TremorButton <TremorButton
icon={TrashIcon} icon={TrashIcon}
variant="secondary" variant="secondary"
@ -227,8 +229,8 @@ export default function ModelInfoView({
> >
Delete Model Delete Model
</TremorButton> </TremorButton>
</div> )}
)} </div>
</div> </div>
<TabGroup> <TabGroup>

View file

@ -1,6 +1,7 @@
/** /**
* Helper file for calls being made to proxy * Helper file for calls being made to proxy
*/ */
import { all_admin_roles } from "@/utils/roles";
import { message } from "antd"; import { message } from "antd";
const isLocal = process.env.NODE_ENV === "development"; const isLocal = process.env.NODE_ENV === "development";
@ -131,9 +132,9 @@ export const modelCreateCall = async (
}); });
if (!response.ok) { if (!response.ok) {
const errorData = await response.json(); const errorData = await response.text();
const errorMsg = const errorMsg =
errorData.error?.message?.error || errorData||
"Network response was not ok"; "Network response was not ok";
message.error(errorMsg); message.error(errorMsg);
throw new Error(errorMsg); throw new Error(errorMsg);
@ -183,9 +184,8 @@ export const modelSettingsCall = async (accessToken: String) => {
//message.info("Received model data"); //message.info("Received model data");
return data; return data;
// Handle success - you might want to update some state or UI based on the created key // Handle success - you might want to update some state or UI based on the created key
} catch (error) { } catch (error: any) {
console.error("Failed to get callbacks:", error); console.error("Failed to get model settings:", error);
throw error;
} }
}; };
@ -1213,8 +1213,12 @@ export const modelInfoCall = async (
* Get all models on proxy * Get all models on proxy
*/ */
try { try {
console.log("modelInfoCall:", accessToken, userID, userRole);
let url = proxyBaseUrl ? `${proxyBaseUrl}/v2/model/info` : `/v2/model/info`; let url = proxyBaseUrl ? `${proxyBaseUrl}/v2/model/info` : `/v2/model/info`;
if (!all_admin_roles.includes(userRole as string)) { // only show users models they've added
url += `?user_models_only=true`;
}
//message.info("Requesting model data"); //message.info("Requesting model data");
const response = await fetch(url, { const response = await fetch(url, {
method: "GET", method: "GET",