From 4414594e7d9d0ea980d38d3409c72014664db363 Mon Sep 17 00:00:00 2001
From: coconut49
Date: Tue, 17 Oct 2023 23:48:55 +0800
Subject: [PATCH] Refactor proxy_server.py for readability and code consistency

---
 litellm/proxy/proxy_server.py | 334 ++++++++++++++++++----------------
 1 file changed, 181 insertions(+), 153 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 90a967921..fc16747c3 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1,11 +1,11 @@
 import sys, os, platform, time, copy
 import threading
 import shutil, random, traceback
+
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path - for litellm local dev
-
 try:
     import uvicorn
     import fastapi
@@ -22,13 +22,14 @@ except ImportError:
     import tomli as tomllib
     import appdirs
     import tomli_w
-    
+
 try:
     from .llm import litellm_completion
-except ImportError as e: 
+except ImportError as e:
     from llm import litellm_completion
 import random
+
 list_of_messages = [
     "'The thing I wish you improved is...'",
     "'A feature I really want is...'",
@@ -37,35 +38,36 @@ list_of_messages = [
     "'I don't like how this works...'",
     "'It would help me if you could add...'",
     "'This feature doesn't meet my needs because...'",
-    "'I get frustrated when the product...'", 
+    "'I get frustrated when the product...'",
 ]

+
 def generate_feedback_box():
-  box_width = 60
+    box_width = 60

-  # Select a random message
-  message = random.choice(list_of_messages)
+    # Select a random message
+    message = random.choice(list_of_messages)
+
+    print()
+    print('\033[1;37m' + '#' + '-' * box_width + '#\033[0m')
+    print('\033[1;37m' + '#' + ' ' * box_width + '#\033[0m')
+    print('\033[1;37m' + '# {:^59} #\033[0m'.format(message))
+    print('\033[1;37m' + '# {:^59} #\033[0m'.format('https://github.com/BerriAI/litellm/issues/new'))
+    print('\033[1;37m' + '#' + ' ' * box_width + '#\033[0m')
+    print('\033[1;37m' + '#' + '-' * box_width + '#\033[0m')
+    print()
+    print(' Thank you for using LiteLLM! - Krrish & Ishaan')
+    print()
+    print()

-  print()
-  print('\033[1;37m' + '#' + '-'*box_width + '#\033[0m')
-  print('\033[1;37m' + '#' + ' '*box_width + '#\033[0m')
-  print('\033[1;37m' + '# {:^59} #\033[0m'.format(message))
-  print('\033[1;37m' + '# {:^59} #\033[0m'.format('https://github.com/BerriAI/litellm/issues/new'))
-  print('\033[1;37m' + '#' + ' '*box_width + '#\033[0m')
-  print('\033[1;37m' + '#' + '-'*box_width + '#\033[0m')
-  print()
-  print(' Thank you for using LiteLLM! - Krrish & Ishaan')
-  print()
-  print()
+

 generate_feedback_box()
-
 print()
 print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m")
 print()
 print("\033[1;34mDocs: https://docs.litellm.ai/docs/proxy_server\033[0m")
-print() 
+print()

 import litellm
 from fastapi import FastAPI, Request
@@ -100,24 +102,29 @@ config_dir = os.getcwd()
 config_dir = appdirs.user_config_dir("litellm")
 user_config_path = os.path.join(config_dir, config_filename)
 log_file = 'api_log.json'
+
+
 #### HELPER FUNCTIONS ####
 def print_verbose(print_statement):
-  global user_debug
-  if user_debug:
-    print(print_statement)
+    global user_debug
+    if user_debug:
+        print(print_statement)

-def usage_telemetry(feature: str): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
-  if user_telemetry:
+
+def usage_telemetry(
+        feature: str):  # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
+    if user_telemetry:
         data = {
-            "feature": feature # "local_proxy_server"
+            "feature": feature  # "local_proxy_server"
         }
         threading.Thread(target=litellm.utils.litellm_telemetry, args=(data,), daemon=True).start()

+
 def add_keys_to_config(key, value):
     # Check if file exists
     if os.path.exists(user_config_path):
         # Load existing file
-        with open(user_config_path, "rb") as f: 
+        with open(user_config_path, "rb") as f:
             config = tomllib.load(f)
     else:
         # File doesn't exist, create empty config
@@ -130,21 +137,22 @@ def add_keys_to_config(key, value):
     with open(user_config_path, 'wb') as f:
         tomli_w.dump(config, f)

-def save_params_to_config(data: dict): 
+
+def save_params_to_config(data: dict):
     # Check if file exists
     if os.path.exists(user_config_path):
         # Load existing file
-        with open(user_config_path, "rb") as f: 
+        with open(user_config_path, "rb") as f:
             config = tomllib.load(f)
     else:
         # File doesn't exist, create empty config
         config = {}
-    
+
     config.setdefault('general', {})

     ## general config
     general_settings = data["general"]
-    
+
     for key, value in general_settings.items():
         config["general"][key] = value
@@ -161,101 +169,104 @@ def save_params_to_config(data: dict):
     # Write config to file
     with open(user_config_path, 'wb') as f:
         tomli_w.dump(config, f)
-    
+

 def load_config():
-    try:
-        global user_config, user_api_base, user_max_tokens, user_temperature, user_model
-        # As the .env file is typically much simpler in structure, we use load_dotenv here directly
-        with open(user_config_path, "rb") as f:
-            user_config = tomllib.load(f)
+    global user_config, user_api_base, user_max_tokens, user_temperature, user_model
+    # As the .env file is typically much simpler in structure, we use load_dotenv here directly
+    with open(user_config_path, "rb") as f:
+        user_config = tomllib.load(f)

-        ## load keys
-        if "keys" in user_config:
-            for key in user_config["keys"]:
-                os.environ[key] = user_config["keys"][key] # litellm can read keys from the environment
-        ## settings
-        if "general" in user_config:
-            litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt", True) # by default add function to prompt if unsupported by provider
-            litellm.drop_params = user_config["general"].get("drop_params", True) # by default drop params if unsupported by provider
-            litellm.model_fallbacks = user_config["general"].get("fallbacks", None) # fallback models in case initial completion call fails
-            default_model = user_config["general"].get("default_model", None) # route all requests to this model.
+    ## load keys
+    if "keys" in user_config:
+        for key in user_config["keys"]:
+            os.environ[key] = user_config["keys"][key]  # litellm can read keys from the environment
+    ## settings
+    if "general" in user_config:
+        litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt",
+                                                                    True)  # by default add function to prompt if unsupported by provider
+        litellm.drop_params = user_config["general"].get("drop_params",
+                                                         True)  # by default drop params if unsupported by provider
+        litellm.model_fallbacks = user_config["general"].get("fallbacks",
+                                                             None)  # fallback models in case initial completion call fails
+        default_model = user_config["general"].get("default_model", None)  # route all requests to this model.

-        if user_model is None: # `litellm --model `` > default_model. 
-            user_model = default_model
+    if user_model is None:  # `litellm --model `` > default_model.
+        user_model = default_model

-        ## load model config - to set this run `litellm --config`
-        model_config = None
-        if "model" in user_config:
-            if user_model in user_config["model"]:
-                model_config = user_config["model"][user_model]
-
-        print_verbose(f"user_config: {user_config}")
-        print_verbose(f"model_config: {model_config}")
-        print_verbose(f"user_model: {user_model}")
-        if model_config is None:
-            return
+    ## load model config - to set this run `litellm --config`
+    model_config = None
+    if "model" in user_config:
+        if user_model in user_config["model"]:
+            model_config = user_config["model"][user_model]

-        user_max_tokens = model_config.get("max_tokens", None)
-        user_temperature = model_config.get("temperature", None)
-        user_api_base = model_config.get("api_base", None)
-
-        ## custom prompt template
-        if "prompt_template" in model_config:
-            model_prompt_template = model_config["prompt_template"]
-            if len(model_prompt_template.keys()) > 0: # if user has initialized this at all
-                litellm.register_prompt_template(
-                    model=user_model,
-                    initial_prompt_value=model_prompt_template.get("MODEL_PRE_PROMPT", ""),
-                    roles={
-                        "system": {
-                            "pre_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
-                            "post_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
-                        },
-                        "user": {
-                            "pre_message": model_prompt_template.get("MODEL_USER_MESSAGE_START_TOKEN", ""),
-                            "post_message": model_prompt_template.get("MODEL_USER_MESSAGE_END_TOKEN", ""),
-                        },
-                        "assistant": {
-                            "pre_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
-                            "post_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_END_TOKEN", ""),
-                        }
-                    },
-                    final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""),
-                )
-    except Exception as e:
-        pass
+    print_verbose(f"user_config: {user_config}")
+    print_verbose(f"model_config: {model_config}")
+    print_verbose(f"user_model: {user_model}")
+    if model_config is None:
+        return

-def initialize(model, alias, api_base, debug, temperature, max_tokens, max_budget, telemetry, drop_params, add_function_to_prompt, headers, save):
+    user_max_tokens = model_config.get("max_tokens", None)
+    user_temperature = model_config.get("temperature", None)
+    user_api_base = model_config.get("api_base", None)
+
+    ## custom prompt template
+    if "prompt_template" in model_config:
+        model_prompt_template = model_config["prompt_template"]
+        if len(model_prompt_template.keys()) > 0:  # if user has initialized this at all
+            litellm.register_prompt_template(
+                model=user_model,
+                initial_prompt_value=model_prompt_template.get("MODEL_PRE_PROMPT", ""),
+                roles={
+                    "system": {
+                        "pre_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
+                        "post_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
+                    },
+                    "user": {
+                        "pre_message": model_prompt_template.get("MODEL_USER_MESSAGE_START_TOKEN", ""),
+                        "post_message": model_prompt_template.get("MODEL_USER_MESSAGE_END_TOKEN", ""),
+                    },
+                    "assistant": {
+                        "pre_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
+                        "post_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_END_TOKEN", ""),
+                    }
+                },
+                final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""),
+            )
+
+
+def initialize(model, alias, api_base, debug, temperature, max_tokens, max_budget, telemetry, drop_params,
+               add_function_to_prompt, headers, save):
     global user_model, user_api_base, user_debug, user_max_tokens, user_temperature, user_telemetry, user_headers
     user_model = model
     user_debug = debug
     load_config()
-    dynamic_config = {"general": {}, user_model: {}} 
-    if headers: # model-specific param
+    dynamic_config = {"general": {}, user_model: {}}
+    if headers:  # model-specific param
         user_headers = headers
         dynamic_config[user_model]["headers"] = headers
-    if api_base: # model-specific param
+    if api_base:  # model-specific param
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
-    if max_tokens: # model-specific param
+    if max_tokens:  # model-specific param
         user_max_tokens = max_tokens
         dynamic_config[user_model]["max_tokens"] = max_tokens
-    if temperature: # model-specific param
+    if temperature:  # model-specific param
        user_temperature = temperature
        dynamic_config[user_model]["temperature"] = temperature
-    if alias: # model-specific param
+    if alias:  # model-specific param
        dynamic_config[user_model]["alias"] = alias
-    if drop_params == True: # litellm-specific param
+    if drop_params == True:  # litellm-specific param
        litellm.drop_params = True
        dynamic_config["general"]["drop_params"] = True
-    if add_function_to_prompt == True: # litellm-specific param
+    if add_function_to_prompt == True:  # litellm-specific param
        litellm.add_function_to_prompt = True
        dynamic_config["general"]["add_function_to_prompt"] = True
-    if max_budget: # litellm-specific param
+    if max_budget:  # litellm-specific param
        litellm.max_budget = max_budget
        dynamic_config["general"]["max_budget"] = max_budget
-    if save: 
+    if save:
        save_params_to_config(dynamic_config)
        with open(user_config_path) as f:
            print(f.read())
@@ -263,6 +274,7 @@ def initialize(model, alias, api_base, debug, temperature, max_tokens, max_budge
     user_telemetry = telemetry
     usage_telemetry(feature="local_proxy_server")

+
 def deploy_proxy(model, api_base, debug, temperature, max_tokens, telemetry, deploy):
     import requests
     # Load .env file
@@ -293,8 +305,6 @@ def deploy_proxy(model, api_base, debug, temperature, max_tokens, telemetry, dep
         files = {"file": open(".env", "rb")}
         # print(files)

-
-
     response = requests.post(url, data=data, files=files)
     # print(response)
     # Check the status of the request
@@ -309,10 +319,11 @@ def deploy_proxy(model, api_base, debug, temperature, max_tokens, telemetry, dep

     return url

+
 def track_cost_callback(
-    kwargs, # kwargs to completion
-    completion_response, # response from completion
-    start_time, end_time # start/end time
+        kwargs,  # kwargs to completion
+        completion_response,  # response from completion
+        start_time, end_time  # start/end time
 ):
     # track cost like this
     # {
@@ -330,12 +341,12 @@ def track_cost_callback(
         # for streaming responses
         if "complete_streaming_response" in kwargs:
             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
-            completion_response=kwargs["complete_streaming_response"]
+            completion_response = kwargs["complete_streaming_response"]
             input_text = kwargs["messages"]
             output_text = completion_response["choices"][0]["message"]["content"]
             response_cost = litellm.completion_cost(
-                model = kwargs["model"],
-                messages = input_text,
+                model=kwargs["model"],
+                messages=input_text,
                 completion=output_text
             )
             model = kwargs['model']
@@ -353,7 +364,7 @@ def track_cost_callback(
             with open("costs.json") as f:
                 cost_data = json.load(f)
         except FileNotFoundError:
-            cost_data = {} 
+            cost_data = {}
         import datetime
         date = datetime.datetime.now().strftime("%b-%d-%Y")
         if date not in cost_data:
@@ -374,47 +385,32 @@ def track_cost_callback(
     except:
         pass

-def logger(
-    kwargs, # kwargs to completion
-    completion_response=None, # response from completion
-    start_time=None,
-    end_time=None # start/end time
-):
-    log_event_type = kwargs['log_event_type']
-    try:
-        if log_event_type == 'pre_api_call':
-            inference_params = copy.deepcopy(kwargs)
-            timestamp = inference_params.pop('start_time')
-            dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
-            log_data = {
-                dt_key: {
-                    'pre_api_call': inference_params
-                }
-            }
-            
-            try:
-                with open(log_file, 'r') as f:
-                    existing_data = json.load(f)
-            except FileNotFoundError:
-                existing_data = {}
-            
-            existing_data.update(log_data)
-            def write_to_log():
-                with open(log_file, 'w') as f:
-                    json.dump(existing_data, f, indent=2)
-            thread = threading.Thread(target=write_to_log, daemon=True)
-            thread.start()
-        elif log_event_type == 'post_api_call':
-            if "stream" not in kwargs["optional_params"] or kwargs["optional_params"]["stream"] is False or kwargs.get("complete_streaming_response", False):
+
+def logger(
+        kwargs,  # kwargs to completion
+        completion_response=None,  # response from completion
+        start_time=None,
+        end_time=None  # start/end time
+):
+    log_event_type = kwargs['log_event_type']
+    try:
+        if log_event_type == 'pre_api_call':
             inference_params = copy.deepcopy(kwargs)
             timestamp = inference_params.pop('start_time')
             dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
-            
-            with open(log_file, 'r') as f:
-                existing_data = json.load(f)
-            
-            existing_data[dt_key]['post_api_call'] = inference_params
+            log_data = {
+                dt_key: {
+                    'pre_api_call': inference_params
+                }
+            }
+
+            try:
+                with open(log_file, 'r') as f:
+                    existing_data = json.load(f)
+            except FileNotFoundError:
+                existing_data = {}
+
+            existing_data.update(log_data)

             def write_to_log():
                 with open(log_file, 'w') as f:
@@ -422,15 +418,35 @@ def logger(
             thread = threading.Thread(target=write_to_log, daemon=True)
             thread.start()
-    except:
-        pass
+        elif log_event_type == 'post_api_call':
+            if "stream" not in kwargs["optional_params"] or kwargs["optional_params"]["stream"] is False or kwargs.get(
+                    "complete_streaming_response", False):
+                inference_params = copy.deepcopy(kwargs)
+                timestamp = inference_params.pop('start_time')
+                dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
+
+                with open(log_file, 'r') as f:
+                    existing_data = json.load(f)
+
+                existing_data[dt_key]['post_api_call'] = inference_params
+
+                def write_to_log():
+                    with open(log_file, 'w') as f:
+                        json.dump(existing_data, f, indent=2)
+
+                thread = threading.Thread(target=write_to_log, daemon=True)
+                thread.start()
+    except:
+        pass
+

 litellm.input_callback = [logger]
 litellm.success_callback = [logger]
 litellm.failure_callback = [logger]

+
 #### API ENDPOINTS ####
-@router.get("/models") # if project requires model list
+@router.get("/models")  # if project requires model list
 def model_list():
     if user_model != None:
         return dict(
@@ -440,19 +456,26 @@ def model_list():
     else:
         all_models = litellm.utils.get_valid_models()
         return dict(
-            data = [{"id": model, "object": "model", "created": 1677610602, "owned_by": "openai"} for model in all_models],
+            data=[{"id": model, "object": "model", "created": 1677610602, "owned_by": "openai"} for model in
+                  all_models],
             object="list",
         )

+
 @router.post("/completions")
 async def completion(request: Request):
     data = await request.json()
-    return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature, user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
+    return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature,
+                              user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers,
+                              user_debug=user_debug)
+

 @router.post("/chat/completions")
 async def chat_completion(request: Request):
     data = await request.json()
-    response = litellm_completion(data, type="chat_completion", user_model=user_model, user_temperature=user_temperature, user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
+    response = litellm_completion(data, type="chat_completion", user_model=user_model,
+                                  user_temperature=user_temperature, user_max_tokens=user_max_tokens,
+                                  user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
     return response
@@ -462,6 +485,7 @@ async def v1_completion(request: Request):
     data = await request.json()
     return litellm_completion(data=data, type="completion")

+
 @router.post("/v1/chat/completions")
 async def v1_chat_completion(request: Request):
     data = await request.json()
@@ -469,6 +493,7 @@ async def v1_chat_completion(request: Request):
     data = await request.json()
     response = litellm_completion(data, type="chat_completion")
     return response

+
 def print_cost_logs():
     with open('costs.json', 'r') as f:
         # print this in green
@@ -477,13 +502,16 @@ def print_cost_logs():
         print(f.read())
         print("\033[0m")
     return

+
 @router.get("/ollama_logs")
 async def retrieve_server_log(request: Request):
     filepath = os.path.expanduser('~/.ollama/logs/server.log')
     return FileResponse(filepath)

+
 @router.get("/")
 async def home(request: Request):
     return "LiteLLM: RUNNING"
-app.include_router(router)
\ No newline at end of file
+
+app.include_router(router)
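
A quick way to sanity-check the routes touched by this patch is to hit them with a plain HTTP client once the proxy is running. The sketch below is not part of the patch itself: it assumes a locally started server (for example via `litellm --model <your-model>`) listening on http://0.0.0.0:8000, and the host, port, and model name are placeholders rather than anything this diff pins down.

# Minimal smoke test for the /models and /chat/completions routes above.
# Assumes the proxy is already running locally; adjust BASE_URL to match your setup.
import requests

BASE_URL = "http://0.0.0.0:8000"  # assumed default, not defined by this patch

# List the models the proxy reports (served by the /models route).
print(requests.get(f"{BASE_URL}/models").json())

# Send an OpenAI-style chat payload through the proxy (served by /chat/completions).
payload = {
    "model": "gpt-3.5-turbo",  # placeholder; use the model the proxy was started with
    "messages": [{"role": "user", "content": "Hello from the refactored proxy!"}],
}
print(requests.post(f"{BASE_URL}/chat/completions", json=payload).json())

The payload mirrors the OpenAI chat format that the /chat/completions route forwards to litellm_completion; nothing in the sketch depends on the refactor beyond the routes it reformats.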