diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 784413d77..9049d78e4 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -23,6 +23,7 @@ litellm_settings:
 general_settings: 
   master_key: sk-1234
   alerting: ["slack"]
+  store_model_in_db: True
   # proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
   enable_jwt_auth: True
   alerting: ["slack"]
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 5ef3f454f..eac1eb31e 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -97,6 +97,8 @@ from litellm.proxy.utils import (
     _is_projected_spend_over_limit,
     _get_projected_spend_over_limit,
     update_spend,
+    encrypt_value,
+    decrypt_value,
 )
 from litellm.proxy.secret_managers.google_kms import load_google_kms
 from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
@@ -104,6 +106,8 @@ import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache, RedisCache
 from litellm.proxy.health_check import perform_health_check
+from litellm.router import LiteLLM_Params, Deployment
+from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.hooks.prompt_injection_detection import (
@@ -2371,6 +2375,64 @@ class ProxyConfig:
         router = litellm.Router(**router_params)  # type:ignore
         return router, model_list, general_settings
 
+    async def add_deployment(
+        self,
+        prisma_client: PrismaClient,
+        proxy_logging_obj: ProxyLogging,
+    ):
+        """
+        - Check db for new models (last 10 most recently updated)
+        - Check if model id's in router already
+        - If not, add to router
+        """
+        global llm_router, llm_model_list, master_key
+
+        import base64
+
+        try:
+            if llm_router is None:
+                raise Exception("No router initialized")
+
+            new_models = await prisma_client.db.litellm_proxymodeltable.find_many(
+                take=10, order={"updated_at": "desc"}
+            )
+
+            for m in new_models:
+                _litellm_params = m.litellm_params
+                if isinstance(_litellm_params, dict):
+                    # decrypt values
+                    for k, v in _litellm_params.items():
+                        # decode base64
+                        decoded_b64 = base64.b64decode(v)
+                        # decrypt value
+                        _litellm_params[k] = decrypt_value(
+                            value=decoded_b64, master_key=master_key
+                        )
+                    _litellm_params = LiteLLM_Params(**_litellm_params)
+                else:
+                    raise Exception(
+                        f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
+                    )
+
+                if m.model_info is not None and isinstance(m.model_info, dict):
+                    if "id" not in m.model_info:
+                        m.model_info["id"] = m.model_id
+                    _model_info = RouterModelInfo(**m.model_info)
+                else:
+                    _model_info = RouterModelInfo(id=m.model_id)
+
+                llm_router.add_deployment(
+                    deployment=Deployment(
+                        model_name=m.model_name,
+                        litellm_params=_litellm_params,
+                        model_info=_model_info,
+                    )
+                )
+
+            llm_model_list = llm_router.get_model_list()
+        except Exception as e:
+            raise e
+
 
 proxy_config = ProxyConfig()
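`ProxyConfig.add_deployment` is the read side of the storage scheme: every value persisted in `litellm_params` is SecretBox-encrypted and then base64-encoded, so it has to be decoded and decrypted before it can become a `LiteLLM_Params`. A minimal round-trip sketch of that loop, using the `encrypt_value`/`decrypt_value` helpers added in `litellm/proxy/utils.py` further down (the master key and param values here are illustrative only):

```python
import base64

from litellm.proxy.utils import decrypt_value, encrypt_value

master_key = "sk-1234"  # assumption: the same master_key the proxy was started with

# simulate what /model/new persists: encrypt each value, then base64-encode it
plaintext_params = {"model": "azure/gpt-4", "api_key": "my-azure-key"}
stored_params = {
    k: base64.b64encode(encrypt_value(value=v, master_key=master_key)).decode("utf-8")
    for k, v in plaintext_params.items()
}

# mirror of the loop in ProxyConfig.add_deployment: base64-decode, then decrypt
decrypted = {
    k: decrypt_value(value=base64.b64decode(v), master_key=master_key)
    for k, v in stored_params.items()
}
assert decrypted == plaintext_params  # ready for LiteLLM_Params(**decrypted)
```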
@@ -2943,7 +3005,7 @@ async def startup_event():
     if prisma_client is not None:
         create_view_response = await prisma_client.check_view_exists()
 
-    ### START BATCH WRITING DB ###
+    ### START BATCH WRITING DB + CHECKING NEW MODELS ###
     if prisma_client is not None:
         scheduler = AsyncIOScheduler()
         interval = random.randint(
@@ -2966,6 +3028,15 @@ async def startup_event():
             seconds=batch_writing_interval,
             args=[prisma_client, db_writer_client, proxy_logging_obj],
         )
+
+        ### ADD NEW MODELS ###
+        if general_settings.get("store_model_in_db", False) == True:
+            scheduler.add_job(
+                proxy_config.add_deployment,
+                "interval",
+                seconds=30,
+                args=[prisma_client, proxy_logging_obj],
+            )
         scheduler.start()
@@ -3314,8 +3385,6 @@ async def chat_completion(
             )
         )
 
-        start_time = time.time()
-
         ### ROUTE THE REQUEST ###
         # Do not change this - it should be a constant time fetch - ALWAYS
         router_model_names = llm_router.model_names if llm_router is not None else []
@@ -3534,8 +3603,6 @@ async def embeddings(
         user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
     )
 
-    start_time = time.time()
-
     ## ROUTE TO CORRECT ENDPOINT ##
     # skip router if user passed their key
     if "api_key" in data:
@@ -6691,30 +6758,47 @@ async def info_budget(data: BudgetRequest):
     tags=["model management"],
     dependencies=[Depends(user_api_key_auth)],
 )
-async def add_new_model(model_params: ModelParams):
-    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
+async def add_new_model(
+    model_params: ModelParams,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key
     try:
-        # Load existing config
-        config = await proxy_config.get_config()
+        import base64
 
-        verbose_proxy_logger.debug("User config path: %s", user_config_file_path)
+        global prisma_client
 
-        verbose_proxy_logger.debug("Loaded config: %s", config)
-        # Add the new model to the config
-        model_info = model_params.model_info.json()
-        model_info = {k: v for k, v in model_info.items() if v is not None}
-        config["model_list"].append(
-            {
-                "model_name": model_params.model_name,
-                "litellm_params": model_params.litellm_params,
-                "model_info": model_info,
-            }
-        )
+        if prisma_client is None:
+            raise HTTPException(
+                status_code=500,
+                detail={
+                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
+                },
+            )
 
-        verbose_proxy_logger.debug("updated model list: %s", config["model_list"])
-
-        # Save new config
-        await proxy_config.save_config(new_config=config)
+        # update DB
+        if general_settings.get("store_model_in_db", False) == True:
+            """
+            - store model_list in db
+            - store keys separately
+            """
+            # encrypt litellm params #
+            for k, v in model_params.litellm_params.items():
+                encrypted_value = encrypt_value(value=v, master_key=master_key)  # type: ignore
+                model_params.litellm_params[k] = base64.b64encode(
+                    encrypted_value
+                ).decode("utf-8")
+            await prisma_client.db.litellm_proxymodeltable.create(
+                data={
+                    "model_name": model_params.model_name,
+                    "litellm_params": json.dumps(model_params.litellm_params),  # type: ignore
+                    "model_info": model_params.model_info.model_dump_json(  # type: ignore
+                        exclude_none=True
+                    ),
+                    "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+                    "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+                }
+            )
 
         return {"message": "Model added successfully"}
 
     except Exception as e:
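For reference, this is how a client would exercise the endpoint once the proxy is running. A sketch only: the route path (`/model/new`, per the LiteLLM docs) sits outside this hunk, and the URL, key, and model values are assumptions for illustration:

```python
import requests

# assumption: proxy running locally on :4000 with master_key sk-1234
resp = requests.post(
    "http://0.0.0.0:4000/model/new",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model_name": "azure-gpt-4",  # public alias routed by the proxy
        "litellm_params": {
            "model": "azure/gpt-4",
            "api_key": "my-azure-key",
            "api_base": "https://my-endpoint.openai.azure.com",
        },
    },
)
print(resp.json())  # expected: {"message": "Model added successfully"}
```

Within ~30 seconds the scheduler job above should pick the row up and the deployment becomes routable, without a config-file edit or restart.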
@@ -6884,14 +6968,16 @@ async def model_info_v1(
 ):
     global llm_model_list, general_settings, user_config_file_path, proxy_config
 
-    # Load existing config
-    config = await proxy_config.get_config()
+    if llm_model_list is None:
+        raise HTTPException(
+            status_code=500, detail={"error": "LLM Model List not loaded in"}
+        )
 
     if len(user_api_key_dict.models) > 0:
         model_names = user_api_key_dict.models
-        all_models = [m for m in config["model_list"] if m["model_name"] in model_names]
+        all_models = [m for m in llm_model_list if m["model_name"] in model_names]
     else:
-        all_models = config["model_list"]
+        all_models = llm_model_list
     for model in all_models:
         # provided model_info in config.yaml
         model_info = model.get("model_info", {})
@@ -6956,6 +7042,7 @@ async def delete_model(model_info: ModelInfoDelete):
 
         # Check if the model with the specified model_id exists
         model_to_delete = None
+
         for model in config["model_list"]:
             if model.get("model_info", {}).get("id", None) == model_info.id:
                 model_to_delete = model
diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma
index 5ce0670b1..a18d4e581 100644
--- a/litellm/proxy/schema.prisma
+++ b/litellm/proxy/schema.prisma
@@ -27,6 +27,19 @@ model LiteLLM_BudgetTable {
   end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
 }
 
+// Models on proxy
+model LiteLLM_ProxyModelTable {
+  model_id       String   @id @default(uuid())
+  model_name     String
+  litellm_params Json
+  model_info     Json?
+  created_at     DateTime @default(now()) @map("created_at")
+  created_by     String
+  updated_at     DateTime @default(now()) @updatedAt @map("updated_at")
+  updated_by     String
+}
+
+
 model LiteLLM_OrganizationTable {
   organization_id String @id @default(uuid())
   organization_alias String
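With this table in place, the polling job's query maps directly onto the generated Prisma client. A sketch assuming prisma-client-py, a client generated from this schema, and a configured `DATABASE_URL` — the same `find_many` shape used by `ProxyConfig.add_deployment` above:

```python
import asyncio

from prisma import Prisma  # prisma-client-py generated client


async def newest_models():
    db = Prisma()
    await db.connect()
    # last 10 most recently updated rows, newest first
    rows = await db.litellm_proxymodeltable.find_many(
        take=10, order={"updated_at": "desc"}
    )
    await db.disconnect()
    return rows


print(asyncio.run(newest_models()))
```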
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 17f70b322..ec6041bc8 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -13,6 +13,7 @@ from litellm.proxy._types import (
     Member,
 )
 from litellm.caching import DualCache, RedisCache
+from litellm.router import Deployment, ModelInfo, LiteLLM_Params
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
@@ -2181,6 +2182,32 @@ async def update_spend(
         raise e
 
 
+# class Models:
+#     """
+#     Need a class to maintain state of models / router across calls to check if new deployments need to be added
+#     """
+
+#     def __init__(
+#         self,
+#         router: litellm.Router,
+#         llm_model_list: list,
+#         prisma_client: PrismaClient,
+#         proxy_logging_obj: ProxyLogging,
+#         master_key: str,
+#     ) -> None:
+#         self.router = router
+#         self.llm_model_list = llm_model_list
+#         self.prisma_client = prisma_client
+#         self.proxy_logging_obj = proxy_logging_obj
+#         self.master_key = master_key
+
+#     def get_router(self) -> litellm.Router:
+#         return self.router
+
+#     def get_model_list(self) -> list:
+#         return self.llm_model_list
+
+
 async def _read_request_body(request):
     """
     Asynchronous function to read the request body and parse it as JSON or literal data.
@@ -2318,6 +2345,45 @@ def _is_user_proxy_admin(user_id_information=None):
     return False
 
 
+def encrypt_value(value: str, master_key: str):
+    import hashlib
+    import nacl.secret
+    import nacl.utils
+
+    # get 32 byte master key #
+    hash_object = hashlib.sha256(master_key.encode())
+    hash_bytes = hash_object.digest()
+
+    # initialize secret box #
+    box = nacl.secret.SecretBox(hash_bytes)
+
+    # encode message #
+    value_bytes = value.encode("utf-8")
+
+    encrypted = box.encrypt(value_bytes)
+
+    return encrypted
+
+
+def decrypt_value(value: bytes, master_key: str) -> str:
+    import hashlib
+    import nacl.secret
+    import nacl.utils
+
+    # get 32 byte master key #
+    hash_object = hashlib.sha256(master_key.encode())
+    hash_bytes = hash_object.digest()
+
+    # initialize secret box #
+    box = nacl.secret.SecretBox(hash_bytes)
+
+    # decrypt ciphertext #
+    plaintext = box.decrypt(value)
+
+    # convert bytes back to string #
+    plaintext = plaintext.decode("utf-8")  # type: ignore
+    return plaintext  # type: ignore
+
+
 # LiteLLM Admin UI - Non SSO Login
 html_form = """
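A standalone sketch of the scheme these helpers implement: `SecretBox` (XSalsa20-Poly1305) requires exactly 32 key bytes, which is why the master key is run through SHA-256 first, and `encrypt()` generates a random nonce and prepends it to the ciphertext, which is why a single base64 string is all that needs to be stored:

```python
import base64
import hashlib

import nacl.secret

# derive a 32-byte key from an arbitrary-length master key (illustrative value)
key = hashlib.sha256("sk-1234".encode()).digest()
box = nacl.secret.SecretBox(key)

encrypted = box.encrypt(b"my-azure-api-key")  # nonce || ciphertext || MAC
stored = base64.b64encode(encrypted).decode("utf-8")  # safe for a Json column

# decrypt() recovers the nonce from the front of the blob
roundtrip = box.decrypt(base64.b64decode(stored))
assert roundtrip == b"my-azure-api-key"
```

The base64 step matters because the encrypted blob is raw bytes, while `litellm_params` is persisted as JSON in the `LiteLLM_ProxyModelTable.litellm_params` column above.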
diff --git a/litellm/router.py b/litellm/router.py
index a04197e19..0d161f1de 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -29,6 +29,103 @@ from litellm.utils import ModelResponse, CustomStreamWrapper
 import copy
 from litellm._logging import verbose_router_logger
 import logging
+from pydantic import BaseModel, validator
+
+
+class ModelInfo(BaseModel):
+    id: Optional[
+        str
+    ]  # Allow id to be optional on input, but it will always be present as a str in the model instance
+
+    def __init__(self, id: Optional[str] = None, **params):
+        if id is None:
+            id = str(uuid.uuid4())  # Generate a UUID if id is None or not provided
+        super().__init__(id=id, **params)
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
+class LiteLLM_Params(BaseModel):
+    model: str
+    tpm: Optional[int] = None
+    rpm: Optional[int] = None
+    api_key: Optional[str] = None
+    api_base: Optional[str] = None
+    api_version: Optional[str] = None
+    timeout: Optional[Union[float, str]] = None  # if str, pass in as os.environ/
+    stream_timeout: Optional[Union[float, str]] = (
+        None  # timeout when making stream=True calls, if str, pass in as os.environ/
+    )
+    max_retries: Optional[int] = 2  # follows openai default of 2
+    organization: Optional[str] = None  # for openai orgs
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+
+class Deployment(BaseModel):
+    model_name: str
+    litellm_params: LiteLLM_Params
+    model_info: ModelInfo
+
+    def to_json(self, **kwargs):
+        try:
+            return self.model_dump(**kwargs)  # noqa
+        except Exception as e:
+            # if using pydantic v1
+            return self.dict(**kwargs)
+
+    class Config:
+        extra = "allow"
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
 
 
 class Router:
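The dunder helpers on these models exist so router code that still treats deployments as plain dicts keeps working when handed a pydantic object. A small illustration of the shimmed access patterns (values are made up):

```python
from litellm.router import LiteLLM_Params

params = LiteLLM_Params(model="azure/gpt-4", api_key="my-key")

assert "api_key" in params                        # __contains__
assert params["model"] == "azure/gpt-4"           # __getitem__
params["api_base"] = "https://example.azure.com"  # __setitem__ (extra = "allow")
assert params.get("custom_llm_provider") is None  # .get() with a default
```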
@@ -2040,76 +2137,114 @@ class Router:
         )  # cache for 1 hr
 
     def set_model_list(self, model_list: list):
-        self.model_list = copy.deepcopy(model_list)
+        original_model_list = copy.deepcopy(model_list)
+        self.model_list = []
         # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
         import os
 
-        for model in self.model_list:
-            #### MODEL ID INIT ########
-            model_info = model.get("model_info", {})
-            model_info["id"] = model_info.get("id", str(uuid.uuid4()))
-            model["model_info"] = model_info
-            #### DEPLOYMENT NAMES INIT ########
-            self.deployment_names.append(model["litellm_params"]["model"])
-            ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
-            # for get_available_deployment, we use the litellm_param["rpm"]
-            # in this snippet we also set rpm to be a litellm_param
-            if (
-                model["litellm_params"].get("rpm") is None
-                and model.get("rpm") is not None
-            ):
-                model["litellm_params"]["rpm"] = model.get("rpm")
-            if (
-                model["litellm_params"].get("tpm") is None
-                and model.get("tpm") is not None
-            ):
-                model["litellm_params"]["tpm"] = model.get("tpm")
-
-            #### VALIDATE MODEL ########
-            # check if model provider in supported providers
-            (
-                _model,
-                custom_llm_provider,
-                dynamic_api_key,
-                api_base,
-            ) = litellm.get_llm_provider(
-                model=model["litellm_params"]["model"],
-                custom_llm_provider=model["litellm_params"].get(
-                    "custom_llm_provider", None
-                ),
+        for model in original_model_list:
+            deployment = Deployment(
+                model_name=model["model_name"],
+                litellm_params=model["litellm_params"],
+                model_info=model.get("model_info", {}),
             )
+            self._add_deployment(deployment=deployment)
 
-            # Check if user is trying to use model_name == "*"
-            # this is a catch all model for their specific api key
-            if model["model_name"] == "*":
-                self.default_deployment = model
+            model = deployment.to_json(exclude_none=True)
 
-            # Azure GPT-Vision Enhancements, users can pass os.environ/
-            data_sources = model.get("litellm_params", {}).get("dataSources", [])
-
-            for data_source in data_sources:
-                params = data_source.get("parameters", {})
-                for param_key in ["endpoint", "key"]:
-                    # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
-                    if param_key in params and params[param_key].startswith(
-                        "os.environ/"
-                    ):
-                        env_name = params[param_key].replace("os.environ/", "")
-                        params[param_key] = os.environ.get(env_name, "")
-
-            # done reading model["litellm_params"]
-            if custom_llm_provider not in litellm.provider_list:
-                raise Exception(f"Unsupported provider - {custom_llm_provider}")
-
-            # init OpenAI, Azure clients
-            self.set_client(model=model)
+            self.model_list.append(model)
 
         verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
 
         self.model_names = [m["model_name"] for m in model_list]
 
+    def _add_deployment(self, deployment: Deployment):
+        import os
+
+        #### DEPLOYMENT NAMES INIT ########
+        self.deployment_names.append(deployment.litellm_params.model)
+        ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
+        # for get_available_deployment, we use the litellm_param["rpm"]
+        # in this snippet we also set rpm to be a litellm_param
+        if (
+            deployment.litellm_params.rpm is None
+            and getattr(deployment, "rpm", None) is not None
+        ):
+            deployment.litellm_params.rpm = getattr(deployment, "rpm")
+
+        if (
+            deployment.litellm_params.tpm is None
+            and getattr(deployment, "tpm", None) is not None
+        ):
+            deployment.litellm_params.tpm = getattr(deployment, "tpm")
+
+        #### VALIDATE MODEL ########
+        # check if model provider in supported providers
+        (
+            _model,
+            custom_llm_provider,
+            dynamic_api_key,
+            api_base,
+        ) = litellm.get_llm_provider(
+            model=deployment.litellm_params.model,
+            custom_llm_provider=deployment.litellm_params.get(
+                "custom_llm_provider", None
+            ),
+        )
+
+        # Check if user is trying to use model_name == "*"
+        # this is a catch all model for their specific api key
+        if deployment.model_name == "*":
+            self.default_deployment = deployment
+
+        # Azure GPT-Vision Enhancements, users can pass os.environ/
+        data_sources = deployment.litellm_params.get("dataSources", [])
+
+        for data_source in data_sources:
+            params = data_source.get("parameters", {})
+            for param_key in ["endpoint", "key"]:
+                # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
+                if param_key in params and params[param_key].startswith("os.environ/"):
+                    env_name = params[param_key].replace("os.environ/", "")
+                    params[param_key] = os.environ.get(env_name, "")
+
+        # done reading model["litellm_params"]
+        if custom_llm_provider not in litellm.provider_list:
+            raise Exception(f"Unsupported provider - {custom_llm_provider}")
+
+        # init OpenAI, Azure clients
+        self.set_client(model=deployment.to_json(exclude_none=True))
+
+    def add_deployment(self, deployment: Deployment):
+        # check if deployment already exists
+
+        if deployment.model_info.id in self.get_model_ids():
+            return
+
+        # add to model list
+        _deployment = deployment.to_json(exclude_none=True)
+        self.model_list.append(_deployment)
+
+        # initialize client
+        self._add_deployment(deployment=deployment)
+
+        # add to model names
+        self.model_names.append(deployment.model_name)
+        return
+
+    def get_model_ids(self):
+        ids = []
+        for model in self.model_list:
+            if "model_info" in model and "id" in model["model_info"]:
+                id = model["model_info"]["id"]
+                ids.append(id)
+        return ids
+
     def get_model_names(self):
         return self.model_names
 
+    def get_model_list(self):
+        return self.model_list
+
     def _get_client(self, deployment, kwargs, client_type=None):
         """
         Returns the appropriate client based on the given deployment, kwargs, and client_type.
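Note that `add_deployment` is idempotent on `model_info.id`, which is what makes the 30-second polling job in `proxy_server.py` safe to re-run against rows the router has already loaded. A sketch with an illustrative fixed id and fake key (no network call is made at registration time):

```python
from litellm import Router
from litellm.router import Deployment, LiteLLM_Params, ModelInfo

router = Router(model_list=[])
dep = Deployment(
    model_name="gpt-instruct",
    litellm_params=LiteLLM_Params(model="gpt-3.5-turbo-instruct", api_key="my-fake-key"),
    model_info=ModelInfo(id="static-id"),  # pinned id, as DB rows would have
)

router.add_deployment(deployment=dep)
router.add_deployment(deployment=dep)  # no-op: id already in get_model_ids()
assert len(router.get_model_ids()) == 1
```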
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index dfcaf9f85..9ad8b8156 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -10,6 +10,7 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import litellm
 from litellm import Router
+from litellm.router import Deployment, LiteLLM_Params, ModelInfo
 from concurrent.futures import ThreadPoolExecutor
 from collections import defaultdict
 from dotenv import load_dotenv
@@ -1193,3 +1194,37 @@ async def test_router_amoderation():
     )
 
     print("moderation result", result)
+
+
+def test_router_add_deployment():
+    initial_model_list = [
+        {
+            "model_name": "fake-openai-endpoint",
+            "litellm_params": {
+                "model": "openai/my-fake-model",
+                "api_key": "my-fake-key",
+                "api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
+            },
+        },
+    ]
+    router = Router(model_list=initial_model_list)
+
+    init_model_id_list = router.get_model_ids()
+
+    print(f"init_model_id_list: {init_model_id_list}")
+
+    router.add_deployment(
+        deployment=Deployment(
+            model_name="gpt-instruct",
+            litellm_params=LiteLLM_Params(model="gpt-3.5-turbo-instruct"),
+            model_info=ModelInfo(),
+        )
+    )
+
+    new_model_id_list = router.get_model_ids()
+
+    print(f"new_model_id_list: {new_model_id_list}")
+
+    assert len(new_model_id_list) > len(init_model_id_list)
+
+    assert new_model_id_list[1] != new_model_id_list[0]
diff --git a/schema.prisma b/schema.prisma
index 6e2400a12..529ae7f2b 100644
--- a/schema.prisma
+++ b/schema.prisma
@@ -27,6 +27,18 @@ model LiteLLM_BudgetTable {
   end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget
 }
 
+// Models on proxy
+model LiteLLM_ProxyModelTable {
+  model_id       String   @id @default(uuid())
+  model_name     String
+  litellm_params Json
+  model_info     Json?
+  created_at     DateTime @default(now()) @map("created_at")
+  created_by     String
+  updated_at     DateTime @default(now()) @updatedAt @map("updated_at")
+  updated_by     String
+}
+
 model LiteLLM_OrganizationTable {
   organization_id String @id @default(uuid())
   organization_alias String
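A closing note on the test's last two assertions: `ModelInfo` mints a fresh `uuid4` whenever no id is supplied, so the initial deployment and the one added by `add_deployment` can never collide on id. A short illustrative sketch:

```python
from litellm.router import ModelInfo

a, b = ModelInfo(), ModelInfo()
assert a.id != b.id  # auto-generated uuid4 per instance

pinned = ModelInfo(id="my-stable-id")
assert pinned.id == "my-stable-id"  # explicit ids are preserved, e.g. from DB rows
```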