diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 7c572f13c1..1419a963b9 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -106,7 +106,7 @@ import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache, RedisCache
 from litellm.proxy.health_check import perform_health_check
-from litellm.router import LiteLLM_Params, Deployment
+from litellm.router import LiteLLM_Params, Deployment, updateDeployment
 from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
@@ -7243,6 +7243,89 @@ async def add_new_model(
     )
 
 
+#### MODEL MANAGEMENT ####
+@router.post(
+    "/model/update",
+    description="Edit existing model params",
+    tags=["model management"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def update_model(
+    model_params: updateDeployment,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj
+    try:
+        import base64
+
+        global prisma_client
+
+        if prisma_client is None:
+            raise HTTPException(
+                status_code=500,
+                detail={
+                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
+                },
+            )
+        # update DB
+        if store_model_in_db == True:
+            _model_id = None
+            _model_info = getattr(model_params, "model_info", None)
+            if _model_info is None:
+                raise Exception("model_info not provided")
+
+            _model_id = _model_info.id
+            if _model_id is None:
+                raise Exception("model_info.id not provided")
+            _existing_litellm_params = (
+                await prisma_client.db.litellm_proxymodeltable.find_unique(
+                    where={"model_id": _model_id}
+                )
+            )
+            if _existing_litellm_params is None:
+                raise Exception("model not found")
+            _existing_litellm_params_dict = dict(
+                _existing_litellm_params.litellm_params
+            )
+
+            if model_params.litellm_params is None:
+                raise Exception("litellm_params not provided")
+
+            _new_litellm_params_dict = model_params.litellm_params.dict(
+                exclude_none=True
+            )
+
+            for key, value in _existing_litellm_params_dict.items():
+                if key in _new_litellm_params_dict:
+                    _existing_litellm_params_dict[key] = _new_litellm_params_dict[key]
+
+            _data: dict = {
+                "litellm_params": json.dumps(_existing_litellm_params_dict),  # type: ignore
+                "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+            }
+            model_response = await prisma_client.db.litellm_proxymodeltable.update(
+                where={"model_id": _model_id},
+                data=_data,  # type: ignore
+            )
+    except Exception as e:
+        traceback.print_exc()
+        if isinstance(e, HTTPException):
+            raise ProxyException(
+                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
+                type="auth_error",
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+            )
+        elif isinstance(e, ProxyException):
+            raise e
+        raise ProxyException(
+            message="Authentication Error, " + str(e),
+            type="auth_error",
+            param=getattr(e, "param", "None"),
+            code=status.HTTP_400_BAD_REQUEST,
+        )
+
+
 @router.get(
     "/v2/model/info",
     description="v2 - returns all the models set on the config.yaml, shows 'user_access' = True if the user has access to the model. Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
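Review note: in the merge loop above, only keys already present in the stored litellm_params are overwritten; an update key the deployment never had is silently dropped rather than added. Below is a minimal sketch of exercising the new endpoint, assuming a locally running proxy with store_model_in_db enabled; the URL, master key, and model id are placeholders:

```python
# Sketch only: placeholder URL / key / model id, not part of this PR.
import requests

resp = requests.post(
    "http://0.0.0.0:4000/model/update",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        # send only the params being changed; other stored params are preserved
        "litellm_params": {"tpm": 123456},
        "model_info": {"id": "<existing-model-id>"},
    },
)
print(resp.status_code, resp.text)  # the handler currently returns None, so expect a null body
```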
diff --git a/litellm/router.py b/litellm/router.py
index 07195aa3a4..3416a84954 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -35,7 +35,14 @@ from litellm.utils import (
 import copy
 from litellm._logging import verbose_router_logger
 import logging
-from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
+from litellm.types.router import (
+    Deployment,
+    ModelInfo,
+    LiteLLM_Params,
+    RouterErrors,
+    updateDeployment,
+    updateLiteLLMParams,
+)
 from litellm.integrations.custom_logger import CustomLogger
diff --git a/litellm/tests/test_add_update_models.py b/litellm/tests/test_add_update_models.py
new file mode 100644
index 0000000000..ec9ab33b6f
--- /dev/null
+++ b/litellm/tests/test_add_update_models.py
@@ -0,0 +1,191 @@
+import sys, os
+import traceback
+from dotenv import load_dotenv
+from fastapi import Request
+from datetime import datetime
+
+load_dotenv()
+import os, io, time
+
+# this file tests the litellm/proxy model management endpoints
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest, logging, asyncio
+import litellm, asyncio
+from litellm.proxy.proxy_server import add_new_model, update_model
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy.utils import PrismaClient, ProxyLogging
+
+verbose_proxy_logger.setLevel(level=logging.DEBUG)
+from litellm.proxy.utils import DBClient
+from litellm.caching import DualCache
+from litellm.router import (
+    Deployment,
+    updateDeployment,
+    LiteLLM_Params,
+    ModelInfo,
+    updateLiteLLMParams,
+)
+
+from litellm.proxy._types import (
+    UserAPIKeyAuth,
+)
+
+proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
+
+
+@pytest.fixture
+def prisma_client():
+    from litellm.proxy.proxy_cli import append_query_params
+
+    ### add connection pool + pool timeout args
+    params = {"connection_limit": 100, "pool_timeout": 60}
+    database_url = os.getenv("DATABASE_URL")
+    modified_url = append_query_params(database_url, params)
+    os.environ["DATABASE_URL"] = modified_url
+    os.environ["STORE_MODEL_IN_DB"] = "true"
+
+    # instantiate the Prisma client the proxy will use
+    prisma_client = PrismaClient(
+        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
+    )
+
+    # reset cached clients / names on litellm.proxy.proxy_server
+    litellm.proxy.proxy_server.custom_db_client = None
+    litellm.proxy.proxy_server.litellm_proxy_budget_name = (
+        f"litellm-proxy-budget-{time.time()}"
+    )
+    litellm.proxy.proxy_server.user_custom_key_generate = None
+
+    return prisma_client
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip(reason="new feature, tests passing locally")
+async def test_add_new_model(prisma_client):
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "store_model_in_db", True)
+
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    from litellm.proxy.proxy_server import user_api_key_cache
+    import uuid
+
+    _new_model_id = f"local-test-{uuid.uuid4().hex}"
+
+    await add_new_model(
+        model_params=Deployment(
+            model_name="test_model",
+            litellm_params=LiteLLM_Params(
+                model="azure/gpt-3.5-turbo",
+                api_key="test_api_key",
+                api_base="test_api_base",
+                rpm=1000,
+                tpm=1000,
+            ),
+            model_info=ModelInfo(
+                id=_new_model_id,
+            ),
+        ),
+        user_api_key_dict=UserAPIKeyAuth(
+            user_role="proxy_admin", api_key="sk-1234", user_id="1234"
+        ),
+    )
+
+    _new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
+    print("_new_models: ", _new_models)
+
+    _new_model_in_db = None
+    for model in _new_models:
+        print("current model: ", model)
+        if model.model_info["id"] == _new_model_id:
+            print("FOUND MODEL: ", model)
+            _new_model_in_db = model
+
+    assert _new_model_in_db is not None
api_key="sk-1234", user_id="1234" + ), + ) + + _new_models = await prisma_client.db.litellm_proxymodeltable.find_many() + print("_new_models: ", _new_models) + + _new_model_in_db = None + for model in _new_models: + print("current model: ", model) + if model.model_info["id"] == _new_model_id: + print("FOUND MODEL: ", model) + _new_model_in_db = model + + assert _new_model_in_db is not None + + +@pytest.mark.asyncio +@pytest.mark.skip(reason="new feature, tests passing locally") +async def test_add_update_model(prisma_client): + # test that existing litellm_params are not updated + # only new / updated params get updated + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "store_model_in_db", True) + + await litellm.proxy.proxy_server.prisma_client.connect() + from litellm.proxy.proxy_server import user_api_key_cache + import uuid + + _new_model_id = f"local-test-{uuid.uuid4().hex}" + + await add_new_model( + model_params=Deployment( + model_name="test_model", + litellm_params=LiteLLM_Params( + model="azure/gpt-3.5-turbo", + api_key="test_api_key", + api_base="test_api_base", + rpm=1000, + tpm=1000, + ), + model_info=ModelInfo( + id=_new_model_id, + ), + ), + user_api_key_dict=UserAPIKeyAuth( + user_role="proxy_admin", api_key="sk-1234", user_id="1234" + ), + ) + + _new_models = await prisma_client.db.litellm_proxymodeltable.find_many() + print("_new_models: ", _new_models) + + _new_model_in_db = None + for model in _new_models: + print("current model: ", model) + if model.model_info["id"] == _new_model_id: + print("FOUND MODEL: ", model) + _new_model_in_db = model + + assert _new_model_in_db is not None + + _original_model = _new_model_in_db + _original_litellm_params = _new_model_in_db.litellm_params + print("_original_litellm_params: ", _original_litellm_params) + print("now updating the tpm for model") + # run update to update "tpm" + await update_model( + model_params=updateDeployment( + litellm_params=updateLiteLLMParams(tpm=123456), + model_info=ModelInfo( + id=_new_model_id, + ), + ), + user_api_key_dict=UserAPIKeyAuth( + user_role="proxy_admin", api_key="sk-1234", user_id="1234" + ), + ) + + _new_models = await prisma_client.db.litellm_proxymodeltable.find_many() + + _new_model_in_db = None + for model in _new_models: + if model.model_info["id"] == _new_model_id: + print("\nFOUND MODEL: ", model) + _new_model_in_db = model + + # assert all other litellm params are identical to _original_litellm_params + for key, value in _original_litellm_params.items(): + if key == "tpm": + # assert that tpm actually got updated + assert _new_model_in_db.litellm_params[key] == 123456 + else: + assert _new_model_in_db.litellm_params[key] == value + + assert _original_model.model_id == _new_model_in_db.model_id + assert _original_model.model_name == _new_model_in_db.model_name + assert _original_model.model_info == _new_model_in_db.model_info diff --git a/litellm/types/router.py b/litellm/types/router.py index 8f0c0d7312..961f20a91d 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -155,6 +155,36 @@ class LiteLLM_Params(BaseModel): setattr(self, key, value) +class updateLiteLLMParams(BaseModel): + # This class is used to update the LiteLLM_Params + # only differece is model is optional + model: Optional[str] = None + tpm: Optional[int] = None + rpm: Optional[int] = None + api_key: Optional[str] = None + api_base: Optional[str] = None + api_version: Optional[str] 
diff --git a/ui/litellm-dashboard/src/components/model_dashboard.tsx b/ui/litellm-dashboard/src/components/model_dashboard.tsx
index d4dc3b8f03..08b9211c12 100644
--- a/ui/litellm-dashboard/src/components/model_dashboard.tsx
+++ b/ui/litellm-dashboard/src/components/model_dashboard.tsx
@@ -18,7 +18,7 @@ import {
 } from "@tremor/react";
 import { TabPanel, TabPanels, TabGroup, TabList, Tab, TextInput, Icon } from "@tremor/react";
 import { Select, SelectItem, MultiSelect, MultiSelectItem } from "@tremor/react";
-import { modelInfoCall, userGetRequesedtModelsCall, modelCreateCall, Model, modelCostMap, modelDeleteCall, healthCheckCall } from "./networking";
+import { modelInfoCall, userGetRequesedtModelsCall, modelCreateCall, Model, modelCostMap, modelDeleteCall, healthCheckCall, modelUpdateCall } from "./networking";
 import { BarChart } from "@tremor/react";
 import {
   Button as Button2,
@@ -51,6 +51,13 @@ interface ModelDashboardProps {
   userID: string | null;
 }
 
+interface EditModelModalProps {
+  visible: boolean;
+  onCancel: () => void;
+  model: any; // the model row being edited (litellm_params + model_info)
+  onSubmit: (data: FormData) => void; // called with the validated form values
+}
+
 //["OpenAI", "Azure OpenAI", "Anthropic", "Gemini (Google AI Studio)", "Amazon Bedrock", "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)"]
 
 enum Providers {
@@ -183,6 +190,169 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
   const [selectedProvider, setSelectedProvider] = useState<string>("OpenAI");
   const [healthCheckResponse, setHealthCheckResponse] = useState<string>('');
 
+  const [editModalVisible, setEditModalVisible] = useState(false);
+  const [selectedModel, setSelectedModel] = useState<any>(null);
+
+  const EditModelModal: React.FC<EditModelModalProps> = ({ visible, onCancel, model, onSubmit }) => {
+    const [form] = Form.useForm();
+    let litellm_params_to_edit: Record<string, any> = {};
+    let model_name = "";
+    let model_id = "";
+    if (model) {
+      litellm_params_to_edit = model.litellm_params;
+      model_name = model.model_name;
+      let model_info = model.model_info;
+      if (model_info) {
+        model_id = model_info.id;
+        console.log(`model_id: ${model_id}`);
+        litellm_params_to_edit.model_id = model_id;
+      }
+    }
+
+    const handleOk = () => {
+      form
+        .validateFields()
+        .then((values) => {
+          onSubmit(values);
+          form.resetFields();
+        })
+        .catch((error) => {
+          console.error("Validation failed:", error);
+        });
+    };
+
+    return (
+      <Modal
+        title={"Edit Model " + model_name}
+        visible={visible}
+        width={800}
+        footer={null}
+        onOk={handleOk}
+        onCancel={onCancel}
+      >
+        <Form
+          form={form}
+          onFinish={handleEditSubmit}
+          initialValues={litellm_params_to_edit} // pre-fill with the current litellm_params
+          labelCol={{ span: 8 }}
+          wrapperCol={{ span: 16 }}
+          labelAlign="left"
+        >
+          <>
+            <Form.Item label="api_base" name="api_base">
+              <TextInput />
+            </Form.Item>
+            <Form.Item label="api_key" name="api_key">
+              <TextInput />
+            </Form.Item>
+            <Form.Item label="tpm" name="tpm">
+              <InputNumber min={0} step={1} />
+            </Form.Item>
+            <Form.Item label="rpm" name="rpm">
+              <InputNumber min={0} step={1} />
+            </Form.Item>
+            <Form.Item label="max_retries" name="max_retries">
+              <InputNumber min={0} step={1} />
+            </Form.Item>
+            <Form.Item label="timeout" name="timeout">
+              <InputNumber min={0} step={1} />
+            </Form.Item>
+            {/* hidden field carrying the model id through to handleEditSubmit */}
+            <Form.Item label="model_id" name="model_id" hidden={true}></Form.Item>
+          </>
+          <div style={{ textAlign: "right", marginTop: "10px" }}>
+            <Button2 htmlType="submit">Save</Button2>
+          </div>
+        </Form>
+      </Modal>
+    );
+  };
+
+  const handleEditClick = (model: any) => {
+    setSelectedModel(model);
+    setEditModalVisible(true);
+  };
+
+  const handleEditCancel = () => {
+    setEditModalVisible(false);
+    setSelectedModel(null);
+  };
+
+  const handleEditSubmit = async (formValues: Record<string, any>) => {
+    // call the API to update the model with the edited values
+    console.log("handleEditSubmit:", formValues);
+    if (accessToken == null) {
+      return;
+    }
+
+    let newLiteLLMParams: Record<string, any> = {};
+    let model_info_model_id = null;
+
+    // everything except model_id is a litellm_params field;
+    // model_id identifies the deployment to update
+    for (const [key, value] of Object.entries(formValues)) {
+      if (key !== "model_id") {
+        newLiteLLMParams[key] = value;
+      } else {
+        model_info_model_id = value;
+      }
+    }
+
+    let payload = {
+      litellm_params: newLiteLLMParams,
+      model_info: {
+        "id": model_info_model_id,
+      },
+    };
+
+    console.log("handleEditSubmit payload:", payload);
+
+    let newModelValue = await modelUpdateCall(accessToken, payload);
+
+    message.success("Model updated successfully, restart server to see updates");
+
+    setEditModalVisible(false);
+    setSelectedModel(null);
+  };
+
+
   const props: UploadProps = {
@@ -510,15 +680,27 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
                     <TableCell>{model.output_cost}</TableCell>
                     <TableCell>{model.max_tokens}</TableCell>
+                    <TableCell>
+                      <Icon
+                        icon={PencilAltIcon}
+                        size="sm"
+                        onClick={() => handleEditClick(model)}
+                      />
+                    </TableCell>
                   </TableRow>
                 ))}
               </TableBody>
             </Table>
+            <EditModelModal visible={editModalVisible} onCancel={handleEditCancel} model={selectedModel} onSubmit={handleEditSubmit} />
           </Card>
         </TabPanel>
         <TabPanel className="h-full">
           <Title2 level={2}>Add new model</Title2>
diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx
index 0f44d51b57..9de32ac50d 100644
--- a/ui/litellm-dashboard/src/components/networking.tsx
+++ b/ui/litellm-dashboard/src/components/networking.tsx
@@ -1015,6 +1015,41 @@ export const teamUpdateCall = async (
   }
 };
 
+export const modelUpdateCall = async (
+  accessToken: string,
+  formValues: Record<string, any> // body for /model/update (litellm_params + model_info)
+) => {
+  try {
+    console.log("Form Values in modelUpdateCall:", formValues); // log the form values before making the API call
+
+    const url = proxyBaseUrl ? `${proxyBaseUrl}/model/update` : `/model/update`;
+    const response = await fetch(url, {
+      method: "POST",
+      headers: {
+        Authorization: `Bearer ${accessToken}`,
+        "Content-Type": "application/json",
+      },
+      body: JSON.stringify({
+        ...formValues, // include formValues in the request body
+      }),
+    });
+
+    if (!response.ok) {
+      const errorData = await response.text();
+      message.error("Failed to update model: " + errorData, 20);
+      console.error("Error response from the server:", errorData);
+      throw new Error("Network response was not ok");
+    }
+    const data = await response.json();
+    console.log("Update model Response:", data);
+    return data;
+  } catch (error) {
+    console.error("Failed to update model:", error);
+    throw error;
+  }
+};
+
 export interface Member {
   role: string;
   user_id: string | null;
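Review note: tracing the dashboard flow, `handleEditSubmit` moves every form value except `model_id` into `litellm_params` and nests `model_id` under `model_info.id`, which is exactly the shape `update_model` expects. A small illustration of that split, with hypothetical form values:

```python
# Hypothetical form values from the edit modal; mirrors handleEditSubmit's split.
import json

form_values = {"api_base": "https://my-endpoint.example", "tpm": 123456, "model_id": "local-test-abc"}

payload = {
    "litellm_params": {k: v for k, v in form_values.items() if k != "model_id"},
    "model_info": {"id": form_values["model_id"]},
}
print(json.dumps(payload, indent=2))
```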