fix(proxy_server.py): persist models added via /model/new to db

allows models added on one proxy instance to be used across all instances

https://github.com/BerriAI/litellm/issues/2319 , https://github.com/BerriAI/litellm/issues/2329
Krrish Dholakia 2024-04-03 20:16:41 -07:00
parent 24d9fcb32c
commit f536fb13e6
7 changed files with 435 additions and 86 deletions
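
The payload shape for /model/new mirrors the Deployment schema this commit adds to the router. A minimal sketch of the flow being enabled, assuming a proxy on port 4000 and a placeholder master key (both illustrative, not taken from this diff):

import requests

# Register a model on instance A; with this commit it is also persisted to the DB
resp = requests.post(
    "http://0.0.0.0:4000/model/new",
    headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
    json={
        "model_name": "gpt-4-team",  # illustrative values
        "litellm_params": {
            "model": "azure/gpt-4",
            "api_base": "os.environ/AZURE_API_BASE",
            "api_key": "os.environ/AZURE_API_KEY",
        },
    },
)
resp.raise_for_status()
# Instance B, pointed at the same DB, can now load and serve "gpt-4-team" too.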

litellm/router.py

@@ -29,6 +29,103 @@ from litellm.utils import ModelResponse, CustomStreamWrapper
 import copy
 from litellm._logging import verbose_router_logger
 import logging
+from pydantic import BaseModel, validator
+
+
+class ModelInfo(BaseModel):
+    id: Optional[
+        str
+    ]  # Allow id to be optional on input, but it will always be present as a str in the model instance
+
+    def __init__(self, id: Optional[str] = None, **params):
+        if id is None:
+            id = str(uuid.uuid4())  # Generate a UUID if id is None or not provided
+        super().__init__(id=id, **params)
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
+class LiteLLM_Params(BaseModel):
+    model: str
+    tpm: Optional[int] = None
+    rpm: Optional[int] = None
+    api_key: Optional[str] = None
+    api_base: Optional[str] = None
+    api_version: Optional[str] = None
+    timeout: Optional[Union[float, str]] = None  # if str, pass in as os.environ/
+    stream_timeout: Optional[Union[float, str]] = (
+        None  # timeout when making stream=True calls, if str, pass in as os.environ/
+    )
+    max_retries: Optional[int] = 2  # follows openai default of 2
+    organization: Optional[str] = None  # for openai orgs
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
+class Deployment(BaseModel):
+    model_name: str
+    litellm_params: LiteLLM_Params
+    model_info: ModelInfo
+
+    def to_json(self, **kwargs):
+        try:
+            return self.model_dump(**kwargs)  # noqa
+        except Exception as e:
+            # if using pydantic v1
+            return self.dict(**kwargs)
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
 class Router:
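
The dunder methods above are what let these pydantic models stand in for the plain dicts the router previously passed around. A quick sketch of that behavior (not part of the diff; the import path is assumed from the file shown):

from litellm.router import Deployment, LiteLLM_Params, ModelInfo  # import path assumed

dep = Deployment(
    model_name="gpt-3.5-turbo",  # illustrative values
    litellm_params=LiteLLM_Params(model="azure/chatgpt-v-2"),
    model_info=ModelInfo(),  # no id passed -> a uuid4 string is generated
)

assert "id" in dep.model_info           # custom __contains__
print(dep.model_info["id"])             # dict-style access via __getitem__
print(dep.to_json(exclude_none=True))   # plain dict on pydantic v1 and v2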
@@ -2040,76 +2137,114 @@ class Router:
         ) # cache for 1 hr

     def set_model_list(self, model_list: list):
-        self.model_list = copy.deepcopy(model_list)
+        original_model_list = copy.deepcopy(model_list)
+        self.model_list = []
         # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
         import os

-        for model in self.model_list:
-            #### MODEL ID INIT ########
-            model_info = model.get("model_info", {})
-            model_info["id"] = model_info.get("id", str(uuid.uuid4()))
-            model["model_info"] = model_info
-            #### DEPLOYMENT NAMES INIT ########
-            self.deployment_names.append(model["litellm_params"]["model"])
-            ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
-            # for get_available_deployment, we use the litellm_param["rpm"]
-            # in this snippet we also set rpm to be a litellm_param
-            if (
-                model["litellm_params"].get("rpm") is None
-                and model.get("rpm") is not None
-            ):
-                model["litellm_params"]["rpm"] = model.get("rpm")
-            if (
-                model["litellm_params"].get("tpm") is None
-                and model.get("tpm") is not None
-            ):
-                model["litellm_params"]["tpm"] = model.get("tpm")
-            #### VALIDATE MODEL ########
-            # check if model provider in supported providers
-            (
-                _model,
-                custom_llm_provider,
-                dynamic_api_key,
-                api_base,
-            ) = litellm.get_llm_provider(
-                model=model["litellm_params"]["model"],
-                custom_llm_provider=model["litellm_params"].get(
-                    "custom_llm_provider", None
-                ),
+        for model in original_model_list:
+            deployment = Deployment(
+                model_name=model["model_name"],
+                litellm_params=model["litellm_params"],
+                model_info=model.get("model_info", {}),
             )
+            self._add_deployment(deployment=deployment)
-            # Check if user is trying to use model_name == "*"
-            # this is a catch all model for their specific api key
-            if model["model_name"] == "*":
-                self.default_deployment = model
+            model = deployment.to_json(exclude_none=True)
-            # Azure GPT-Vision Enhancements, users can pass os.environ/
-            data_sources = model.get("litellm_params", {}).get("dataSources", [])
-            for data_source in data_sources:
-                params = data_source.get("parameters", {})
-                for param_key in ["endpoint", "key"]:
-                    # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
-                    if param_key in params and params[param_key].startswith(
-                        "os.environ/"
-                    ):
-                        env_name = params[param_key].replace("os.environ/", "")
-                        params[param_key] = os.environ.get(env_name, "")
-            # done reading model["litellm_params"]
-            if custom_llm_provider not in litellm.provider_list:
-                raise Exception(f"Unsupported provider - {custom_llm_provider}")
-            # init OpenAI, Azure clients
-            self.set_client(model=model)
+            self.model_list.append(model)
         verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
         self.model_names = [m["model_name"] for m in model_list]

+    def _add_deployment(self, deployment: Deployment):
+        import os
+
+        #### DEPLOYMENT NAMES INIT ########
+        self.deployment_names.append(deployment.litellm_params.model)
+        ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
+        # for get_available_deployment, we use the litellm_param["rpm"]
+        # in this snippet we also set rpm to be a litellm_param
+        if (
+            deployment.litellm_params.rpm is None
+            and getattr(deployment, "rpm", None) is not None
+        ):
+            deployment.litellm_params.rpm = getattr(deployment, "rpm")
+
+        if (
+            deployment.litellm_params.tpm is None
+            and getattr(deployment, "tpm", None) is not None
+        ):
+            deployment.litellm_params.tpm = getattr(deployment, "tpm")
+
+        #### VALIDATE MODEL ########
+        # check if model provider in supported providers
+        (
+            _model,
+            custom_llm_provider,
+            dynamic_api_key,
+            api_base,
+        ) = litellm.get_llm_provider(
+            model=deployment.litellm_params.model,
+            custom_llm_provider=deployment.litellm_params.get(
+                "custom_llm_provider", None
+            ),
+        )
+
+        # Check if user is trying to use model_name == "*"
+        # this is a catch all model for their specific api key
+        if deployment.model_name == "*":
+            self.default_deployment = deployment
+
+        # Azure GPT-Vision Enhancements, users can pass os.environ/
+        data_sources = deployment.litellm_params.get("dataSources", [])
+        for data_source in data_sources:
+            params = data_source.get("parameters", {})
+            for param_key in ["endpoint", "key"]:
+                # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
+                if param_key in params and params[param_key].startswith("os.environ/"):
+                    env_name = params[param_key].replace("os.environ/", "")
+                    params[param_key] = os.environ.get(env_name, "")
+
+        # done reading model["litellm_params"]
+        if custom_llm_provider not in litellm.provider_list:
+            raise Exception(f"Unsupported provider - {custom_llm_provider}")
+
+        # init OpenAI, Azure clients
+        self.set_client(model=deployment.to_json(exclude_none=True))
+
+    def add_deployment(self, deployment: Deployment):
+        # check if deployment already exists
+        if deployment.model_info.id in self.get_model_ids():
+            return
+
+        # add to model list
+        _deployment = deployment.to_json(exclude_none=True)
+        self.model_list.append(_deployment)
+
+        # initialize client
+        self._add_deployment(deployment=deployment)
+
+        # add to model names
+        self.model_names.append(deployment.model_name)
+        return
+
+    def get_model_ids(self):
+        ids = []
+        for model in self.model_list:
+            if "model_info" in model and "id" in model["model_info"]:
+                id = model["model_info"]["id"]
+                ids.append(id)
+        return ids
+
+    def get_model_names(self):
+        return self.model_names
+
+    def get_model_list(self):
+        return self.model_list
+
     def _get_client(self, deployment, kwargs, client_type=None):
         """
         Returns the appropriate client based on the given deployment, kwargs, and client_type.
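
Together, add_deployment and get_model_ids make the DB-backed model list safe to re-apply: loading the same rows twice is a no-op. A sketch of the cross-instance sync this commit builds toward in proxy_server.py (the prisma table accessor and helper name are hypothetical, not confirmed by this hunk):

from litellm import Router
from litellm.router import Deployment  # import path assumed

async def sync_models_from_db(router: Router, prisma_client) -> None:
    # hypothetical table accessor for the rows written by /model/new
    db_models = await prisma_client.db.litellm_proxymodeltable.find_many()
    for row in db_models:
        deployment = Deployment(
            model_name=row.model_name,
            litellm_params=row.litellm_params,  # pydantic coerces the stored dict
            model_info=row.model_info or {},
        )
        # no-op when model_info.id is already in router.get_model_ids(),
        # so polling on a timer never duplicates deployments
        router.add_deployment(deployment=deployment)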