/model/update endpoint

This commit is contained in:
Ishaan Jaff 2024-04-24 10:39:20 -07:00
parent 3d1a158b63
commit efbf85a5ad
2 changed files with 273 additions and 1 deletions

View file

@ -106,7 +106,7 @@ import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm.router import LiteLLM_Params, Deployment
from litellm.router import LiteLLM_Params, Deployment, updateDeployment
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
@ -7236,6 +7236,89 @@ async def add_new_model(
)
#### MODEL MANAGEMENT ####
@router.post(
"/model/update",
description="Edit existing model params",
tags=["model management"],
dependencies=[Depends(user_api_key_auth)],
)
async def update_model(
model_params: updateDeployment,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj
try:
import base64
global prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={
"error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
},
)
# update DB
if store_model_in_db == True:
_model_id = None
_model_info = getattr(model_params, "model_info", None)
if _model_info is None:
raise Exception("model_info not provided")
_model_id = _model_info.id
if _model_id is None:
raise Exception("model_info.id not provided")
_existing_litellm_params = (
await prisma_client.db.litellm_proxymodeltable.find_unique(
where={"model_id": _model_id}
)
)
if _existing_litellm_params is None:
raise Exception("model not found")
_existing_litellm_params_dict = dict(
_existing_litellm_params.litellm_params
)
if model_params.litellm_params is None:
raise Exception("litellm_params not provided")
_new_litellm_params_dict = model_params.litellm_params.dict(
exclude_none=True
)
for key, value in _existing_litellm_params_dict.items():
if key in _new_litellm_params_dict:
_existing_litellm_params_dict[key] = _new_litellm_params_dict[key]
_data: dict = {
"litellm_params": json.dumps(_existing_litellm_params_dict), # type: ignore
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
}
model_response = await prisma_client.db.litellm_proxymodeltable.update(
where={"model_id": _model_id},
data=_data, # type: ignore
)
except Exception as e:
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
type="auth_error",
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
elif isinstance(e, ProxyException):
raise e
raise ProxyException(
message="Authentication Error, " + str(e),
type="auth_error",
param=getattr(e, "param", "None"),
code=status.HTTP_400_BAD_REQUEST,
)
@router.get(
"/v2/model/info",
description="v2 - returns all the models set on the config.yaml, shows 'user_access' = True if the user has access to the model. Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",

View file

@ -0,0 +1,189 @@
import sys, os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime
load_dotenv()
import os, io, time
# this file is to test litellm/proxy
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm, asyncio
from litellm.proxy.proxy_server import add_new_model, update_model
from litellm._logging import verbose_proxy_logger
from litellm.proxy.utils import PrismaClient, ProxyLogging
verbose_proxy_logger.setLevel(level=logging.DEBUG)
from litellm.proxy.utils import DBClient
from litellm.caching import DualCache
from litellm.router import (
Deployment,
updateDeployment,
LiteLLM_Params,
ModelInfo,
updateLiteLLMParams,
)
from litellm.proxy._types import (
UserAPIKeyAuth,
)
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
@pytest.fixture
def prisma_client():
from litellm.proxy.proxy_cli import append_query_params
### add connection pool + pool timeout args
params = {"connection_limit": 100, "pool_timeout": 60}
database_url = os.getenv("DATABASE_URL")
modified_url = append_query_params(database_url, params)
os.environ["DATABASE_URL"] = modified_url
os.environ["STORE_MODEL_IN_DB"] = "true"
# Assuming DBClient is a class that needs to be instantiated
prisma_client = PrismaClient(
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
)
# Reset litellm.proxy.proxy_server.prisma_client to None
litellm.proxy.proxy_server.custom_db_client = None
litellm.proxy.proxy_server.litellm_proxy_budget_name = (
f"litellm-proxy-budget-{time.time()}"
)
litellm.proxy.proxy_server.user_custom_key_generate = None
return prisma_client
@pytest.mark.asyncio
async def test_add_new_model(prisma_client):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "store_model_in_db", True)
await litellm.proxy.proxy_server.prisma_client.connect()
from litellm.proxy.proxy_server import user_api_key_cache
import uuid
_new_model_id = f"local-test-{uuid.uuid4().hex}"
await add_new_model(
model_params=Deployment(
model_name="test_model",
litellm_params=LiteLLM_Params(
model="azure/gpt-3.5-turbo",
api_key="test_api_key",
api_base="test_api_base",
rpm=1000,
tpm=1000,
),
model_info=ModelInfo(
id=_new_model_id,
),
),
user_api_key_dict=UserAPIKeyAuth(
user_role="proxy_admin", api_key="sk-1234", user_id="1234"
),
)
_new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
print("_new_models: ", _new_models)
_new_model_in_db = None
for model in _new_models:
print("current model: ", model)
if model.model_info["id"] == _new_model_id:
print("FOUND MODEL: ", model)
_new_model_in_db = model
assert _new_model_in_db is not None
@pytest.mark.asyncio
async def test_add_update_model(prisma_client):
# test that existing litellm_params are not updated
# only new / updated params get updated
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "store_model_in_db", True)
await litellm.proxy.proxy_server.prisma_client.connect()
from litellm.proxy.proxy_server import user_api_key_cache
import uuid
_new_model_id = f"local-test-{uuid.uuid4().hex}"
await add_new_model(
model_params=Deployment(
model_name="test_model",
litellm_params=LiteLLM_Params(
model="azure/gpt-3.5-turbo",
api_key="test_api_key",
api_base="test_api_base",
rpm=1000,
tpm=1000,
),
model_info=ModelInfo(
id=_new_model_id,
),
),
user_api_key_dict=UserAPIKeyAuth(
user_role="proxy_admin", api_key="sk-1234", user_id="1234"
),
)
_new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
print("_new_models: ", _new_models)
_new_model_in_db = None
for model in _new_models:
print("current model: ", model)
if model.model_info["id"] == _new_model_id:
print("FOUND MODEL: ", model)
_new_model_in_db = model
assert _new_model_in_db is not None
_original_model = _new_model_in_db
_original_litellm_params = _new_model_in_db.litellm_params
print("_original_litellm_params: ", _original_litellm_params)
print("now updating the tpm for model")
# run update to update "tpm"
await update_model(
model_params=updateDeployment(
litellm_params=updateLiteLLMParams(tpm=123456),
model_info=ModelInfo(
id=_new_model_id,
),
),
user_api_key_dict=UserAPIKeyAuth(
user_role="proxy_admin", api_key="sk-1234", user_id="1234"
),
)
_new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
_new_model_in_db = None
for model in _new_models:
if model.model_info["id"] == _new_model_id:
print("\nFOUND MODEL: ", model)
_new_model_in_db = model
# assert all other litellm params are identical to _original_litellm_params
for key, value in _original_litellm_params.items():
if key == "tpm":
# assert that tpm actually got updated
assert _new_model_in_db.litellm_params[key] == 123456
else:
assert _new_model_in_db.litellm_params[key] == value
assert _original_model.model_id == _new_model_in_db.model_id
assert _original_model.model_name == _new_model_in_db.model_name
assert _original_model.model_info == _new_model_in_db.model_info