Merge pull request #2827 from BerriAI/litellm_model_add_api

fix(proxy_server.py): persist models added via `/model/new` to db
Krish Dholakia 2024-04-03 23:30:39 -07:00 committed by GitHub
commit 0c5b8a7667
11 changed files with 432 additions and 90 deletions

View file

@@ -23,6 +23,7 @@ litellm_settings:
general_settings:
master_key: sk-1234
alerting: ["slack"]
store_model_in_db: True
# proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
enable_jwt_auth: True
alerting: ["slack"]
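With `store_model_in_db: True` set as above, models registered at runtime through the proxy's `/model/new` endpoint (added further down in proxy_server.py) are persisted to the database rather than living only in memory. A minimal client-side sketch, assuming the proxy listens on http://0.0.0.0:4000 and uses the master key from this config (all values illustrative):

import requests

response = requests.post(
    "http://0.0.0.0:4000/model/new",              # assumed proxy address
    headers={"Authorization": "Bearer sk-1234"},  # master key from the config above
    json={
        "model_name": "azure-gpt-35",             # public model group name
        "litellm_params": {                       # provider params; values are encrypted before the DB write
            "model": "azure/gpt-35-turbo",
            "api_key": "my-azure-key",
            "api_base": "https://my-endpoint.openai.azure.com/",
        },
        "model_info": {},                         # optional; an id is generated when omitted
    },
)
print(response.json())  # {"message": "Model added successfully"}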

View file

@@ -97,6 +97,8 @@ from litellm.proxy.utils import (
_is_projected_spend_over_limit,
_get_projected_spend_over_limit,
update_spend,
encrypt_value,
decrypt_value,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
@@ -104,6 +106,8 @@ import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm.router import LiteLLM_Params, Deployment
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.hooks.prompt_injection_detection import (
@@ -2371,6 +2375,64 @@ class ProxyConfig:
router = litellm.Router(**router_params) # type:ignore
return router, model_list, general_settings
async def add_deployment(
self,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
):
"""
- Check db for new models (last 10 most recently updated)
- Check if model ids are already in the router
- If not, add to router
"""
global llm_router, llm_model_list, master_key
import base64
try:
if llm_router is None:
raise Exception("No router initialized")
new_models = await prisma_client.db.litellm_proxymodeltable.find_many(
take=10, order={"updated_at": "desc"}
)
for m in new_models:
_litellm_params = m.litellm_params
if isinstance(_litellm_params, dict):
# decrypt values
for k, v in _litellm_params.items():
# decode base64
decoded_b64 = base64.b64decode(v)
# decrypt value
_litellm_params[k] = decrypt_value(
value=decoded_b64, master_key=master_key
)
_litellm_params = LiteLLM_Params(**_litellm_params)
else:
raise Exception(
f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
)
if m.model_info is not None and isinstance(m.model_info, dict):
if "id" not in m.model_info:
m.model_info["id"] = m.model_id
_model_info = RouterModelInfo(**m.model_info)
else:
_model_info = RouterModelInfo(id=m.model_id)
llm_router.add_deployment(
deployment=Deployment(
model_name=m.model_name,
litellm_params=_litellm_params,
model_info=_model_info,
)
)
llm_model_list = llm_router.get_model_list()
except Exception as e:
raise e
proxy_config = ProxyConfig()
@@ -2943,7 +3005,7 @@ async def startup_event():
if prisma_client is not None:
create_view_response = await prisma_client.check_view_exists()
### START BATCH WRITING DB ###
### START BATCH WRITING DB + CHECKING NEW MODELS###
if prisma_client is not None:
scheduler = AsyncIOScheduler()
interval = random.randint(
@@ -2966,6 +3028,15 @@ async def startup_event():
seconds=batch_writing_interval,
args=[prisma_client, db_writer_client, proxy_logging_obj],
)
### ADD NEW MODELS ###
if general_settings.get("store_model_in_db", False) == True:
scheduler.add_job(
proxy_config.add_deployment,
"interval",
seconds=30,
args=[prisma_client, proxy_logging_obj],
)
scheduler.start()
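The 30-second poll wired up above is plain APScheduler, using the same AsyncIOScheduler already created for the batch writer. A self-contained sketch of the pattern, with the job body standing in for proxy_config.add_deployment (function name and sleep duration are illustrative):

import asyncio
from apscheduler.schedulers.asyncio import AsyncIOScheduler

async def poll_new_models():
    # stand-in for proxy_config.add_deployment(prisma_client, proxy_logging_obj)
    print("checking LiteLLM_ProxyModelTable for newly added rows...")

async def main():
    scheduler = AsyncIOScheduler()
    scheduler.add_job(poll_new_models, "interval", seconds=30)
    scheduler.start()
    await asyncio.sleep(90)  # keep the loop alive long enough for a few polling cycles

asyncio.run(main())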
@@ -3314,8 +3385,6 @@ async def chat_completion(
)
)
start_time = time.time()
### ROUTE THE REQUEST ###
# Do not change this - it should be a constant time fetch - ALWAYS
router_model_names = llm_router.model_names if llm_router is not None else []
@@ -3534,8 +3603,6 @@ async def embeddings(
user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
)
start_time = time.time()
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
@@ -6692,30 +6759,54 @@ async def info_budget(data: BudgetRequest):
tags=["model management"],
dependencies=[Depends(user_api_key_auth)],
)
async def add_new_model(model_params: ModelParams):
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
async def add_new_model(
model_params: ModelParams,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key
try:
# Load existing config
config = await proxy_config.get_config()
import base64
verbose_proxy_logger.debug("User config path: %s", user_config_file_path)
global prisma_client
verbose_proxy_logger.debug("Loaded config: %s", config)
# Add the new model to the config
model_info = model_params.model_info.json()
model_info = {k: v for k, v in model_info.items() if v is not None}
config["model_list"].append(
{
"model_name": model_params.model_name,
"litellm_params": model_params.litellm_params,
"model_info": model_info,
}
)
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={
"error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
},
)
verbose_proxy_logger.debug("updated model list: %s", config["model_list"])
# Save new config
await proxy_config.save_config(new_config=config)
# update DB
if general_settings.get("store_model_in_db", False) == True:
"""
- store model_list in db
- store keys separately
"""
# encrypt litellm params #
for k, v in model_params.litellm_params.items():
encrypted_value = encrypt_value(value=v, master_key=master_key) # type: ignore
model_params.litellm_params[k] = base64.b64encode(
encrypted_value
).decode("utf-8")
await prisma_client.db.litellm_proxymodeltable.create(
data={
"model_name": model_params.model_name,
"litellm_params": json.dumps(model_params.litellm_params), # type: ignore
"model_info": model_params.model_info.model_dump_json( # type: ignore
exclude_none=True
),
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
}
)
else:
raise HTTPException(
status_code=500,
detail={
"error": "Set `store_model_in_db: true` in general_settings on your config.yaml"
},
)
return {"message": "Model added successfully"}
except Exception as e:
@@ -6885,14 +6976,16 @@ async def model_info_v1(
):
global llm_model_list, general_settings, user_config_file_path, proxy_config
# Load existing config
config = await proxy_config.get_config()
if llm_model_list is None:
raise HTTPException(
status_code=500, detail={"error": "LLM Model List not loaded in"}
)
if len(user_api_key_dict.models) > 0:
model_names = user_api_key_dict.models
all_models = [m for m in config["model_list"] if m["model_name"] in model_names]
all_models = [m for m in llm_model_list if m["model_name"] in model_names]
else:
all_models = config["model_list"]
all_models = llm_model_list
for model in all_models:
# provided model_info in config.yaml
model_info = model.get("model_info", {})
@@ -6957,6 +7050,7 @@ async def delete_model(model_info: ModelInfoDelete):
# Check if the model with the specified model_id exists
model_to_delete = None
for model in config["model_list"]:
if model.get("model_info", {}).get("id", None) == model_info.id:
model_to_delete = model

View file

@@ -39,6 +39,7 @@ model LiteLLM_ProxyModelTable {
updated_by String
}
model LiteLLM_OrganizationTable {
organization_id String @id @default(uuid())
organization_alias String

View file

@@ -13,6 +13,7 @@ from litellm.proxy._types import (
Member,
)
from litellm.caching import DualCache, RedisCache
from litellm.router import Deployment, ModelInfo, LiteLLM_Params
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
@@ -2181,6 +2182,32 @@ async def update_spend(
raise e
# class Models:
# """
# Need a class to maintain state of models / router across calls to check if new deployments need to be added
# """
# def __init__(
# self,
# router: litellm.Router,
# llm_model_list: list,
# prisma_client: PrismaClient,
# proxy_logging_obj: ProxyLogging,
# master_key: str,
# ) -> None:
# self.router = router
# self.llm_model_list = llm_model_list
# self.prisma_client = prisma_client
# self.proxy_logging_obj = proxy_logging_obj
# self.master_key = master_key
# def get_router(self) -> litellm.Router:
# return self.router
# def get_model_list(self) -> list:
# return self.llm_model_list
async def _read_request_body(request):
"""
Asynchronous function to read the request body and parse it as JSON or literal data.
@@ -2318,6 +2345,45 @@ def _is_user_proxy_admin(user_id_information=None):
return False
def encrypt_value(value: str, master_key: str):
import hashlib
import nacl.secret
import nacl.utils
# get 32 byte master key #
hash_object = hashlib.sha256(master_key.encode())
hash_bytes = hash_object.digest()
# initialize secret box #
box = nacl.secret.SecretBox(hash_bytes)
# encode message #
value_bytes = value.encode("utf-8")
encrypted = box.encrypt(value_bytes)
return encrypted
def decrypt_value(value: bytes, master_key: str) -> str:
import hashlib
import nacl.secret
import nacl.utils
# get 32 byte master key #
hash_object = hashlib.sha256(master_key.encode())
hash_bytes = hash_object.digest()
# initialize secret box #
box = nacl.secret.SecretBox(hash_bytes)
# Convert the bytes object to a string
plaintext = box.decrypt(value)
plaintext = plaintext.decode("utf-8") # type: ignore
return plaintext # type: ignore
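A round-trip sketch of how these helpers combine with the base64 layer used in `add_new_model` (encrypt + encode before the DB write) and `ProxyConfig.add_deployment` (decode + decrypt after the read); the master key and secret values are illustrative:

import base64

master_key = "sk-1234"  # illustrative
ciphertext = encrypt_value(value="my-azure-api-key", master_key=master_key)   # nonce + ciphertext bytes
stored = base64.b64encode(ciphertext).decode("utf-8")                         # string persisted in litellm_params
recovered = decrypt_value(value=base64.b64decode(stored), master_key=master_key)
assert recovered == "my-azure-api-key"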
# LiteLLM Admin UI - Non SSO Login
html_form = """
<!DOCTYPE html>

View file

@@ -29,6 +29,103 @@ from litellm.utils import ModelResponse, CustomStreamWrapper
import copy
from litellm._logging import verbose_router_logger
import logging
from pydantic import BaseModel, validator
class ModelInfo(BaseModel):
id: Optional[
str
] # Allow id to be optional on input, but it will always be present as a str in the model instance
def __init__(self, id: Optional[str] = None, **params):
if id is None:
id = str(uuid.uuid4()) # Generate a UUID if id is None or not provided
super().__init__(id=id, **params)
class Config:
extra = "allow"
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
class LiteLLM_Params(BaseModel):
model: str
tpm: Optional[int] = None
rpm: Optional[int] = None
api_key: Optional[str] = None
api_base: Optional[str] = None
api_version: Optional[str] = None
timeout: Optional[Union[float, str]] = None # if str, pass in as os.environ/
stream_timeout: Optional[Union[float, str]] = (
None # timeout when making stream=True calls, if str, pass in as os.environ/
)
max_retries: Optional[int] = 2 # follows openai default of 2
organization: Optional[str] = None # for openai orgs
class Config:
extra = "allow"
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
class Deployment(BaseModel):
model_name: str
litellm_params: LiteLLM_Params
model_info: ModelInfo
def to_json(self, **kwargs):
try:
return self.model_dump(**kwargs) # noqa
except Exception as e:
# if using pydantic v1
return self.dict(**kwargs)
class Config:
extra = "allow"
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
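A quick illustration of how these wrappers behave both as pydantic models and as dicts, which is what lets the existing Router code keep using dictionary-style access (all values are placeholders):

deployment = Deployment(
    model_name="gpt-3.5-turbo",
    litellm_params=LiteLLM_Params(model="gpt-3.5-turbo", api_key="my-fake-key"),
    model_info=ModelInfo(),                       # id auto-populated with a uuid4 string
)
print(deployment.model_info["id"])                # dict-style access via __getitem__
print(deployment.litellm_params.get("rpm"))       # .get() with a default -> None
print(deployment.to_json(exclude_none=True))      # plain dict, as consumed by Router.set_model_list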
class Router:
@@ -2040,76 +2137,121 @@ class Router:
) # cache for 1 hr
def set_model_list(self, model_list: list):
self.model_list = copy.deepcopy(model_list)
original_model_list = copy.deepcopy(model_list)
self.model_list = []
# we add api_base/api_key to each model so load balancing between azure/gpt on api_base1 and api_base2 works
import os
for model in self.model_list:
#### MODEL ID INIT ########
model_info = model.get("model_info", {})
model_info["id"] = model_info.get("id", str(uuid.uuid4()))
model["model_info"] = model_info
#### DEPLOYMENT NAMES INIT ########
self.deployment_names.append(model["litellm_params"]["model"])
############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
# for get_available_deployment, we use the litellm_param["rpm"]
# in this snippet we also set rpm to be a litellm_param
if (
model["litellm_params"].get("rpm") is None
and model.get("rpm") is not None
):
model["litellm_params"]["rpm"] = model.get("rpm")
if (
model["litellm_params"].get("tpm") is None
and model.get("tpm") is not None
):
model["litellm_params"]["tpm"] = model.get("tpm")
#### VALIDATE MODEL ########
# check if model provider in supported providers
(
_model,
custom_llm_provider,
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=model["litellm_params"]["model"],
custom_llm_provider=model["litellm_params"].get(
"custom_llm_provider", None
),
for model in original_model_list:
_model_name = model.pop("model_name")
_litellm_params = model.pop("litellm_params")
_model_info = model.pop("model_info", {})
deployment = Deployment(
**model,
model_name=_model_name,
litellm_params=_litellm_params,
model_info=_model_info,
)
# Check if user is trying to use model_name == "*"
# this is a catch all model for their specific api key
if model["model_name"] == "*":
self.default_deployment = model
deployment = self._add_deployment(deployment=deployment)
# Azure GPT-Vision Enhancements, users can pass os.environ/
data_sources = model.get("litellm_params", {}).get("dataSources", [])
model = deployment.to_json(exclude_none=True)
for data_source in data_sources:
params = data_source.get("parameters", {})
for param_key in ["endpoint", "key"]:
# if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
if param_key in params and params[param_key].startswith(
"os.environ/"
):
env_name = params[param_key].replace("os.environ/", "")
params[param_key] = os.environ.get(env_name, "")
# done reading model["litellm_params"]
if custom_llm_provider not in litellm.provider_list:
raise Exception(f"Unsupported provider - {custom_llm_provider}")
# init OpenAI, Azure clients
self.set_client(model=model)
self.model_list.append(model)
verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
self.model_names = [m["model_name"] for m in model_list]
def _add_deployment(self, deployment: Deployment) -> Deployment:
import os
#### DEPLOYMENT NAMES INIT ########
self.deployment_names.append(deployment.litellm_params.model)
############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
# for get_available_deployment, we use the litellm_param["rpm"]
# in this snippet we also set rpm to be a litellm_param
if (
deployment.litellm_params.rpm is None
and getattr(deployment, "rpm", None) is not None
):
deployment.litellm_params.rpm = getattr(deployment, "rpm")
if (
deployment.litellm_params.tpm is None
and getattr(deployment, "tpm", None) is not None
):
deployment.litellm_params.tpm = getattr(deployment, "tpm")
#### VALIDATE MODEL ########
# check if model provider in supported providers
(
_model,
custom_llm_provider,
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=deployment.litellm_params.model,
custom_llm_provider=deployment.litellm_params.get(
"custom_llm_provider", None
),
)
# Check if user is trying to use model_name == "*"
# this is a catch all model for their specific api key
if deployment.model_name == "*":
self.default_deployment = deployment.to_json(exclude_none=True)
# Azure GPT-Vision Enhancements, users can pass os.environ/
data_sources = deployment.litellm_params.get("dataSources", [])
for data_source in data_sources:
params = data_source.get("parameters", {})
for param_key in ["endpoint", "key"]:
# if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
if param_key in params and params[param_key].startswith("os.environ/"):
env_name = params[param_key].replace("os.environ/", "")
params[param_key] = os.environ.get(env_name, "")
# done reading model["litellm_params"]
if custom_llm_provider not in litellm.provider_list:
raise Exception(f"Unsupported provider - {custom_llm_provider}")
# init OpenAI, Azure clients
self.set_client(model=deployment.to_json(exclude_none=True))
return deployment
def add_deployment(self, deployment: Deployment):
# check if deployment already exists
if deployment.model_info.id in self.get_model_ids():
return
# add to model list
_deployment = deployment.to_json(exclude_none=True)
self.model_list.append(_deployment)
# initialize client
self._add_deployment(deployment=deployment)
# add to model names
self.model_names.append(deployment.model_name)
return
def get_model_ids(self):
ids = []
for model in self.model_list:
if "model_info" in model and "id" in model["model_info"]:
id = model["model_info"]["id"]
ids.append(id)
return ids
def get_model_names(self):
return self.model_names
def get_model_list(self):
return self.model_list
def _get_client(self, deployment, kwargs, client_type=None):
"""
Returns the appropriate client based on the given deployment, kwargs, and client_type.
@@ -2315,6 +2457,7 @@ class Router:
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
deployment = self.leastbusy_logger.get_available_deployments(
model_group=model, healthy_deployments=healthy_deployments

View file

@@ -146,7 +146,7 @@ def test_cooldown_same_model_name():
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
"tpm": 0.000001,
"tpm": 1,
},
},
]
@@ -181,4 +181,4 @@ def test_cooldown_same_model_name():
pytest.fail(f"Got unexpected exception on router! - {e}")
test_cooldown_same_model_name()
# test_cooldown_same_model_name()

View file

@@ -194,7 +194,7 @@ def test_img_gen(client_no_auth):
#### ADDITIONAL
# @pytest.mark.skip(reason="hitting yaml load issues on circle-ci")
@pytest.mark.skip(reason="test via docker tests. Requires prisma client.")
def test_add_new_model(client_no_auth):
global headers
try:

View file

@@ -10,6 +10,7 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
from litellm import Router
from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
@@ -1193,3 +1194,37 @@ async def test_router_amoderation():
)
print("moderation result", result)
def test_router_add_deployment():
initial_model_list = [
{
"model_name": "fake-openai-endpoint",
"litellm_params": {
"model": "openai/my-fake-model",
"api_key": "my-fake-key",
"api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
},
},
]
router = Router(model_list=initial_model_list)
init_model_id_list = router.get_model_ids()
print(f"init_model_id_list: {init_model_id_list}")
router.add_deployment(
deployment=Deployment(
model_name="gpt-instruct",
litellm_params=LiteLLM_Params(model="gpt-3.5-turbo-instruct"),
model_info=ModelInfo(),
)
)
new_model_id_list = router.get_model_ids()
print(f"new_model_id_list: {new_model_id_list}")
assert len(new_model_id_list) > len(init_model_id_list)
assert new_model_id_list[1] != new_model_id_list[0]

View file

@@ -181,7 +181,7 @@ def test_weighted_selection_router_tpm_as_router_param():
pytest.fail(f"Error occurred: {e}")
test_weighted_selection_router_tpm_as_router_param()
# test_weighted_selection_router_tpm_as_router_param()
def test_weighted_selection_router_rpm_as_router_param():
@@ -433,7 +433,7 @@ def test_usage_based_routing():
selection_counts[response["model"]] += 1
# print("selection counts", selection_counts)
print("selection counts", selection_counts)
total_requests = sum(selection_counts.values())

View file

@@ -54,6 +54,7 @@ litellm_settings:
general_settings:
master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
store_model_in_db: True
proxy_budget_rescheduler_min_time: 60
proxy_budget_rescheduler_max_time: 64
proxy_batch_write_at: 1

View file

@@ -13,6 +13,7 @@ numpy==1.24.3 # semantic caching
pandas==2.1.1 # for viewing clickhouse spend analytics
prisma==0.11.0 # for db
mangum==0.17.0 # for aws lambda functions
pynacl==1.5.0 # for encrypting keys
google-cloud-aiplatform==1.43.0 # for vertex ai calls
google-generativeai==0.3.2 # for vertex ai calls
async_generator==1.10.0 # for async ollama calls