diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py
index 1fa714bc67..3a885c87c9 100644
--- a/litellm/proxy/proxy_cli.py
+++ b/litellm/proxy/proxy_cli.py
@@ -5,10 +5,17 @@ import random
 import subprocess
 import sys
 import urllib.parse as urlparse
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import click
+import httpx
 from dotenv import load_dotenv
 
+if TYPE_CHECKING:
+    from fastapi import FastAPI
+else:
+    FastAPI = Any
+
 sys.path.append(os.getcwd())
 
 config_filename = "litellm.secrets"
@@ -39,25 +46,252 @@ def append_query_params(url, params) -> str:
     return modified_url  # type: ignore
 
 
-def run_ollama_serve():
-    try:
-        command = ["ollama", "serve"]
+class ProxyInitializationHelpers:
+    @staticmethod
+    def _echo_litellm_version():
+        pkg_version = importlib.metadata.version("litellm")  # type: ignore
+        click.echo(f"\nLiteLLM: Current Version = {pkg_version}\n")
+
+    @staticmethod
+    def _run_health_check(host, port):
+        print("\nLiteLLM: Health Testing models in config")  # noqa
+        response = httpx.get(url=f"http://{host}:{port}/health")
+        print(json.dumps(response.json(), indent=4))  # noqa
+
+    @staticmethod
+    def _run_test_chat_completion(
+        host: str,
+        port: int,
+        model: str,
+        test: Union[bool, str],
+    ):
+        request_model = model or "gpt-3.5-turbo"
+        click.echo(
+            f"\nLiteLLM: Making a test ChatCompletions request to your proxy. Model={request_model}"
+        )
+        import openai
+
+        api_base = f"http://{host}:{port}"
+        if isinstance(test, str):
+            api_base = test
+        else:
+            raise ValueError("Invalid test value")
+        client = openai.OpenAI(api_key="My API Key", base_url=api_base)
+
+        response = client.chat.completions.create(
+            model=request_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "this is a test request, write a short poem",
+                }
+            ],
+            max_tokens=256,
+        )
+        click.echo(f"\nLiteLLM: response from proxy {response}")
 
-        with open(os.devnull, "w") as devnull:
-            subprocess.Popen(command, stdout=devnull, stderr=devnull)
-    except Exception as e:
         print(  # noqa
-            f"""
-            LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
+            f"\n LiteLLM: Making a test ChatCompletions + streaming r equest to proxy. Model={request_model}"
+        )
+
+        stream_response = client.chat.completions.create(
+            model=request_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "this is a test request, write a short poem",
+                }
+            ],
+            stream=True,
+        )
+        for chunk in stream_response:
+            click.echo(f"LiteLLM: streaming response from proxy {chunk}")
+        print("\n making completion request to proxy")  # noqa
+        completion_response = client.completions.create(
+            model=request_model, prompt="this is a test request, write a short poem"
+        )
+        print(completion_response)  # noqa
+
+    @staticmethod
+    def _get_default_unvicorn_init_args(
+        host: str,
+        port: int,
+        log_config: Optional[str] = None,
+    ) -> dict:
         """
+        Get the arguments for `uvicorn` worker
+        """
+        import litellm
+
+        uvicorn_args = {
+            "app": "litellm.proxy.proxy_server:app",
+            "host": host,
+            "port": port,
+        }
+        if log_config is not None:
+            print(f"Using log_config: {log_config}")  # noqa
+            uvicorn_args["log_config"] = log_config
+        elif litellm.json_logs:
+            print("Using json logs. Setting log_config to None.")  # noqa
+            uvicorn_args["log_config"] = None
+        return uvicorn_args
+
+    @staticmethod
+    def _init_hypercorn_server(
+        app: FastAPI,
+        host: str,
+        port: int,
+        ssl_certfile_path: str,
+        ssl_keyfile_path: str,
+    ):
+        """
+        Initialize litellm with `hypercorn`
+        """
+        import asyncio
+
+        from hypercorn.asyncio import serve
+        from hypercorn.config import Config
+
+        print(  # noqa
+            f"\033[1;32mLiteLLM Proxy: Starting server on {host}:{port} using Hypercorn\033[0m\n"  # noqa
         )  # noqa
+        config = Config()
+        config.bind = [f"{host}:{port}"]
+        if ssl_certfile_path is not None and ssl_keyfile_path is not None:
+            print(  # noqa
+                f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n"  # noqa
+            )
+            config.certfile = ssl_certfile_path
+            config.keyfile = ssl_keyfile_path
 
-def is_port_in_use(port):
-    import socket
+        # hypercorn serve raises a type warning when passing a fast api app - even though fast API is a valid type
+        asyncio.run(serve(app, config))  # type: ignore
 
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        return s.connect_ex(("localhost", port)) == 0
+    @staticmethod
+    def _run_gunicorn_server(
+        host: str,
+        port: int,
+        app: FastAPI,
+        num_workers: int,
+        ssl_certfile_path: str,
+        ssl_keyfile_path: str,
+    ):
+        """
+        Run litellm with `gunicorn`
+        """
+        if os.name == "nt":
+            pass
+        else:
+            import gunicorn.app.base
+
+        # Gunicorn Application Class
+        class StandaloneApplication(gunicorn.app.base.BaseApplication):
+            def __init__(self, app, options=None):
+                self.options = options or {}  # gunicorn options
+                self.application = app  # FastAPI app
+                super().__init__()
+
+                _endpoint_str = (
+                    f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\"
+                )
+                curl_command = (
+                    _endpoint_str
+                    + """
+                --header 'Content-Type: application/json' \\
+                --data ' {
+                "model": "gpt-3.5-turbo",
+                "messages": [
+                    {
+                    "role": "user",
+                    "content": "what llm are you"
+                    }
+                ]
+                }'
+                \n
+                """
+                )
+                print()  # noqa
+                print(  # noqa
+                    '\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n'
+                )
+                print(  # noqa
+                    f"\033[1;34mLiteLLM: Curl Command Test for your local proxy\n {curl_command} \033[0m\n"
+                )
+                print(  # noqa
+                    "\033[1;34mDocs: https://docs.litellm.ai/docs/simple_proxy\033[0m\n"
+                )  # noqa
+                print(  # noqa
+                    f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:{port} \033[0m\n"
+                )  # noqa
+
+            def load_config(self):
+                # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config
+                if self.cfg is not None:
+                    config = {
+                        key: value
+                        for key, value in self.options.items()
+                        if key in self.cfg.settings and value is not None
+                    }
+                else:
+                    config = {}
+                for key, value in config.items():
+                    if self.cfg is not None:
+                        self.cfg.set(key.lower(), value)
+
+            def load(self):
+                # gunicorn app function
+                return self.application
+
+        print(  # noqa
+            f"\033[1;32mLiteLLM Proxy: Starting server on {host}:{port} with {num_workers} workers\033[0m\n"  # noqa
+        )
+        gunicorn_options = {
+            "bind": f"{host}:{port}",
+            "workers": num_workers,  # default is 1
+            "worker_class": "uvicorn.workers.UvicornWorker",
+            "preload": True,  # Add the preload flag,
+            "accesslog": "-",  # Log to stdout
+            "timeout": 600,  # default to very high number, bedrock/anthropic.claude-v2:1 can take 30+ seconds for the 1st chunk to come in
+            "access_log_format": '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s',
+        }
+
+        if ssl_certfile_path is not None and ssl_keyfile_path is not None:
+            print(  # noqa
+                f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n"  # noqa
+            )
+            gunicorn_options["certfile"] = ssl_certfile_path
+            gunicorn_options["keyfile"] = ssl_keyfile_path
+
+        StandaloneApplication(app=app, options=gunicorn_options).run()  # Run gunicorn
+
+    @staticmethod
+    def _run_ollama_serve():
+        try:
+            command = ["ollama", "serve"]
+
+            with open(os.devnull, "w") as devnull:
+                subprocess.Popen(command, stdout=devnull, stderr=devnull)
+        except Exception as e:
+            print(  # noqa
+                f"""
+                LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
+            """
+            )  # noqa
+
+    @staticmethod
+    def _is_port_in_use(port):
+        import socket
+
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            return s.connect_ex(("localhost", port)) == 0
+
+    @staticmethod
+    def _get_loop_type():
+        """Helper function to determine the event loop type based on platform"""
+        if sys.platform in ("win32", "cygwin", "cli"):
+            return None  # Let uvicorn choose the default loop on Windows
+        return "uvloop"
 
 
 @click.command()
@@ -282,136 +516,15 @@ def run_server(  # noqa: PLR0915
         save_worker_config,
     )
     if version is True:
-        pkg_version = importlib.metadata.version("litellm")  # type: ignore
-        click.echo(f"\nLiteLLM: Current Version = {pkg_version}\n")
+        ProxyInitializationHelpers._echo_litellm_version()
         return
     if model and "ollama" in model and api_base is None:
-        run_ollama_serve()
-    import httpx
-
-    if test_async is True:
-        import concurrent
-        import time
-
-        api_base = f"http://{host}:{port}"
-
-        def _make_openai_completion():
-            data = {
-                "model": "gpt-3.5-turbo",
-                "messages": [
-                    {"role": "user", "content": "Write a short poem about the moon"}
-                ],
-            }
-
-            response = httpx.post("http://0.0.0.0:4000/queue/request", json=data)
-
-            response = response.json()
-
-            while True:
-                try:
-                    url = response["url"]
-                    polling_url = f"{api_base}{url}"
-                    polling_response = httpx.get(polling_url)
-                    polling_response = polling_response.json()
-                    print("\n RESPONSE FROM POLLING JOB", polling_response)  # noqa
-                    status = polling_response["status"]
-                    if status == "finished":
-                        polling_response["result"]
-                        break
-                    print(  # noqa
-                        f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"  # noqa
-                    )  # noqa
-                    time.sleep(0.5)
-                except Exception as e:
-                    print("got exception in polling", e)  # noqa
-                    break
-
-        # Number of concurrent calls (you can adjust this)
-        concurrent_calls = num_requests
-
-        # List to store the futures of concurrent calls
-        futures = []
-        start_time = time.time()
-        # Make concurrent calls
-        with concurrent.futures.ThreadPoolExecutor(  # type: ignore
-            max_workers=concurrent_calls
-        ) as executor:
-            for _ in range(concurrent_calls):
-                futures.append(executor.submit(_make_openai_completion))
-
-        # Wait for all futures to complete
-        concurrent.futures.wait(futures)  # type: ignore
-
-        # Summarize the results
-        successful_calls = 0
-        failed_calls = 0
-
-        for future in futures:
-            if future.done():
-                if future.result() is not None:
-                    successful_calls += 1
-                else:
-                    failed_calls += 1
-        end_time = time.time()
-        print(f"Elapsed Time: {end_time-start_time}")  # noqa
-        print(f"Load test Summary:")  # noqa
-        print(f"Total Requests: {concurrent_calls}")  # noqa
-        print(f"Successful Calls: {successful_calls}")  # noqa
-        print(f"Failed Calls: {failed_calls}")  # noqa
+        ProxyInitializationHelpers._run_ollama_serve()
+    if health is True:
+        ProxyInitializationHelpers._run_health_check(host, port)
         return
-    if health is not False:
-
-        print("\nLiteLLM: Health Testing models in config")  # noqa
-        response = httpx.get(url=f"http://{host}:{port}/health")
-        print(json.dumps(response.json(), indent=4))  # noqa
-        return
-    if test is not False:
-        request_model = model or "gpt-3.5-turbo"
-        click.echo(
-            f"\nLiteLLM: Making a test ChatCompletions request to your proxy. Model={request_model}"
-        )
-        import openai
-
-        if test is True:  # flag value set
-            api_base = f"http://{host}:{port}"
-        else:
-            api_base = test
-        client = openai.OpenAI(api_key="My API Key", base_url=api_base)
-
-        response = client.chat.completions.create(
-            model=request_model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": "this is a test request, write a short poem",
-                }
-            ],
-            max_tokens=256,
-        )
-        click.echo(f"\nLiteLLM: response from proxy {response}")
-
-        print(  # noqa
-            f"\n LiteLLM: Making a test ChatCompletions + streaming r equest to proxy. Model={request_model}"
-        )
-
-        response = client.chat.completions.create(
-            model=request_model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": "this is a test request, write a short poem",
-                }
-            ],
-            stream=True,
-        )
-        for chunk in response:
-            click.echo(f"LiteLLM: streaming response from proxy {chunk}")
-        print("\n making completion request to proxy")  # noqa
-        response = client.completions.create(
-            model=request_model, prompt="this is a test request, write a short poem"
-        )
-        print(response)  # noqa
-
+    if test is True:
+        ProxyInitializationHelpers._run_test_chat_completion(host, port, model, test)
         return
     else:
         if headers:
@@ -437,11 +550,6 @@ def run_server(  # noqa: PLR0915
             )
     try:
         import uvicorn
-
-        if os.name == "nt":
-            pass
-        else:
-            import gunicorn.app.base
     except Exception:
         raise ImportError(
             "uvicorn, gunicorn needs to be imported. Run - `pip install 'litellm[proxy]'`"
         )
@@ -641,7 +749,7 @@ def run_server(  # noqa: PLR0915
             print(  # noqa
                 f"Unable to connect to DB. DATABASE_URL found in environment, but prisma package not found."  # noqa
            )
-    if port == 4000 and is_port_in_use(port):
+    if port == 4000 and ProxyInitializationHelpers._is_port_in_use(port):
        port = random.randint(1024, 49152)
 
     import litellm
@@ -652,18 +760,11 @@ def run_server(  # noqa: PLR0915
     # DO NOT DELETE - enables global variables to work across files
     from litellm.proxy.proxy_server import app  # noqa
 
-    uvicorn_args = {
-        "app": "litellm.proxy.proxy_server:app",
-        "host": host,
-        "port": port,
-    }
-    if log_config is not None:
-        print(f"Using log_config: {log_config}")  # noqa
-        uvicorn_args["log_config"] = log_config
-    elif litellm.json_logs:
-        print("Using json logs. Setting log_config to None.")  # noqa
-        uvicorn_args["log_config"] = None
-
+    uvicorn_args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
+        host=host,
+        port=port,
+        log_config=log_config,
+    )
     if run_gunicorn is False and run_hypercorn is False:
         if ssl_certfile_path is not None and ssl_keyfile_path is not None:
             print(  # noqa
@@ -671,114 +772,32 @@ def run_server(  # noqa: PLR0915
             )
             uvicorn_args["ssl_keyfile"] = ssl_keyfile_path
             uvicorn_args["ssl_certfile"] = ssl_certfile_path
+
+        loop_type = ProxyInitializationHelpers._get_loop_type()
+        if loop_type:
+            uvicorn_args["loop"] = loop_type
+
         uvicorn.run(
             **uvicorn_args,
-            loop="uvloop",
             workers=num_workers,
         )
     elif run_gunicorn is True:
-        # Gunicorn Application Class
-        class StandaloneApplication(gunicorn.app.base.BaseApplication):
-            def __init__(self, app, options=None):
-                self.options = options or {}  # gunicorn options
-                self.application = app  # FastAPI app
-                super().__init__()
-
-                _endpoint_str = (
-                    f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\"
-                )
-                curl_command = (
-                    _endpoint_str
-                    + """
-                --header 'Content-Type: application/json' \\
-                --data ' {
-                "model": "gpt-3.5-turbo",
-                "messages": [
-                    {
-                    "role": "user",
-                    "content": "what llm are you"
-                    }
-                ]
-                }'
-                \n
-                """
-                )
-                print()  # noqa
-                print(  # noqa
-                    '\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n'
-                )
-                print(  # noqa
-                    f"\033[1;34mLiteLLM: Curl Command Test for your local proxy\n {curl_command} \033[0m\n"
-                )
-                print(  # noqa
-                    "\033[1;34mDocs: https://docs.litellm.ai/docs/simple_proxy\033[0m\n"
-                )  # noqa
-                print(  # noqa
-                    f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:{port} \033[0m\n"
-                )  # noqa
-
-            def load_config(self):
-                # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config
-                if self.cfg is not None:
-                    config = {
-                        key: value
-                        for key, value in self.options.items()
-                        if key in self.cfg.settings and value is not None
-                    }
-                else:
-                    config = {}
-                for key, value in config.items():
-                    if self.cfg is not None:
-                        self.cfg.set(key.lower(), value)
-
-            def load(self):
-                # gunicorn app function
-                return self.application
-
-        print(  # noqa
-            f"\033[1;32mLiteLLM Proxy: Starting server on {host}:{port} with {num_workers} workers\033[0m\n"  # noqa
+        ProxyInitializationHelpers._run_gunicorn_server(
+            host=host,
+            port=port,
+            app=app,
+            num_workers=num_workers,
+            ssl_certfile_path=ssl_certfile_path,
+            ssl_keyfile_path=ssl_keyfile_path,
         )
-        gunicorn_options = {
-            "bind": f"{host}:{port}",
-            "workers": num_workers,  # default is 1
-            "worker_class": "uvicorn.workers.UvicornWorker",
-            "preload": True,  # Add the preload flag,
-            "accesslog": "-",  # Log to stdout
-            "timeout": 600,  # default to very high number, bedrock/anthropic.claude-v2:1 can take 30+ seconds for the 1st chunk to come in
-            "access_log_format": '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s',
-        }
-
-        if ssl_certfile_path is not None and ssl_keyfile_path is not None:
-            print(  # noqa
-                f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n"  # noqa
-            )
-            gunicorn_options["certfile"] = ssl_certfile_path
-            gunicorn_options["keyfile"] = ssl_keyfile_path
-
-        StandaloneApplication(
-            app=app, options=gunicorn_options
-        ).run()  # Run gunicorn
     elif run_hypercorn is True:
-        import asyncio
-
-        from hypercorn.asyncio import serve
-        from hypercorn.config import Config
-
-        print(  # noqa
-            f"\033[1;32mLiteLLM Proxy: Starting server on {host}:{port} using Hypercorn\033[0m\n"  # noqa
-        )  # noqa
-        config = Config()
-        config.bind = [f"{host}:{port}"]
-
-        if ssl_certfile_path is not None and ssl_keyfile_path is not None:
-            print(  # noqa
-                f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n"  # noqa
-            )
-            config.certfile = ssl_certfile_path
-            config.keyfile = ssl_keyfile_path
-
-        # hypercorn serve raises a type warning when passing a fast api app - even though fast API is a valid type
-        asyncio.run(serve(app, config))  # type: ignore
+        ProxyInitializationHelpers._init_hypercorn_server(
+            app=app,
+            host=host,
+            port=port,
+            ssl_certfile_path=ssl_certfile_path,
+            ssl_keyfile_path=ssl_keyfile_path,
+        )
 
 
 if __name__ == "__main__":
diff --git a/tests/litellm/proxy/test_proxy_cli.py b/tests/litellm/proxy/test_proxy_cli.py
new file mode 100644
index 0000000000..6e1d70553f
--- /dev/null
+++ b/tests/litellm/proxy/test_proxy_cli.py
@@ -0,0 +1,167 @@
+import importlib
+import json
+import os
+import socket
+import subprocess
+import sys
+from unittest.mock import MagicMock, mock_open, patch
+
+import click
+import httpx
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system-path
+
+import litellm
+from litellm.proxy.proxy_cli import ProxyInitializationHelpers
+
+
+class TestProxyInitializationHelpers:
+
+    @patch("importlib.metadata.version")
+    @patch("click.echo")
+    def test_echo_litellm_version(self, mock_echo, mock_version):
+        # Setup
+        mock_version.return_value = "1.0.0"
+
+        # Execute
+        ProxyInitializationHelpers._echo_litellm_version()
+
+        # Assert
+        mock_version.assert_called_once_with("litellm")
+        mock_echo.assert_called_once_with("\nLiteLLM: Current Version = 1.0.0\n")
+
+    @patch("httpx.get")
+    @patch("builtins.print")
+    @patch("json.dumps")
+    def test_run_health_check(self, mock_dumps, mock_print, mock_get):
+        # Setup
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"status": "healthy"}
+        mock_get.return_value = mock_response
+        mock_dumps.return_value = '{"status": "healthy"}'
+
+        # Execute
+        ProxyInitializationHelpers._run_health_check("localhost", 8000)
+
+        # Assert
+        mock_get.assert_called_once_with(url="http://localhost:8000/health")
+        mock_response.json.assert_called_once()
+        mock_dumps.assert_called_once_with({"status": "healthy"}, indent=4)
+
+    @patch("openai.OpenAI")
+    @patch("click.echo")
+    @patch("builtins.print")
+    def test_run_test_chat_completion(self, mock_print, mock_echo, mock_openai):
+        # Setup
+        mock_client = MagicMock()
+        mock_openai.return_value = mock_client
+
+        mock_response = MagicMock()
+        mock_client.chat.completions.create.return_value = mock_response
+
+        mock_stream_response = MagicMock()
+        mock_stream_response.__iter__.return_value = [MagicMock(), MagicMock()]
+        mock_client.chat.completions.create.side_effect = [
+            mock_response,
+            mock_stream_response,
+        ]
+
+        # Execute
+        with pytest.raises(ValueError, match="Invalid test value"):
+            ProxyInitializationHelpers._run_test_chat_completion(
+                "localhost", 8000, "gpt-3.5-turbo", True
+            )
+
+        # Test with valid string test value
+        ProxyInitializationHelpers._run_test_chat_completion(
+            "localhost", 8000, "gpt-3.5-turbo", "http://test-url"
+        )
+
+        # Assert
+        mock_openai.assert_called_once_with(
+            api_key="My API Key", base_url="http://test-url"
+        )
+        mock_client.chat.completions.create.assert_called()
+
+    def test_get_default_unvicorn_init_args(self):
+        # Test without log_config
+        args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
+            "localhost", 8000
+        )
+        assert args["app"] == "litellm.proxy.proxy_server:app"
+        assert args["host"] == "localhost"
+        assert args["port"] == 8000
+
+        # Test with log_config
+        args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
+            "localhost", 8000, "log_config.json"
+        )
+        assert args["log_config"] == "log_config.json"
+
+        # Test with json_logs=True
+        with patch("litellm.json_logs", True):
+            args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
+                "localhost", 8000
+            )
+            assert args["log_config"] is None
+
+    @patch("asyncio.run")
+    @patch("builtins.print")
+    def test_init_hypercorn_server(self, mock_print, mock_asyncio_run):
+        # Setup
+        mock_app = MagicMock()
+
+        # Execute
+        ProxyInitializationHelpers._init_hypercorn_server(
+            mock_app, "localhost", 8000, None, None
+        )
+
+        # Assert
+        mock_asyncio_run.assert_called_once()
+
+        # Test with SSL
+        ProxyInitializationHelpers._init_hypercorn_server(
+            mock_app, "localhost", 8000, "cert.pem", "key.pem"
+        )
+
+    @patch("subprocess.Popen")
+    def test_run_ollama_serve(self, mock_popen):
+        # Execute
+        ProxyInitializationHelpers._run_ollama_serve()
+
+        # Assert
+        mock_popen.assert_called_once()
+
+        # Test exception handling
+        mock_popen.side_effect = Exception("Test exception")
+        ProxyInitializationHelpers._run_ollama_serve()  # Should not raise
+
+    @patch("socket.socket")
+    def test_is_port_in_use(self, mock_socket):
+        # Setup for port in use
+        mock_socket_instance = MagicMock()
+        mock_socket_instance.connect_ex.return_value = 0
+        mock_socket.return_value.__enter__.return_value = mock_socket_instance
+
+        # Execute and Assert
+        assert ProxyInitializationHelpers._is_port_in_use(8000) is True
+
+        # Setup for port not in use
+        mock_socket_instance.connect_ex.return_value = 1
+
+        # Execute and Assert
+        assert ProxyInitializationHelpers._is_port_in_use(8000) is False
+
+    def test_get_loop_type(self):
+        # Test on Windows
+        with patch("sys.platform", "win32"):
+            assert ProxyInitializationHelpers._get_loop_type() is None
+
+        # Test on Linux
+        with patch("sys.platform", "linux"):
+            assert ProxyInitializationHelpers._get_loop_type() == "uvloop"
diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py
index d84825ed89..f5c75b80bb 100644
--- a/tests/local_testing/test_completion.py
+++ b/tests/local_testing/test_completion.py
@@ -11,7 +11,7 @@ import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
+)  # Adds the parent directory to the system-path
 
 import os