forked from phoenix/litellm-mirror
Merge pull request #2827 from BerriAI/litellm_model_add_api
fix(proxy_server.py): persist models added via `/model/new` to db
commit 0c5b8a7667
11 changed files with 432 additions and 90 deletions
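For context, a minimal sketch of how a deployment would be registered through the new endpoint (the URL, key, and model values below are illustrative assumptions; the payload fields mirror the `ModelParams` shape used by `add_new_model` in the diff):

# Illustrative only: add a model via the proxy's /model/new endpoint.
# Assumes a locally running proxy (port is an example) with master_key sk-1234
# and `store_model_in_db: True` set under general_settings.
import requests

response = requests.post(
    "http://0.0.0.0:4000/model/new",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model_name": "gpt-instruct",
        "litellm_params": {
            "model": "gpt-3.5-turbo-instruct",
            "api_key": "os.environ/OPENAI_API_KEY",
        },
    },
)
print(response.json())  # expected: {"message": "Model added successfully"}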
@@ -23,6 +23,7 @@ litellm_settings:
general_settings:
  master_key: sk-1234
  alerting: ["slack"]
  store_model_in_db: True
  # proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
  enable_jwt_auth: True
  alerting: ["slack"]
@@ -97,6 +97,8 @@ from litellm.proxy.utils import (
    _is_projected_spend_over_limit,
    _get_projected_spend_over_limit,
    update_spend,
    encrypt_value,
    decrypt_value,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
@@ -104,6 +106,8 @@ import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm.router import LiteLLM_Params, Deployment
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.hooks.prompt_injection_detection import (
@@ -2371,6 +2375,64 @@ class ProxyConfig:
        router = litellm.Router(**router_params)  # type:ignore
        return router, model_list, general_settings

    async def add_deployment(
        self,
        prisma_client: PrismaClient,
        proxy_logging_obj: ProxyLogging,
    ):
        """
        - Check db for new models (last 10 most recently updated)
        - Check if model id's in router already
        - If not, add to router
        """
        global llm_router, llm_model_list, master_key

        import base64

        try:
            if llm_router is None:
                raise Exception("No router initialized")

            new_models = await prisma_client.db.litellm_proxymodeltable.find_many(
                take=10, order={"updated_at": "desc"}
            )

            for m in new_models:
                _litellm_params = m.litellm_params
                if isinstance(_litellm_params, dict):
                    # decrypt values
                    for k, v in _litellm_params.items():
                        # decode base64
                        decoded_b64 = base64.b64decode(v)
                        # decrypt value
                        _litellm_params[k] = decrypt_value(
                            value=decoded_b64, master_key=master_key
                        )
                    _litellm_params = LiteLLM_Params(**_litellm_params)
                else:
                    raise Exception(
                        f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
                    )

                if m.model_info is not None and isinstance(m.model_info, dict):
                    if "id" not in m.model_info:
                        m.model_info["id"] = m.model_id
                    _model_info = RouterModelInfo(**m.model_info)
                else:
                    _model_info = RouterModelInfo(id=m.model_id)

                llm_router.add_deployment(
                    deployment=Deployment(
                        model_name=m.model_name,
                        litellm_params=_litellm_params,
                        model_info=_model_info,
                    )
                )

            llm_model_list = llm_router.get_model_list()
        except Exception as e:
            raise e


proxy_config = ProxyConfig()
@@ -2943,7 +3005,7 @@ async def startup_event():
    if prisma_client is not None:
        create_view_response = await prisma_client.check_view_exists()

    ### START BATCH WRITING DB ###
    ### START BATCH WRITING DB + CHECKING NEW MODELS###
    if prisma_client is not None:
        scheduler = AsyncIOScheduler()
        interval = random.randint(
@@ -2966,6 +3028,15 @@ async def startup_event():
            seconds=batch_writing_interval,
            args=[prisma_client, db_writer_client, proxy_logging_obj],
        )

        ### ADD NEW MODELS ###
        if general_settings.get("store_model_in_db", False) == True:
            scheduler.add_job(
                proxy_config.add_deployment,
                "interval",
                seconds=30,
                args=[prisma_client, proxy_logging_obj],
            )
        scheduler.start()
@@ -3314,8 +3385,6 @@ async def chat_completion(
            )
        )

        start_time = time.time()

        ### ROUTE THE REQUEST ###
        # Do not change this - it should be a constant time fetch - ALWAYS
        router_model_names = llm_router.model_names if llm_router is not None else []
@@ -3534,8 +3603,6 @@ async def embeddings(
            user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
        )

        start_time = time.time()

        ## ROUTE TO CORRECT ENDPOINT ##
        # skip router if user passed their key
        if "api_key" in data:
@@ -6692,30 +6759,54 @@ async def info_budget(data: BudgetRequest):
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def add_new_model(model_params: ModelParams):
    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
async def add_new_model(
    model_params: ModelParams,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key
    try:
        # Load existing config
        config = await proxy_config.get_config()
        import base64

        verbose_proxy_logger.debug("User config path: %s", user_config_file_path)
        global prisma_client

        verbose_proxy_logger.debug("Loaded config: %s", config)
        # Add the new model to the config
        model_info = model_params.model_info.json()
        model_info = {k: v for k, v in model_info.items() if v is not None}
        config["model_list"].append(
            {
                "model_name": model_params.model_name,
                "litellm_params": model_params.litellm_params,
                "model_info": model_info,
            }
        )
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )

        verbose_proxy_logger.debug("updated model list: %s", config["model_list"])

        # Save new config
        await proxy_config.save_config(new_config=config)
        # update DB
        if general_settings.get("store_model_in_db", False) == True:
            """
            - store model_list in db
            - store keys separately
            """
            # encrypt litellm params #
            for k, v in model_params.litellm_params.items():
                encrypted_value = encrypt_value(value=v, master_key=master_key)  # type: ignore
                model_params.litellm_params[k] = base64.b64encode(
                    encrypted_value
                ).decode("utf-8")
            await prisma_client.db.litellm_proxymodeltable.create(
                data={
                    "model_name": model_params.model_name,
                    "litellm_params": json.dumps(model_params.litellm_params),  # type: ignore
                    "model_info": model_params.model_info.model_dump_json(  # type: ignore
                        exclude_none=True
                    ),
                    "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                    "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                }
            )
        else:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Set `store_model_in_db: true` in general_settings on your config.yaml"
                },
            )
        return {"message": "Model added successfully"}

    except Exception as e:
@@ -6885,14 +6976,16 @@ async def model_info_v1(
):
    global llm_model_list, general_settings, user_config_file_path, proxy_config

    # Load existing config
    config = await proxy_config.get_config()
    if llm_model_list is None:
        raise HTTPException(
            status_code=500, detail={"error": "LLM Model List not loaded in"}
        )

    if len(user_api_key_dict.models) > 0:
        model_names = user_api_key_dict.models
        all_models = [m for m in config["model_list"] if m["model_name"] in model_names]
        all_models = [m for m in llm_model_list if m["model_name"] in model_names]
    else:
        all_models = config["model_list"]
        all_models = llm_model_list
    for model in all_models:
        # provided model_info in config.yaml
        model_info = model.get("model_info", {})
@@ -6957,6 +7050,7 @@ async def delete_model(model_info: ModelInfoDelete):

        # Check if the model with the specified model_id exists
        model_to_delete = None

        for model in config["model_list"]:
            if model.get("model_info", {}).get("id", None) == model_info.id:
                model_to_delete = model
@@ -39,6 +39,7 @@ model LiteLLM_ProxyModelTable {
  updated_by String
}


model LiteLLM_OrganizationTable {
  organization_id String @id @default(uuid())
  organization_alias String
@@ -13,6 +13,7 @@ from litellm.proxy._types import (
    Member,
)
from litellm.caching import DualCache, RedisCache
from litellm.router import Deployment, ModelInfo, LiteLLM_Params
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
    _PROXY_MaxParallelRequestsHandler,
@@ -2181,6 +2182,32 @@ async def update_spend(
        raise e


# class Models:
#     """
#     Need a class to maintain state of models / router across calls to check if new deployments need to be added
#     """

#     def __init__(
#         self,
#         router: litellm.Router,
#         llm_model_list: list,
#         prisma_client: PrismaClient,
#         proxy_logging_obj: ProxyLogging,
#         master_key: str,
#     ) -> None:
#         self.router = router
#         self.llm_model_list = llm_model_list
#         self.prisma_client = prisma_client
#         self.proxy_logging_obj = proxy_logging_obj
#         self.master_key = master_key

#     def get_router(self) -> litellm.Router:
#         return self.router

#     def get_model_list(self) -> list:
#         return self.llm_model_list


async def _read_request_body(request):
    """
    Asynchronous function to read the request body and parse it as JSON or literal data.
@@ -2318,6 +2345,45 @@ def _is_user_proxy_admin(user_id_information=None):
    return False


def encrypt_value(value: str, master_key: str):
    import hashlib
    import nacl.secret
    import nacl.utils

    # get 32 byte master key #
    hash_object = hashlib.sha256(master_key.encode())
    hash_bytes = hash_object.digest()

    # initialize secret box #
    box = nacl.secret.SecretBox(hash_bytes)

    # encode message #
    value_bytes = value.encode("utf-8")

    encrypted = box.encrypt(value_bytes)

    return encrypted


def decrypt_value(value: bytes, master_key: str) -> str:
    import hashlib
    import nacl.secret
    import nacl.utils

    # get 32 byte master key #
    hash_object = hashlib.sha256(master_key.encode())
    hash_bytes = hash_object.digest()

    # initialize secret box #
    box = nacl.secret.SecretBox(hash_bytes)

    # Convert the bytes object to a string
    plaintext = box.decrypt(value)

    plaintext = plaintext.decode("utf-8")  # type: ignore
    return plaintext  # type: ignore
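A quick round-trip sketch of how these helpers pair with base64, mirroring the write path in `add_new_model` (encrypt, then b64-encode for storage) and the read path in `ProxyConfig.add_deployment` (b64-decode, then decrypt); the key and value here are placeholders:

# Illustrative usage of encrypt_value / decrypt_value (placeholder key and value)
import base64

master_key = "sk-1234"        # example proxy master key
secret = "my-azure-api-key"   # example litellm_params value

# write path: encrypt, then base64-encode so it can be stored as JSON text
stored = base64.b64encode(encrypt_value(value=secret, master_key=master_key)).decode("utf-8")

# read path: base64-decode, then decrypt back to plaintext
assert decrypt_value(value=base64.b64decode(stored), master_key=master_key) == secret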

# LiteLLM Admin UI - Non SSO Login
html_form = """
<!DOCTYPE html>
@@ -29,6 +29,103 @@ from litellm.utils import ModelResponse, CustomStreamWrapper
import copy
from litellm._logging import verbose_router_logger
import logging
from pydantic import BaseModel, validator


class ModelInfo(BaseModel):
    id: Optional[
        str
    ]  # Allow id to be optional on input, but it will always be present as a str in the model instance

    def __init__(self, id: Optional[str] = None, **params):
        if id is None:
            id = str(uuid.uuid4())  # Generate a UUID if id is None or not provided
        super().__init__(id=id, **params)

    class Config:
        extra = "allow"

    def __contains__(self, key):
        # Define custom behavior for the 'in' operator
        return hasattr(self, key)

    def get(self, key, default=None):
        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
        return getattr(self, key, default)

    def __getitem__(self, key):
        # Allow dictionary-style access to attributes
        return getattr(self, key)

    def __setitem__(self, key, value):
        # Allow dictionary-style assignment of attributes
        setattr(self, key, value)


class LiteLLM_Params(BaseModel):
    model: str
    tpm: Optional[int] = None
    rpm: Optional[int] = None
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    timeout: Optional[Union[float, str]] = None  # if str, pass in as os.environ/
    stream_timeout: Optional[Union[float, str]] = (
        None  # timeout when making stream=True calls, if str, pass in as os.environ/
    )
    max_retries: Optional[int] = 2  # follows openai default of 2
    organization: Optional[str] = None  # for openai orgs

    class Config:
        extra = "allow"

    def __contains__(self, key):
        # Define custom behavior for the 'in' operator
        return hasattr(self, key)

    def get(self, key, default=None):
        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
        return getattr(self, key, default)

    def __getitem__(self, key):
        # Allow dictionary-style access to attributes
        return getattr(self, key)

    def __setitem__(self, key, value):
        # Allow dictionary-style assignment of attributes
        setattr(self, key, value)


class Deployment(BaseModel):
    model_name: str
    litellm_params: LiteLLM_Params
    model_info: ModelInfo

    def to_json(self, **kwargs):
        try:
            return self.model_dump(**kwargs)  # noqa
        except Exception as e:
            # if using pydantic v1
            return self.dict(**kwargs)

    class Config:
        extra = "allow"

    def __contains__(self, key):
        # Define custom behavior for the 'in' operator
        return hasattr(self, key)

    def get(self, key, default=None):
        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
        return getattr(self, key, default)

    def __getitem__(self, key):
        # Allow dictionary-style access to attributes
        return getattr(self, key)

    def __setitem__(self, key, value):
        # Allow dictionary-style assignment of attributes
        setattr(self, key, value)


class Router:
@@ -2040,76 +2137,121 @@ class Router:
        )  # cache for 1 hr

    def set_model_list(self, model_list: list):
        self.model_list = copy.deepcopy(model_list)
        original_model_list = copy.deepcopy(model_list)
        self.model_list = []
        # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
        import os

        for model in self.model_list:
            #### MODEL ID INIT ########
            model_info = model.get("model_info", {})
            model_info["id"] = model_info.get("id", str(uuid.uuid4()))
            model["model_info"] = model_info
            #### DEPLOYMENT NAMES INIT ########
            self.deployment_names.append(model["litellm_params"]["model"])
            ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
            # for get_available_deployment, we use the litellm_param["rpm"]
            # in this snippet we also set rpm to be a litellm_param
            if (
                model["litellm_params"].get("rpm") is None
                and model.get("rpm") is not None
            ):
                model["litellm_params"]["rpm"] = model.get("rpm")
            if (
                model["litellm_params"].get("tpm") is None
                and model.get("tpm") is not None
            ):
                model["litellm_params"]["tpm"] = model.get("tpm")

            #### VALIDATE MODEL ########
            # check if model provider in supported providers
            (
                _model,
                custom_llm_provider,
                dynamic_api_key,
                api_base,
            ) = litellm.get_llm_provider(
                model=model["litellm_params"]["model"],
                custom_llm_provider=model["litellm_params"].get(
                    "custom_llm_provider", None
                ),
        for model in original_model_list:
            _model_name = model.pop("model_name")
            _litellm_params = model.pop("litellm_params")
            _model_info = model.pop("model_info", {})
            deployment = Deployment(
                **model,
                model_name=_model_name,
                litellm_params=_litellm_params,
                model_info=_model_info,
            )

            # Check if user is trying to use model_name == "*"
            # this is a catch all model for their specific api key
            if model["model_name"] == "*":
                self.default_deployment = model
            deployment = self._add_deployment(deployment=deployment)

            # Azure GPT-Vision Enhancements, users can pass os.environ/
            data_sources = model.get("litellm_params", {}).get("dataSources", [])
            model = deployment.to_json(exclude_none=True)

            for data_source in data_sources:
                params = data_source.get("parameters", {})
                for param_key in ["endpoint", "key"]:
                    # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
                    if param_key in params and params[param_key].startswith(
                        "os.environ/"
                    ):
                        env_name = params[param_key].replace("os.environ/", "")
                        params[param_key] = os.environ.get(env_name, "")

            # done reading model["litellm_params"]
            if custom_llm_provider not in litellm.provider_list:
                raise Exception(f"Unsupported provider - {custom_llm_provider}")

            # init OpenAI, Azure clients
            self.set_client(model=model)
            self.model_list.append(model)

        verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
        self.model_names = [m["model_name"] for m in model_list]

    def _add_deployment(self, deployment: Deployment) -> Deployment:
        import os

        #### DEPLOYMENT NAMES INIT ########
        self.deployment_names.append(deployment.litellm_params.model)
        ############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
        # for get_available_deployment, we use the litellm_param["rpm"]
        # in this snippet we also set rpm to be a litellm_param
        if (
            deployment.litellm_params.rpm is None
            and getattr(deployment, "rpm", None) is not None
        ):
            deployment.litellm_params.rpm = getattr(deployment, "rpm")

        if (
            deployment.litellm_params.tpm is None
            and getattr(deployment, "tpm", None) is not None
        ):
            deployment.litellm_params.tpm = getattr(deployment, "tpm")

        #### VALIDATE MODEL ########
        # check if model provider in supported providers
        (
            _model,
            custom_llm_provider,
            dynamic_api_key,
            api_base,
        ) = litellm.get_llm_provider(
            model=deployment.litellm_params.model,
            custom_llm_provider=deployment.litellm_params.get(
                "custom_llm_provider", None
            ),
        )

        # Check if user is trying to use model_name == "*"
        # this is a catch all model for their specific api key
        if deployment.model_name == "*":
            self.default_deployment = deployment.to_json(exclude_none=True)

        # Azure GPT-Vision Enhancements, users can pass os.environ/
        data_sources = deployment.litellm_params.get("dataSources", [])

        for data_source in data_sources:
            params = data_source.get("parameters", {})
            for param_key in ["endpoint", "key"]:
                # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
                if param_key in params and params[param_key].startswith("os.environ/"):
                    env_name = params[param_key].replace("os.environ/", "")
                    params[param_key] = os.environ.get(env_name, "")

        # done reading model["litellm_params"]
        if custom_llm_provider not in litellm.provider_list:
            raise Exception(f"Unsupported provider - {custom_llm_provider}")

        # init OpenAI, Azure clients
        self.set_client(model=deployment.to_json(exclude_none=True))

        return deployment

    def add_deployment(self, deployment: Deployment):
        # check if deployment already exists

        if deployment.model_info.id in self.get_model_ids():
            return

        # add to model list
        _deployment = deployment.to_json(exclude_none=True)
        self.model_list.append(_deployment)

        # initialize client
        self._add_deployment(deployment=deployment)

        # add to model names
        self.model_names.append(deployment.model_name)
        return

    def get_model_ids(self):
        ids = []
        for model in self.model_list:
            if "model_info" in model and "id" in model["model_info"]:
                id = model["model_info"]["id"]
                ids.append(id)
        return ids

    def get_model_names(self):
        return self.model_names

    def get_model_list(self):
        return self.model_list

    def _get_client(self, deployment, kwargs, client_type=None):
        """
        Returns the appropriate client based on the given deployment, kwargs, and client_type.
@@ -2315,6 +2457,7 @@ class Router:
            model = litellm.model_alias_map[
                model
            ]  # update the model to the actual value if an alias has been passed in

        if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
            deployment = self.leastbusy_logger.get_available_deployments(
                model_group=model, healthy_deployments=healthy_deployments
@@ -146,7 +146,7 @@ def test_cooldown_same_model_name():
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
                "tpm": 0.000001,
                "tpm": 1,
            },
        },
    ]
@@ -181,4 +181,4 @@ def test_cooldown_same_model_name():
        pytest.fail(f"Got unexpected exception on router! - {e}")


test_cooldown_same_model_name()
# test_cooldown_same_model_name()
@@ -194,7 +194,7 @@ def test_img_gen(client_no_auth):


#### ADDITIONAL
# @pytest.mark.skip(reason="hitting yaml load issues on circle-ci")
@pytest.mark.skip(reason="test via docker tests. Requires prisma client.")
def test_add_new_model(client_no_auth):
    global headers
    try:
@@ -10,6 +10,7 @@ sys.path.insert(
)  # Adds the parent directory to the system path
import litellm
from litellm import Router
from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
@@ -1193,3 +1194,37 @@ async def test_router_amoderation():
    )

    print("moderation result", result)


def test_router_add_deployment():
    initial_model_list = [
        {
            "model_name": "fake-openai-endpoint",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "my-fake-key",
                "api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
            },
        },
    ]
    router = Router(model_list=initial_model_list)

    init_model_id_list = router.get_model_ids()

    print(f"init_model_id_list: {init_model_id_list}")

    router.add_deployment(
        deployment=Deployment(
            model_name="gpt-instruct",
            litellm_params=LiteLLM_Params(model="gpt-3.5-turbo-instruct"),
            model_info=ModelInfo(),
        )
    )

    new_model_id_list = router.get_model_ids()

    print(f"new_model_id_list: {new_model_id_list}")

    assert len(new_model_id_list) > len(init_model_id_list)

    assert new_model_id_list[1] != new_model_id_list[0]
@@ -181,7 +181,7 @@ def test_weighted_selection_router_tpm_as_router_param():
        pytest.fail(f"Error occurred: {e}")


test_weighted_selection_router_tpm_as_router_param()
# test_weighted_selection_router_tpm_as_router_param()


def test_weighted_selection_router_rpm_as_router_param():
@@ -433,7 +433,7 @@ def test_usage_based_routing():

        selection_counts[response["model"]] += 1

    # print("selection counts", selection_counts)
    print("selection counts", selection_counts)

    total_requests = sum(selection_counts.values())
@@ -54,6 +54,7 @@ litellm_settings:

general_settings:
  master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
  store_model_in_db: True
  proxy_budget_rescheduler_min_time: 60
  proxy_budget_rescheduler_max_time: 64
  proxy_batch_write_at: 1
@@ -13,6 +13,7 @@ numpy==1.24.3 # semantic caching
pandas==2.1.1 # for viewing clickhouse spend analytics
prisma==0.11.0 # for db
mangum==0.17.0 # for aws lambda functions
pynacl==1.5.0 # for encrypting keys
google-cloud-aiplatform==1.43.0 # for vertex ai calls
google-generativeai==0.3.2 # for vertex ai calls
async_generator==1.10.0 # for async ollama calls