diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
new file mode 100644
index 0000000000..616ba5c9f9
--- /dev/null
+++ b/.github/workflows/docker.yml
@@ -0,0 +1,46 @@
+name: Build Docker Images
+on:
+  workflow_dispatch:
+    inputs:
+      tag:
+        description: "The tag version you want to build"
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v3
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+      - name: Login to GitHub Container Registry
+        uses: docker/login-action@v3
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+      - name: Extract metadata (tags, labels) for Docker
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: ghcr.io/${{ github.repository }}
+      - name: Get tag to build
+        id: tag
+        run: |
+          echo "latest=ghcr.io/${{ github.repository }}:latest" >> $GITHUB_OUTPUT
+          if [[ -z "${{ github.event.inputs.tag }}" ]]; then
+            echo "versioned=ghcr.io/${{ github.repository }}:${{ github.ref_name }}" >> $GITHUB_OUTPUT
+          else
+            echo "versioned=ghcr.io/${{ github.repository }}:${{ github.event.inputs.tag }}" >> $GITHUB_OUTPUT
+          fi
+      - name: Build and release Docker images
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          platforms: linux/amd64,linux/arm64
+          tags: |
+            ${{ steps.tag.outputs.latest }}
+            ${{ steps.tag.outputs.versioned }}
+          labels: ${{ steps.meta.outputs.labels }}
+          push: true
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index be162d4511..42b223b1fc 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,13 +1,10 @@
 FROM python:3.10
 
+ENV LITELLM_CONFIG_PATH="/litellm.secrets.toml"
 COPY . /app
 WORKDIR /app
-RUN mkdir -p /root/.config/litellm/ && cp /app/secrets_template.toml /root/.config/litellm/litellm.secrets.toml
 RUN pip install -r requirements.txt
 
 WORKDIR /app/litellm/proxy
 EXPOSE 8000
-ENTRYPOINT [ "python3", "proxy_cli.py" ]
-# TODO - Set up a GitHub Action to automatically create the Docker image,
-# and then we can quickly deploy the litellm proxy in the following way
-# `docker run -p 8000:8000 -v ./secrets_template.toml:/root/.config/litellm/litellm.secrets.toml ghcr.io/BerriAI/litellm:v0.8.4`
\ No newline at end of file
+ENTRYPOINT [ "python3", "proxy_cli.py" ]
\ No newline at end of file
diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py
index f2b29bf5f3..96e089caaa 100644
--- a/litellm/proxy/proxy_cli.py
+++ b/litellm/proxy/proxy_cli.py
@@ -9,7 +9,7 @@ import operator
 config_filename = "litellm.secrets.toml"
 # Using appdirs to determine user-specific config path
 config_dir = appdirs.user_config_dir("litellm")
-user_config_path = os.path.join(config_dir, config_filename)
+user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
 load_dotenv()
 
 from importlib import resources
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 5736465840..f82177418c 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -18,7 +18,20 @@ except ImportError:
     import subprocess
     import sys
 
-    subprocess.check_call([sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"])
+    subprocess.check_call(
+        [
+            sys.executable,
+            "-m",
+            "pip",
+            "install",
+            "uvicorn",
+            "fastapi",
+            "tomli",
+            "appdirs",
+            "tomli-w",
+            "backoff",
+        ]
+    )
     import uvicorn
     import fastapi
     import tomli as tomllib
@@ -26,9 +39,9 @@ except ImportError:
     import tomli_w
 
 try:
-    from .llm import litellm_completion
+    from .llm import litellm_completion
 except ImportError as e:
-    from llm import litellm_completion # type: ignore
+    from llm import litellm_completion  # type: ignore
 
 import random
 
@@ -51,14 +64,17 @@ def generate_feedback_box():
     message = random.choice(list_of_messages)
 
     print()
-    print('\033[1;37m' + '#' + '-' * box_width + '#\033[0m')
-    print('\033[1;37m' + '#' + ' ' * box_width + '#\033[0m')
-    print('\033[1;37m' + '# {:^59} #\033[0m'.format(message))
-    print('\033[1;37m' + '# {:^59} #\033[0m'.format('https://github.com/BerriAI/litellm/issues/new'))
-    print('\033[1;37m' + '#' + ' ' * box_width + '#\033[0m')
-    print('\033[1;37m' + '#' + '-' * box_width + '#\033[0m')
+    print("\033[1;37m" + "#" + "-" * box_width + "#\033[0m")
+    print("\033[1;37m" + "#" + " " * box_width + "#\033[0m")
+    print("\033[1;37m" + "# {:^59} #\033[0m".format(message))
+    print(
+        "\033[1;37m"
+        + "# {:^59} #\033[0m".format("https://github.com/BerriAI/litellm/issues/new")
+    )
+    print("\033[1;37m" + "#" + " " * box_width + "#\033[0m")
+    print("\033[1;37m" + "#" + "-" * box_width + "#\033[0m")
     print()
-    print(' Thank you for using LiteLLM! - Krrish & Ishaan')
+    print(" Thank you for using LiteLLM! - Krrish & Ishaan")
     print()
     print()
@@ -66,7 +82,9 @@ def generate_feedback_box():
 generate_feedback_box()
 
 print()
-print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m")
+print(
+    "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m"
+)
 print()
 print("\033[1;34mDocs: https://docs.litellm.ai/docs/proxy_server\033[0m")
 print()
@@ -105,8 +123,10 @@ model_router = litellm.Router()
 config_filename = "litellm.secrets.toml"
 config_dir = os.getcwd()
 config_dir = appdirs.user_config_dir("litellm")
-user_config_path = os.path.join(config_dir, config_filename)
-log_file = 'api_log.json'
+user_config_path = os.getenv(
+    "LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
+)
+log_file = "api_log.json"
 
 #### HELPER FUNCTIONS ####
 
@@ -124,12 +144,13 @@ def find_avatar_url(role):
 
 
 def usage_telemetry(
-    feature: str): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
+    feature: str,
+):  # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
     if user_telemetry:
-        data = {
-            "feature": feature # "local_proxy_server"
-        }
-        threading.Thread(target=litellm.utils.litellm_telemetry, args=(data,), daemon=True).start()
+        data = {"feature": feature}  # "local_proxy_server"
+        threading.Thread(
+            target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
+        ).start()
 
 
 def add_keys_to_config(key, value):
@@ -142,11 +163,11 @@ def add_keys_to_config(key, value):
         # File doesn't exist, create empty config
         config = {}
 
-    # Add new key
-    config.setdefault('keys', {})[key] = value
+    # Add new key
+    config.setdefault("keys", {})[key] = value
 
-    # Write config to file
-    with open(user_config_path, 'wb') as f:
+    # Write config to file
+    with open(user_config_path, "wb") as f:
         tomli_w.dump(config, f)
 
 
@@ -160,15 +181,15 @@ def save_params_to_config(data: dict):
         # File doesn't exist, create empty config
         config = {}
 
-    config.setdefault('general', {})
+    config.setdefault("general", {})
 
-    ## general config
+    ## general config
     general_settings = data["general"]
     for key, value in general_settings.items():
         config["general"][key] = value
 
-    ## model-specific config
+    ## model-specific config
     config.setdefault("model", {})
     config["model"].setdefault(user_model, {})
 
@@ -178,13 +199,13 @@ def save_params_to_config(data: dict):
     for key, value in user_model_config.items():
         config["model"][model_key][key] = value
 
-    # Write config to file
-    with open(user_config_path, 'wb') as f:
+    # Write config to file
+    with open(user_config_path, "wb") as f:
         tomli_w.dump(config, f)
 
 
 def load_config():
-    try:
+    try:
         global user_config, user_api_base, user_max_tokens, user_temperature, user_model, local_logging
         # As the .env file is typically much simpler in structure, we use load_dotenv here directly
         with open(user_config_path, "rb") as f:
@@ -193,16 +214,23 @@ def load_config():
 
         ## load keys
         if "keys" in user_config:
             for key in user_config["keys"]:
-                os.environ[key] = user_config["keys"][key] # litellm can read keys from the environment
+                os.environ[key] = user_config["keys"][
+                    key
+                ]  # litellm can read keys from the environment
 
         ## settings
         if "general" in user_config:
-            litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt",
-                                                                        True) # by default add function to prompt if unsupported by provider
-            litellm.drop_params = user_config["general"].get("drop_params",
-                                                             True) # by default drop params if unsupported by provider
-            litellm.model_fallbacks = user_config["general"].get("fallbacks",
-                                                                 None) # fallback models in case initial completion call fails
-            default_model = user_config["general"].get("default_model", None) # route all requests to this model.
+            litellm.add_function_to_prompt = user_config["general"].get(
+                "add_function_to_prompt", True
+            )  # by default add function to prompt if unsupported by provider
+            litellm.drop_params = user_config["general"].get(
+                "drop_params", True
+            )  # by default drop params if unsupported by provider
+            litellm.model_fallbacks = user_config["general"].get(
+                "fallbacks", None
+            )  # fallback models in case initial completion call fails
+            default_model = user_config["general"].get(
+                "default_model", None
+            )  # route all requests to this model.
             local_logging = user_config["general"].get("local_logging", True)
@@ -215,10 +243,10 @@ def load_config():
             if user_model in user_config["model"]:
                 model_config = user_config["model"][user_model]
             model_list = []
-            for model in user_config["model"]:
+            for model in user_config["model"]:
                 if "model_list" in user_config["model"][model]:
                     model_list.extend(user_config["model"][model]["model_list"])
-            if len(model_list) > 0:
+            if len(model_list) > 0:
                 model_router.set_model_list(model_list=model_list)
 
         print_verbose(f"user_config: {user_config}")
@@ -234,32 +262,63 @@ def load_config():
         ## custom prompt template
         if "prompt_template" in model_config:
             model_prompt_template = model_config["prompt_template"]
-            if len(model_prompt_template.keys()) > 0: # if user has initialized this at all
+            if (
+                len(model_prompt_template.keys()) > 0
+            ):  # if user has initialized this at all
                 litellm.register_prompt_template(
                     model=user_model,
-                    initial_prompt_value=model_prompt_template.get("MODEL_PRE_PROMPT", ""),
+                    initial_prompt_value=model_prompt_template.get(
+                        "MODEL_PRE_PROMPT", ""
+                    ),
                     roles={
                         "system": {
-                            "pre_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
-                            "post_message": model_prompt_template.get("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
+                            "pre_message": model_prompt_template.get(
+                                "MODEL_SYSTEM_MESSAGE_START_TOKEN", ""
+                            ),
+                            "post_message": model_prompt_template.get(
+                                "MODEL_SYSTEM_MESSAGE_END_TOKEN", ""
+                            ),
                         },
                         "user": {
-                            "pre_message": model_prompt_template.get("MODEL_USER_MESSAGE_START_TOKEN", ""),
-                            "post_message": model_prompt_template.get("MODEL_USER_MESSAGE_END_TOKEN", ""),
+                            "pre_message": model_prompt_template.get(
+                                "MODEL_USER_MESSAGE_START_TOKEN", ""
+                            ),
+                            "post_message": model_prompt_template.get(
+                                "MODEL_USER_MESSAGE_END_TOKEN", ""
+                            ),
                         },
                         "assistant": {
-                            "pre_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
-                            "post_message": model_prompt_template.get("MODEL_ASSISTANT_MESSAGE_END_TOKEN", ""),
-                        }
+                            "pre_message": model_prompt_template.get(
+                                "MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""
+                            ),
+                            "post_message": model_prompt_template.get(
+                                "MODEL_ASSISTANT_MESSAGE_END_TOKEN", ""
+                            ),
+                        },
                     },
-                    final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""),
+                    final_prompt_value=model_prompt_template.get(
+                        "MODEL_POST_PROMPT", ""
+                    ),
                 )
-    except:
+    except:
         pass
 
 
-def initialize(model, alias, api_base, api_version, debug, temperature, max_tokens, max_budget, telemetry, drop_params,
-               add_function_to_prompt, headers, save):
+def initialize(
+    model,
+    alias,
+    api_base,
+    api_version,
+    debug,
+    temperature,
+    max_tokens,
+    max_budget,
+    telemetry,
+    drop_params,
+    add_function_to_prompt,
+    headers,
+    save,
+):
     global user_model, user_api_base, user_debug, user_max_tokens, user_temperature, user_telemetry, user_headers
     user_model = model
     user_debug = debug
@@ -271,8 +330,10 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
     if api_base: # model-specific param
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
-    if api_version:
-        os.environ["AZURE_API_VERSION"] = api_version # set this for azure - litellm can read this from the env
+    if api_version:
+        os.environ[
+            "AZURE_API_VERSION"
+        ] = api_version  # set this for azure - litellm can read this from the env
     if max_tokens: # model-specific param
         user_max_tokens = max_tokens
         dynamic_config[user_model]["max_tokens"] = max_tokens
@@ -290,7 +351,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
     if max_budget: # litellm-specific param
         litellm.max_budget = max_budget
         dynamic_config["general"]["max_budget"] = max_budget
-    if debug: # litellm-specific param
+    if debug:  # litellm-specific param
         litellm.set_verbose = True
     if save:
         save_params_to_config(dynamic_config)
@@ -300,16 +361,18 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
     user_telemetry = telemetry
     usage_telemetry(feature="local_proxy_server")
 
+
 def track_cost_callback(
-    kwargs, # kwargs to completion
-    completion_response, # response from completion
-    start_time, end_time # start/end time
+    kwargs,  # kwargs to completion
+    completion_response,  # response from completion
+    start_time,
+    end_time,  # start/end time
 ):
-    # track cost like this
+    # track cost like this
     # {
     #     "Oct12": {
     #         "gpt-4": 10,
-    #         "claude-2": 12.01,
+    #         "claude-2": 12.01,
     #     },
     #     "Oct 15": {
     #         "ollama/llama2": 0.0,
@@ -317,28 +380,27 @@ def track_cost_callback(
     #     }
     # }
     try:
-        # for streaming responses
         if "complete_streaming_response" in kwargs:
-            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
+            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
             completion_response = kwargs["complete_streaming_response"]
             input_text = kwargs["messages"]
            output_text = completion_response["choices"][0]["message"]["content"]
             response_cost = litellm.completion_cost(
-                model=kwargs["model"],
-                messages=input_text,
-                completion=output_text
+                model=kwargs["model"], messages=input_text, completion=output_text
             )
-            model = kwargs['model']
+            model = kwargs["model"]
 
         # for non streaming responses
         else:
             # we pass the completion_response obj
             if kwargs["stream"] != True:
-                response_cost = litellm.completion_cost(completion_response=completion_response)
+                response_cost = litellm.completion_cost(
+                    completion_response=completion_response
+                )
                 model = completion_response["model"]
 
-        # read/write from json for storing daily model costs
+        # read/write from json for storing daily model costs
         cost_data = {}
         try:
             with open("costs.json") as f:
@@ -346,6 +408,7 @@ def track_cost_callback(
         except FileNotFoundError:
             cost_data = {}
         import datetime
+
         date = datetime.datetime.now().strftime("%b-%d-%Y")
         if date not in cost_data:
             cost_data[date] = {}
@@ -356,7 +419,7 @@ def track_cost_callback(
         else:
             cost_data[date][kwargs["model"]] = {
                 "cost": response_cost,
-                "num_requests": 1
+                "num_requests": 1,
             }
 
         with open("costs.json", "w") as f:
@@ -367,25 +430,21 @@ def track_cost_callback(
 
 
 def logger(
-    kwargs, # kwargs to completion
-    completion_response=None, # response from completion
-    start_time=None,
-    end_time=None # start/end time
+    kwargs,  # kwargs to completion
+    completion_response=None,  # response from completion
+    start_time=None,
+    end_time=None,  # start/end time
 ):
-    log_event_type = kwargs['log_event_type']
+    log_event_type = kwargs["log_event_type"]
     try:
-        if log_event_type == 'pre_api_call':
+        if log_event_type == "pre_api_call":
             inference_params = copy.deepcopy(kwargs)
-            timestamp = inference_params.pop('start_time')
+            timestamp = inference_params.pop("start_time")
             dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
-            log_data = {
-                dt_key: {
-                    'pre_api_call': inference_params
-                }
-            }
+            log_data = {dt_key: {"pre_api_call": inference_params}}
 
             try:
-                with open(log_file, 'r') as f:
+                with open(log_file, "r") as f:
                     existing_data = json.load(f)
             except FileNotFoundError:
                 existing_data = {}
@@ -393,7 +452,7 @@ def logger(
             existing_data.update(log_data)
 
             def write_to_log():
-                with open(log_file, 'w') as f:
+                with open(log_file, "w") as f:
                     json.dump(existing_data, f, indent=2)
 
             thread = threading.Thread(target=write_to_log, daemon=True)
@@ -413,14 +472,28 @@ litellm.failure_callback = [logger]
 def model_list():
     if user_model != None:
         return dict(
-            data=[{"id": user_model, "object": "model", "created": 1677610602, "owned_by": "openai"}],
+            data=[
+                {
+                    "id": user_model,
+                    "object": "model",
+                    "created": 1677610602,
+                    "owned_by": "openai",
+                }
+            ],
             object="list",
         )
     else:
         all_models = litellm.utils.get_valid_models()
         return dict(
-            data=[{"id": model, "object": "model", "created": 1677610602, "owned_by": "openai"} for model in
-                  all_models],
+            data=[
+                {
+                    "id": model,
+                    "object": "model",
+                    "created": 1677610602,
+                    "owned_by": "openai",
+                }
+                for model in all_models
+            ],
             object="list",
         )
@@ -445,7 +518,7 @@ async def chat_completion(request: Request):
 
 
 def print_cost_logs():
-    with open('costs.json', 'r') as f:
+    with open("costs.json", "r") as f:
         # print this in green
         print("\033[1;32m")
         print(f.read())
@@ -455,7 +528,7 @@ def print_cost_logs():
 
 @router.get("/ollama_logs")
 async def retrieve_server_log(request: Request):
-    filepath = os.path.expanduser('~/.ollama/logs/server.log')
+    filepath = os.path.expanduser("~/.ollama/logs/server.log")
     return FileResponse(filepath)
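
For reference, a minimal sketch of how the image published by this workflow could be run, assuming the repository is BerriAI/litellm (the ghcr.io path named in the removed Dockerfile comment) and that a local secrets file is mounted over the /litellm.secrets.toml default that the new LITELLM_CONFIG_PATH variable points at; the exact tag depends on what was passed to workflow_dispatch.

# pull whichever tag the workflow pushed (":latest" is always pushed; the
# versioned tag is either the dispatch input or the branch/ref name)
docker pull ghcr.io/BerriAI/litellm:latest

# mount a local secrets file over the LITELLM_CONFIG_PATH default and
# publish the port from EXPOSE 8000 in the Dockerfile
docker run -p 8000:8000 \
  -v "$(pwd)/secrets_template.toml:/litellm.secrets.toml" \
  ghcr.io/BerriAI/litellm:latest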