forked from phoenix/litellm-mirror

fix(proxy_server.py): persist models added via /model/new to db

Persisting newly added models to the db allows them to be used across all proxy instances. See https://github.com/BerriAI/litellm/issues/2319 and https://github.com/BerriAI/litellm/issues/2329.

This commit is contained in:
parent 24d9fcb32c
commit f536fb13e6

7 changed files with 435 additions and 86 deletions
@@ -29,6 +29,103 @@ from litellm.utils import ModelResponse, CustomStreamWrapper
 import copy
 from litellm._logging import verbose_router_logger
 import logging
 from pydantic import BaseModel, validator
+
+
+class ModelInfo(BaseModel):
+    id: Optional[
+        str
+    ]  # Allow id to be optional on input, but it will always be present as a str in the model instance
+
+    def __init__(self, id: Optional[str] = None, **params):
+        if id is None:
+            id = str(uuid.uuid4())  # Generate a UUID if id is None or not provided
+        super().__init__(id=id, **params)
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
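Note: a minimal usage sketch of the new ModelInfo model (illustrative only, not part of this diff; assumes the class is importable from litellm.router):

from litellm.router import ModelInfo  # assumed import path

info = ModelInfo()  # no id passed -> a uuid4 string is generated in __init__
assert isinstance(info.id, str)

info = ModelInfo(id="deployment-1", base_model="gpt-4")  # extra = "allow" keeps unknown fields
assert "base_model" in info             # __contains__ delegates to hasattr
assert info["base_model"] == "gpt-4"    # __getitem__ delegates to getattr
assert info.get("missing", 42) == 42    # .get() mirrors dict.get

The dict-style helpers let existing call sites that treat model_info as a plain dict keep working.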
+class LiteLLM_Params(BaseModel):
+    model: str
+    tpm: Optional[int] = None
+    rpm: Optional[int] = None
+    api_key: Optional[str] = None
+    api_base: Optional[str] = None
+    api_version: Optional[str] = None
+    timeout: Optional[Union[float, str]] = None  # if str, pass in as os.environ/
+    stream_timeout: Optional[Union[float, str]] = (
+        None  # timeout when making stream=True calls, if str, pass in as os.environ/
+    )
+    max_retries: Optional[int] = 2  # follows openai default of 2
+    organization: Optional[str] = None  # for openai orgs
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
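Note: a sketch of the params this model accepts (values illustrative; the os.environ/ prefix is litellm's convention for resolving secrets from the environment):

from litellm.router import LiteLLM_Params  # assumed import path

params = LiteLLM_Params(
    model="azure/chatgpt-v-2",
    api_key="os.environ/AZURE_API_KEY",    # resolved from the environment downstream
    api_base="os.environ/AZURE_API_BASE",
    timeout=30.0,                          # a float, or an "os.environ/..." string
    rpm=1440,
)
assert params.max_retries == 2             # follows the openai default of 2
assert params.get("api_version") is None   # unset optional fields default to None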
+class Deployment(BaseModel):
+    model_name: str
+    litellm_params: LiteLLM_Params
+    model_info: ModelInfo
+
+    def to_json(self, **kwargs):
+        try:
+            return self.model_dump(**kwargs)  # noqa
+        except Exception as e:
+            # if using pydantic v1
+            return self.dict(**kwargs)
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
 class Router:
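Note: putting the three new models together, a hedged sketch of building a Deployment and serializing it; to_json uses pydantic v2's model_dump and falls back to .dict() on v1:

from litellm.router import Deployment, LiteLLM_Params, ModelInfo  # assumed import path

deployment = Deployment(
    model_name="gpt-3.5-turbo",
    litellm_params=LiteLLM_Params(model="azure/chatgpt-v-2"),
    model_info=ModelInfo(),  # id auto-generated as a uuid4 string
)
payload = deployment.to_json(exclude_none=True)
# payload is a plain dict, suitable for persisting to the db or appending to Router.model_list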
@@ -2040,76 +2137,114 @@ class Router:
             )  # cache for 1 hr
 
     def set_model_list(self, model_list: list):
-        self.model_list = copy.deepcopy(model_list)
+        original_model_list = copy.deepcopy(model_list)
+        self.model_list = []
         # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
         import os
 
-        for model in self.model_list:
-            #### MODEL ID INIT ########
-            model_info = model.get("model_info", {})
-            model_info["id"] = model_info.get("id", str(uuid.uuid4()))
-            model["model_info"] = model_info
-            #### DEPLOYMENT NAMES INIT ########
-            self.deployment_names.append(model["litellm_params"]["model"])
-            ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
-            # for get_available_deployment, we use the litellm_param["rpm"]
-            # in this snippet we also set rpm to be a litellm_param
-            if (
-                model["litellm_params"].get("rpm") is None
-                and model.get("rpm") is not None
-            ):
-                model["litellm_params"]["rpm"] = model.get("rpm")
-            if (
-                model["litellm_params"].get("tpm") is None
-                and model.get("tpm") is not None
-            ):
-                model["litellm_params"]["tpm"] = model.get("tpm")
-
-            #### VALIDATE MODEL ########
-            # check if model provider in supported providers
-            (
-                _model,
-                custom_llm_provider,
-                dynamic_api_key,
-                api_base,
-            ) = litellm.get_llm_provider(
-                model=model["litellm_params"]["model"],
-                custom_llm_provider=model["litellm_params"].get(
-                    "custom_llm_provider", None
-                ),
+        for model in original_model_list:
+            deployment = Deployment(
+                model_name=model["model_name"],
+                litellm_params=model["litellm_params"],
+                model_info=model.get("model_info", {}),
             )
+            self._add_deployment(deployment=deployment)
 
-            # Check if user is trying to use model_name == "*"
-            # this is a catch all model for their specific api key
-            if model["model_name"] == "*":
-                self.default_deployment = model
+            model = deployment.to_json(exclude_none=True)
 
-            # Azure GPT-Vision Enhancements, users can pass os.environ/
-            data_sources = model.get("litellm_params", {}).get("dataSources", [])
-
-            for data_source in data_sources:
-                params = data_source.get("parameters", {})
-                for param_key in ["endpoint", "key"]:
-                    # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
-                    if param_key in params and params[param_key].startswith(
-                        "os.environ/"
-                    ):
-                        env_name = params[param_key].replace("os.environ/", "")
-                        params[param_key] = os.environ.get(env_name, "")
-
-            # done reading model["litellm_params"]
-            if custom_llm_provider not in litellm.provider_list:
-                raise Exception(f"Unsupported provider - {custom_llm_provider}")
-
-            # init OpenAI, Azure clients
-            self.set_client(model=model)
+            self.model_list.append(model)
 
         verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
         self.model_names = [m["model_name"] for m in model_list]
+
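Note: for context, the model_list shape that set_model_list consumes (illustrative values; model_info may be omitted, in which case a uuid id is generated):

import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",  # public, load-balanced name
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": "os.environ/AZURE_API_KEY",
                "api_base": "os.environ/AZURE_API_BASE",
                "rpm": 1440,
            },
        },
    ]
)
assert router.get_model_names() == ["gpt-3.5-turbo"]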
+    def _add_deployment(self, deployment: Deployment):
+        import os
+
+        #### DEPLOYMENT NAMES INIT ########
+        self.deployment_names.append(deployment.litellm_params.model)
+        ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
+        # for get_available_deployment, we use the litellm_param["rpm"]
+        # in this snippet we also set rpm to be a litellm_param
+        if (
+            deployment.litellm_params.rpm is None
+            and getattr(deployment, "rpm", None) is not None
+        ):
+            deployment.litellm_params.rpm = getattr(deployment, "rpm")
+
+        if (
+            deployment.litellm_params.tpm is None
+            and getattr(deployment, "tpm", None) is not None
+        ):
+            deployment.litellm_params.tpm = getattr(deployment, "tpm")
+
+        #### VALIDATE MODEL ########
+        # check if model provider in supported providers
+        (
+            _model,
+            custom_llm_provider,
+            dynamic_api_key,
+            api_base,
+        ) = litellm.get_llm_provider(
+            model=deployment.litellm_params.model,
+            custom_llm_provider=deployment.litellm_params.get(
+                "custom_llm_provider", None
+            ),
+        )
+
+        # Check if user is trying to use model_name == "*"
+        # this is a catch all model for their specific api key
+        if deployment.model_name == "*":
+            self.default_deployment = deployment
+
+        # Azure GPT-Vision Enhancements, users can pass os.environ/
+        data_sources = deployment.litellm_params.get("dataSources", [])
+
+        for data_source in data_sources:
+            params = data_source.get("parameters", {})
+            for param_key in ["endpoint", "key"]:
+                # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
+                if param_key in params and params[param_key].startswith("os.environ/"):
+                    env_name = params[param_key].replace("os.environ/", "")
+                    params[param_key] = os.environ.get(env_name, "")
+
+        # done reading model["litellm_params"]
+        if custom_llm_provider not in litellm.provider_list:
+            raise Exception(f"Unsupported provider - {custom_llm_provider}")
+
+        # init OpenAI, Azure clients
+        self.set_client(model=deployment.to_json(exclude_none=True))
+
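Note: the os.environ/ handling above resolves indirections once, at client-init time. The same pattern in isolation (hypothetical standalone helper, mirroring the loop above):

import os

def resolve_env_ref(value: str) -> str:
    # "os.environ/AZURE_VISION_KEY" -> contents of $AZURE_VISION_KEY, "" if unset
    if value.startswith("os.environ/"):
        return os.environ.get(value.replace("os.environ/", ""), "")
    return value

os.environ["AZURE_VISION_KEY"] = "my-key"
assert resolve_env_ref("os.environ/AZURE_VISION_KEY") == "my-key"
assert resolve_env_ref("plain-value") == "plain-value"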
+    def add_deployment(self, deployment: Deployment):
+        # check if deployment already exists
+
+        if deployment.model_info.id in self.get_model_ids():
+            return
+
+        # add to model list
+        _deployment = deployment.to_json(exclude_none=True)
+        self.model_list.append(_deployment)
+
+        # initialize client
+        self._add_deployment(deployment=deployment)
+
+        # add to model names
+        self.model_names.append(deployment.model_name)
+        return
+
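Note: add_deployment is the runtime hook that lets a model created via /model/new on one instance be registered on the others once it is read back from the db. A hedged sketch (the db row shape is assumed for illustration; the load-from-db plumbing lives in proxy_server.py, not in this hunk):

from litellm.router import Deployment  # assumed import path

row = {  # as persisted by /model/new (shape assumed)
    "model_name": "gpt-4",
    "litellm_params": {"model": "azure/gpt-4", "api_key": "os.environ/AZURE_API_KEY"},
    "model_info": {"id": "uuid-from-db"},
}
router.add_deployment(
    Deployment(
        model_name=row["model_name"],
        litellm_params=row["litellm_params"],
        model_info=row["model_info"],
    )
)
# idempotent: a second call with the same model_info.id is a no-op (checked via get_model_ids)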
+    def get_model_ids(self):
+        ids = []
+        for model in self.model_list:
+            if "model_info" in model and "id" in model["model_info"]:
+                id = model["model_info"]["id"]
+                ids.append(id)
+        return ids
+
+    def get_model_names(self):
+        return self.model_names
+
+    def get_model_list(self):
+        return self.model_list
+
     def _get_client(self, deployment, kwargs, client_type=None):
         """
         Returns the appropriate client based on the given deployment, kwargs, and client_type.