forked from phoenix/litellm-mirror

commit f536fb13e6 (parent 24d9fcb32c)

fix(proxy_server.py): persist models added via /model/new to db

Allows models added via /model/new to be used across instances.
Refs: https://github.com/BerriAI/litellm/issues/2319, https://github.com/BerriAI/litellm/issues/2329

7 changed files with 435 additions and 86 deletions

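Usage sketch (not from the patch itself): with `store_model_in_db: True` set (first hunk below) and a database-backed proxy, a model registered against `/model/new` is written to `LiteLLM_ProxyModelTable` and picked up by every other instance on its next poll. The host, port, key, and payload values here are hypothetical; the field names mirror `ModelParams` (`model_name`, `litellm_params`, `model_info`), and the `requests` package is assumed.

```python
import requests  # assumed HTTP client; any client works

# hypothetical local proxy + the master key from the example config below
resp = requests.post(
    "http://localhost:4000/model/new",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model_name": "my-azure-gpt",  # public alias the proxy routes on
        "litellm_params": {            # stored encrypted + base64-encoded, key by key
            "model": "azure/gpt-35-turbo",
            "api_key": "my-azure-key",
            "api_base": "https://my-endpoint.openai.azure.com/",
        },
        "model_info": {},              # optional metadata; an id is generated if omitted
    },
    timeout=30,
)
print(resp.json())  # expected: {"message": "Model added successfully"}
```
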
@@ -23,6 +23,7 @@ litellm_settings:
 general_settings:
   master_key: sk-1234
   alerting: ["slack"]
+  store_model_in_db: True
   # proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
   enable_jwt_auth: True
   alerting: ["slack"]

@@ -97,6 +97,8 @@ from litellm.proxy.utils import (
     _is_projected_spend_over_limit,
     _get_projected_spend_over_limit,
     update_spend,
+    encrypt_value,
+    decrypt_value,
 )
 from litellm.proxy.secret_managers.google_kms import load_google_kms
 from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager

@@ -104,6 +106,8 @@ import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache, RedisCache
 from litellm.proxy.health_check import perform_health_check
+from litellm.router import LiteLLM_Params, Deployment
+from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.hooks.prompt_injection_detection import (

@@ -2371,6 +2375,64 @@ class ProxyConfig:
         router = litellm.Router(**router_params)  # type:ignore
         return router, model_list, general_settings

+    async def add_deployment(
+        self,
+        prisma_client: PrismaClient,
+        proxy_logging_obj: ProxyLogging,
+    ):
+        """
+        - Check db for new models (last 10 most recently updated)
+        - Check if model id's in router already
+        - If not, add to router
+        """
+        global llm_router, llm_model_list, master_key
+
+        import base64
+
+        try:
+            if llm_router is None:
+                raise Exception("No router initialized")
+
+            new_models = await prisma_client.db.litellm_proxymodeltable.find_many(
+                take=10, order={"updated_at": "desc"}
+            )
+
+            for m in new_models:
+                _litellm_params = m.litellm_params
+                if isinstance(_litellm_params, dict):
+                    # decrypt values
+                    for k, v in _litellm_params.items():
+                        # decode base64
+                        decoded_b64 = base64.b64decode(v)
+                        # decrypt value
+                        _litellm_params[k] = decrypt_value(
+                            value=decoded_b64, master_key=master_key
+                        )
+                    _litellm_params = LiteLLM_Params(**_litellm_params)
+                else:
+                    raise Exception(
+                        f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
+                    )
+
+                if m.model_info is not None and isinstance(m.model_info, dict):
+                    if "id" not in m.model_info:
+                        m.model_info["id"] = m.model_id
+                    _model_info = RouterModelInfo(**m.model_info)
+                else:
+                    _model_info = RouterModelInfo(id=m.model_id)
+
+                llm_router.add_deployment(
+                    deployment=Deployment(
+                        model_name=m.model_name,
+                        litellm_params=_litellm_params,
+                        model_info=_model_info,
+                    )
+                )
+
+            llm_model_list = llm_router.get_model_list()
+        except Exception as e:
+            raise e
+

 proxy_config = ProxyConfig()

@@ -2943,7 +3005,7 @@ async def startup_event():
     if prisma_client is not None:
         create_view_response = await prisma_client.check_view_exists()

-    ### START BATCH WRITING DB ###
+    ### START BATCH WRITING DB + CHECKING NEW MODELS###
     if prisma_client is not None:
        scheduler = AsyncIOScheduler()
        interval = random.randint(

@@ -2966,6 +3028,15 @@ async def startup_event():
             seconds=batch_writing_interval,
             args=[prisma_client, db_writer_client, proxy_logging_obj],
         )
+
+        ### ADD NEW MODELS ###
+        if general_settings.get("store_model_in_db", False) == True:
+            scheduler.add_job(
+                proxy_config.add_deployment,
+                "interval",
+                seconds=30,
+                args=[prisma_client, proxy_logging_obj],
+            )
         scheduler.start()

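For readers unfamiliar with APScheduler, a self-contained sketch of the scheduling pattern used above: an async job run on a fixed interval by `AsyncIOScheduler`. The function name and interval are placeholders, not proxy code.

```python
import asyncio

from apscheduler.schedulers.asyncio import AsyncIOScheduler


async def check_for_new_models():
    # stand-in for proxy_config.add_deployment(prisma_client, proxy_logging_obj)
    print("polling DB for newly added models ...")


async def main():
    scheduler = AsyncIOScheduler()
    # run every 30 seconds, mirroring the job registered in startup_event()
    scheduler.add_job(check_for_new_models, "interval", seconds=30)
    scheduler.start()
    await asyncio.sleep(90)  # keep the event loop alive so the job can fire


if __name__ == "__main__":
    asyncio.run(main())
```
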
@@ -3314,8 +3385,6 @@ async def chat_completion(
             )
         )

-        start_time = time.time()
-
         ### ROUTE THE REQUEST ###
         # Do not change this - it should be a constant time fetch - ALWAYS
         router_model_names = llm_router.model_names if llm_router is not None else []

@@ -3534,8 +3603,6 @@ async def embeddings(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )

-        start_time = time.time()
-
         ## ROUTE TO CORRECT ENDPOINT ##
         # skip router if user passed their key
         if "api_key" in data:

@@ -6691,30 +6758,47 @@ async def info_budget(data: BudgetRequest):
     tags=["model management"],
     dependencies=[Depends(user_api_key_auth)],
 )
-async def add_new_model(model_params: ModelParams):
-    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
+async def add_new_model(
+    model_params: ModelParams,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key
     try:
-        # Load existing config
-        config = await proxy_config.get_config()
-
-        verbose_proxy_logger.debug("User config path: %s", user_config_file_path)
-
-        verbose_proxy_logger.debug("Loaded config: %s", config)
-        # Add the new model to the config
-        model_info = model_params.model_info.json()
-        model_info = {k: v for k, v in model_info.items() if v is not None}
-        config["model_list"].append(
-            {
-                "model_name": model_params.model_name,
-                "litellm_params": model_params.litellm_params,
-                "model_info": model_info,
-            }
-        )
-
-        verbose_proxy_logger.debug("updated model list: %s", config["model_list"])
-
-        # Save new config
-        await proxy_config.save_config(new_config=config)
+        import base64
+
+        global prisma_client
+
+        if prisma_client is None:
+            raise HTTPException(
+                status_code=500,
+                detail={
+                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
+                },
+            )
+
+        # update DB
+        if general_settings.get("store_model_in_db", False) == True:
+            """
+            - store model_list in db
+            - store keys separately
+            """
+            # encrypt litellm params #
+            for k, v in model_params.litellm_params.items():
+                encrypted_value = encrypt_value(value=v, master_key=master_key)  # type: ignore
+                model_params.litellm_params[k] = base64.b64encode(
+                    encrypted_value
+                ).decode("utf-8")
+            await prisma_client.db.litellm_proxymodeltable.create(
+                data={
+                    "model_name": model_params.model_name,
+                    "litellm_params": json.dumps(model_params.litellm_params),  # type: ignore
+                    "model_info": model_params.model_info.model_dump_json(  # type: ignore
+                        exclude_none=True
+                    ),
+                    "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+                    "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+                }
+            )
         return {"message": "Model added successfully"}

     except Exception as e:

@@ -6884,14 +6968,16 @@ async def model_info_v1(
 ):
     global llm_model_list, general_settings, user_config_file_path, proxy_config

-    # Load existing config
-    config = await proxy_config.get_config()
+    if llm_model_list is None:
+        raise HTTPException(
+            status_code=500, detail={"error": "LLM Model List not loaded in"}
+        )

     if len(user_api_key_dict.models) > 0:
         model_names = user_api_key_dict.models
-        all_models = [m for m in config["model_list"] if m["model_name"] in model_names]
+        all_models = [m for m in llm_model_list if m["model_name"] in model_names]
     else:
-        all_models = config["model_list"]
+        all_models = llm_model_list
     for model in all_models:
         # provided model_info in config.yaml
         model_info = model.get("model_info", {})

@@ -6956,6 +7042,7 @@ async def delete_model(model_info: ModelInfoDelete):

         # Check if the model with the specified model_id exists
         model_to_delete = None
+
         for model in config["model_list"]:
             if model.get("model_info", {}).get("id", None) == model_info.id:
                 model_to_delete = model

@@ -27,6 +27,19 @@ model LiteLLM_BudgetTable {
   end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
 }
+
+// Models on proxy
+model LiteLLM_ProxyModelTable {
+  model_id String @id @default(uuid())
+  model_name String
+  litellm_params Json
+  model_info Json?
+  created_at DateTime @default(now()) @map("created_at")
+  created_by String
+  updated_at DateTime @default(now()) @updatedAt @map("updated_at")
+  updated_by String
+}
+

 model LiteLLM_OrganizationTable {
   organization_id String @id @default(uuid())
   organization_alias String

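Illustrative only (all values fabricated): roughly what a row in the new table holds after `add_new_model` runs. Note that the `litellm_params` values are base64-encoded ciphertext, not plaintext credentials.

```python
# hypothetical LiteLLM_ProxyModelTable row, shown as a Python dict
row = {
    "model_id": "d6f1a1e0-0000-0000-0000-000000000000",  # @default(uuid())
    "model_name": "my-azure-gpt",
    "litellm_params": {  # every value encrypted with the master key, then base64-encoded
        "model": "dGhpcyBpcyBjaXBoZXJ0ZXh0Li4u",
        "api_key": "bW9yZSBjaXBoZXJ0ZXh0Li4u",
        "api_base": "eWV0IG1vcmUgY2lwaGVydGV4dA==",
    },
    "model_info": {"id": "d6f1a1e0-0000-0000-0000-000000000000"},  # optional Json
    "created_by": "proxy-admin",  # user_api_key_dict.user_id or the proxy admin name
    "updated_by": "proxy-admin",
}
```
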
@@ -13,6 +13,7 @@ from litellm.proxy._types import (
     Member,
 )
 from litellm.caching import DualCache, RedisCache
+from litellm.router import Deployment, ModelInfo, LiteLLM_Params
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,

@@ -2181,6 +2182,32 @@ async def update_spend(
         raise e


+# class Models:
+#     """
+#     Need a class to maintain state of models / router across calls to check if new deployments need to be added
+#     """
+
+#     def __init__(
+#         self,
+#         router: litellm.Router,
+#         llm_model_list: list,
+#         prisma_client: PrismaClient,
+#         proxy_logging_obj: ProxyLogging,
+#         master_key: str,
+#     ) -> None:
+#         self.router = router
+#         self.llm_model_list = llm_model_list
+#         self.prisma_client = prisma_client
+#         self.proxy_logging_obj = proxy_logging_obj
+#         self.master_key = master_key
+
+#     def get_router(self) -> litellm.Router:
+#         return self.router
+
+#     def get_model_list(self) -> list:
+#         return self.llm_model_list
+
+
 async def _read_request_body(request):
     """
     Asynchronous function to read the request body and parse it as JSON or literal data.

@@ -2318,6 +2345,45 @@ def _is_user_proxy_admin(user_id_information=None):
     return False


+def encrypt_value(value: str, master_key: str):
+    import hashlib
+    import nacl.secret
+    import nacl.utils
+
+    # get 32 byte master key #
+    hash_object = hashlib.sha256(master_key.encode())
+    hash_bytes = hash_object.digest()
+
+    # initialize secret box #
+    box = nacl.secret.SecretBox(hash_bytes)
+
+    # encode message #
+    value_bytes = value.encode("utf-8")
+
+    encrypted = box.encrypt(value_bytes)
+
+    return encrypted
+
+
+def decrypt_value(value: bytes, master_key: str) -> str:
+    import hashlib
+    import nacl.secret
+    import nacl.utils
+
+    # get 32 byte master key #
+    hash_object = hashlib.sha256(master_key.encode())
+    hash_bytes = hash_object.digest()
+
+    # initialize secret box #
+    box = nacl.secret.SecretBox(hash_bytes)
+
+    # Convert the bytes object to a string
+    plaintext = box.decrypt(value)
+
+    plaintext = plaintext.decode("utf-8")  # type: ignore
+    return plaintext  # type: ignore
+
+
 # LiteLLM Admin UI - Non SSO Login
 html_form = """
 <!DOCTYPE html>

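A round-trip sketch (not from the patch) of how these helpers are used elsewhere in the commit: `add_new_model` encrypts each `litellm_params` value and base64-encodes it before the Prisma write, and `ProxyConfig.add_deployment` reverses both steps when polling. The key and value below are made up.

```python
import base64

from litellm.proxy.utils import decrypt_value, encrypt_value

master_key = "sk-1234"       # hypothetical proxy master key
secret = "my-azure-api-key"  # hypothetical provider credential

# write path (add_new_model): encrypt, then base64-encode so it can live in a JSON column
stored = base64.b64encode(
    encrypt_value(value=secret, master_key=master_key)
).decode("utf-8")

# read path (ProxyConfig.add_deployment): base64-decode, then decrypt
assert decrypt_value(value=base64.b64decode(stored), master_key=master_key) == secret
```
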
@@ -29,6 +29,103 @@ from litellm.utils import ModelResponse, CustomStreamWrapper
 import copy
 from litellm._logging import verbose_router_logger
 import logging
+from pydantic import BaseModel, validator
+
+
+class ModelInfo(BaseModel):
+    id: Optional[
+        str
+    ]  # Allow id to be optional on input, but it will always be present as a str in the model instance
+
+    def __init__(self, id: Optional[str] = None, **params):
+        if id is None:
+            id = str(uuid.uuid4())  # Generate a UUID if id is None or not provided
+        super().__init__(id=id, **params)
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
+class LiteLLM_Params(BaseModel):
+    model: str
+    tpm: Optional[int] = None
+    rpm: Optional[int] = None
+    api_key: Optional[str] = None
+    api_base: Optional[str] = None
+    api_version: Optional[str] = None
+    timeout: Optional[Union[float, str]] = None  # if str, pass in as os.environ/
+    stream_timeout: Optional[Union[float, str]] = (
+        None  # timeout when making stream=True calls, if str, pass in as os.environ/
+    )
+    max_retries: Optional[int] = 2  # follows openai default of 2
+    organization: Optional[str] = None  # for openai orgs
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
+class Deployment(BaseModel):
+    model_name: str
+    litellm_params: LiteLLM_Params
+    model_info: ModelInfo
+
+    def to_json(self, **kwargs):
+        try:
+            return self.model_dump(**kwargs)  # noqa
+        except Exception as e:
+            # if using pydantic v1
+            return self.dict(**kwargs)
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
 class Router:

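A small sketch (not from the patch) of the new router-level types: `ModelInfo` generates a uuid4 `id` when none is supplied, and `Deployment.to_json(exclude_none=True)` flattens the object back into the dict shape the router stores internally.

```python
from litellm.router import Deployment, LiteLLM_Params, ModelInfo

deployment = Deployment(
    model_name="gpt-instruct",
    litellm_params=LiteLLM_Params(model="gpt-3.5-turbo-instruct"),
    model_info=ModelInfo(),  # no id passed -> a uuid4 string is generated
)

print(deployment.model_info.id)               # random uuid, always a str
print(deployment.to_json(exclude_none=True))  # plain dict, None fields dropped
```
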
@@ -2040,30 +2137,45 @@ class Router:
         )  # cache for 1 hr

     def set_model_list(self, model_list: list):
-        self.model_list = copy.deepcopy(model_list)
+        original_model_list = copy.deepcopy(model_list)
+        self.model_list = []
         # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
         import os

-        for model in self.model_list:
-            #### MODEL ID INIT ########
-            model_info = model.get("model_info", {})
-            model_info["id"] = model_info.get("id", str(uuid.uuid4()))
-            model["model_info"] = model_info
+        for model in original_model_list:
+            deployment = Deployment(
+                model_name=model["model_name"],
+                litellm_params=model["litellm_params"],
+                model_info=model.get("model_info", {}),
+            )
+            self._add_deployment(deployment=deployment)
+
+            model = deployment.to_json(exclude_none=True)
+
+            self.model_list.append(model)
+
+        verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
+        self.model_names = [m["model_name"] for m in model_list]
+
+    def _add_deployment(self, deployment: Deployment):
+        import os
+
         #### DEPLOYMENT NAMES INIT ########
-        self.deployment_names.append(model["litellm_params"]["model"])
+        self.deployment_names.append(deployment.litellm_params.model)
         ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
         # for get_available_deployment, we use the litellm_param["rpm"]
         # in this snippet we also set rpm to be a litellm_param
         if (
-            model["litellm_params"].get("rpm") is None
-            and model.get("rpm") is not None
+            deployment.litellm_params.rpm is None
+            and getattr(deployment, "rpm", None) is not None
         ):
-            model["litellm_params"]["rpm"] = model.get("rpm")
+            deployment.litellm_params.rpm = getattr(deployment, "rpm")

         if (
-            model["litellm_params"].get("tpm") is None
-            and model.get("tpm") is not None
+            deployment.litellm_params.tpm is None
+            and getattr(deployment, "tpm", None) is not None
         ):
-            model["litellm_params"]["tpm"] = model.get("tpm")
+            deployment.litellm_params.tpm = getattr(deployment, "tpm")

         #### VALIDATE MODEL ########
         # check if model provider in supported providers

@@ -2073,27 +2185,25 @@ class Router:
             dynamic_api_key,
             api_base,
         ) = litellm.get_llm_provider(
-            model=model["litellm_params"]["model"],
-            custom_llm_provider=model["litellm_params"].get(
+            model=deployment.litellm_params.model,
+            custom_llm_provider=deployment.litellm_params.get(
                 "custom_llm_provider", None
             ),
         )

         # Check if user is trying to use model_name == "*"
         # this is a catch all model for their specific api key
-        if model["model_name"] == "*":
-            self.default_deployment = model
+        if deployment.model_name == "*":
+            self.default_deployment = deployment

         # Azure GPT-Vision Enhancements, users can pass os.environ/
-        data_sources = model.get("litellm_params", {}).get("dataSources", [])
+        data_sources = deployment.litellm_params.get("dataSources", [])

         for data_source in data_sources:
             params = data_source.get("parameters", {})
             for param_key in ["endpoint", "key"]:
                 # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
-                if param_key in params and params[param_key].startswith(
-                    "os.environ/"
-                ):
+                if param_key in params and params[param_key].startswith("os.environ/"):
                     env_name = params[param_key].replace("os.environ/", "")
                     params[param_key] = os.environ.get(env_name, "")

@@ -2102,14 +2212,39 @@ class Router:
                 raise Exception(f"Unsupported provider - {custom_llm_provider}")

         # init OpenAI, Azure clients
-        self.set_client(model=model)
+        self.set_client(model=deployment.to_json(exclude_none=True))

-        verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
-        self.model_names = [m["model_name"] for m in model_list]
+    def add_deployment(self, deployment: Deployment):
+        # check if deployment already exists
+        if deployment.model_info.id in self.get_model_ids():
+            return
+
+        # add to model list
+        _deployment = deployment.to_json(exclude_none=True)
+        self.model_list.append(_deployment)
+
+        # initialize client
+        self._add_deployment(deployment=deployment)
+
+        # add to model names
+        self.model_names.append(deployment.model_name)
+        return
+
+    def get_model_ids(self):
+        ids = []
+        for model in self.model_list:
+            if "model_info" in model and "id" in model["model_info"]:
+                id = model["model_info"]["id"]
+                ids.append(id)
+        return ids

     def get_model_names(self):
         return self.model_names

+    def get_model_list(self):
+        return self.model_list
+
     def _get_client(self, deployment, kwargs, client_type=None):
         """
         Returns the appropriate client based on the given deployment, kwargs, and client_type.

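A quick sketch (not from the patch; mirrors the new test at the bottom of the commit) of why the 30-second DB poll is safe to re-run: `add_deployment` returns early when the deployment's `model_info.id` is already in `get_model_ids()`, so re-adding the same model is a no-op. The model names, keys, and endpoint below are made up.

```python
from litellm import Router
from litellm.router import Deployment, LiteLLM_Params, ModelInfo

router = Router(
    model_list=[
        {
            "model_name": "fake-openai-endpoint",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "my-fake-key",
                "api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
            },
        }
    ]
)

deployment = Deployment(
    model_name="gpt-instruct",
    litellm_params=LiteLLM_Params(model="gpt-3.5-turbo-instruct", api_key="my-fake-key"),
    model_info=ModelInfo(id="fixed-id"),
)

router.add_deployment(deployment=deployment)
router.add_deployment(deployment=deployment)  # duplicate id -> early return, nothing added

assert router.get_model_ids().count("fixed-id") == 1
```
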
@@ -10,6 +10,7 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import litellm
 from litellm import Router
+from litellm.router import Deployment, LiteLLM_Params, ModelInfo
 from concurrent.futures import ThreadPoolExecutor
 from collections import defaultdict
 from dotenv import load_dotenv

@@ -1193,3 +1194,37 @@ async def test_router_amoderation():
     )

     print("moderation result", result)
+
+
+def test_router_add_deployment():
+    initial_model_list = [
+        {
+            "model_name": "fake-openai-endpoint",
+            "litellm_params": {
+                "model": "openai/my-fake-model",
+                "api_key": "my-fake-key",
+                "api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
+            },
+        },
+    ]
+    router = Router(model_list=initial_model_list)
+
+    init_model_id_list = router.get_model_ids()
+
+    print(f"init_model_id_list: {init_model_id_list}")
+
+    router.add_deployment(
+        deployment=Deployment(
+            model_name="gpt-instruct",
+            litellm_params=LiteLLM_Params(model="gpt-3.5-turbo-instruct"),
+            model_info=ModelInfo(),
+        )
+    )
+
+    new_model_id_list = router.get_model_ids()
+
+    print(f"new_model_id_list: {new_model_id_list}")
+
+    assert len(new_model_id_list) > len(init_model_id_list)
+
+    assert new_model_id_list[1] != new_model_id_list[0]

@@ -27,6 +27,18 @@ model LiteLLM_BudgetTable {
   end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
 }
+
+// Models on proxy
+model LiteLLM_ProxyModelTable {
+  model_id String @id @default(uuid())
+  model_name String
+  litellm_params Json
+  model_info Json?
+  created_at DateTime @default(now()) @map("created_at")
+  created_by String
+  updated_at DateTime @default(now()) @updatedAt @map("updated_at")
+  updated_by String
+}

 model LiteLLM_OrganizationTable {
   organization_id String @id @default(uuid())
   organization_alias String