From e3a1c58dd96ebf71fcea082943577544ac723dd1 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 2 Nov 2023 20:56:33 -0700
Subject: [PATCH] build(litellm_server/utils.py): add support for general
 settings + num retries as a module variable

---
 litellm/__init__.py                      |  1 +
 litellm/llms/prompt_templates/factory.py |  1 -
 litellm/main.py                          |  1 +
 litellm/utils.py                         |  7 ++-
 litellm_server/Dockerfile                |  2 +-
 litellm_server/main.py                   | 55 +++++++++++++-----------
 litellm_server/utils.py                  | 18 +++++---
 7 files changed, 52 insertions(+), 33 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index b2627ecbb..271ee7a84 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -46,6 +46,7 @@ add_function_to_prompt: bool = False # if function calling not supported by api,
 client_session: Optional[requests.Session] = None
 model_fallbacks: Optional[List] = None
 model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+num_retries: Optional[int] = None
 #############################################
 
 def get_model_cost_map(url: str):
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index f23691612..959b8759f 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -267,7 +267,6 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
             bos_open = False
 
     prompt += final_prompt_value
-    print(f"COMPLETE PROMPT: {prompt}")
     return prompt
 
 def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
diff --git a/litellm/main.py b/litellm/main.py
index 00255c18c..55918f3a1 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -270,6 +270,7 @@ def completion(
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
         return mock_completion(model, messages, stream=stream, mock_response=mock_response)
+
     try:
         logging = litellm_logging_obj
         fallbacks = (
diff --git a/litellm/utils.py b/litellm/utils.py
index 207315120..e97a1bd6c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -826,7 +826,12 @@ def client(original_function):
         except Exception as e:
             call_type = original_function.__name__
             if call_type == CallTypes.completion.value:
-                num_retries = kwargs.get("num_retries", None)
+                num_retries = (
+                    kwargs.get("num_retries", None)
+                    or litellm.num_retries
+                    or None
+                )
+                litellm.num_retries = None # set retries to None to prevent infinite loops
                 context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {})
 
                 if num_retries:
diff --git a/litellm_server/Dockerfile b/litellm_server/Dockerfile
index 70d12a253..7be7ba4c9 100644
--- a/litellm_server/Dockerfile
+++ b/litellm_server/Dockerfile
@@ -7,4 +7,4 @@ RUN pip install -r requirements.txt
 
 EXPOSE $PORT
 
-CMD exec uvicorn main:app --host 0.0.0.0 --port $PORT
\ No newline at end of file
+CMD exec uvicorn main:app --host 0.0.0.0 --port $PORT --workers 10
\ No newline at end of file
diff --git a/litellm_server/main.py b/litellm_server/main.py
index c69bda620..c7b26b685 100644
--- a/litellm_server/main.py
+++ b/litellm_server/main.py
@@ -5,10 +5,11 @@ from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
 import json, sys
 from typing import Optional
-# sys.path.insert(
-#     0, os.path.abspath("../")
-# ) # Adds the parent directory to the system path - for litellm local dev
+sys.path.insert(
+    0, os.path.abspath("../")
+) # Adds the parent directory to the system path - for litellm local dev
 import litellm
+print(f"litellm: {litellm}")
 try:
     from utils import set_callbacks, load_router_config, print_verbose
 except ImportError:
@@ -30,14 +31,15 @@ app.add_middleware(
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
 llm_model_list: Optional[list] = None
+server_settings: Optional[dict] = None
 set_callbacks() # sets litellm callbacks for logging if they exist in the environment
 
 if "CONFIG_FILE_PATH" in os.environ:
     print(f"CONFIG FILE DETECTED")
-    llm_router, llm_model_list = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
+    llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
 else:
-    llm_router, llm_model_list = load_router_config(router=llm_router)
+    llm_router, llm_model_list, server_settings = load_router_config(router=llm_router)
 
 #### API ENDPOINTS ####
 @router.get("/v1/models")
 @router.get("/models") # if project requires model list
@@ -100,27 +102,31 @@ async def embedding(request: Request):
 @router.post("/chat/completions")
 @router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None):
-    global llm_model_list
+    global llm_model_list, server_settings
     try:
         data = await request.json()
-        if model:
-            data["model"] = model
+        print(f"data: {data}")
+        data["model"] = (
+            server_settings.get("completion_model", None) # server default
+            or model # model passed in url
+            or data["model"] # default passed in
+        )
         ## CHECK KEYS ##
         # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
-        env_validation = litellm.validate_environment(model=data["model"])
-        if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
-            if "authorization" in request.headers:
-                api_key = request.headers.get("authorization")
-            elif "api-key" in request.headers:
-                api_key = request.headers.get("api-key")
-            print(f"api_key in headers: {api_key}")
-            if " " in api_key:
-                api_key = api_key.split(" ")[1]
-            print(f"api_key split: {api_key}")
-            if len(api_key) > 0:
-                api_key = api_key
-                data["api_key"] = api_key
-            print(f"api_key in data: {api_key}")
+        # env_validation = litellm.validate_environment(model=data["model"])
+        # if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
+        #     if "authorization" in request.headers:
+        #         api_key = request.headers.get("authorization")
+        #     elif "api-key" in request.headers:
+        #         api_key = request.headers.get("api-key")
+        #     print(f"api_key in headers: {api_key}")
+        #     if " " in api_key:
+        #         api_key = api_key.split(" ")[1]
+        #     print(f"api_key split: {api_key}")
+        #     if len(api_key) > 0:
+        #         api_key = api_key
+        #         data["api_key"] = api_key
+        #     print(f"api_key in data: {api_key}")
         ## CHECK CONFIG ##
         if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]:
             for m in llm_model_list:
@@ -133,13 +139,14 @@ async def chat_completion(request: Request, model: Optional[str] = None):
             )
         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
             return StreamingResponse(data_generator(response), media_type='text/event-stream')
+        print(f"response: {response}")
         return response
     except Exception as e:
         error_traceback = traceback.format_exc()
         print(f"{error_traceback}")
         error_msg = f"{str(e)}\n\n{error_traceback}"
-        return {"error": error_msg}
-        # raise HTTPException(status_code=500, detail=error_msg)
+        # return {"error": error_msg}
+        raise HTTPException(status_code=500, detail=error_msg)
 
 @router.post("/router/completions")
 async def router_completion(request: Request):
diff --git a/litellm_server/utils.py b/litellm_server/utils.py
index 5f328d328..ffaa64c91 100644
--- a/litellm_server/utils.py
+++ b/litellm_server/utils.py
@@ -43,16 +43,17 @@ def set_callbacks():
 
     ## CACHING
     ### REDIS
-    if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
-        from litellm.caching import Cache
-        litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
-        print("\033[92mLiteLLM: Switched on Redis caching\033[0m")
+    # if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
+    #     print(f"redis host: {os.getenv('REDIS_HOST')}; redis port: {os.getenv('REDIS_PORT')}; password: {os.getenv('REDIS_PASSWORD')}")
+    #     from litellm.caching import Cache
+    #     litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
+    #     print("\033[92mLiteLLM: Switched on Redis caching\033[0m")
 
 
 
 def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]='/app/config.yaml'):
     config = {}
-
+    server_settings = {}
     try:
         if os.path.exists(config_file_path):
             with open(config_file_path, 'r') as file:
@@ -62,6 +63,11 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: Optio
     except:
         pass
 
+    ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')
+    server_settings = config.get("server_settings", None)
+    if server_settings:
+        server_settings = server_settings
+
     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
    if litellm_settings:
@@ -79,4 +85,4 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: Optio
         for key, value in environment_variables.items():
             os.environ[key] = value
 
-    return router, model_list
+    return router, model_list, server_settings
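
Usage sketch (not part of the patch): with this change, retries can be configured once at the module level via litellm.num_retries instead of being passed on every call. A per-call num_retries kwarg still takes precedence, and the client wrapper in litellm/utils.py resets the module variable to None once it is consumed (the "prevent infinite loops" comment in the hunk above). The model name and API key below are illustrative placeholders.

import os
import litellm

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder; any provider litellm supports works

# module-level default added in litellm/__init__.py; picked up by the client
# wrapper in litellm/utils.py when the call itself does not pass num_retries
litellm.num_retries = 2

response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
)

# an explicit per-call value wins over the module variable
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    num_retries=3,
)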
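
A second sketch (also not part of the patch) for the new server_settings block: load_router_config() now returns it alongside the router and model list, and /chat/completions uses server_settings["completion_model"] as the server-wide default before falling back to the model in the URL path or the request body. The 'ollama/mistral' value mirrors the example in the patch comment; the other keys and values shown are illustrative and follow the config blocks already handled in litellm_server/utils.py.

import yaml  # PyYAML; the proxy's config file is YAML (default path /app/config.yaml)

CONFIG_YAML = """
server_settings:                  # new block introduced by this patch
  completion_model: ollama/mistral

litellm_settings:                 # existing block: sets litellm module attributes
  drop_params: true

environment_variables:            # existing block: copied into os.environ
  EXAMPLE_API_BASE: http://localhost:11434
"""

config = yaml.safe_load(CONFIG_YAML)
server_settings = config.get("server_settings", None)

def resolve_model(url_model, body_model):
    # same or-chain as the new /chat/completions handler:
    # server default -> model from the URL path -> model in the request body
    return (
        server_settings.get("completion_model", None)
        or url_model
        or body_model
    )

print(resolve_model(None, "gpt-3.5-turbo"))  # -> ollama/mistral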