[Feat] LiteLLM Tag/Policy Management (#9813)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 15s
Helm unit test / unit-test (push) Successful in 21s

* rendering tags on UI

* use /models for building tags

* CRUD endpoints for Tag management

* fix tag management

* working api for LIST tags

* working tag management

* refactor UI components

* fixes ui tag management

* clean up ui tag management

* fix tag management ui

* fix show allowed llms

* e2e tag controls

* stash change for rendering tags on UI

* ui working tag selector on Test Key page

* fixes for tag management

* clean up tag info

* fix code quality

* test for tag management

* ui clarify what tag routing is
This commit is contained in:
Ishaan Jaff 2025-04-07 21:54:24 -07:00 committed by GitHub
parent ac9f03beae
commit ff3a6830a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1595 additions and 9 deletions

View file

@ -0,0 +1,356 @@
"""
TAG MANAGEMENT
All /tag management endpoints
/tag/new
/tag/info
/tag/update
/tag/delete
/tag/list
"""
import datetime
import json
from typing import Dict
from fastapi import APIRouter, Depends, HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.tag_management import (
TagConfig,
TagDeleteRequest,
TagInfoRequest,
TagNewRequest,
TagUpdateRequest,
)
router = APIRouter()
async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]:
"""Helper function to get model names from model IDs"""
try:
models = await prisma_client.db.litellm_proxymodeltable.find_many(
where={"model_id": {"in": model_ids}}
)
return {model.model_id: model.model_name for model in models}
except Exception as e:
verbose_proxy_logger.error(f"Error getting model names: {str(e)}")
return {}
async def _get_tags_config(prisma_client) -> Dict[str, TagConfig]:
"""Helper function to get tags config from db"""
try:
tags_config = await prisma_client.db.litellm_config.find_unique(
where={"param_name": "tags_config"}
)
if tags_config is None:
return {}
# Convert from JSON if needed
if isinstance(tags_config.param_value, str):
config_dict = json.loads(tags_config.param_value)
else:
config_dict = tags_config.param_value or {}
# For each tag, get the model names
for tag_name, tag_config in config_dict.items():
if isinstance(tag_config, dict) and tag_config.get("models"):
model_info = await _get_model_names(prisma_client, tag_config["models"])
tag_config["model_info"] = model_info
return config_dict
except Exception:
return {}
async def _save_tags_config(prisma_client, tags_config: Dict[str, TagConfig]):
"""Helper function to save tags config to db"""
try:
verbose_proxy_logger.debug(f"Saving tags config: {tags_config}")
# Convert TagConfig objects to dictionaries
tags_config_dict = {}
for name, tag in tags_config.items():
if isinstance(tag, TagConfig):
tag_dict = tag.model_dump()
# Remove model_info before saving as it will be dynamically generated
if "model_info" in tag_dict:
del tag_dict["model_info"]
tags_config_dict[name] = tag_dict
else:
# If it's already a dict, remove model_info
tag_copy = tag.copy()
if "model_info" in tag_copy:
del tag_copy["model_info"]
tags_config_dict[name] = tag_copy
json_tags_config = json.dumps(tags_config_dict, default=str)
verbose_proxy_logger.debug(f"JSON tags config: {json_tags_config}")
await prisma_client.db.litellm_config.upsert(
where={"param_name": "tags_config"},
data={
"create": {
"param_name": "tags_config",
"param_value": json_tags_config,
},
"update": {"param_value": json_tags_config},
},
)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error saving tags config: {str(e)}"
)
@router.post(
"/tag/new",
tags=["tag management"],
dependencies=[Depends(user_api_key_auth)],
)
async def new_tag(
tag: TagNewRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Create a new tag.
Parameters:
- name: str - The name of the tag
- description: Optional[str] - Description of what this tag represents
- models: List[str] - List of LLM models allowed for this tag
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail="Database not connected")
try:
# Get existing tags config
tags_config = await _get_tags_config(prisma_client)
# Check if tag already exists
if tag.name in tags_config:
raise HTTPException(
status_code=400, detail=f"Tag {tag.name} already exists"
)
# Add new tag
tags_config[tag.name] = TagConfig(
name=tag.name,
description=tag.description,
models=tag.models,
created_at=str(datetime.datetime.now()),
updated_at=str(datetime.datetime.now()),
created_by=user_api_key_dict.user_id,
)
# Save updated config
await _save_tags_config(
prisma_client=prisma_client,
tags_config=tags_config,
)
# Update models with new tag
if tag.models:
for model_id in tag.models:
await _add_tag_to_deployment(
model_id=model_id,
tag=tag.name,
)
# Get model names for response
model_info = await _get_model_names(prisma_client, tag.models or [])
tags_config[tag.name].model_info = model_info
return {
"message": f"Tag {tag.name} created successfully",
"tag": tags_config[tag.name],
}
except Exception as e:
verbose_proxy_logger.exception(f"Error creating tag: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
async def _add_tag_to_deployment(model_id: str, tag: str):
"""Helper function to add tag to deployment"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail="Database not connected")
deployment = await prisma_client.db.litellm_proxymodeltable.find_unique(
where={"model_id": model_id}
)
if deployment is None:
raise HTTPException(status_code=404, detail=f"Deployment {model_id} not found")
litellm_params = deployment.litellm_params
if "tags" not in litellm_params:
litellm_params["tags"] = []
litellm_params["tags"].append(tag)
await prisma_client.db.litellm_proxymodeltable.update(
where={"model_id": model_id},
data={"litellm_params": safe_dumps(litellm_params)},
)
@router.post(
"/tag/update",
tags=["tag management"],
dependencies=[Depends(user_api_key_auth)],
)
async def update_tag(
tag: TagUpdateRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Update an existing tag.
Parameters:
- name: str - The name of the tag to update
- description: Optional[str] - Updated description
- models: List[str] - Updated list of allowed LLM models
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail="Database not connected")
try:
# Get existing tags config
tags_config = await _get_tags_config(prisma_client)
# Check if tag exists
if tag.name not in tags_config:
raise HTTPException(status_code=404, detail=f"Tag {tag.name} not found")
# Update tag
tag_config_dict = dict(tags_config[tag.name])
tag_config_dict.update(
{
"description": tag.description,
"models": tag.models,
"updated_at": str(datetime.datetime.now()),
"updated_by": user_api_key_dict.user_id,
}
)
tags_config[tag.name] = TagConfig(**tag_config_dict)
# Save updated config
await _save_tags_config(prisma_client, tags_config)
# Get model names for response
model_info = await _get_model_names(prisma_client, tag.models or [])
tags_config[tag.name].model_info = model_info
return {
"message": f"Tag {tag.name} updated successfully",
"tag": tags_config[tag.name],
}
except Exception as e:
verbose_proxy_logger.exception(f"Error updating tag: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post(
"/tag/info",
tags=["tag management"],
dependencies=[Depends(user_api_key_auth)],
)
async def info_tag(
data: TagInfoRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Get information about specific tags.
Parameters:
- names: List[str] - List of tag names to get information for
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail="Database not connected")
try:
tags_config = await _get_tags_config(prisma_client)
# Filter tags based on requested names
requested_tags = {name: tags_config.get(name) for name in data.names}
# Check if any requested tags don't exist
missing_tags = [name for name in data.names if name not in tags_config]
if missing_tags:
raise HTTPException(
status_code=404, detail=f"Tags not found: {missing_tags}"
)
return requested_tags
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get(
"/tag/list",
tags=["tag management"],
dependencies=[Depends(user_api_key_auth)],
)
async def list_tags(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
List all available tags.
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail="Database not connected")
try:
tags_config = await _get_tags_config(prisma_client)
list_of_tags = list(tags_config.values())
return list_of_tags
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post(
"/tag/delete",
tags=["tag management"],
dependencies=[Depends(user_api_key_auth)],
)
async def delete_tag(
data: TagDeleteRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Delete a tag.
Parameters:
- name: str - The name of the tag to delete
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(status_code=500, detail="Database not connected")
try:
# Get existing tags config
tags_config = await _get_tags_config(prisma_client)
# Check if tag exists
if data.name not in tags_config:
raise HTTPException(status_code=404, detail=f"Tag {data.name} not found")
# Delete tag
del tags_config[data.name]
# Save updated config
await _save_tags_config(prisma_client, tags_config)
return {"message": f"Tag {data.name} deleted successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View file

@ -11,6 +11,10 @@ model_list:
litellm_settings:
require_auth_for_metrics_endpoint: true
callbacks: ["prometheus"]
service_callback: ["prometheus_system"]
service_callback: ["prometheus_system"]
router_settings:
enable_tag_filtering: True # 👈 Key Change

View file

@ -237,6 +237,9 @@ from litellm.proxy.management_endpoints.model_management_endpoints import (
from litellm.proxy.management_endpoints.organization_endpoints import (
router as organization_router,
)
from litellm.proxy.management_endpoints.tag_management_endpoints import (
router as tag_management_router,
)
from litellm.proxy.management_endpoints.team_callback_endpoints import (
router as team_callback_router,
)
@ -347,13 +350,13 @@ from fastapi import (
Request,
Response,
UploadFile,
applications,
status,
applications
)
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import (
FileResponse,
JSONResponse,
@ -735,7 +738,7 @@ try:
except Exception:
pass
# current_dir = os.path.dirname(os.path.abspath(__file__))
current_dir = os.path.dirname(os.path.abspath(__file__))
# ui_path = os.path.join(current_dir, "_experimental", "out")
# # Mount this test directory instead
# app.mount("/ui", StaticFiles(directory=ui_path, html=True), name="ui")
@ -753,14 +756,18 @@ app.add_middleware(PrometheusAuthMiddleware)
swagger_path = os.path.join(current_dir, "swagger")
app.mount("/swagger", StaticFiles(directory=swagger_path), name="swagger")
def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
*args,
**kwargs,
swagger_js_url="/swagger/swagger-ui-bundle.js",
swagger_css_url="/swagger/swagger-ui.css",
swagger_favicon_url="/swagger/favicon.png"
swagger_favicon_url="/swagger/favicon.png",
)
applications.get_swagger_ui_html = swagger_monkey_patch
from typing import Dict
@ -8174,3 +8181,4 @@ app.include_router(openai_files_router)
app.include_router(team_callback_router)
app.include_router(budget_management_router)
app.include_router(model_management_router)
app.include_router(tag_management_router)

View file

@ -0,0 +1,32 @@
from typing import Dict, List, Optional
from pydantic import BaseModel
class TagBase(BaseModel):
name: str
description: Optional[str] = None
models: Optional[List[str]] = None
model_info: Optional[Dict[str, str]] = None # maps model_id to model_name
class TagConfig(TagBase):
created_at: str
updated_at: str
created_by: Optional[str] = None
class TagNewRequest(TagBase):
pass
class TagUpdateRequest(TagBase):
pass
class TagDeleteRequest(BaseModel):
name: str
class TagInfoRequest(BaseModel):
names: List[str]

View file

@ -0,0 +1,160 @@
import json
import os
import sys
from typing import Any, Dict, Optional
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
from unittest.mock import patch
import litellm
from litellm.proxy.proxy_server import app
from litellm.types.tag_management import TagDeleteRequest, TagInfoRequest, TagNewRequest
client = TestClient(app)
@pytest.mark.asyncio
async def test_create_and_get_tag():
"""
Test creation of a new tag and retrieving its information
"""
# Mock the prisma client and _get_tags_config and _save_tags_config
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
"litellm.proxy.management_endpoints.tag_management_endpoints._get_tags_config"
) as mock_get_tags, patch(
"litellm.proxy.management_endpoints.tag_management_endpoints._save_tags_config"
) as mock_save_tags, patch(
"litellm.proxy.management_endpoints.tag_management_endpoints._add_tag_to_deployment"
) as mock_add_tag, patch(
"litellm.proxy.management_endpoints.tag_management_endpoints._get_model_names"
) as mock_get_models:
# Setup mocks
mock_get_tags.return_value = {}
mock_get_models.return_value = {"model-1": "gpt-3.5-turbo"}
# Create a new tag
tag_data = {
"name": "test-tag",
"description": "Test tag for unit testing",
"models": ["model-1"],
}
# Set admin access for the test
headers = {"Authorization": f"Bearer sk-1234"}
# Test tag creation
response = client.post("/tag/new", json=tag_data, headers=headers)
assert response.status_code == 200
result = response.json()
assert result["message"] == "Tag test-tag created successfully"
assert result["tag"]["name"] == "test-tag"
assert result["tag"]["description"] == "Test tag for unit testing"
# Mock updated tag config for the get request
mock_get_tags.return_value = {
"test-tag": {
"name": "test-tag",
"description": "Test tag for unit testing",
"models": ["model-1"],
"model_info": {"model-1": "gpt-3.5-turbo"},
}
}
# Test retrieving tag info
info_data = {"names": ["test-tag"]}
response = client.post("/tag/info", json=info_data, headers=headers)
assert response.status_code == 200
result = response.json()
assert "test-tag" in result
assert result["test-tag"]["description"] == "Test tag for unit testing"
@pytest.mark.asyncio
async def test_update_tag():
"""
Test updating an existing tag
"""
# Mock the prisma client and _get_tags_config and _save_tags_config
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
"litellm.proxy.management_endpoints.tag_management_endpoints._get_tags_config"
) as mock_get_tags, patch(
"litellm.proxy.management_endpoints.tag_management_endpoints._save_tags_config"
) as mock_save_tags, patch(
"litellm.proxy.management_endpoints.tag_management_endpoints._get_model_names"
) as mock_get_models:
# Setup mocks for existing tag
mock_get_tags.return_value = {
"test-tag": {
"name": "test-tag",
"description": "Original description",
"models": ["model-1"],
"created_at": "2023-01-01T00:00:00",
"updated_at": "2023-01-01T00:00:00",
"created_by": "user-123",
}
}
mock_get_models.return_value = {"model-1": "gpt-3.5-turbo", "model-2": "gpt-4"}
# Update tag data
update_data = {
"name": "test-tag",
"description": "Updated description",
"models": ["model-1", "model-2"],
}
# Set admin access for the test
headers = {"Authorization": f"Bearer sk-1234"}
# Test tag update
response = client.post("/tag/update", json=update_data, headers=headers)
assert response.status_code == 200
result = response.json()
assert result["message"] == "Tag test-tag updated successfully"
assert result["tag"]["description"] == "Updated description"
assert len(result["tag"]["models"]) == 2
assert "model-2" in result["tag"]["models"]
@pytest.mark.asyncio
async def test_delete_tag():
"""
Test deleting a tag
"""
# Mock the prisma client and _get_tags_config and _save_tags_config
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
"litellm.proxy.management_endpoints.tag_management_endpoints._get_tags_config"
) as mock_get_tags, patch(
"litellm.proxy.management_endpoints.tag_management_endpoints._save_tags_config"
) as mock_save_tags:
# Setup mocks for existing tag
mock_get_tags.return_value = {
"test-tag": {
"name": "test-tag",
"description": "Test tag for deletion",
"models": ["model-1"],
"created_at": "2023-01-01T00:00:00",
"updated_at": "2023-01-01T00:00:00",
"created_by": "user-123",
}
}
# Delete tag data
delete_data = {"name": "test-tag"}
# Set admin access for the test
headers = {"Authorization": f"Bearer sk-1234"}
# Test tag deletion
response = client.post("/tag/delete", json=delete_data, headers=headers)
assert response.status_code == 200
result = response.json()
assert result["message"] == "Tag test-tag deleted successfully"
# Verify _save_tags_config was called without the deleted tag
mock_save_tags.assert_called_once()

View file

@ -33,6 +33,7 @@ import TransformRequestPanel from "@/components/transform_request";
import { fetchUserModels } from "@/components/create_key_button";
import { fetchTeams } from "@/components/common_components/fetch_teams";
import MCPToolsViewer from "@/components/mcp_tools";
import TagManagement from "@/components/tag_management";
function getCookie(name: string) {
const cookieValue = document.cookie
@ -355,6 +356,12 @@ export default function CreateKeyPage() {
userRole={userRole}
userID={userID}
/>
) : page == "tag-management" ? (
<TagManagement
accessToken={accessToken}
userRole={userRole}
userID={userID}
/>
) : page == "new_usage" ? (
<NewUsagePage
userID={userID}

View file

@ -31,6 +31,7 @@ import { litellmModeMapping, ModelMode, EndpointType, getEndpointType } from "./
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
import { coy } from 'react-syntax-highlighter/dist/esm/styles/prism';
import EndpointSelector from "./chat_ui/EndpointSelector";
import TagSelector from "./tag_management/TagSelector";
import { determineEndpointType } from "./chat_ui/EndpointUtils";
import {
SendOutlined,
@ -40,7 +41,8 @@ import {
RobotOutlined,
UserOutlined,
DeleteOutlined,
LoadingOutlined
LoadingOutlined,
TagsOutlined
} from "@ant-design/icons";
interface ChatUIProps {
@ -73,6 +75,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
const [endpointType, setEndpointType] = useState<string>(EndpointType.CHAT);
const [isLoading, setIsLoading] = useState<boolean>(false);
const abortControllerRef = useRef<AbortController | null>(null);
const [selectedTags, setSelectedTags] = useState<string[]>([]);
const chatEndRef = useRef<HTMLDivElement>(null);
@ -202,6 +205,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
(chunk, model) => updateTextUI("assistant", chunk, model),
selectedModel,
effectiveApiKey,
selectedTags,
signal
);
} else if (endpointType === EndpointType.IMAGE) {
@ -211,6 +215,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
(imageUrl, model) => updateImageUI(imageUrl, model),
selectedModel,
effectiveApiKey,
selectedTags,
signal
);
}
@ -343,6 +348,18 @@ const ChatUI: React.FC<ChatUIProps> = ({
endpointType={endpointType}
onEndpointChange={handleEndpointChange}
className="mb-4"
/>
</div>
<div>
<Text className="font-medium block mb-2 text-gray-700 flex items-center">
<TagsOutlined className="mr-2" /> Tags
</Text>
<TagSelector
value={selectedTags}
onChange={setSelectedTags}
className="mb-4"
accessToken={accessToken || ""}
/>
</div>

View file

@ -7,6 +7,7 @@ export async function makeOpenAIChatCompletionRequest(
updateUI: (chunk: string, model: string) => void,
selectedModel: string,
accessToken: string,
tags?: string[],
signal?: AbortSignal
) {
// base url should be the current base_url
@ -22,6 +23,7 @@ export async function makeOpenAIChatCompletionRequest(
apiKey: accessToken, // Replace with your OpenAI API key
baseURL: proxyBaseUrl, // Replace with your OpenAI API base URL
dangerouslyAllowBrowser: true, // using a temporary litellm proxy key
defaultHeaders: tags && tags.length > 0 ? { 'x-litellm-tags': tags.join(',') } : undefined,
});
try {

View file

@ -6,6 +6,7 @@ export async function makeOpenAIImageGenerationRequest(
updateUI: (imageUrl: string, model: string) => void,
selectedModel: string,
accessToken: string,
tags?: string[],
signal?: AbortSignal
) {
// base url should be the current base_url
@ -21,6 +22,7 @@ export async function makeOpenAIImageGenerationRequest(
apiKey: accessToken,
baseURL: proxyBaseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders: tags && tags.length > 0 ? { 'x-litellm-tags': tags.join(',') } : undefined,
});
try {

View file

@ -22,6 +22,7 @@ import {
ThunderboltOutlined,
LockOutlined,
ToolOutlined,
TagsOutlined,
} from '@ant-design/icons';
import { old_admin_roles, v2_admin_role_names, all_admin_roles, rolesAllowedToSeeUsage, rolesWithWriteAccess, internalUserRoles } from '../utils/roles';
@ -63,9 +64,6 @@ const Sidebar: React.FC<SidebarProps> = ({
{ key: "14", page: "api_ref", label: "API Reference", icon: <ApiOutlined /> },
{ key: "16", page: "model-hub", label: "Model Hub", icon: <AppstoreOutlined /> },
{ key: "15", page: "logs", label: "Logs", icon: <LineChartOutlined />},
{
key: "experimental",
page: "experimental",
@ -77,6 +75,7 @@ const Sidebar: React.FC<SidebarProps> = ({
{ key: "11", page: "guardrails", label: "Guardrails", icon: <SafetyOutlined />, roles: all_admin_roles },
{ key: "12", page: "new_usage", label: "New Usage", icon: <BarChartOutlined />, roles: [...all_admin_roles, ...internalUserRoles] },
{ key: "18", page: "mcp-tools", label: "MCP Tools", icon: <ToolOutlined />, roles: all_admin_roles },
{ key: "19", page: "tag-management", label: "Tag Management", icon: <TagsOutlined />, roles: all_admin_roles },
]
},
{

View file

@ -3,6 +3,7 @@
*/
import { all_admin_roles } from "@/utils/roles";
import { message } from "antd";
import { TagNewRequest, TagUpdateRequest, TagDeleteRequest, TagInfoRequest, TagListResponse, TagInfoResponse } from "./tag_management/types";
const isLocal = process.env.NODE_ENV === "development";
export const proxyBaseUrl = isLocal ? "http://localhost:4000" : null;
@ -4155,4 +4156,157 @@ export const callMCPTool = async (accessToken: string, toolName: string, toolArg
console.error("Failed to call MCP tool:", error);
throw error;
}
};
export const tagCreateCall = async (
accessToken: string,
formValues: TagNewRequest
): Promise<void> => {
try {
let url = proxyBaseUrl
? `${proxyBaseUrl}/tag/new`
: `/tag/new`;
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${accessToken}`,
},
body: JSON.stringify(formValues),
});
if (!response.ok) {
const errorData = await response.text();
await handleError(errorData);
return;
}
return await response.json();
} catch (error) {
console.error("Error creating tag:", error);
throw error;
}
};
export const tagUpdateCall = async (
accessToken: string,
formValues: TagUpdateRequest
): Promise<void> => {
try {
let url = proxyBaseUrl
? `${proxyBaseUrl}/tag/update`
: `/tag/update`;
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${accessToken}`,
},
body: JSON.stringify(formValues),
});
if (!response.ok) {
const errorData = await response.text();
await handleError(errorData);
return;
}
return await response.json();
} catch (error) {
console.error("Error updating tag:", error);
throw error;
}
};
export const tagInfoCall = async (
accessToken: string,
tagNames: string[]
): Promise<TagInfoResponse> => {
try {
let url = proxyBaseUrl
? `${proxyBaseUrl}/tag/info`
: `/tag/info`;
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${accessToken}`,
},
body: JSON.stringify({ names: tagNames }),
});
if (!response.ok) {
const errorData = await response.text();
await handleError(errorData);
return {};
}
const data = await response.json();
return data as TagInfoResponse;
} catch (error) {
console.error("Error getting tag info:", error);
throw error;
}
};
export const tagListCall = async (accessToken: string): Promise<TagListResponse> => {
try {
let url = proxyBaseUrl
? `${proxyBaseUrl}/tag/list`
: `/tag/list`;
const response = await fetch(url, {
method: "GET",
headers: {
Authorization: `Bearer ${accessToken}`,
},
});
if (!response.ok) {
const errorData = await response.text();
await handleError(errorData);
return {};
}
const data = await response.json();
return data as TagListResponse;
} catch (error) {
console.error("Error listing tags:", error);
throw error;
}
};
export const tagDeleteCall = async (
accessToken: string,
tagName: string
): Promise<void> => {
try {
let url = proxyBaseUrl
? `${proxyBaseUrl}/tag/delete`
: `/tag/delete`;
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${accessToken}`,
},
body: JSON.stringify({ name: tagName }),
});
if (!response.ok) {
const errorData = await response.text();
await handleError(errorData);
return;
}
return await response.json();
} catch (error) {
console.error("Error deleting tag:", error);
throw error;
}
};

View file

@ -0,0 +1,52 @@
import React, { useEffect, useState } from 'react';
import { Select } from 'antd';
import { Tag } from './types';
import { tagListCall } from '../networking';
interface TagSelectorProps {
onChange: (selectedTags: string[]) => void;
value?: string[];
className?: string;
accessToken: string;
}
const TagSelector: React.FC<TagSelectorProps> = ({ onChange, value, className, accessToken }) => {
const [tags, setTags] = useState<Tag[]>([]);
const [loading, setLoading] = useState(false);
useEffect(() => {
const fetchTags = async () => {
if (!accessToken) return;
try {
const response = await tagListCall(accessToken);
console.log("List tags response:", response);
setTags(Object.values(response));
} catch (error) {
console.error("Error fetching tags:", error);
}
};
fetchTags();
}, []);
return (
<Select
mode="multiple"
placeholder="Select tags"
onChange={onChange}
value={value}
loading={loading}
className={className}
options={tags.map(tag => ({
label: tag.name,
value: tag.name,
title: tag.description || tag.name,
}))}
optionFilterProp="label"
showSearch
style={{ width: '100%' }}
/>
);
};
export default TagSelector;

View file

@ -0,0 +1,244 @@
import React from "react";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeaderCell,
TableRow,
Icon,
Button,
Badge,
Text,
} from "@tremor/react";
import {
PencilAltIcon,
TrashIcon,
SwitchVerticalIcon,
ChevronUpIcon,
ChevronDownIcon,
} from "@heroicons/react/outline";
import { Tooltip } from "antd";
import {
ColumnDef,
flexRender,
getCoreRowModel,
getSortedRowModel,
SortingState,
useReactTable,
} from "@tanstack/react-table";
import { Tag } from "./types";
interface TagTableProps {
data: Tag[];
onEdit: (tag: Tag) => void;
onDelete: (tagName: string) => void;
onSelectTag: (tagName: string) => void;
}
const TagTable: React.FC<TagTableProps> = ({
data,
onEdit,
onDelete,
onSelectTag,
}) => {
const [sorting, setSorting] = React.useState<SortingState>([
{ id: "created_at", desc: true }
]);
const columns: ColumnDef<Tag>[] = [
{
header: "Tag Name",
accessorKey: "name",
cell: ({ row }) => {
const tag = row.original;
return (
<div className="overflow-hidden">
<Tooltip title={tag.name}>
<Button
size="xs"
variant="light"
className="font-mono text-blue-500 bg-blue-50 hover:bg-blue-100 text-xs font-normal px-2 py-0.5"
onClick={() => onSelectTag(tag.name)}
>
{tag.name}
</Button>
</Tooltip>
</div>
);
},
},
{
header: "Description",
accessorKey: "description",
cell: ({ row }) => {
const tag = row.original;
return (
<Tooltip title={tag.description}>
<span className="text-xs">
{tag.description || "-"}
</span>
</Tooltip>
);
},
},
{
header: "Allowed LLMs",
accessorKey: "models",
cell: ({ row }) => {
const tag = row.original;
return (
<div style={{ display: "flex", flexDirection: "column" }}>
{tag?.models?.length === 0 ? (
<Badge size="xs" className="mb-1" color="red">
All Models
</Badge>
) : (
tag?.models?.map((modelId) => (
<Badge
key={modelId}
size="xs"
className="mb-1"
color="blue"
>
<Tooltip title={`ID: ${modelId}`}>
<Text>
{tag.model_info?.[modelId] || modelId}
</Text>
</Tooltip>
</Badge>
))
)}
</div>
);
},
},
{
header: "Created",
accessorKey: "created_at",
sortingFn: "datetime",
cell: ({ row }) => {
const tag = row.original;
return (
<span className="text-xs">
{new Date(tag.created_at).toLocaleDateString()}
</span>
);
},
},
{
id: "actions",
header: "",
cell: ({ row }) => {
const tag = row.original;
return (
<div className="flex space-x-2">
<Icon
icon={PencilAltIcon}
size="sm"
onClick={() => onEdit(tag)}
className="cursor-pointer"
/>
<Icon
icon={TrashIcon}
size="sm"
onClick={() => onDelete(tag.name)}
className="cursor-pointer"
/>
</div>
);
},
},
];
const table = useReactTable({
data,
columns,
state: {
sorting,
},
onSortingChange: setSorting,
getCoreRowModel: getCoreRowModel(),
getSortedRowModel: getSortedRowModel(),
enableSorting: true,
});
return (
<div className="rounded-lg custom-border relative">
<div className="overflow-x-auto">
<Table className="[&_td]:py-0.5 [&_th]:py-1">
<TableHead>
{table.getHeaderGroups().map((headerGroup) => (
<TableRow key={headerGroup.id}>
{headerGroup.headers.map((header) => (
<TableHeaderCell
key={header.id}
className={`py-1 h-8 ${
header.id === 'actions'
? 'sticky right-0 bg-white shadow-[-4px_0_8px_-6px_rgba(0,0,0,0.1)]'
: ''
}`}
onClick={header.column.getToggleSortingHandler()}
>
<div className="flex items-center justify-between gap-2">
<div className="flex items-center">
{header.isPlaceholder ? null : (
flexRender(
header.column.columnDef.header,
header.getContext()
)
)}
</div>
{header.id !== 'actions' && (
<div className="w-4">
{header.column.getIsSorted() ? (
{
asc: <ChevronUpIcon className="h-4 w-4 text-blue-500" />,
desc: <ChevronDownIcon className="h-4 w-4 text-blue-500" />
}[header.column.getIsSorted() as string]
) : (
<SwitchVerticalIcon className="h-4 w-4 text-gray-400" />
)}
</div>
)}
</div>
</TableHeaderCell>
))}
</TableRow>
))}
</TableHead>
<TableBody>
{table.getRowModel().rows.length > 0 ? (
table.getRowModel().rows.map((row) => (
<TableRow key={row.id} className="h-8">
{row.getVisibleCells().map((cell) => (
<TableCell
key={cell.id}
className={`py-0.5 max-h-8 overflow-hidden text-ellipsis whitespace-nowrap ${
cell.column.id === 'actions'
? 'sticky right-0 bg-white shadow-[-4px_0_8px_-6px_rgba(0,0,0,0.1)]'
: ''
}`}
>
{flexRender(cell.column.columnDef.cell, cell.getContext())}
</TableCell>
))}
</TableRow>
))
) : (
<TableRow>
<TableCell colSpan={columns.length} className="h-8 text-center">
<div className="text-center text-gray-500">
<p>No tags found</p>
</div>
</TableCell>
</TableRow>
)}
</TableBody>
</Table>
</div>
</div>
);
};
export default TagTable;

View file

@ -0,0 +1,309 @@
import React, { useState, useEffect } from "react";
import {
Card,
Icon,
Button,
Col,
Text,
Grid,
TextInput,
} from "@tremor/react";
import {
InformationCircleIcon,
RefreshIcon,
} from "@heroicons/react/outline";
import {
Modal,
Form,
Select as Select2,
message,
Tooltip,
Input
} from "antd";
import { InfoCircleOutlined } from '@ant-design/icons';
import NumericalInput from "../shared/numerical_input";
import TagInfoView from "./tag_info";
import { modelInfoCall } from "../networking";
import { tagCreateCall, tagListCall, tagDeleteCall } from "../networking";
import { Tag } from "./types";
import TagTable from "./TagTable";
interface ModelInfo {
model_name: string;
litellm_params: {
model: string;
};
model_info: {
id: string;
};
}
interface TagProps {
accessToken: string | null;
userID: string | null;
userRole: string | null;
}
const TagManagement: React.FC<TagProps> = ({
accessToken,
userID,
userRole,
}) => {
const [tags, setTags] = useState<Tag[]>([]);
const [isCreateModalVisible, setIsCreateModalVisible] = useState(false);
const [selectedTagId, setSelectedTagId] = useState<string | null>(null);
const [editTag, setEditTag] = useState<boolean>(false);
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
const [tagToDelete, setTagToDelete] = useState<string | null>(null);
const [lastRefreshed, setLastRefreshed] = useState("");
const [form] = Form.useForm();
const [availableModels, setAvailableModels] = useState<ModelInfo[]>([]);
const fetchTags = async () => {
if (!accessToken) return;
try {
const response = await tagListCall(accessToken);
console.log("List tags response:", response);
setTags(Object.values(response));
} catch (error) {
console.error("Error fetching tags:", error);
message.error("Error fetching tags: " + error);
}
};
const handleRefreshClick = () => {
fetchTags();
const currentDate = new Date();
setLastRefreshed(currentDate.toLocaleString());
};
const handleCreate = async (formValues: any) => {
if (!accessToken) return;
try {
await tagCreateCall(accessToken, {
name: formValues.tag_name,
description: formValues.description,
models: formValues.allowed_llms,
});
message.success("Tag created successfully");
setIsCreateModalVisible(false);
form.resetFields();
fetchTags();
} catch (error) {
console.error("Error creating tag:", error);
message.error("Error creating tag: " + error);
}
};
const handleDelete = async (tagName: string) => {
setTagToDelete(tagName);
setIsDeleteModalOpen(true);
};
const confirmDelete = async () => {
if (!accessToken || !tagToDelete) return;
try {
await tagDeleteCall(accessToken, tagToDelete);
message.success("Tag deleted successfully");
fetchTags();
} catch (error) {
console.error("Error deleting tag:", error);
message.error("Error deleting tag: " + error);
}
setIsDeleteModalOpen(false);
setTagToDelete(null);
};
useEffect(() => {
if (userID && userRole && accessToken) {
const fetchModels = async () => {
try {
const response = await modelInfoCall(accessToken, userID, userRole);
if (response && response.data) {
setAvailableModels(response.data);
}
} catch (error) {
console.error("Error fetching models:", error);
message.error("Error fetching models: " + error);
}
};
fetchModels();
}
}, [accessToken, userID, userRole]);
useEffect(() => {
fetchTags();
}, [accessToken]);
return (
<div className="w-full mx-4 h-[75vh]">
{selectedTagId ? (
<TagInfoView
tagId={selectedTagId}
onClose={() => {
setSelectedTagId(null);
setEditTag(false);
}}
accessToken={accessToken}
is_admin={userRole === "Admin"}
editTag={editTag}
/>
) : (
<div className="gap-2 p-8 h-[75vh] w-full mt-2">
<div className="flex justify-between mt-2 w-full items-center mb-4">
<h1>Tag Management</h1>
<div className="flex items-center space-x-2">
{lastRefreshed && <Text>Last Refreshed: {lastRefreshed}</Text>}
<Icon
icon={RefreshIcon}
variant="shadow"
size="xs"
className="self-center cursor-pointer"
onClick={handleRefreshClick}
/>
</div>
</div>
<Text className="mb-4">
Click on a tag name to view and edit its details.
<p>You can use tags to restrict the usage of certain LLMs based on tags passed in the request. Read more about tag routing <a href="https://docs.litellm.ai/docs/proxy/tag_routing" target="_blank" rel="noopener noreferrer">here</a>.</p>
</Text>
<Button
className="mb-4"
onClick={() => setIsCreateModalVisible(true)}
>
+ Create New Tag
</Button>
<Grid numItems={1} className="gap-2 pt-2 pb-2 h-[75vh] w-full mt-2">
<Col numColSpan={1}>
<TagTable
data={tags}
onEdit={(tag) => {
setSelectedTagId(tag.name);
setEditTag(true);
}}
onDelete={handleDelete}
onSelectTag={setSelectedTagId}
/>
</Col>
</Grid>
{/* Create Tag Modal */}
<Modal
title="Create New Tag"
visible={isCreateModalVisible}
width={800}
footer={null}
onCancel={() => {
setIsCreateModalVisible(false);
form.resetFields();
}}
>
<Form
form={form}
onFinish={handleCreate}
labelCol={{ span: 8 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
>
<Form.Item
label="Tag Name"
name="tag_name"
rules={[{ required: true, message: "Please input a tag name" }]}
>
<TextInput />
</Form.Item>
<Form.Item
label="Description"
name="description"
>
<Input.TextArea rows={4} />
</Form.Item>
<Form.Item
label={
<span>
Allowed Models{' '}
<Tooltip title="Select which LLMs are allowed to process requests from this tag">
<InfoCircleOutlined style={{ marginLeft: '4px' }} />
</Tooltip>
</span>
}
name="allowed_llms"
>
<Select2
mode="multiple"
placeholder="Select LLMs"
>
{availableModels.map((model) => (
<Select2.Option key={model.model_info.id} value={model.model_info.id}>
<div>
<span>{model.model_name}</span>
<span className="text-gray-400 ml-2">({model.model_info.id})</span>
</div>
</Select2.Option>
))}
</Select2>
</Form.Item>
<div style={{ textAlign: "right", marginTop: "10px" }}>
<Button type="submit">
Create Tag
</Button>
</div>
</Form>
</Modal>
{/* Delete Confirmation Modal */}
{isDeleteModalOpen && (
<div className="fixed z-10 inset-0 overflow-y-auto">
<div className="flex items-end justify-center min-h-screen pt-4 px-4 pb-20 text-center sm:block sm:p-0">
<div className="fixed inset-0 transition-opacity" aria-hidden="true">
<div className="absolute inset-0 bg-gray-500 opacity-75"></div>
</div>
<div className="inline-block align-bottom bg-white rounded-lg text-left overflow-hidden shadow-xl transform transition-all sm:my-8 sm:align-middle sm:max-w-lg sm:w-full">
<div className="bg-white px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
<div className="sm:flex sm:items-start">
<div className="mt-3 text-center sm:mt-0 sm:ml-4 sm:text-left">
<h3 className="text-lg leading-6 font-medium text-gray-900">
Delete Tag
</h3>
<div className="mt-2">
<p className="text-sm text-gray-500">
Are you sure you want to delete this tag?
</p>
</div>
</div>
</div>
</div>
<div className="bg-gray-50 px-4 py-3 sm:px-6 sm:flex sm:flex-row-reverse">
<Button
onClick={confirmDelete}
color="red"
className="ml-2"
>
Delete
</Button>
<Button onClick={() => {
setIsDeleteModalOpen(false);
setTagToDelete(null);
}}>
Cancel
</Button>
</div>
</div>
</div>
</div>
)}
</div>
)}
</div>
);
};
export default TagManagement;

View file

@ -0,0 +1,206 @@
import React, { useState, useEffect } from "react";
import {
Card,
Text,
Title,
Button,
Badge,
} from "@tremor/react";
import {
Form,
Input,
Select as Select2,
message,
Tooltip,
} from "antd";
import { InfoCircleOutlined } from '@ant-design/icons';
import { fetchUserModels } from "../create_key_button";
import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
import { tagInfoCall, tagUpdateCall } from "../networking";
import { Tag, TagInfoResponse } from "./types";
interface TagInfoViewProps {
tagId: string;
onClose: () => void;
accessToken: string | null;
is_admin: boolean;
editTag: boolean;
}
const TagInfoView: React.FC<TagInfoViewProps> = ({
tagId,
onClose,
accessToken,
is_admin,
editTag,
}) => {
const [form] = Form.useForm();
const [tagDetails, setTagDetails] = useState<Tag | null>(null);
const [isEditing, setIsEditing] = useState<boolean>(editTag);
const [userModels, setUserModels] = useState<string[]>([]);
const fetchTagDetails = async () => {
if (!accessToken) return;
try {
const response = await tagInfoCall(accessToken, [tagId]);
const tagData = response[tagId];
if (tagData) {
setTagDetails(tagData);
if (editTag) {
form.setFieldsValue({
name: tagData.name,
description: tagData.description,
models: tagData.models,
});
}
}
} catch (error) {
console.error("Error fetching tag details:", error);
message.error("Error fetching tag details: " + error);
}
};
useEffect(() => {
fetchTagDetails();
}, [tagId, accessToken]);
useEffect(() => {
if (accessToken) {
// Using dummy values for userID and userRole since they're required by the function
// TODO: Pass these as props if needed for the actual API implementation
fetchUserModels("dummy-user", "Admin", accessToken, setUserModels);
}
}, [accessToken]);
const handleSave = async (values: any) => {
if (!accessToken) return;
try {
await tagUpdateCall(accessToken, {
name: values.name,
description: values.description,
models: values.models,
});
message.success("Tag updated successfully");
setIsEditing(false);
fetchTagDetails();
} catch (error) {
console.error("Error updating tag:", error);
message.error("Error updating tag: " + error);
}
};
if (!tagDetails) {
return <div>Loading...</div>;
}
return (
<div className="p-4">
<div className="flex justify-between items-center mb-6">
<div>
<Button onClick={onClose} className="mb-4"> Back to Tags</Button>
<Title>Tag Name: {tagDetails.name}</Title>
<Text className="text-gray-500">{tagDetails.description || "No description"}</Text>
</div>
{is_admin && !isEditing && (
<Button onClick={() => setIsEditing(true)}>Edit Tag</Button>
)}
</div>
{isEditing ? (
<Card>
<Form
form={form}
onFinish={handleSave}
layout="vertical"
initialValues={tagDetails}
>
<Form.Item
label="Tag Name"
name="name"
rules={[{ required: true, message: "Please input a tag name" }]}
>
<Input />
</Form.Item>
<Form.Item
label="Description"
name="description"
>
<Input.TextArea rows={4} />
</Form.Item>
<Form.Item
label={
<span>
Allowed LLMs{' '}
<Tooltip title="Select which LLMs are allowed to process this type of data">
<InfoCircleOutlined style={{ marginLeft: '4px' }} />
</Tooltip>
</span>
}
name="models"
>
<Select2
mode="multiple"
placeholder="Select LLMs"
>
{userModels.map((modelId) => (
<Select2.Option key={modelId} value={modelId}>
{getModelDisplayName(modelId)}
</Select2.Option>
))}
</Select2>
</Form.Item>
<div className="flex justify-end space-x-2">
<Button onClick={() => setIsEditing(false)}>Cancel</Button>
<Button type="submit">Save Changes</Button>
</div>
</Form>
</Card>
) : (
<div className="space-y-6">
<Card>
<Title>Tag Details</Title>
<div className="space-y-4 mt-4">
<div>
<Text className="font-medium">Name</Text>
<Text>{tagDetails.name}</Text>
</div>
<div>
<Text className="font-medium">Description</Text>
<Text>{tagDetails.description || "-"}</Text>
</div>
<div>
<Text className="font-medium">Allowed LLMs</Text>
<div className="flex flex-wrap gap-2 mt-2">
{tagDetails.models.length === 0 ? (
<Badge color="red">All Models</Badge>
) : (
tagDetails.models.map((modelId) => (
<Badge key={modelId} color="blue">
<Tooltip title={`ID: ${modelId}`}>
{tagDetails.model_info?.[modelId] || modelId}
</Tooltip>
</Badge>
))
)}
</div>
</div>
<div>
<Text className="font-medium">Created</Text>
<Text>{tagDetails.created_at ? new Date(tagDetails.created_at).toLocaleString() : "-"}</Text>
</div>
<div>
<Text className="font-medium">Last Updated</Text>
<Text>{tagDetails.updated_at ? new Date(tagDetails.updated_at).toLocaleString() : "-"}</Text>
</div>
</div>
</Card>
</div>
)}
</div>
);
};
export default TagInfoView;

View file

@ -0,0 +1,34 @@
export interface Tag {
name: string;
description?: string;
models: string[]; // model IDs
model_info?: { [key: string]: string }; // maps model_id to model_name
created_at: string;
updated_at: string;
created_by?: string;
updated_by?: string;
}
export interface TagInfoRequest {
names: string[];
}
export interface TagNewRequest {
name: string;
description?: string;
models: string[];
}
export interface TagUpdateRequest {
name: string;
description?: string;
models: string[];
}
export interface TagDeleteRequest {
name: string;
}
// The API returns a dictionary of tags where the key is the tag name
export type TagListResponse = Record<string, Tag>;
export type TagInfoResponse = Record<string, Tag>;