mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
[Feat] LiteLLM Tag/Policy Management (#9813)
* rendering tags on UI * use /models for building tags * CRUD endpoints for Tag management * fix tag management * working api for LIST tags * working tag management * refactor UI components * fixes ui tag management * clean up ui tag management * fix tag management ui * fix show allowed llms * e2e tag controls * stash change for rendering tags on UI * ui working tag selector on Test Key page * fixes for tag management * clean up tag info * fix code quality * test for tag management * ui clarify what tag routing is
This commit is contained in:
parent
ac9f03beae
commit
ff3a6830a4
16 changed files with 1595 additions and 9 deletions
356
litellm/proxy/management_endpoints/tag_management_endpoints.py
Normal file
356
litellm/proxy/management_endpoints/tag_management_endpoints.py
Normal file
|
@ -0,0 +1,356 @@
|
|||
"""
|
||||
TAG MANAGEMENT
|
||||
|
||||
All /tag management endpoints
|
||||
|
||||
/tag/new
|
||||
/tag/info
|
||||
/tag/update
|
||||
/tag/delete
|
||||
/tag/list
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.tag_management import (
|
||||
TagConfig,
|
||||
TagDeleteRequest,
|
||||
TagInfoRequest,
|
||||
TagNewRequest,
|
||||
TagUpdateRequest,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]:
|
||||
"""Helper function to get model names from model IDs"""
|
||||
try:
|
||||
models = await prisma_client.db.litellm_proxymodeltable.find_many(
|
||||
where={"model_id": {"in": model_ids}}
|
||||
)
|
||||
return {model.model_id: model.model_name for model in models}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error getting model names: {str(e)}")
|
||||
return {}
|
||||
|
||||
|
||||
async def _get_tags_config(prisma_client) -> Dict[str, TagConfig]:
|
||||
"""Helper function to get tags config from db"""
|
||||
try:
|
||||
tags_config = await prisma_client.db.litellm_config.find_unique(
|
||||
where={"param_name": "tags_config"}
|
||||
)
|
||||
if tags_config is None:
|
||||
return {}
|
||||
# Convert from JSON if needed
|
||||
if isinstance(tags_config.param_value, str):
|
||||
config_dict = json.loads(tags_config.param_value)
|
||||
else:
|
||||
config_dict = tags_config.param_value or {}
|
||||
|
||||
# For each tag, get the model names
|
||||
for tag_name, tag_config in config_dict.items():
|
||||
if isinstance(tag_config, dict) and tag_config.get("models"):
|
||||
model_info = await _get_model_names(prisma_client, tag_config["models"])
|
||||
tag_config["model_info"] = model_info
|
||||
|
||||
return config_dict
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
async def _save_tags_config(prisma_client, tags_config: Dict[str, TagConfig]):
|
||||
"""Helper function to save tags config to db"""
|
||||
try:
|
||||
verbose_proxy_logger.debug(f"Saving tags config: {tags_config}")
|
||||
# Convert TagConfig objects to dictionaries
|
||||
tags_config_dict = {}
|
||||
for name, tag in tags_config.items():
|
||||
if isinstance(tag, TagConfig):
|
||||
tag_dict = tag.model_dump()
|
||||
# Remove model_info before saving as it will be dynamically generated
|
||||
if "model_info" in tag_dict:
|
||||
del tag_dict["model_info"]
|
||||
tags_config_dict[name] = tag_dict
|
||||
else:
|
||||
# If it's already a dict, remove model_info
|
||||
tag_copy = tag.copy()
|
||||
if "model_info" in tag_copy:
|
||||
del tag_copy["model_info"]
|
||||
tags_config_dict[name] = tag_copy
|
||||
|
||||
json_tags_config = json.dumps(tags_config_dict, default=str)
|
||||
verbose_proxy_logger.debug(f"JSON tags config: {json_tags_config}")
|
||||
await prisma_client.db.litellm_config.upsert(
|
||||
where={"param_name": "tags_config"},
|
||||
data={
|
||||
"create": {
|
||||
"param_name": "tags_config",
|
||||
"param_value": json_tags_config,
|
||||
},
|
||||
"update": {"param_value": json_tags_config},
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error saving tags config: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tag/new",
|
||||
tags=["tag management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def new_tag(
|
||||
tag: TagNewRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Create a new tag.
|
||||
|
||||
Parameters:
|
||||
- name: str - The name of the tag
|
||||
- description: Optional[str] - Description of what this tag represents
|
||||
- models: List[str] - List of LLM models allowed for this tag
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
try:
|
||||
# Get existing tags config
|
||||
tags_config = await _get_tags_config(prisma_client)
|
||||
|
||||
# Check if tag already exists
|
||||
if tag.name in tags_config:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Tag {tag.name} already exists"
|
||||
)
|
||||
|
||||
# Add new tag
|
||||
tags_config[tag.name] = TagConfig(
|
||||
name=tag.name,
|
||||
description=tag.description,
|
||||
models=tag.models,
|
||||
created_at=str(datetime.datetime.now()),
|
||||
updated_at=str(datetime.datetime.now()),
|
||||
created_by=user_api_key_dict.user_id,
|
||||
)
|
||||
|
||||
# Save updated config
|
||||
await _save_tags_config(
|
||||
prisma_client=prisma_client,
|
||||
tags_config=tags_config,
|
||||
)
|
||||
|
||||
# Update models with new tag
|
||||
if tag.models:
|
||||
for model_id in tag.models:
|
||||
await _add_tag_to_deployment(
|
||||
model_id=model_id,
|
||||
tag=tag.name,
|
||||
)
|
||||
|
||||
# Get model names for response
|
||||
model_info = await _get_model_names(prisma_client, tag.models or [])
|
||||
tags_config[tag.name].model_info = model_info
|
||||
|
||||
return {
|
||||
"message": f"Tag {tag.name} created successfully",
|
||||
"tag": tags_config[tag.name],
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error creating tag: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _add_tag_to_deployment(model_id: str, tag: str):
|
||||
"""Helper function to add tag to deployment"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
deployment = await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
where={"model_id": model_id}
|
||||
)
|
||||
if deployment is None:
|
||||
raise HTTPException(status_code=404, detail=f"Deployment {model_id} not found")
|
||||
|
||||
litellm_params = deployment.litellm_params
|
||||
if "tags" not in litellm_params:
|
||||
litellm_params["tags"] = []
|
||||
litellm_params["tags"].append(tag)
|
||||
await prisma_client.db.litellm_proxymodeltable.update(
|
||||
where={"model_id": model_id},
|
||||
data={"litellm_params": safe_dumps(litellm_params)},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tag/update",
|
||||
tags=["tag management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def update_tag(
|
||||
tag: TagUpdateRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Update an existing tag.
|
||||
|
||||
Parameters:
|
||||
- name: str - The name of the tag to update
|
||||
- description: Optional[str] - Updated description
|
||||
- models: List[str] - Updated list of allowed LLM models
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
# Get existing tags config
|
||||
tags_config = await _get_tags_config(prisma_client)
|
||||
|
||||
# Check if tag exists
|
||||
if tag.name not in tags_config:
|
||||
raise HTTPException(status_code=404, detail=f"Tag {tag.name} not found")
|
||||
|
||||
# Update tag
|
||||
tag_config_dict = dict(tags_config[tag.name])
|
||||
tag_config_dict.update(
|
||||
{
|
||||
"description": tag.description,
|
||||
"models": tag.models,
|
||||
"updated_at": str(datetime.datetime.now()),
|
||||
"updated_by": user_api_key_dict.user_id,
|
||||
}
|
||||
)
|
||||
tags_config[tag.name] = TagConfig(**tag_config_dict)
|
||||
|
||||
# Save updated config
|
||||
await _save_tags_config(prisma_client, tags_config)
|
||||
|
||||
# Get model names for response
|
||||
model_info = await _get_model_names(prisma_client, tag.models or [])
|
||||
tags_config[tag.name].model_info = model_info
|
||||
|
||||
return {
|
||||
"message": f"Tag {tag.name} updated successfully",
|
||||
"tag": tags_config[tag.name],
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error updating tag: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tag/info",
|
||||
tags=["tag management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def info_tag(
|
||||
data: TagInfoRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Get information about specific tags.
|
||||
|
||||
Parameters:
|
||||
- names: List[str] - List of tag names to get information for
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
tags_config = await _get_tags_config(prisma_client)
|
||||
|
||||
# Filter tags based on requested names
|
||||
requested_tags = {name: tags_config.get(name) for name in data.names}
|
||||
|
||||
# Check if any requested tags don't exist
|
||||
missing_tags = [name for name in data.names if name not in tags_config]
|
||||
if missing_tags:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Tags not found: {missing_tags}"
|
||||
)
|
||||
|
||||
return requested_tags
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tag/list",
|
||||
tags=["tag management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def list_tags(
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
List all available tags.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
tags_config = await _get_tags_config(prisma_client)
|
||||
list_of_tags = list(tags_config.values())
|
||||
return list_of_tags
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tag/delete",
|
||||
tags=["tag management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def delete_tag(
|
||||
data: TagDeleteRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Delete a tag.
|
||||
|
||||
Parameters:
|
||||
- name: str - The name of the tag to delete
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
# Get existing tags config
|
||||
tags_config = await _get_tags_config(prisma_client)
|
||||
|
||||
# Check if tag exists
|
||||
if data.name not in tags_config:
|
||||
raise HTTPException(status_code=404, detail=f"Tag {data.name} not found")
|
||||
|
||||
# Delete tag
|
||||
del tags_config[data.name]
|
||||
|
||||
# Save updated config
|
||||
await _save_tags_config(prisma_client, tags_config)
|
||||
|
||||
return {"message": f"Tag {data.name} deleted successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
|
@ -11,6 +11,10 @@ model_list:
|
|||
|
||||
litellm_settings:
|
||||
require_auth_for_metrics_endpoint: true
|
||||
|
||||
|
||||
callbacks: ["prometheus"]
|
||||
service_callback: ["prometheus_system"]
|
||||
service_callback: ["prometheus_system"]
|
||||
|
||||
router_settings:
|
||||
enable_tag_filtering: True # 👈 Key Change
|
|
@ -237,6 +237,9 @@ from litellm.proxy.management_endpoints.model_management_endpoints import (
|
|||
from litellm.proxy.management_endpoints.organization_endpoints import (
|
||||
router as organization_router,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.tag_management_endpoints import (
|
||||
router as tag_management_router,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.team_callback_endpoints import (
|
||||
router as team_callback_router,
|
||||
)
|
||||
|
@ -347,13 +350,13 @@ from fastapi import (
|
|||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
applications,
|
||||
status,
|
||||
applications
|
||||
)
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.openapi.docs import get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import (
|
||||
FileResponse,
|
||||
JSONResponse,
|
||||
|
@ -735,7 +738,7 @@ try:
|
|||
|
||||
except Exception:
|
||||
pass
|
||||
# current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# ui_path = os.path.join(current_dir, "_experimental", "out")
|
||||
# # Mount this test directory instead
|
||||
# app.mount("/ui", StaticFiles(directory=ui_path, html=True), name="ui")
|
||||
|
@ -753,14 +756,18 @@ app.add_middleware(PrometheusAuthMiddleware)
|
|||
|
||||
swagger_path = os.path.join(current_dir, "swagger")
|
||||
app.mount("/swagger", StaticFiles(directory=swagger_path), name="swagger")
|
||||
|
||||
|
||||
def swagger_monkey_patch(*args, **kwargs):
|
||||
return get_swagger_ui_html(
|
||||
*args,
|
||||
**kwargs,
|
||||
swagger_js_url="/swagger/swagger-ui-bundle.js",
|
||||
swagger_css_url="/swagger/swagger-ui.css",
|
||||
swagger_favicon_url="/swagger/favicon.png"
|
||||
swagger_favicon_url="/swagger/favicon.png",
|
||||
)
|
||||
|
||||
|
||||
applications.get_swagger_ui_html = swagger_monkey_patch
|
||||
|
||||
from typing import Dict
|
||||
|
@ -8174,3 +8181,4 @@ app.include_router(openai_files_router)
|
|||
app.include_router(team_callback_router)
|
||||
app.include_router(budget_management_router)
|
||||
app.include_router(model_management_router)
|
||||
app.include_router(tag_management_router)
|
||||
|
|
32
litellm/types/tag_management.py
Normal file
32
litellm/types/tag_management.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TagBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
models: Optional[List[str]] = None
|
||||
model_info: Optional[Dict[str, str]] = None # maps model_id to model_name
|
||||
|
||||
|
||||
class TagConfig(TagBase):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
created_by: Optional[str] = None
|
||||
|
||||
|
||||
class TagNewRequest(TagBase):
|
||||
pass
|
||||
|
||||
|
||||
class TagUpdateRequest(TagBase):
|
||||
pass
|
||||
|
||||
|
||||
class TagDeleteRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class TagInfoRequest(BaseModel):
|
||||
names: List[str]
|
|
@ -0,0 +1,160 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import litellm
|
||||
from litellm.proxy.proxy_server import app
|
||||
from litellm.types.tag_management import TagDeleteRequest, TagInfoRequest, TagNewRequest
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_tag():
|
||||
"""
|
||||
Test creation of a new tag and retrieving its information
|
||||
"""
|
||||
# Mock the prisma client and _get_tags_config and _save_tags_config
|
||||
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
|
||||
"litellm.proxy.management_endpoints.tag_management_endpoints._get_tags_config"
|
||||
) as mock_get_tags, patch(
|
||||
"litellm.proxy.management_endpoints.tag_management_endpoints._save_tags_config"
|
||||
) as mock_save_tags, patch(
|
||||
"litellm.proxy.management_endpoints.tag_management_endpoints._add_tag_to_deployment"
|
||||
) as mock_add_tag, patch(
|
||||
"litellm.proxy.management_endpoints.tag_management_endpoints._get_model_names"
|
||||
) as mock_get_models:
|
||||
# Setup mocks
|
||||
mock_get_tags.return_value = {}
|
||||
mock_get_models.return_value = {"model-1": "gpt-3.5-turbo"}
|
||||
|
||||
# Create a new tag
|
||||
tag_data = {
|
||||
"name": "test-tag",
|
||||
"description": "Test tag for unit testing",
|
||||
"models": ["model-1"],
|
||||
}
|
||||
|
||||
# Set admin access for the test
|
||||
headers = {"Authorization": f"Bearer sk-1234"}
|
||||
|
||||
# Test tag creation
|
||||
response = client.post("/tag/new", json=tag_data, headers=headers)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert result["message"] == "Tag test-tag created successfully"
|
||||
assert result["tag"]["name"] == "test-tag"
|
||||
assert result["tag"]["description"] == "Test tag for unit testing"
|
||||
|
||||
# Mock updated tag config for the get request
|
||||
mock_get_tags.return_value = {
|
||||
"test-tag": {
|
||||
"name": "test-tag",
|
||||
"description": "Test tag for unit testing",
|
||||
"models": ["model-1"],
|
||||
"model_info": {"model-1": "gpt-3.5-turbo"},
|
||||
}
|
||||
}
|
||||
|
||||
# Test retrieving tag info
|
||||
info_data = {"names": ["test-tag"]}
|
||||
response = client.post("/tag/info", json=info_data, headers=headers)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert "test-tag" in result
|
||||
assert result["test-tag"]["description"] == "Test tag for unit testing"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_tag():
|
||||
"""
|
||||
Test updating an existing tag
|
||||
"""
|
||||
# Mock the prisma client and _get_tags_config and _save_tags_config
|
||||
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
|
||||
"litellm.proxy.management_endpoints.tag_management_endpoints._get_tags_config"
|
||||
) as mock_get_tags, patch(
|
||||
"litellm.proxy.management_endpoints.tag_management_endpoints._save_tags_config"
|
||||
) as mock_save_tags, patch(
|
||||
"litellm.proxy.management_endpoints.tag_management_endpoints._get_model_names"
|
||||
) as mock_get_models:
|
||||
# Setup mocks for existing tag
|
||||
mock_get_tags.return_value = {
|
||||
"test-tag": {
|
||||
"name": "test-tag",
|
||||
"description": "Original description",
|
||||
"models": ["model-1"],
|
||||
"created_at": "2023-01-01T00:00:00",
|
||||
"updated_at": "2023-01-01T00:00:00",
|
||||
"created_by": "user-123",
|
||||
}
|
||||
}
|
||||
mock_get_models.return_value = {"model-1": "gpt-3.5-turbo", "model-2": "gpt-4"}
|
||||
|
||||
# Update tag data
|
||||
update_data = {
|
||||
"name": "test-tag",
|
||||
"description": "Updated description",
|
||||
"models": ["model-1", "model-2"],
|
||||
}
|
||||
|
||||
# Set admin access for the test
|
||||
headers = {"Authorization": f"Bearer sk-1234"}
|
||||
|
||||
# Test tag update
|
||||
response = client.post("/tag/update", json=update_data, headers=headers)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert result["message"] == "Tag test-tag updated successfully"
|
||||
assert result["tag"]["description"] == "Updated description"
|
||||
assert len(result["tag"]["models"]) == 2
|
||||
assert "model-2" in result["tag"]["models"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_tag():
|
||||
"""
|
||||
Test deleting a tag
|
||||
"""
|
||||
# Mock the prisma client and _get_tags_config and _save_tags_config
|
||||
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
|
||||
"litellm.proxy.management_endpoints.tag_management_endpoints._get_tags_config"
|
||||
) as mock_get_tags, patch(
|
||||
"litellm.proxy.management_endpoints.tag_management_endpoints._save_tags_config"
|
||||
) as mock_save_tags:
|
||||
# Setup mocks for existing tag
|
||||
mock_get_tags.return_value = {
|
||||
"test-tag": {
|
||||
"name": "test-tag",
|
||||
"description": "Test tag for deletion",
|
||||
"models": ["model-1"],
|
||||
"created_at": "2023-01-01T00:00:00",
|
||||
"updated_at": "2023-01-01T00:00:00",
|
||||
"created_by": "user-123",
|
||||
}
|
||||
}
|
||||
|
||||
# Delete tag data
|
||||
delete_data = {"name": "test-tag"}
|
||||
|
||||
# Set admin access for the test
|
||||
headers = {"Authorization": f"Bearer sk-1234"}
|
||||
|
||||
# Test tag deletion
|
||||
response = client.post("/tag/delete", json=delete_data, headers=headers)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert result["message"] == "Tag test-tag deleted successfully"
|
||||
|
||||
# Verify _save_tags_config was called without the deleted tag
|
||||
mock_save_tags.assert_called_once()
|
|
@ -33,6 +33,7 @@ import TransformRequestPanel from "@/components/transform_request";
|
|||
import { fetchUserModels } from "@/components/create_key_button";
|
||||
import { fetchTeams } from "@/components/common_components/fetch_teams";
|
||||
import MCPToolsViewer from "@/components/mcp_tools";
|
||||
import TagManagement from "@/components/tag_management";
|
||||
|
||||
function getCookie(name: string) {
|
||||
const cookieValue = document.cookie
|
||||
|
@ -355,6 +356,12 @@ export default function CreateKeyPage() {
|
|||
userRole={userRole}
|
||||
userID={userID}
|
||||
/>
|
||||
) : page == "tag-management" ? (
|
||||
<TagManagement
|
||||
accessToken={accessToken}
|
||||
userRole={userRole}
|
||||
userID={userID}
|
||||
/>
|
||||
) : page == "new_usage" ? (
|
||||
<NewUsagePage
|
||||
userID={userID}
|
||||
|
|
|
@ -31,6 +31,7 @@ import { litellmModeMapping, ModelMode, EndpointType, getEndpointType } from "./
|
|||
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
|
||||
import { coy } from 'react-syntax-highlighter/dist/esm/styles/prism';
|
||||
import EndpointSelector from "./chat_ui/EndpointSelector";
|
||||
import TagSelector from "./tag_management/TagSelector";
|
||||
import { determineEndpointType } from "./chat_ui/EndpointUtils";
|
||||
import {
|
||||
SendOutlined,
|
||||
|
@ -40,7 +41,8 @@ import {
|
|||
RobotOutlined,
|
||||
UserOutlined,
|
||||
DeleteOutlined,
|
||||
LoadingOutlined
|
||||
LoadingOutlined,
|
||||
TagsOutlined
|
||||
} from "@ant-design/icons";
|
||||
|
||||
interface ChatUIProps {
|
||||
|
@ -73,6 +75,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
|
|||
const [endpointType, setEndpointType] = useState<string>(EndpointType.CHAT);
|
||||
const [isLoading, setIsLoading] = useState<boolean>(false);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const [selectedTags, setSelectedTags] = useState<string[]>([]);
|
||||
|
||||
const chatEndRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
|
@ -202,6 +205,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
|
|||
(chunk, model) => updateTextUI("assistant", chunk, model),
|
||||
selectedModel,
|
||||
effectiveApiKey,
|
||||
selectedTags,
|
||||
signal
|
||||
);
|
||||
} else if (endpointType === EndpointType.IMAGE) {
|
||||
|
@ -211,6 +215,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
|
|||
(imageUrl, model) => updateImageUI(imageUrl, model),
|
||||
selectedModel,
|
||||
effectiveApiKey,
|
||||
selectedTags,
|
||||
signal
|
||||
);
|
||||
}
|
||||
|
@ -343,6 +348,18 @@ const ChatUI: React.FC<ChatUIProps> = ({
|
|||
endpointType={endpointType}
|
||||
onEndpointChange={handleEndpointChange}
|
||||
className="mb-4"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Text className="font-medium block mb-2 text-gray-700 flex items-center">
|
||||
<TagsOutlined className="mr-2" /> Tags
|
||||
</Text>
|
||||
<TagSelector
|
||||
value={selectedTags}
|
||||
onChange={setSelectedTags}
|
||||
className="mb-4"
|
||||
accessToken={accessToken || ""}
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ export async function makeOpenAIChatCompletionRequest(
|
|||
updateUI: (chunk: string, model: string) => void,
|
||||
selectedModel: string,
|
||||
accessToken: string,
|
||||
tags?: string[],
|
||||
signal?: AbortSignal
|
||||
) {
|
||||
// base url should be the current base_url
|
||||
|
@ -22,6 +23,7 @@ export async function makeOpenAIChatCompletionRequest(
|
|||
apiKey: accessToken, // Replace with your OpenAI API key
|
||||
baseURL: proxyBaseUrl, // Replace with your OpenAI API base URL
|
||||
dangerouslyAllowBrowser: true, // using a temporary litellm proxy key
|
||||
defaultHeaders: tags && tags.length > 0 ? { 'x-litellm-tags': tags.join(',') } : undefined,
|
||||
});
|
||||
|
||||
try {
|
||||
|
|
|
@ -6,6 +6,7 @@ export async function makeOpenAIImageGenerationRequest(
|
|||
updateUI: (imageUrl: string, model: string) => void,
|
||||
selectedModel: string,
|
||||
accessToken: string,
|
||||
tags?: string[],
|
||||
signal?: AbortSignal
|
||||
) {
|
||||
// base url should be the current base_url
|
||||
|
@ -21,6 +22,7 @@ export async function makeOpenAIImageGenerationRequest(
|
|||
apiKey: accessToken,
|
||||
baseURL: proxyBaseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: tags && tags.length > 0 ? { 'x-litellm-tags': tags.join(',') } : undefined,
|
||||
});
|
||||
|
||||
try {
|
||||
|
|
|
@ -22,6 +22,7 @@ import {
|
|||
ThunderboltOutlined,
|
||||
LockOutlined,
|
||||
ToolOutlined,
|
||||
TagsOutlined,
|
||||
} from '@ant-design/icons';
|
||||
import { old_admin_roles, v2_admin_role_names, all_admin_roles, rolesAllowedToSeeUsage, rolesWithWriteAccess, internalUserRoles } from '../utils/roles';
|
||||
|
||||
|
@ -63,9 +64,6 @@ const Sidebar: React.FC<SidebarProps> = ({
|
|||
{ key: "14", page: "api_ref", label: "API Reference", icon: <ApiOutlined /> },
|
||||
{ key: "16", page: "model-hub", label: "Model Hub", icon: <AppstoreOutlined /> },
|
||||
{ key: "15", page: "logs", label: "Logs", icon: <LineChartOutlined />},
|
||||
|
||||
|
||||
|
||||
{
|
||||
key: "experimental",
|
||||
page: "experimental",
|
||||
|
@ -77,6 +75,7 @@ const Sidebar: React.FC<SidebarProps> = ({
|
|||
{ key: "11", page: "guardrails", label: "Guardrails", icon: <SafetyOutlined />, roles: all_admin_roles },
|
||||
{ key: "12", page: "new_usage", label: "New Usage", icon: <BarChartOutlined />, roles: [...all_admin_roles, ...internalUserRoles] },
|
||||
{ key: "18", page: "mcp-tools", label: "MCP Tools", icon: <ToolOutlined />, roles: all_admin_roles },
|
||||
{ key: "19", page: "tag-management", label: "Tag Management", icon: <TagsOutlined />, roles: all_admin_roles },
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
*/
|
||||
import { all_admin_roles } from "@/utils/roles";
|
||||
import { message } from "antd";
|
||||
import { TagNewRequest, TagUpdateRequest, TagDeleteRequest, TagInfoRequest, TagListResponse, TagInfoResponse } from "./tag_management/types";
|
||||
|
||||
const isLocal = process.env.NODE_ENV === "development";
|
||||
export const proxyBaseUrl = isLocal ? "http://localhost:4000" : null;
|
||||
|
@ -4155,4 +4156,157 @@ export const callMCPTool = async (accessToken: string, toolName: string, toolArg
|
|||
console.error("Failed to call MCP tool:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
export const tagCreateCall = async (
|
||||
accessToken: string,
|
||||
formValues: TagNewRequest
|
||||
): Promise<void> => {
|
||||
try {
|
||||
let url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/tag/new`
|
||||
: `/tag/new`;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
body: JSON.stringify(formValues),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
await handleError(errorData);
|
||||
return;
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
console.error("Error creating tag:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const tagUpdateCall = async (
|
||||
accessToken: string,
|
||||
formValues: TagUpdateRequest
|
||||
): Promise<void> => {
|
||||
try {
|
||||
let url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/tag/update`
|
||||
: `/tag/update`;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
body: JSON.stringify(formValues),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
await handleError(errorData);
|
||||
return;
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
console.error("Error updating tag:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const tagInfoCall = async (
|
||||
accessToken: string,
|
||||
tagNames: string[]
|
||||
): Promise<TagInfoResponse> => {
|
||||
try {
|
||||
let url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/tag/info`
|
||||
: `/tag/info`;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
body: JSON.stringify({ names: tagNames }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
await handleError(errorData);
|
||||
return {};
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return data as TagInfoResponse;
|
||||
} catch (error) {
|
||||
console.error("Error getting tag info:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const tagListCall = async (accessToken: string): Promise<TagListResponse> => {
|
||||
try {
|
||||
let url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/tag/list`
|
||||
: `/tag/list`;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
await handleError(errorData);
|
||||
return {};
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return data as TagListResponse;
|
||||
} catch (error) {
|
||||
console.error("Error listing tags:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const tagDeleteCall = async (
|
||||
accessToken: string,
|
||||
tagName: string
|
||||
): Promise<void> => {
|
||||
try {
|
||||
let url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/tag/delete`
|
||||
: `/tag/delete`;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
body: JSON.stringify({ name: tagName }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
await handleError(errorData);
|
||||
return;
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
console.error("Error deleting tag:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
|
@ -0,0 +1,52 @@
|
|||
import React, { useEffect, useState } from 'react';
|
||||
import { Select } from 'antd';
|
||||
import { Tag } from './types';
|
||||
import { tagListCall } from '../networking';
|
||||
|
||||
interface TagSelectorProps {
|
||||
onChange: (selectedTags: string[]) => void;
|
||||
value?: string[];
|
||||
className?: string;
|
||||
accessToken: string;
|
||||
}
|
||||
|
||||
const TagSelector: React.FC<TagSelectorProps> = ({ onChange, value, className, accessToken }) => {
|
||||
const [tags, setTags] = useState<Tag[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchTags = async () => {
|
||||
if (!accessToken) return;
|
||||
try {
|
||||
const response = await tagListCall(accessToken);
|
||||
console.log("List tags response:", response);
|
||||
setTags(Object.values(response));
|
||||
} catch (error) {
|
||||
console.error("Error fetching tags:", error);
|
||||
}
|
||||
};
|
||||
|
||||
fetchTags();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Select
|
||||
mode="multiple"
|
||||
placeholder="Select tags"
|
||||
onChange={onChange}
|
||||
value={value}
|
||||
loading={loading}
|
||||
className={className}
|
||||
options={tags.map(tag => ({
|
||||
label: tag.name,
|
||||
value: tag.name,
|
||||
title: tag.description || tag.name,
|
||||
}))}
|
||||
optionFilterProp="label"
|
||||
showSearch
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default TagSelector;
|
244
ui/litellm-dashboard/src/components/tag_management/TagTable.tsx
Normal file
244
ui/litellm-dashboard/src/components/tag_management/TagTable.tsx
Normal file
|
@ -0,0 +1,244 @@
|
|||
import React from "react";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeaderCell,
|
||||
TableRow,
|
||||
Icon,
|
||||
Button,
|
||||
Badge,
|
||||
Text,
|
||||
} from "@tremor/react";
|
||||
import {
|
||||
PencilAltIcon,
|
||||
TrashIcon,
|
||||
SwitchVerticalIcon,
|
||||
ChevronUpIcon,
|
||||
ChevronDownIcon,
|
||||
} from "@heroicons/react/outline";
|
||||
import { Tooltip } from "antd";
|
||||
import {
|
||||
ColumnDef,
|
||||
flexRender,
|
||||
getCoreRowModel,
|
||||
getSortedRowModel,
|
||||
SortingState,
|
||||
useReactTable,
|
||||
} from "@tanstack/react-table";
|
||||
import { Tag } from "./types";
|
||||
|
||||
interface TagTableProps {
|
||||
data: Tag[];
|
||||
onEdit: (tag: Tag) => void;
|
||||
onDelete: (tagName: string) => void;
|
||||
onSelectTag: (tagName: string) => void;
|
||||
}
|
||||
|
||||
const TagTable: React.FC<TagTableProps> = ({
|
||||
data,
|
||||
onEdit,
|
||||
onDelete,
|
||||
onSelectTag,
|
||||
}) => {
|
||||
const [sorting, setSorting] = React.useState<SortingState>([
|
||||
{ id: "created_at", desc: true }
|
||||
]);
|
||||
|
||||
const columns: ColumnDef<Tag>[] = [
|
||||
{
|
||||
header: "Tag Name",
|
||||
accessorKey: "name",
|
||||
cell: ({ row }) => {
|
||||
const tag = row.original;
|
||||
return (
|
||||
<div className="overflow-hidden">
|
||||
<Tooltip title={tag.name}>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="light"
|
||||
className="font-mono text-blue-500 bg-blue-50 hover:bg-blue-100 text-xs font-normal px-2 py-0.5"
|
||||
onClick={() => onSelectTag(tag.name)}
|
||||
>
|
||||
{tag.name}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
header: "Description",
|
||||
accessorKey: "description",
|
||||
cell: ({ row }) => {
|
||||
const tag = row.original;
|
||||
return (
|
||||
<Tooltip title={tag.description}>
|
||||
<span className="text-xs">
|
||||
{tag.description || "-"}
|
||||
</span>
|
||||
</Tooltip>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
header: "Allowed LLMs",
|
||||
accessorKey: "models",
|
||||
cell: ({ row }) => {
|
||||
const tag = row.original;
|
||||
return (
|
||||
<div style={{ display: "flex", flexDirection: "column" }}>
|
||||
{tag?.models?.length === 0 ? (
|
||||
<Badge size="xs" className="mb-1" color="red">
|
||||
All Models
|
||||
</Badge>
|
||||
) : (
|
||||
tag?.models?.map((modelId) => (
|
||||
<Badge
|
||||
key={modelId}
|
||||
size="xs"
|
||||
className="mb-1"
|
||||
color="blue"
|
||||
>
|
||||
<Tooltip title={`ID: ${modelId}`}>
|
||||
<Text>
|
||||
{tag.model_info?.[modelId] || modelId}
|
||||
</Text>
|
||||
</Tooltip>
|
||||
</Badge>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
header: "Created",
|
||||
accessorKey: "created_at",
|
||||
sortingFn: "datetime",
|
||||
cell: ({ row }) => {
|
||||
const tag = row.original;
|
||||
return (
|
||||
<span className="text-xs">
|
||||
{new Date(tag.created_at).toLocaleDateString()}
|
||||
</span>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
id: "actions",
|
||||
header: "",
|
||||
cell: ({ row }) => {
|
||||
const tag = row.original;
|
||||
return (
|
||||
<div className="flex space-x-2">
|
||||
<Icon
|
||||
icon={PencilAltIcon}
|
||||
size="sm"
|
||||
onClick={() => onEdit(tag)}
|
||||
className="cursor-pointer"
|
||||
/>
|
||||
<Icon
|
||||
icon={TrashIcon}
|
||||
size="sm"
|
||||
onClick={() => onDelete(tag.name)}
|
||||
className="cursor-pointer"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const table = useReactTable({
|
||||
data,
|
||||
columns,
|
||||
state: {
|
||||
sorting,
|
||||
},
|
||||
onSortingChange: setSorting,
|
||||
getCoreRowModel: getCoreRowModel(),
|
||||
getSortedRowModel: getSortedRowModel(),
|
||||
enableSorting: true,
|
||||
});
|
||||
|
||||
return (
|
||||
<div className="rounded-lg custom-border relative">
|
||||
<div className="overflow-x-auto">
|
||||
<Table className="[&_td]:py-0.5 [&_th]:py-1">
|
||||
<TableHead>
|
||||
{table.getHeaderGroups().map((headerGroup) => (
|
||||
<TableRow key={headerGroup.id}>
|
||||
{headerGroup.headers.map((header) => (
|
||||
<TableHeaderCell
|
||||
key={header.id}
|
||||
className={`py-1 h-8 ${
|
||||
header.id === 'actions'
|
||||
? 'sticky right-0 bg-white shadow-[-4px_0_8px_-6px_rgba(0,0,0,0.1)]'
|
||||
: ''
|
||||
}`}
|
||||
onClick={header.column.getToggleSortingHandler()}
|
||||
>
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<div className="flex items-center">
|
||||
{header.isPlaceholder ? null : (
|
||||
flexRender(
|
||||
header.column.columnDef.header,
|
||||
header.getContext()
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
{header.id !== 'actions' && (
|
||||
<div className="w-4">
|
||||
{header.column.getIsSorted() ? (
|
||||
{
|
||||
asc: <ChevronUpIcon className="h-4 w-4 text-blue-500" />,
|
||||
desc: <ChevronDownIcon className="h-4 w-4 text-blue-500" />
|
||||
}[header.column.getIsSorted() as string]
|
||||
) : (
|
||||
<SwitchVerticalIcon className="h-4 w-4 text-gray-400" />
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</TableHeaderCell>
|
||||
))}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{table.getRowModel().rows.length > 0 ? (
|
||||
table.getRowModel().rows.map((row) => (
|
||||
<TableRow key={row.id} className="h-8">
|
||||
{row.getVisibleCells().map((cell) => (
|
||||
<TableCell
|
||||
key={cell.id}
|
||||
className={`py-0.5 max-h-8 overflow-hidden text-ellipsis whitespace-nowrap ${
|
||||
cell.column.id === 'actions'
|
||||
? 'sticky right-0 bg-white shadow-[-4px_0_8px_-6px_rgba(0,0,0,0.1)]'
|
||||
: ''
|
||||
}`}
|
||||
>
|
||||
{flexRender(cell.column.columnDef.cell, cell.getContext())}
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
))
|
||||
) : (
|
||||
<TableRow>
|
||||
<TableCell colSpan={columns.length} className="h-8 text-center">
|
||||
<div className="text-center text-gray-500">
|
||||
<p>No tags found</p>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default TagTable;
|
309
ui/litellm-dashboard/src/components/tag_management/index.tsx
Normal file
309
ui/litellm-dashboard/src/components/tag_management/index.tsx
Normal file
|
@ -0,0 +1,309 @@
|
|||
import React, { useState, useEffect } from "react";
|
||||
import {
|
||||
Card,
|
||||
Icon,
|
||||
Button,
|
||||
Col,
|
||||
Text,
|
||||
Grid,
|
||||
TextInput,
|
||||
} from "@tremor/react";
|
||||
import {
|
||||
InformationCircleIcon,
|
||||
RefreshIcon,
|
||||
} from "@heroicons/react/outline";
|
||||
import {
|
||||
Modal,
|
||||
Form,
|
||||
Select as Select2,
|
||||
message,
|
||||
Tooltip,
|
||||
Input
|
||||
} from "antd";
|
||||
import { InfoCircleOutlined } from '@ant-design/icons';
|
||||
import NumericalInput from "../shared/numerical_input";
|
||||
import TagInfoView from "./tag_info";
|
||||
import { modelInfoCall } from "../networking";
|
||||
import { tagCreateCall, tagListCall, tagDeleteCall } from "../networking";
|
||||
import { Tag } from "./types";
|
||||
import TagTable from "./TagTable";
|
||||
|
||||
interface ModelInfo {
|
||||
model_name: string;
|
||||
litellm_params: {
|
||||
model: string;
|
||||
};
|
||||
model_info: {
|
||||
id: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface TagProps {
|
||||
accessToken: string | null;
|
||||
userID: string | null;
|
||||
userRole: string | null;
|
||||
}
|
||||
|
||||
const TagManagement: React.FC<TagProps> = ({
|
||||
accessToken,
|
||||
userID,
|
||||
userRole,
|
||||
}) => {
|
||||
const [tags, setTags] = useState<Tag[]>([]);
|
||||
const [isCreateModalVisible, setIsCreateModalVisible] = useState(false);
|
||||
const [selectedTagId, setSelectedTagId] = useState<string | null>(null);
|
||||
const [editTag, setEditTag] = useState<boolean>(false);
|
||||
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
|
||||
const [tagToDelete, setTagToDelete] = useState<string | null>(null);
|
||||
const [lastRefreshed, setLastRefreshed] = useState("");
|
||||
const [form] = Form.useForm();
|
||||
const [availableModels, setAvailableModels] = useState<ModelInfo[]>([]);
|
||||
|
||||
const fetchTags = async () => {
|
||||
if (!accessToken) return;
|
||||
try {
|
||||
const response = await tagListCall(accessToken);
|
||||
console.log("List tags response:", response);
|
||||
setTags(Object.values(response));
|
||||
} catch (error) {
|
||||
console.error("Error fetching tags:", error);
|
||||
message.error("Error fetching tags: " + error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleRefreshClick = () => {
|
||||
fetchTags();
|
||||
const currentDate = new Date();
|
||||
setLastRefreshed(currentDate.toLocaleString());
|
||||
};
|
||||
|
||||
const handleCreate = async (formValues: any) => {
|
||||
if (!accessToken) return;
|
||||
try {
|
||||
await tagCreateCall(accessToken, {
|
||||
name: formValues.tag_name,
|
||||
description: formValues.description,
|
||||
models: formValues.allowed_llms,
|
||||
});
|
||||
message.success("Tag created successfully");
|
||||
setIsCreateModalVisible(false);
|
||||
form.resetFields();
|
||||
fetchTags();
|
||||
} catch (error) {
|
||||
console.error("Error creating tag:", error);
|
||||
message.error("Error creating tag: " + error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = async (tagName: string) => {
|
||||
setTagToDelete(tagName);
|
||||
setIsDeleteModalOpen(true);
|
||||
};
|
||||
|
||||
const confirmDelete = async () => {
|
||||
if (!accessToken || !tagToDelete) return;
|
||||
try {
|
||||
await tagDeleteCall(accessToken, tagToDelete);
|
||||
message.success("Tag deleted successfully");
|
||||
fetchTags();
|
||||
} catch (error) {
|
||||
console.error("Error deleting tag:", error);
|
||||
message.error("Error deleting tag: " + error);
|
||||
}
|
||||
setIsDeleteModalOpen(false);
|
||||
setTagToDelete(null);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (userID && userRole && accessToken) {
|
||||
const fetchModels = async () => {
|
||||
try {
|
||||
const response = await modelInfoCall(accessToken, userID, userRole);
|
||||
if (response && response.data) {
|
||||
setAvailableModels(response.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching models:", error);
|
||||
message.error("Error fetching models: " + error);
|
||||
}
|
||||
};
|
||||
fetchModels();
|
||||
}
|
||||
}, [accessToken, userID, userRole]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchTags();
|
||||
}, [accessToken]);
|
||||
|
||||
return (
|
||||
<div className="w-full mx-4 h-[75vh]">
|
||||
{selectedTagId ? (
|
||||
<TagInfoView
|
||||
tagId={selectedTagId}
|
||||
onClose={() => {
|
||||
setSelectedTagId(null);
|
||||
setEditTag(false);
|
||||
}}
|
||||
accessToken={accessToken}
|
||||
is_admin={userRole === "Admin"}
|
||||
editTag={editTag}
|
||||
/>
|
||||
) : (
|
||||
<div className="gap-2 p-8 h-[75vh] w-full mt-2">
|
||||
<div className="flex justify-between mt-2 w-full items-center mb-4">
|
||||
<h1>Tag Management</h1>
|
||||
<div className="flex items-center space-x-2">
|
||||
{lastRefreshed && <Text>Last Refreshed: {lastRefreshed}</Text>}
|
||||
<Icon
|
||||
icon={RefreshIcon}
|
||||
variant="shadow"
|
||||
size="xs"
|
||||
className="self-center cursor-pointer"
|
||||
onClick={handleRefreshClick}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
<Text className="mb-4">
|
||||
Click on a tag name to view and edit its details.
|
||||
|
||||
<p>You can use tags to restrict the usage of certain LLMs based on tags passed in the request. Read more about tag routing <a href="https://docs.litellm.ai/docs/proxy/tag_routing" target="_blank" rel="noopener noreferrer">here</a>.</p>
|
||||
</Text>
|
||||
|
||||
<Button
|
||||
className="mb-4"
|
||||
onClick={() => setIsCreateModalVisible(true)}
|
||||
>
|
||||
+ Create New Tag
|
||||
</Button>
|
||||
|
||||
<Grid numItems={1} className="gap-2 pt-2 pb-2 h-[75vh] w-full mt-2">
|
||||
<Col numColSpan={1}>
|
||||
<TagTable
|
||||
data={tags}
|
||||
onEdit={(tag) => {
|
||||
setSelectedTagId(tag.name);
|
||||
setEditTag(true);
|
||||
}}
|
||||
onDelete={handleDelete}
|
||||
onSelectTag={setSelectedTagId}
|
||||
/>
|
||||
</Col>
|
||||
</Grid>
|
||||
|
||||
{/* Create Tag Modal */}
|
||||
<Modal
|
||||
title="Create New Tag"
|
||||
visible={isCreateModalVisible}
|
||||
width={800}
|
||||
footer={null}
|
||||
onCancel={() => {
|
||||
setIsCreateModalVisible(false);
|
||||
form.resetFields();
|
||||
}}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
onFinish={handleCreate}
|
||||
labelCol={{ span: 8 }}
|
||||
wrapperCol={{ span: 16 }}
|
||||
labelAlign="left"
|
||||
>
|
||||
<Form.Item
|
||||
label="Tag Name"
|
||||
name="tag_name"
|
||||
rules={[{ required: true, message: "Please input a tag name" }]}
|
||||
>
|
||||
<TextInput />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label="Description"
|
||||
name="description"
|
||||
>
|
||||
<Input.TextArea rows={4} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label={
|
||||
<span>
|
||||
Allowed Models{' '}
|
||||
<Tooltip title="Select which LLMs are allowed to process requests from this tag">
|
||||
<InfoCircleOutlined style={{ marginLeft: '4px' }} />
|
||||
</Tooltip>
|
||||
</span>
|
||||
}
|
||||
name="allowed_llms"
|
||||
>
|
||||
<Select2
|
||||
mode="multiple"
|
||||
placeholder="Select LLMs"
|
||||
>
|
||||
{availableModels.map((model) => (
|
||||
<Select2.Option key={model.model_info.id} value={model.model_info.id}>
|
||||
<div>
|
||||
<span>{model.model_name}</span>
|
||||
<span className="text-gray-400 ml-2">({model.model_info.id})</span>
|
||||
</div>
|
||||
</Select2.Option>
|
||||
))}
|
||||
</Select2>
|
||||
</Form.Item>
|
||||
|
||||
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
||||
<Button type="submit">
|
||||
Create Tag
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
</Modal>
|
||||
|
||||
{/* Delete Confirmation Modal */}
|
||||
{isDeleteModalOpen && (
|
||||
<div className="fixed z-10 inset-0 overflow-y-auto">
|
||||
<div className="flex items-end justify-center min-h-screen pt-4 px-4 pb-20 text-center sm:block sm:p-0">
|
||||
<div className="fixed inset-0 transition-opacity" aria-hidden="true">
|
||||
<div className="absolute inset-0 bg-gray-500 opacity-75"></div>
|
||||
</div>
|
||||
<div className="inline-block align-bottom bg-white rounded-lg text-left overflow-hidden shadow-xl transform transition-all sm:my-8 sm:align-middle sm:max-w-lg sm:w-full">
|
||||
<div className="bg-white px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
|
||||
<div className="sm:flex sm:items-start">
|
||||
<div className="mt-3 text-center sm:mt-0 sm:ml-4 sm:text-left">
|
||||
<h3 className="text-lg leading-6 font-medium text-gray-900">
|
||||
Delete Tag
|
||||
</h3>
|
||||
<div className="mt-2">
|
||||
<p className="text-sm text-gray-500">
|
||||
Are you sure you want to delete this tag?
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="bg-gray-50 px-4 py-3 sm:px-6 sm:flex sm:flex-row-reverse">
|
||||
<Button
|
||||
onClick={confirmDelete}
|
||||
color="red"
|
||||
className="ml-2"
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
<Button onClick={() => {
|
||||
setIsDeleteModalOpen(false);
|
||||
setTagToDelete(null);
|
||||
}}>
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default TagManagement;
|
206
ui/litellm-dashboard/src/components/tag_management/tag_info.tsx
Normal file
206
ui/litellm-dashboard/src/components/tag_management/tag_info.tsx
Normal file
|
@ -0,0 +1,206 @@
|
|||
import React, { useState, useEffect } from "react";
|
||||
import {
|
||||
Card,
|
||||
Text,
|
||||
Title,
|
||||
Button,
|
||||
Badge,
|
||||
} from "@tremor/react";
|
||||
import {
|
||||
Form,
|
||||
Input,
|
||||
Select as Select2,
|
||||
message,
|
||||
Tooltip,
|
||||
} from "antd";
|
||||
import { InfoCircleOutlined } from '@ant-design/icons';
|
||||
import { fetchUserModels } from "../create_key_button";
|
||||
import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
|
||||
import { tagInfoCall, tagUpdateCall } from "../networking";
|
||||
import { Tag, TagInfoResponse } from "./types";
|
||||
|
||||
interface TagInfoViewProps {
|
||||
tagId: string;
|
||||
onClose: () => void;
|
||||
accessToken: string | null;
|
||||
is_admin: boolean;
|
||||
editTag: boolean;
|
||||
}
|
||||
|
||||
const TagInfoView: React.FC<TagInfoViewProps> = ({
|
||||
tagId,
|
||||
onClose,
|
||||
accessToken,
|
||||
is_admin,
|
||||
editTag,
|
||||
}) => {
|
||||
const [form] = Form.useForm();
|
||||
const [tagDetails, setTagDetails] = useState<Tag | null>(null);
|
||||
const [isEditing, setIsEditing] = useState<boolean>(editTag);
|
||||
const [userModels, setUserModels] = useState<string[]>([]);
|
||||
|
||||
const fetchTagDetails = async () => {
|
||||
if (!accessToken) return;
|
||||
try {
|
||||
const response = await tagInfoCall(accessToken, [tagId]);
|
||||
const tagData = response[tagId];
|
||||
if (tagData) {
|
||||
setTagDetails(tagData);
|
||||
if (editTag) {
|
||||
form.setFieldsValue({
|
||||
name: tagData.name,
|
||||
description: tagData.description,
|
||||
models: tagData.models,
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching tag details:", error);
|
||||
message.error("Error fetching tag details: " + error);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchTagDetails();
|
||||
}, [tagId, accessToken]);
|
||||
|
||||
useEffect(() => {
|
||||
if (accessToken) {
|
||||
// Using dummy values for userID and userRole since they're required by the function
|
||||
// TODO: Pass these as props if needed for the actual API implementation
|
||||
fetchUserModels("dummy-user", "Admin", accessToken, setUserModels);
|
||||
}
|
||||
}, [accessToken]);
|
||||
|
||||
const handleSave = async (values: any) => {
|
||||
if (!accessToken) return;
|
||||
try {
|
||||
await tagUpdateCall(accessToken, {
|
||||
name: values.name,
|
||||
description: values.description,
|
||||
models: values.models,
|
||||
});
|
||||
message.success("Tag updated successfully");
|
||||
setIsEditing(false);
|
||||
fetchTagDetails();
|
||||
} catch (error) {
|
||||
console.error("Error updating tag:", error);
|
||||
message.error("Error updating tag: " + error);
|
||||
}
|
||||
};
|
||||
|
||||
if (!tagDetails) {
|
||||
return <div>Loading...</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-4">
|
||||
<div className="flex justify-between items-center mb-6">
|
||||
<div>
|
||||
<Button onClick={onClose} className="mb-4">← Back to Tags</Button>
|
||||
<Title>Tag Name: {tagDetails.name}</Title>
|
||||
<Text className="text-gray-500">{tagDetails.description || "No description"}</Text>
|
||||
</div>
|
||||
{is_admin && !isEditing && (
|
||||
<Button onClick={() => setIsEditing(true)}>Edit Tag</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{isEditing ? (
|
||||
<Card>
|
||||
<Form
|
||||
form={form}
|
||||
onFinish={handleSave}
|
||||
layout="vertical"
|
||||
initialValues={tagDetails}
|
||||
>
|
||||
<Form.Item
|
||||
label="Tag Name"
|
||||
name="name"
|
||||
rules={[{ required: true, message: "Please input a tag name" }]}
|
||||
>
|
||||
<Input />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label="Description"
|
||||
name="description"
|
||||
>
|
||||
<Input.TextArea rows={4} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label={
|
||||
<span>
|
||||
Allowed LLMs{' '}
|
||||
<Tooltip title="Select which LLMs are allowed to process this type of data">
|
||||
<InfoCircleOutlined style={{ marginLeft: '4px' }} />
|
||||
</Tooltip>
|
||||
</span>
|
||||
}
|
||||
name="models"
|
||||
>
|
||||
<Select2
|
||||
mode="multiple"
|
||||
placeholder="Select LLMs"
|
||||
>
|
||||
{userModels.map((modelId) => (
|
||||
<Select2.Option key={modelId} value={modelId}>
|
||||
{getModelDisplayName(modelId)}
|
||||
</Select2.Option>
|
||||
))}
|
||||
</Select2>
|
||||
</Form.Item>
|
||||
|
||||
<div className="flex justify-end space-x-2">
|
||||
<Button onClick={() => setIsEditing(false)}>Cancel</Button>
|
||||
<Button type="submit">Save Changes</Button>
|
||||
</div>
|
||||
</Form>
|
||||
</Card>
|
||||
) : (
|
||||
<div className="space-y-6">
|
||||
<Card>
|
||||
<Title>Tag Details</Title>
|
||||
<div className="space-y-4 mt-4">
|
||||
<div>
|
||||
<Text className="font-medium">Name</Text>
|
||||
<Text>{tagDetails.name}</Text>
|
||||
</div>
|
||||
<div>
|
||||
<Text className="font-medium">Description</Text>
|
||||
<Text>{tagDetails.description || "-"}</Text>
|
||||
</div>
|
||||
<div>
|
||||
<Text className="font-medium">Allowed LLMs</Text>
|
||||
<div className="flex flex-wrap gap-2 mt-2">
|
||||
{tagDetails.models.length === 0 ? (
|
||||
<Badge color="red">All Models</Badge>
|
||||
) : (
|
||||
tagDetails.models.map((modelId) => (
|
||||
<Badge key={modelId} color="blue">
|
||||
<Tooltip title={`ID: ${modelId}`}>
|
||||
{tagDetails.model_info?.[modelId] || modelId}
|
||||
</Tooltip>
|
||||
</Badge>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<Text className="font-medium">Created</Text>
|
||||
<Text>{tagDetails.created_at ? new Date(tagDetails.created_at).toLocaleString() : "-"}</Text>
|
||||
</div>
|
||||
<div>
|
||||
<Text className="font-medium">Last Updated</Text>
|
||||
<Text>{tagDetails.updated_at ? new Date(tagDetails.updated_at).toLocaleString() : "-"}</Text>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default TagInfoView;
|
34
ui/litellm-dashboard/src/components/tag_management/types.tsx
Normal file
34
ui/litellm-dashboard/src/components/tag_management/types.tsx
Normal file
|
@ -0,0 +1,34 @@
|
|||
export interface Tag {
|
||||
name: string;
|
||||
description?: string;
|
||||
models: string[]; // model IDs
|
||||
model_info?: { [key: string]: string }; // maps model_id to model_name
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
created_by?: string;
|
||||
updated_by?: string;
|
||||
}
|
||||
|
||||
export interface TagInfoRequest {
|
||||
names: string[];
|
||||
}
|
||||
|
||||
export interface TagNewRequest {
|
||||
name: string;
|
||||
description?: string;
|
||||
models: string[];
|
||||
}
|
||||
|
||||
export interface TagUpdateRequest {
|
||||
name: string;
|
||||
description?: string;
|
||||
models: string[];
|
||||
}
|
||||
|
||||
export interface TagDeleteRequest {
|
||||
name: string;
|
||||
}
|
||||
|
||||
// The API returns a dictionary of tags where the key is the tag name
|
||||
export type TagListResponse = Record<string, Tag>;
|
||||
export type TagInfoResponse = Record<string, Tag>;
|
Loading…
Add table
Add a link
Reference in a new issue