diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 65b5dd3dc..58722a99f 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -31,6 +31,7 @@ except ImportError:
             "appdirs",
             "tomli-w",
             "backoff",
+            "pyyaml"
         ]
     )
     import uvicorn
@@ -125,7 +126,7 @@ user_config_path = os.getenv(
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
 llm_model_list: Optional[list] = None
-server_settings: Optional[dict] = None
+server_settings: dict = {}
 
 log_file = "api_log.json"
 
@@ -197,7 +198,7 @@ def save_params_to_config(data: dict):
         tomli_w.dump(config, f)
 
 
-def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]):
+def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     config = {}
     server_settings = {}
     try:
@@ -210,9 +211,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: Optio
         pass
 
     ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')
-    server_settings = config.get("server_settings", None)
-    if server_settings:
-        server_settings = server_settings
+    _server_settings = config.get("server_settings", None)
+    if _server_settings:
+        server_settings = _server_settings
 
     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
@@ -543,8 +544,10 @@ def litellm_completion(*args, **kwargs):
 @router.get("/v1/models")
 @router.get("/models") # if project requires model list
 def model_list():
-    global llm_model_list
-    all_models = litellm.utils.get_valid_models()
+    global llm_model_list, server_settings
+    all_models = []
+    if server_settings.get("infer_model_from_keys", False):
+        all_models = litellm.utils.get_valid_models()
     if llm_model_list:
         all_models += llm_model_list
     if user_model is not None:
@@ -573,13 +576,19 @@ def model_list():
 @router.post("/v1/completions")
 @router.post("/completions")
 @router.post("/engines/{model:path}/completions")
-async def completion(request: Request):
+async def completion(request: Request, model: Optional[str] = None):
     body = await request.body()
     body_str = body.decode()
     try:
         data = ast.literal_eval(body_str)
     except:
         data = json.loads(body_str)
+    data["model"] = (
+        server_settings.get("completion_model", None) # server default
+        or user_model # model name passed via cli args
+        or model # for azure deployments
+        or data["model"] # default passed in http request
+    )
     if user_model:
         data["model"] = user_model
     data["call_type"] = "text_completion"
@@ -590,15 +599,21 @@
 
 @router.post("/v1/chat/completions")
 @router.post("/chat/completions")
-async def chat_completion(request: Request):
+@router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
+async def chat_completion(request: Request, model: Optional[str] = None):
+    global server_settings
     body = await request.body()
     body_str = body.decode()
     try:
         data = ast.literal_eval(body_str)
     except:
         data = json.loads(body_str)
-    if user_model:
-        data["model"] = user_model
+    data["model"] = (
+        server_settings.get("completion_model", None) # server default
+        or user_model # model name passed via cli args
+        or model # for azure deployments
+        or data["model"] # default passed in http request
+    )
     data["call_type"] = "chat_completion"
     return litellm_completion(
         **data
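
Notes (not part of the patch). Below is a minimal sketch of the config shape and the model-resolution precedence this diff introduces. The keys "server_settings", "completion_model", and "infer_model_from_keys" are taken from the diff itself; the file contents, concrete values, and the resolve_model() helper are hypothetical illustrations.

# Sketch only: mirrors the `or` chain added to completion()/chat_completion().
# Config keys come from the diff; YAML values and the helper are made up.
from typing import Optional

import yaml  # the reason "pyyaml" joins the auto-install list above

EXAMPLE_CONFIG = """
server_settings:
  completion_model: ollama/mistral   # server-wide default (wins over everything)
  infer_model_from_keys: true        # let /models probe detected API keys
"""

server_settings: dict = yaml.safe_load(EXAMPLE_CONFIG).get("server_settings") or {}


def resolve_model(
    user_model: Optional[str],   # model name passed via cli args
    path_model: Optional[str],   # {model:path} segment on the azure-style route
    body: dict,                  # decoded request body
) -> str:
    # Same precedence as the patched handlers: server default > CLI arg >
    # URL path deployment > request body. The chain short-circuits, so
    # body["model"] is only read when every earlier source is empty.
    return (
        server_settings.get("completion_model", None)
        or user_model
        or path_model
        or body["model"]
    )


assert resolve_model(None, "gpt-4-canada", {"messages": []}) == "ollama/mistral"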
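
A usage sketch for the new azure-compatible route follows; the host, port, and deployment name are assumptions for illustration, not values from the diff.

# Azure OpenAI clients post to /openai/deployments/<deployment>/chat/completions;
# the new route captures <deployment> as the `model` path parameter.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/openai/deployments/my-deployment/chat/completions",
    json={"messages": [{"role": "user", "content": "hey"}]},
    timeout=30,
)
print(resp.json())

Because the chain short-circuits, omitting "model" from the body is safe whenever a server default, CLI override, or path deployment supplies one; if none of those are set and the body also lacks "model", the handler's `data["model"]` lookup raises a KeyError rather than returning a clean 4xx, which may be worth handling.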