forked from phoenix/litellm-mirror

feat(proxy_server.py): dynamic reloading config.yaml with new models

parent eae5b3ce50
commit 72381c3cc2

2 changed files with 42 additions and 63 deletions
config.yaml (deleted)
@@ -1,35 +0,0 @@
-model_list:
-  - model_name: gpt-4
-    litellm_params:
-      model: azure/chatgpt-v-2
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_version: "2023-05-15"
-      api_key:
-  - model_name: gpt-4
-    litellm_params:
-      model: azure/gpt-4
-      api_key:
-      api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
-  - model_name: gpt-4
-    litellm_params:
-      model: azure/gpt-4
-      api_key:
-      api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
-
-
-
-litellm_settings:
-  drop_params: True
-  set_verbose: True
-  # cache: # optional if you want to use caching
-  #   type: redis # tell litellm to use redis caching
-
-general_settings:
-  # master_key: sk-1234 # [OPTIONAL] Only use this if you want to require all calls to contain this key (Authorization: Bearer sk-1234)
-  # database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy
-
-environment_variables:
-  # settings for using redis caching
-  # REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
-  # REDIS_PORT: "16337"
-  # REDIS_PASSWORD:
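Usage note: the deleted file above shows the config shape the proxy consumes. A minimal sketch (assumed local filename "config.yaml"; not part of this commit) that sanity-checks a config's model_list entries before pointing the proxy at it:

import yaml

# Load the proxy config and check each model entry has the expected shape.
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

for entry in config.get("model_list", []):
    # every entry needs a public-facing model_name plus litellm_params
    assert "model_name" in entry, "missing model_name"
    assert "litellm_params" in entry, "missing litellm_params"
    print(entry["model_name"], "->", entry["litellm_params"].get("model"))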
proxy_server.py
@@ -153,6 +153,10 @@ class ProxyChatCompletionRequest(BaseModel):
     class Config:
         extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
 
+class ModelParams(BaseModel):
+    model_name: str
+    litellm_params: dict
+
 user_api_base = None
 user_model = None
 user_debug = False
@@ -162,6 +166,7 @@ user_temperature = None
 user_telemetry = True
 user_config = None
 user_headers = None
+user_config_file_path = None
 local_logging = True # writes logs to a local api_log.json file for debugging
 experimental = False
 #### GLOBAL VARIABLES ####
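The new ModelParams model defines the request body for the /model/new endpoint added further down. A hedged, standalone sketch (pydantic v1-style .dict(); all field values are placeholders) of the payload it validates:

from pydantic import BaseModel

class ModelParams(BaseModel):
    model_name: str
    litellm_params: dict

# A request body for /model/new: a public alias plus the litellm routing params.
payload = ModelParams(
    model_name="gpt-4",
    litellm_params={
        "model": "azure/gpt-4",
        "api_base": "https://openai-gpt-4-test-v-2.openai.azure.com/",
        "api_key": "<your-azure-key>",
    },
)
print(payload.dict())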
@@ -315,7 +320,6 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
     except Exception as e:
         print("Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`")
 
-
 def cost_tracking():
     global prisma_client, master_key
     if prisma_client is not None and master_key is not None:
@@ -405,10 +409,11 @@ def run_ollama_serve():
     """)
 
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
-    global master_key
+    global master_key, user_config_file_path
     config = {}
     try:
         if os.path.exists(config_file_path):
+            user_config_file_path = config_file_path
             with open(config_file_path, 'r') as file:
                 config = yaml.safe_load(file)
         else:
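Caching user_config_file_path here is what makes dynamic reloading possible: once the proxy records where its config lives, any handler can re-read that file and rebuild the router. An illustrative fragment (not standalone; assumes it runs inside proxy_server.py, where llm_router and user_config_file_path are module globals):

# Re-read the cached YAML path and rebuild the router with any new models.
llm_router, llm_model_list, general_settings = load_router_config(
    router=llm_router,
    config_file_path=user_config_file_path,
)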
@@ -947,36 +952,46 @@ async def info_key_fn(key: str = fastapi.Query(..., description="Key in the requ
 
 #### MODEL MANAGEMENT ####
 
+#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
+@router.post("/model/new", description="Allows adding new models to the model list in the config.yaml", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
+async def add_new_model(model_params: ModelParams):
+    global llm_router, llm_model_list, general_settings
+    try:
+        # Load existing config
+        with open(user_config_file_path, "r") as config_file:
+            config = yaml.safe_load(config_file)
+
+        # Add the new model to the config
+        config['model_list'].append({
+            'model_name': model_params.model_name,
+            'litellm_params': model_params.litellm_params
+        })
+
+        # Save the updated config
+        with open(user_config_file_path, "w") as config_file:
+            yaml.dump(config, config_file, default_flow_style=False)
+
+        # update Router (load_router_config expects the config path, not the dict)
+        llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=user_config_file_path)
+
+        return {"message": "Model added successfully"}
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
+
 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933
 @router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
 async def model_info(request: Request):
     global llm_model_list, general_settings
-    all_models = []
-    if llm_model_list is not None:
-        for m in llm_model_list:
-            model_dict = {}
-            model_name = m["model_name"]
-            model_params = {}
-            for k,v in m["litellm_params"].items():
-                if k == "api_key" or k == "api_base": # don't send the api key or api base
-                    continue
-
-                if k == "model":
-                    ########## remove -ModelID-XXXX from model ##############
-                    original_model_string = v
-                    # Find the index of "ModelID" in the string
-                    index_of_model_id = original_model_string.find("-ModelID")
-                    # Remove everything after "-ModelID" if it exists
-                    if index_of_model_id != -1:
-                        v = original_model_string[:index_of_model_id]
-                    else:
-                        v = original_model_string
-
-                model_params[k] = v
-
-            model_dict["model_name"] = model_name
-            model_dict["model_params"] = model_params
-            all_models.append(model_dict)
+    # Load existing config
+    with open(user_config_file_path, "r") as config_file:
+        config = yaml.safe_load(config_file)
+    all_models = config['model_list']
+
+    for model in all_models:
+        # don't return the api key
+        model["litellm_params"].pop("api_key", None)
+
     # all_models = list(set([m["model_name"] for m in llm_model_list]))
     print_verbose(f"all_models: {all_models}")
     return dict(
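Client-side usage sketch for the two endpoints above (assumed proxy URL http://0.0.0.0:8000 and master key sk-1234, both placeholders):

import requests

base_url = "http://0.0.0.0:8000"
headers = {"Authorization": "Bearer sk-1234"}

# Register a new model alias at runtime; the proxy appends it to config.yaml
# and reloads the router.
resp = requests.post(
    f"{base_url}/model/new",
    headers=headers,
    json={
        "model_name": "gpt-4",
        "litellm_params": {
            "model": "azure/gpt-4",
            "api_base": "https://openai-gpt-4-test-v-2.openai.azure.com/",
            "api_key": "<your-azure-key>",
        },
    },
)
print(resp.json())  # {"message": "Model added successfully"}

# List the configured models; api_key is stripped from each entry.
print(requests.get(f"{base_url}/model/info", headers=headers).json())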
@@ -991,7 +1006,6 @@ async def model_info(request: Request):
         ],
         object="list",
     )
-    pass
 #### EXPERIMENTAL QUEUING ####
 @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_request(request: Request):