diff --git a/litellm/proxy/management_endpoints/tag_management_endpoints.py b/litellm/proxy/management_endpoints/tag_management_endpoints.py new file mode 100644 index 0000000000..014a1f3c57 --- /dev/null +++ b/litellm/proxy/management_endpoints/tag_management_endpoints.py @@ -0,0 +1,356 @@ +""" +TAG MANAGEMENT + +All /tag management endpoints + +/tag/new +/tag/info +/tag/update +/tag/delete +/tag/list +""" + +import datetime +import json +from typing import Dict + +from fastapi import APIRouter, Depends, HTTPException + +from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.safe_json_dumps import safe_dumps +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.types.tag_management import ( + TagConfig, + TagDeleteRequest, + TagInfoRequest, + TagNewRequest, + TagUpdateRequest, +) + +router = APIRouter() + + +async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]: + """Helper function to get model names from model IDs""" + try: + models = await prisma_client.db.litellm_proxymodeltable.find_many( + where={"model_id": {"in": model_ids}} + ) + return {model.model_id: model.model_name for model in models} + except Exception as e: + verbose_proxy_logger.error(f"Error getting model names: {str(e)}") + return {} + + +async def _get_tags_config(prisma_client) -> Dict[str, TagConfig]: + """Helper function to get tags config from db""" + try: + tags_config = await prisma_client.db.litellm_config.find_unique( + where={"param_name": "tags_config"} + ) + if tags_config is None: + return {} + # Convert from JSON if needed + if isinstance(tags_config.param_value, str): + config_dict = json.loads(tags_config.param_value) + else: + config_dict = tags_config.param_value or {} + + # For each tag, get the model names + for tag_name, tag_config in config_dict.items(): + if isinstance(tag_config, dict) and tag_config.get("models"): + model_info = await _get_model_names(prisma_client, tag_config["models"]) + tag_config["model_info"] = model_info + + return config_dict + except Exception: + return {} + + +async def _save_tags_config(prisma_client, tags_config: Dict[str, TagConfig]): + """Helper function to save tags config to db""" + try: + verbose_proxy_logger.debug(f"Saving tags config: {tags_config}") + # Convert TagConfig objects to dictionaries + tags_config_dict = {} + for name, tag in tags_config.items(): + if isinstance(tag, TagConfig): + tag_dict = tag.model_dump() + # Remove model_info before saving as it will be dynamically generated + if "model_info" in tag_dict: + del tag_dict["model_info"] + tags_config_dict[name] = tag_dict + else: + # If it's already a dict, remove model_info + tag_copy = tag.copy() + if "model_info" in tag_copy: + del tag_copy["model_info"] + tags_config_dict[name] = tag_copy + + json_tags_config = json.dumps(tags_config_dict, default=str) + verbose_proxy_logger.debug(f"JSON tags config: {json_tags_config}") + await prisma_client.db.litellm_config.upsert( + where={"param_name": "tags_config"}, + data={ + "create": { + "param_name": "tags_config", + "param_value": json_tags_config, + }, + "update": {"param_value": json_tags_config}, + }, + ) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error saving tags config: {str(e)}" + ) + + +@router.post( + "/tag/new", + tags=["tag management"], + dependencies=[Depends(user_api_key_auth)], +) +async def new_tag( + tag: TagNewRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Create a new tag. + + Parameters: + - name: str - The name of the tag + - description: Optional[str] - Description of what this tag represents + - models: List[str] - List of LLM models allowed for this tag + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + try: + # Get existing tags config + tags_config = await _get_tags_config(prisma_client) + + # Check if tag already exists + if tag.name in tags_config: + raise HTTPException( + status_code=400, detail=f"Tag {tag.name} already exists" + ) + + # Add new tag + tags_config[tag.name] = TagConfig( + name=tag.name, + description=tag.description, + models=tag.models, + created_at=str(datetime.datetime.now()), + updated_at=str(datetime.datetime.now()), + created_by=user_api_key_dict.user_id, + ) + + # Save updated config + await _save_tags_config( + prisma_client=prisma_client, + tags_config=tags_config, + ) + + # Update models with new tag + if tag.models: + for model_id in tag.models: + await _add_tag_to_deployment( + model_id=model_id, + tag=tag.name, + ) + + # Get model names for response + model_info = await _get_model_names(prisma_client, tag.models or []) + tags_config[tag.name].model_info = model_info + + return { + "message": f"Tag {tag.name} created successfully", + "tag": tags_config[tag.name], + } + except Exception as e: + verbose_proxy_logger.exception(f"Error creating tag: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + +async def _add_tag_to_deployment(model_id: str, tag: str): + """Helper function to add tag to deployment""" + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + deployment = await prisma_client.db.litellm_proxymodeltable.find_unique( + where={"model_id": model_id} + ) + if deployment is None: + raise HTTPException(status_code=404, detail=f"Deployment {model_id} not found") + + litellm_params = deployment.litellm_params + if "tags" not in litellm_params: + litellm_params["tags"] = [] + litellm_params["tags"].append(tag) + await prisma_client.db.litellm_proxymodeltable.update( + where={"model_id": model_id}, + data={"litellm_params": safe_dumps(litellm_params)}, + ) + + +@router.post( + "/tag/update", + tags=["tag management"], + dependencies=[Depends(user_api_key_auth)], +) +async def update_tag( + tag: TagUpdateRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Update an existing tag. + + Parameters: + - name: str - The name of the tag to update + - description: Optional[str] - Updated description + - models: List[str] - Updated list of allowed LLM models + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + # Get existing tags config + tags_config = await _get_tags_config(prisma_client) + + # Check if tag exists + if tag.name not in tags_config: + raise HTTPException(status_code=404, detail=f"Tag {tag.name} not found") + + # Update tag + tag_config_dict = dict(tags_config[tag.name]) + tag_config_dict.update( + { + "description": tag.description, + "models": tag.models, + "updated_at": str(datetime.datetime.now()), + "updated_by": user_api_key_dict.user_id, + } + ) + tags_config[tag.name] = TagConfig(**tag_config_dict) + + # Save updated config + await _save_tags_config(prisma_client, tags_config) + + # Get model names for response + model_info = await _get_model_names(prisma_client, tag.models or []) + tags_config[tag.name].model_info = model_info + + return { + "message": f"Tag {tag.name} updated successfully", + "tag": tags_config[tag.name], + } + except Exception as e: + verbose_proxy_logger.exception(f"Error updating tag: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post( + "/tag/info", + tags=["tag management"], + dependencies=[Depends(user_api_key_auth)], +) +async def info_tag( + data: TagInfoRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Get information about specific tags. + + Parameters: + - names: List[str] - List of tag names to get information for + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + tags_config = await _get_tags_config(prisma_client) + + # Filter tags based on requested names + requested_tags = {name: tags_config.get(name) for name in data.names} + + # Check if any requested tags don't exist + missing_tags = [name for name in data.names if name not in tags_config] + if missing_tags: + raise HTTPException( + status_code=404, detail=f"Tags not found: {missing_tags}" + ) + + return requested_tags + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get( + "/tag/list", + tags=["tag management"], + dependencies=[Depends(user_api_key_auth)], +) +async def list_tags( + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + List all available tags. + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + tags_config = await _get_tags_config(prisma_client) + list_of_tags = list(tags_config.values()) + return list_of_tags + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post( + "/tag/delete", + tags=["tag management"], + dependencies=[Depends(user_api_key_auth)], +) +async def delete_tag( + data: TagDeleteRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Delete a tag. + + Parameters: + - name: str - The name of the tag to delete + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + # Get existing tags config + tags_config = await _get_tags_config(prisma_client) + + # Check if tag exists + if data.name not in tags_config: + raise HTTPException(status_code=404, detail=f"Tag {data.name} not found") + + # Delete tag + del tags_config[data.name] + + # Save updated config + await _save_tags_config(prisma_client, tags_config) + + return {"message": f"Tag {data.name} deleted successfully"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index da8bc383b1..709cf08729 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -11,6 +11,10 @@ model_list: litellm_settings: require_auth_for_metrics_endpoint: true + callbacks: ["prometheus"] - service_callback: ["prometheus_system"] \ No newline at end of file + service_callback: ["prometheus_system"] + +router_settings: + enable_tag_filtering: True # 👈 Key Change \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 051d8f89ca..c270d41cf0 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -237,6 +237,9 @@ from litellm.proxy.management_endpoints.model_management_endpoints import ( from litellm.proxy.management_endpoints.organization_endpoints import ( router as organization_router, ) +from litellm.proxy.management_endpoints.tag_management_endpoints import ( + router as tag_management_router, +) from litellm.proxy.management_endpoints.team_callback_endpoints import ( router as team_callback_router, ) @@ -347,13 +350,13 @@ from fastapi import ( Request, Response, UploadFile, + applications, status, - applications ) from fastapi.encoders import jsonable_encoder from fastapi.middleware.cors import CORSMiddleware -from fastapi.openapi.utils import get_openapi from fastapi.openapi.docs import get_swagger_ui_html +from fastapi.openapi.utils import get_openapi from fastapi.responses import ( FileResponse, JSONResponse, @@ -735,7 +738,7 @@ try: except Exception: pass -# current_dir = os.path.dirname(os.path.abspath(__file__)) +current_dir = os.path.dirname(os.path.abspath(__file__)) # ui_path = os.path.join(current_dir, "_experimental", "out") # # Mount this test directory instead # app.mount("/ui", StaticFiles(directory=ui_path, html=True), name="ui") @@ -753,14 +756,18 @@ app.add_middleware(PrometheusAuthMiddleware) swagger_path = os.path.join(current_dir, "swagger") app.mount("/swagger", StaticFiles(directory=swagger_path), name="swagger") + + def swagger_monkey_patch(*args, **kwargs): return get_swagger_ui_html( *args, **kwargs, swagger_js_url="/swagger/swagger-ui-bundle.js", swagger_css_url="/swagger/swagger-ui.css", - swagger_favicon_url="/swagger/favicon.png" + swagger_favicon_url="/swagger/favicon.png", ) + + applications.get_swagger_ui_html = swagger_monkey_patch from typing import Dict @@ -8174,3 +8181,4 @@ app.include_router(openai_files_router) app.include_router(team_callback_router) app.include_router(budget_management_router) app.include_router(model_management_router) +app.include_router(tag_management_router) diff --git a/litellm/types/tag_management.py b/litellm/types/tag_management.py new file mode 100644 index 0000000000..e530b37cab --- /dev/null +++ b/litellm/types/tag_management.py @@ -0,0 +1,32 @@ +from typing import Dict, List, Optional + +from pydantic import BaseModel + + +class TagBase(BaseModel): + name: str + description: Optional[str] = None + models: Optional[List[str]] = None + model_info: Optional[Dict[str, str]] = None # maps model_id to model_name + + +class TagConfig(TagBase): + created_at: str + updated_at: str + created_by: Optional[str] = None + + +class TagNewRequest(TagBase): + pass + + +class TagUpdateRequest(TagBase): + pass + + +class TagDeleteRequest(BaseModel): + name: str + + +class TagInfoRequest(BaseModel): + names: List[str] diff --git a/tests/litellm/proxy/management_endpoints/test_tag_management_endpoints.py b/tests/litellm/proxy/management_endpoints/test_tag_management_endpoints.py new file mode 100644 index 0000000000..8c2da0cc8a --- /dev/null +++ b/tests/litellm/proxy/management_endpoints/test_tag_management_endpoints.py @@ -0,0 +1,160 @@ +import json +import os +import sys +from typing import Any, Dict, Optional + +import pytest +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../../..") +) # Adds the parent directory to the system path + +from unittest.mock import patch + +import litellm +from litellm.proxy.proxy_server import app +from litellm.types.tag_management import TagDeleteRequest, TagInfoRequest, TagNewRequest + +client = TestClient(app) + + +@pytest.mark.asyncio +async def test_create_and_get_tag(): + """ + Test creation of a new tag and retrieving its information + """ + # Mock the prisma client and _get_tags_config and _save_tags_config + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch( + "litellm.proxy.management_endpoints.tag_management_endpoints._get_tags_config" + ) as mock_get_tags, patch( + "litellm.proxy.management_endpoints.tag_management_endpoints._save_tags_config" + ) as mock_save_tags, patch( + "litellm.proxy.management_endpoints.tag_management_endpoints._add_tag_to_deployment" + ) as mock_add_tag, patch( + "litellm.proxy.management_endpoints.tag_management_endpoints._get_model_names" + ) as mock_get_models: + # Setup mocks + mock_get_tags.return_value = {} + mock_get_models.return_value = {"model-1": "gpt-3.5-turbo"} + + # Create a new tag + tag_data = { + "name": "test-tag", + "description": "Test tag for unit testing", + "models": ["model-1"], + } + + # Set admin access for the test + headers = {"Authorization": f"Bearer sk-1234"} + + # Test tag creation + response = client.post("/tag/new", json=tag_data, headers=headers) + assert response.status_code == 200 + result = response.json() + assert result["message"] == "Tag test-tag created successfully" + assert result["tag"]["name"] == "test-tag" + assert result["tag"]["description"] == "Test tag for unit testing" + + # Mock updated tag config for the get request + mock_get_tags.return_value = { + "test-tag": { + "name": "test-tag", + "description": "Test tag for unit testing", + "models": ["model-1"], + "model_info": {"model-1": "gpt-3.5-turbo"}, + } + } + + # Test retrieving tag info + info_data = {"names": ["test-tag"]} + response = client.post("/tag/info", json=info_data, headers=headers) + assert response.status_code == 200 + result = response.json() + assert "test-tag" in result + assert result["test-tag"]["description"] == "Test tag for unit testing" + + +@pytest.mark.asyncio +async def test_update_tag(): + """ + Test updating an existing tag + """ + # Mock the prisma client and _get_tags_config and _save_tags_config + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch( + "litellm.proxy.management_endpoints.tag_management_endpoints._get_tags_config" + ) as mock_get_tags, patch( + "litellm.proxy.management_endpoints.tag_management_endpoints._save_tags_config" + ) as mock_save_tags, patch( + "litellm.proxy.management_endpoints.tag_management_endpoints._get_model_names" + ) as mock_get_models: + # Setup mocks for existing tag + mock_get_tags.return_value = { + "test-tag": { + "name": "test-tag", + "description": "Original description", + "models": ["model-1"], + "created_at": "2023-01-01T00:00:00", + "updated_at": "2023-01-01T00:00:00", + "created_by": "user-123", + } + } + mock_get_models.return_value = {"model-1": "gpt-3.5-turbo", "model-2": "gpt-4"} + + # Update tag data + update_data = { + "name": "test-tag", + "description": "Updated description", + "models": ["model-1", "model-2"], + } + + # Set admin access for the test + headers = {"Authorization": f"Bearer sk-1234"} + + # Test tag update + response = client.post("/tag/update", json=update_data, headers=headers) + assert response.status_code == 200 + result = response.json() + assert result["message"] == "Tag test-tag updated successfully" + assert result["tag"]["description"] == "Updated description" + assert len(result["tag"]["models"]) == 2 + assert "model-2" in result["tag"]["models"] + + +@pytest.mark.asyncio +async def test_delete_tag(): + """ + Test deleting a tag + """ + # Mock the prisma client and _get_tags_config and _save_tags_config + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch( + "litellm.proxy.management_endpoints.tag_management_endpoints._get_tags_config" + ) as mock_get_tags, patch( + "litellm.proxy.management_endpoints.tag_management_endpoints._save_tags_config" + ) as mock_save_tags: + # Setup mocks for existing tag + mock_get_tags.return_value = { + "test-tag": { + "name": "test-tag", + "description": "Test tag for deletion", + "models": ["model-1"], + "created_at": "2023-01-01T00:00:00", + "updated_at": "2023-01-01T00:00:00", + "created_by": "user-123", + } + } + + # Delete tag data + delete_data = {"name": "test-tag"} + + # Set admin access for the test + headers = {"Authorization": f"Bearer sk-1234"} + + # Test tag deletion + response = client.post("/tag/delete", json=delete_data, headers=headers) + assert response.status_code == 200 + result = response.json() + assert result["message"] == "Tag test-tag deleted successfully" + + # Verify _save_tags_config was called without the deleted tag + mock_save_tags.assert_called_once() diff --git a/ui/litellm-dashboard/src/app/page.tsx b/ui/litellm-dashboard/src/app/page.tsx index a4256b3f4b..592c7bf0f2 100644 --- a/ui/litellm-dashboard/src/app/page.tsx +++ b/ui/litellm-dashboard/src/app/page.tsx @@ -33,6 +33,7 @@ import TransformRequestPanel from "@/components/transform_request"; import { fetchUserModels } from "@/components/create_key_button"; import { fetchTeams } from "@/components/common_components/fetch_teams"; import MCPToolsViewer from "@/components/mcp_tools"; +import TagManagement from "@/components/tag_management"; function getCookie(name: string) { const cookieValue = document.cookie @@ -355,6 +356,12 @@ export default function CreateKeyPage() { userRole={userRole} userID={userID} /> + ) : page == "tag-management" ? ( + ) : page == "new_usage" ? ( = ({ const [endpointType, setEndpointType] = useState(EndpointType.CHAT); const [isLoading, setIsLoading] = useState(false); const abortControllerRef = useRef(null); + const [selectedTags, setSelectedTags] = useState([]); const chatEndRef = useRef(null); @@ -202,6 +205,7 @@ const ChatUI: React.FC = ({ (chunk, model) => updateTextUI("assistant", chunk, model), selectedModel, effectiveApiKey, + selectedTags, signal ); } else if (endpointType === EndpointType.IMAGE) { @@ -211,6 +215,7 @@ const ChatUI: React.FC = ({ (imageUrl, model) => updateImageUI(imageUrl, model), selectedModel, effectiveApiKey, + selectedTags, signal ); } @@ -343,6 +348,18 @@ const ChatUI: React.FC = ({ endpointType={endpointType} onEndpointChange={handleEndpointChange} className="mb-4" + /> + + +
+ + Tags + +
diff --git a/ui/litellm-dashboard/src/components/chat_ui/llm_calls/chat_completion.tsx b/ui/litellm-dashboard/src/components/chat_ui/llm_calls/chat_completion.tsx index e40eb3e696..a5a44e94f6 100644 --- a/ui/litellm-dashboard/src/components/chat_ui/llm_calls/chat_completion.tsx +++ b/ui/litellm-dashboard/src/components/chat_ui/llm_calls/chat_completion.tsx @@ -7,6 +7,7 @@ export async function makeOpenAIChatCompletionRequest( updateUI: (chunk: string, model: string) => void, selectedModel: string, accessToken: string, + tags?: string[], signal?: AbortSignal ) { // base url should be the current base_url @@ -22,6 +23,7 @@ export async function makeOpenAIChatCompletionRequest( apiKey: accessToken, // Replace with your OpenAI API key baseURL: proxyBaseUrl, // Replace with your OpenAI API base URL dangerouslyAllowBrowser: true, // using a temporary litellm proxy key + defaultHeaders: tags && tags.length > 0 ? { 'x-litellm-tags': tags.join(',') } : undefined, }); try { diff --git a/ui/litellm-dashboard/src/components/chat_ui/llm_calls/image_generation.tsx b/ui/litellm-dashboard/src/components/chat_ui/llm_calls/image_generation.tsx index d972870a42..254d9621ad 100644 --- a/ui/litellm-dashboard/src/components/chat_ui/llm_calls/image_generation.tsx +++ b/ui/litellm-dashboard/src/components/chat_ui/llm_calls/image_generation.tsx @@ -6,6 +6,7 @@ export async function makeOpenAIImageGenerationRequest( updateUI: (imageUrl: string, model: string) => void, selectedModel: string, accessToken: string, + tags?: string[], signal?: AbortSignal ) { // base url should be the current base_url @@ -21,6 +22,7 @@ export async function makeOpenAIImageGenerationRequest( apiKey: accessToken, baseURL: proxyBaseUrl, dangerouslyAllowBrowser: true, + defaultHeaders: tags && tags.length > 0 ? { 'x-litellm-tags': tags.join(',') } : undefined, }); try { diff --git a/ui/litellm-dashboard/src/components/leftnav.tsx b/ui/litellm-dashboard/src/components/leftnav.tsx index f58f60db33..8f8c5469e8 100644 --- a/ui/litellm-dashboard/src/components/leftnav.tsx +++ b/ui/litellm-dashboard/src/components/leftnav.tsx @@ -22,6 +22,7 @@ import { ThunderboltOutlined, LockOutlined, ToolOutlined, + TagsOutlined, } from '@ant-design/icons'; import { old_admin_roles, v2_admin_role_names, all_admin_roles, rolesAllowedToSeeUsage, rolesWithWriteAccess, internalUserRoles } from '../utils/roles'; @@ -63,9 +64,6 @@ const Sidebar: React.FC = ({ { key: "14", page: "api_ref", label: "API Reference", icon: }, { key: "16", page: "model-hub", label: "Model Hub", icon: }, { key: "15", page: "logs", label: "Logs", icon: }, - - - { key: "experimental", page: "experimental", @@ -77,6 +75,7 @@ const Sidebar: React.FC = ({ { key: "11", page: "guardrails", label: "Guardrails", icon: , roles: all_admin_roles }, { key: "12", page: "new_usage", label: "New Usage", icon: , roles: [...all_admin_roles, ...internalUserRoles] }, { key: "18", page: "mcp-tools", label: "MCP Tools", icon: , roles: all_admin_roles }, + { key: "19", page: "tag-management", label: "Tag Management", icon: , roles: all_admin_roles }, ] }, { diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index fb33ee4815..ac79237fb8 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -3,6 +3,7 @@ */ import { all_admin_roles } from "@/utils/roles"; import { message } from "antd"; +import { TagNewRequest, TagUpdateRequest, TagDeleteRequest, TagInfoRequest, TagListResponse, TagInfoResponse } from "./tag_management/types"; const isLocal = process.env.NODE_ENV === "development"; export const proxyBaseUrl = isLocal ? "http://localhost:4000" : null; @@ -4155,4 +4156,157 @@ export const callMCPTool = async (accessToken: string, toolName: string, toolArg console.error("Failed to call MCP tool:", error); throw error; } +}; + + +export const tagCreateCall = async ( + accessToken: string, + formValues: TagNewRequest +): Promise => { + try { + let url = proxyBaseUrl + ? `${proxyBaseUrl}/tag/new` + : `/tag/new`; + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${accessToken}`, + }, + body: JSON.stringify(formValues), + }); + + if (!response.ok) { + const errorData = await response.text(); + await handleError(errorData); + return; + } + + return await response.json(); + } catch (error) { + console.error("Error creating tag:", error); + throw error; + } +}; + +export const tagUpdateCall = async ( + accessToken: string, + formValues: TagUpdateRequest +): Promise => { + try { + let url = proxyBaseUrl + ? `${proxyBaseUrl}/tag/update` + : `/tag/update`; + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${accessToken}`, + }, + body: JSON.stringify(formValues), + }); + + if (!response.ok) { + const errorData = await response.text(); + await handleError(errorData); + return; + } + + return await response.json(); + } catch (error) { + console.error("Error updating tag:", error); + throw error; + } +}; + +export const tagInfoCall = async ( + accessToken: string, + tagNames: string[] +): Promise => { + try { + let url = proxyBaseUrl + ? `${proxyBaseUrl}/tag/info` + : `/tag/info`; + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${accessToken}`, + }, + body: JSON.stringify({ names: tagNames }), + }); + + if (!response.ok) { + const errorData = await response.text(); + await handleError(errorData); + return {}; + } + + const data = await response.json(); + return data as TagInfoResponse; + } catch (error) { + console.error("Error getting tag info:", error); + throw error; + } +}; + +export const tagListCall = async (accessToken: string): Promise => { + try { + let url = proxyBaseUrl + ? `${proxyBaseUrl}/tag/list` + : `/tag/list`; + + const response = await fetch(url, { + method: "GET", + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }); + + if (!response.ok) { + const errorData = await response.text(); + await handleError(errorData); + return {}; + } + + const data = await response.json(); + return data as TagListResponse; + } catch (error) { + console.error("Error listing tags:", error); + throw error; + } +}; + +export const tagDeleteCall = async ( + accessToken: string, + tagName: string +): Promise => { + try { + let url = proxyBaseUrl + ? `${proxyBaseUrl}/tag/delete` + : `/tag/delete`; + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${accessToken}`, + }, + body: JSON.stringify({ name: tagName }), + }); + + if (!response.ok) { + const errorData = await response.text(); + await handleError(errorData); + return; + } + + return await response.json(); + } catch (error) { + console.error("Error deleting tag:", error); + throw error; + } }; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/tag_management/TagSelector.tsx b/ui/litellm-dashboard/src/components/tag_management/TagSelector.tsx new file mode 100644 index 0000000000..beb41b874f --- /dev/null +++ b/ui/litellm-dashboard/src/components/tag_management/TagSelector.tsx @@ -0,0 +1,52 @@ +import React, { useEffect, useState } from 'react'; +import { Select } from 'antd'; +import { Tag } from './types'; +import { tagListCall } from '../networking'; + +interface TagSelectorProps { + onChange: (selectedTags: string[]) => void; + value?: string[]; + className?: string; + accessToken: string; +} + +const TagSelector: React.FC = ({ onChange, value, className, accessToken }) => { + const [tags, setTags] = useState([]); + const [loading, setLoading] = useState(false); + + useEffect(() => { + const fetchTags = async () => { + if (!accessToken) return; + try { + const response = await tagListCall(accessToken); + console.log("List tags response:", response); + setTags(Object.values(response)); + } catch (error) { + console.error("Error fetching tags:", error); + } + }; + + fetchTags(); + }, []); + + return ( + + + + + + + + + Allowed LLMs{' '} + + + + + } + name="models" + > + + {userModels.map((modelId) => ( + + {getModelDisplayName(modelId)} + + ))} + + + +
+ + +
+ + + ) : ( +
+ + Tag Details +
+
+ Name + {tagDetails.name} +
+
+ Description + {tagDetails.description || "-"} +
+
+ Allowed LLMs +
+ {tagDetails.models.length === 0 ? ( + All Models + ) : ( + tagDetails.models.map((modelId) => ( + + + {tagDetails.model_info?.[modelId] || modelId} + + + )) + )} +
+
+
+ Created + {tagDetails.created_at ? new Date(tagDetails.created_at).toLocaleString() : "-"} +
+
+ Last Updated + {tagDetails.updated_at ? new Date(tagDetails.updated_at).toLocaleString() : "-"} +
+
+
+
+ )} + + ); +}; + +export default TagInfoView; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/tag_management/types.tsx b/ui/litellm-dashboard/src/components/tag_management/types.tsx new file mode 100644 index 0000000000..78ba973bf6 --- /dev/null +++ b/ui/litellm-dashboard/src/components/tag_management/types.tsx @@ -0,0 +1,34 @@ +export interface Tag { + name: string; + description?: string; + models: string[]; // model IDs + model_info?: { [key: string]: string }; // maps model_id to model_name + created_at: string; + updated_at: string; + created_by?: string; + updated_by?: string; +} + +export interface TagInfoRequest { + names: string[]; +} + +export interface TagNewRequest { + name: string; + description?: string; + models: string[]; +} + +export interface TagUpdateRequest { + name: string; + description?: string; + models: string[]; +} + +export interface TagDeleteRequest { + name: string; +} + +// The API returns a dictionary of tags where the key is the tag name +export type TagListResponse = Record; +export type TagInfoResponse = Record; \ No newline at end of file