diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 14f4c330c..4a321d009 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -1,33 +1,70 @@
 from typing import Optional
+import litellm
 from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
 
-async def max_parallel_request_allow_request(max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
-    if api_key is None:
-        return
+class MaxParallelRequestsHandler(CustomLogger):
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    def print_verbose(self, print_statement):
+        if litellm.set_verbose is True:
+            print(print_statement) # noqa
 
-    if max_parallel_requests is None:
-        return
-    # CHECK IF REQUEST ALLOWED
-    request_count_api_key = f"{api_key}_request_count"
-    current = user_api_key_cache.get_cache(key=request_count_api_key)
-    if current is None:
-        user_api_key_cache.set_cache(request_count_api_key, 1)
-    elif int(current) < max_parallel_requests:
-        # Increase count for this token
-        user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
-    else:
-        raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
+    async def max_parallel_request_allow_request(self, max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
+        if api_key is None:
+            return
+
+        if max_parallel_requests is None:
+            return
+
+        self.user_api_key_cache = user_api_key_cache # save the api key cache for updating the value
+
+        # CHECK IF REQUEST ALLOWED
+        request_count_api_key = f"{api_key}_request_count"
+        current = user_api_key_cache.get_cache(key=request_count_api_key)
+        self.print_verbose(f"current: {current}")
+        if current is None:
+            user_api_key_cache.set_cache(request_count_api_key, 1)
+        elif int(current) < max_parallel_requests:
+            # Increase count for this token
+            user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
+        else:
+            raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
 
-async def max_parallel_request_update_count(api_key: Optional[str], user_api_key_cache: DualCache):
-    if api_key is None:
-        return
-
-    request_count_api_key = f"{api_key}_request_count"
-    # Decrease count for this token
-    current = user_api_key_cache.get_cache(key=request_count_api_key) or 1
-    user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING")
+            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
+            if user_api_key is None:
+                return
+
+            request_count_api_key = f"{user_api_key}_request_count"
+            # check if it has collected an entire stream response
+            self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}")
+            if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
+                # Decrease count for this token
+                current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
+                new_val = current - 1
+                self.print_verbose(f"updated_value in success call: {new_val}")
+                self.user_api_key_cache.set_cache(request_count_api_key, new_val)
+        except Exception as e:
+            self.print_verbose(e) # noqa
 
-    return
\ No newline at end of file
+    async def async_log_failure_call(self, api_key, user_api_key_cache):
+        try:
+            if api_key is None:
+                return
+
+            request_count_api_key = f"{api_key}_request_count"
+            # Decrease count for this token
+            current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
+            new_val = current - 1
+            self.print_verbose(f"updated_value in failure call: {new_val}")
+            self.user_api_key_cache.set_cache(request_count_api_key, new_val)
+        except Exception as e:
+            self.print_verbose(f"An exception occurred - {str(e)}") # noqa
\ No newline at end of file
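Taken on its own, the new handler can be exercised directly. A minimal sketch of the intended behaviour, assuming the litellm package and its proxy modules are importable (the key value is illustrative; the `{api_key}_request_count` cache key comes from the handler above):

```python
import asyncio

from fastapi import HTTPException
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler


async def demo():
    cache = DualCache()  # in-memory (optionally redis-backed) cache used by the proxy
    limiter = MaxParallelRequestsHandler()

    # first request: sets "sk-test_request_count" to 1 and is allowed through
    await limiter.max_parallel_request_allow_request(
        max_parallel_requests=1, api_key="sk-test", user_api_key_cache=cache
    )

    # second concurrent request for the same key: the limit of 1 is already used up
    try:
        await limiter.max_parallel_request_allow_request(
            max_parallel_requests=1, api_key="sk-test", user_api_key_cache=cache
        )
    except HTTPException as e:
        print(e.status_code, e.detail)  # 429 Max parallel request limit reached.


asyncio.run(demo())
```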
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 14d487f9e..47630163a 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1,4 +1,4 @@
-import sys, os, platform, time, copy, re, asyncio
+import sys, os, platform, time, copy, re, asyncio, inspect
 import threading, ast
 import shutil, random, traceback, requests
 from datetime import datetime, timedelta
@@ -94,7 +94,6 @@ import litellm
 from litellm.proxy.utils import (
     PrismaClient,
     get_instance_fn,
-    CallHooks,
     ProxyLogging
 )
 import pydantic
@@ -198,8 +197,8 @@ user_custom_auth = None
 use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
-call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)
-proxy_logging_obj: Optional[ProxyLogging] = None
+### INITIALIZE GLOBAL LOGGING OBJECT ###
+proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -309,10 +308,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         )
 
 def prisma_setup(database_url: Optional[str]):
-    global prisma_client, proxy_logging_obj
-    ### INITIALIZE GLOBAL LOGGING OBJECT ###
-    proxy_logging_obj = ProxyLogging()
+    global prisma_client, proxy_logging_obj, user_api_key_cache
+    proxy_logging_obj._init_litellm_callbacks()
     if database_url is not None:
         try:
             prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
@@ -390,6 +388,10 @@ async def track_cost_callback(
                 completion=output_text
             )
             print("streaming response_cost", response_cost)
+            user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+            print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
+            if user_api_key and prisma_client:
+                await update_prisma_database(token=user_api_key, response_cost=response_cost)
         elif kwargs["stream"] == False: # for non streaming responses
             input_text = kwargs.get("messages", "")
             print(f"type of input_text: {type(input_text)}")
@@ -400,10 +402,10 @@
             print(f"received completion response: {completion_response}")
             print(f"regular response_cost: {response_cost}")
-        user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
-        print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
-        if user_api_key and prisma_client:
-            await update_prisma_database(token=user_api_key, response_cost=response_cost)
+            user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+            print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
+            if user_api_key and prisma_client:
+                await update_prisma_database(token=user_api_key, response_cost=response_cost)
     except Exception as e:
         print(f"error in tracking cost callback - {str(e)}")
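The streaming branch above can only bill the right key because the proxy stores the key under `litellm_params.metadata.user_api_key` and litellm attaches `complete_streaming_response` to the success-callback kwargs once the stream has been fully collected. A rough sketch of a custom callback reading the same fields (the `SpendTracker` class and its prints are illustrative):

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger


class SpendTracker(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # set by the proxy when it forwards the request
        user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)

        if kwargs.get("stream", False) is True:
            # same guard as the proxy: only act once the whole stream was collected
            final_response = kwargs.get("complete_streaming_response", None)
            if final_response is None:
                return
        else:
            final_response = response_obj

        print(f"key={user_api_key}, got final response: {final_response is not None}")


litellm.callbacks.append(SpendTracker())
```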
@@ -475,9 +477,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
 
     print_verbose(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")
 
-    ## ROUTER CONFIG
-    cache_responses = False
-
     ## ENVIRONMENT VARIABLES
     environment_variables = config.get('environment_variables', None)
     if environment_variables:
@@ -554,6 +553,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             else:
                 setattr(litellm, key, value)
 
+
+    ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
     general_settings = config.get("general_settings", {})
     if general_settings is None:
@@ -589,18 +590,41 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         use_background_health_checks = general_settings.get("background_health_checks", False)
         health_check_interval = general_settings.get("health_check_interval", 300)
 
+    router_params: dict = {
+        "num_retries": 3,
+        "cache_responses": litellm.cache != None # cache if user passed in cache values
+    }
     ## MODEL LIST
     model_list = config.get('model_list', None)
     if model_list:
-        router = litellm.Router(model_list=model_list, num_retries=3, cache_responses=cache_responses)
+        router_params["model_list"] = model_list
         print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
         for model in model_list:
            print(f"\033[32m    {model.get('model_name', '')}\033[0m")
            litellm_model_name = model["litellm_params"]["model"]
-            if "ollama" in litellm_model_name:
+            litellm_model_api_base = model["litellm_params"].get("api_base", None)
+            if "ollama" in litellm_model_name and litellm_model_api_base is None:
                 run_ollama_serve()
-    call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
+
+    ## ROUTER SETTINGS (e.g. routing_strategy, ...)
+    router_settings = config.get("router_settings", None)
+    if router_settings and isinstance(router_settings, dict):
+        arg_spec = inspect.getfullargspec(litellm.Router)
+        # model list already set
+        exclude_args = {
+            "self",
+            "model_list",
+        }
+
+        available_args = [
+            x for x in arg_spec.args if x not in exclude_args
+        ]
+
+        for k, v in router_settings.items():
+            if k in available_args:
+                router_params[k] = v
+
+    router = litellm.Router(**router_params) # type:ignore
     return router, model_list, general_settings
 
 async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
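What the new `router_settings` handling boils down to: any key in that config block which matches a `litellm.Router` constructor argument is forwarded, everything else is dropped. A small sketch of the same filtering (`routing_strategy` is the example named in the comment above; `build_router_params` and the sample config dict are illustrative):

```python
import inspect

import litellm


def build_router_params(config: dict, model_list: list) -> dict:
    # mirrors the load_router_config logic in this hunk
    router_params: dict = {
        "num_retries": 3,
        "cache_responses": litellm.cache is not None,  # cache if user passed in cache values
        "model_list": model_list,
    }
    router_settings = config.get("router_settings", None)
    if router_settings and isinstance(router_settings, dict):
        arg_spec = inspect.getfullargspec(litellm.Router)
        exclude_args = {"self", "model_list"}  # model list already set
        available_args = [x for x in arg_spec.args if x not in exclude_args]
        for k, v in router_settings.items():
            if k in available_args:  # keys Router does not accept are silently dropped
                router_params[k] = v
    return router_params


# equivalent of a `router_settings:` block in the proxy config YAML
config = {"router_settings": {"routing_strategy": "simple-shuffle", "not_a_router_arg": 1}}
params = build_router_params(config, model_list=[])
print(sorted(params))  # ['cache_responses', 'model_list', 'num_retries', 'routing_strategy']
```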
@@ -772,10 +796,13 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"
 
 async def async_data_generator(response, user_api_key_dict):
-    global call_hooks
-
     print_verbose("inside generator")
     async for chunk in response:
+        # try:
+        #     await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=None, call_type="completion")
+        # except Exception as e:
+        #     print(f"An exception occurred - {str(e)}")
+
         print_verbose(f"returned chunk: {chunk}")
         try:
             yield f"data: {json.dumps(chunk.dict())}\n\n"
@@ -946,7 +973,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
-    global general_settings, user_debug, call_hooks
+    global general_settings, user_debug, proxy_logging_obj
     try:
         data = {}
         data = await request.json() # type: ignore
@@ -992,7 +1019,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             data["api_base"] = user_api_base
 
         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
 
         ### ROUTE THE REQUEST ###
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
@@ -1009,15 +1036,10 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
             return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
 
-        ### CALL HOOKS ### - modify outgoing response
-        response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")
-
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
         return response
     except Exception as e:
-        print(f"Exception received: {str(e)}")
-        raise e
-        await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
+        await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data.get("model", "") in router_model_names:
@@ -1052,7 +1074,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
-    global call_hooks
+    global proxy_logging_obj
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
@@ -1095,8 +1117,8 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
                     data["input"] = input_list
                     break
 
-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
 
         ## ROUTE TO CORRECT ENDPOINT ##
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
@@ -1107,12 +1129,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
             response = await litellm.aembedding(**data)
 
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
-        ### CALL HOOKS ### - modify outgoing response
-        data = call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings")
-
         return response
     except Exception as e:
-        await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e)
+        await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
         traceback.print_exc()
         raise e
@@ -1139,7 +1158,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
     - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
     """
    # data = await request.json()
-    data_json = json.loads(data.json())  # type: ignore
+    data_json = data.json()  # type: ignore
     response = await generate_key_helper_fn(**data_json)
     return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
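Seen from a client, the hook wiring above means a key generated with `max_parallel_requests` is enforced on `/chat/completions` (and `/embeddings`) before the model is ever called. A hedged sketch of what hitting the limit looks like, assuming a proxy running locally on port 8000 and a previously generated key (both values are placeholders):

```python
import requests

PROXY_BASE = "http://0.0.0.0:8000"     # assumed local proxy address
KEY = "sk-key-from-key-generate"       # placeholder: key created with max_parallel_requests=1

resp = requests.post(
    f"{PROXY_BASE}/chat/completions",
    headers={"Authorization": f"Bearer {KEY}"},
    json={
        "model": "azure-model",
        "messages": [{"role": "user", "content": "write a short poem"}],
    },
    timeout=60,
)

# while another request for the same key is still in flight, pre_call_hook
# raises HTTPException(429, "Max parallel request limit reached.")
if resp.status_code == 429:
    print("parallel limit hit:", resp.text)
else:
    print(resp.json()["choices"][0]["message"]["content"])
```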
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 688820e83..e972eff4d 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -3,7 +3,7 @@ import os, subprocess, hashlib, importlib, asyncio
 import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
-from litellm.proxy.hooks.parallel_request_limiter import max_parallel_request_allow_request, max_parallel_request_update_count
+from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
 
 def print_verbose(print_statement):
     if litellm.set_verbose:
@@ -11,32 +11,35 @@ def print_verbose(print_statement):
 ### LOGGING ###
 class ProxyLogging:
     """
-    Logging for proxy.
+    Logging/Custom Handlers for proxy.
 
-    Implemented mainly to log successful/failed db read/writes.
-
-    Currently just logs this to a provided sentry integration.
+    Implemented mainly to:
+    - log successful/failed db read/writes
+    - support the max parallel request integration
     """
-    def __init__(self,):
+    def __init__(self, user_api_key_cache: DualCache):
         ## INITIALIZE LITELLM CALLBACKS ##
-        self._init_litellm_callbacks()
+        self.call_details: dict = {}
+        self.call_details["user_api_key_cache"] = user_api_key_cache
+        self.max_parallel_request_limiter = MaxParallelRequestsHandler()
         pass
 
-    def _init_litellm_callbacks(self):
-        if len(litellm.callbacks) > 0:
-            for callback in litellm.callbacks:
-                if callback not in litellm.input_callback:
-                    litellm.input_callback.append(callback)
-                if callback not in litellm.success_callback:
-                    litellm.success_callback.append(callback)
-                if callback not in litellm.failure_callback:
-                    litellm.failure_callback.append(callback)
-                if callback not in litellm._async_success_callback:
-                    litellm._async_success_callback.append(callback)
-                if callback not in litellm._async_failure_callback:
-                    litellm._async_failure_callback.append(callback)
+    def _init_litellm_callbacks(self):
+        litellm.callbacks.append(self.max_parallel_request_limiter)
+        for callback in litellm.callbacks:
+            if callback not in litellm.input_callback:
+                litellm.input_callback.append(callback)
+            if callback not in litellm.success_callback:
+                litellm.success_callback.append(callback)
+            if callback not in litellm.failure_callback:
+                litellm.failure_callback.append(callback)
+            if callback not in litellm._async_success_callback:
+                litellm._async_success_callback.append(callback)
+            if callback not in litellm._async_failure_callback:
+                litellm._async_failure_callback.append(callback)
+
         if (
             len(litellm.input_callback) > 0
             or len(litellm.success_callback) > 0
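The reworked `_init_litellm_callbacks` registers the limiter itself and then fans every entry in `litellm.callbacks` out to the input/success/failure lists, including the async ones. The same registration pattern for a user-defined handler, as a sketch (the `MyHandler` class is illustrative):

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger


class MyHandler(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("async success for model:", kwargs.get("model"))


handler = MyHandler()

# same fan-out as ProxyLogging._init_litellm_callbacks above
litellm.callbacks.append(handler)
for callback in litellm.callbacks:
    for cb_list in (
        litellm.input_callback,
        litellm.success_callback,
        litellm.failure_callback,
        litellm._async_success_callback,
        litellm._async_failure_callback,
    ):
        if callback not in cb_list:
            cb_list.append(callback)
```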
@@ -53,6 +56,30 @@ class ProxyLogging:
                 callback_list=callback_list
             )
 
+    async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
+        """
+        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
+
+        Covers:
+        1. /chat/completions
+        2. /embeddings
+        """
+        try:
+            self.call_details["data"] = data
+            self.call_details["call_type"] = call_type
+
+            ## check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## if set, check if request allowed
+                await self.max_parallel_request_limiter.max_parallel_request_allow_request(
+                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return data
+        except Exception as e:
+            raise e
+
     async def success_handler(self, *args, **kwargs):
         """
         Log successful db read/writes
@@ -67,6 +94,27 @@
         """
         if litellm.utils.capture_exception:
             litellm.utils.capture_exception(error=original_exception)
+
+    async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
+        """
+        Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
+
+        Covers:
+        1. /chat/completions
+        2. /embeddings
+        """
+        # check if max parallel requests set
+        if user_api_key_dict.max_parallel_requests is not None:
+            ## decrement call count if call failed
+            if (hasattr(original_exception, "status_code")
+                and original_exception.status_code == 429
+                and "Max parallel request limit reached" in str(original_exception)):
+                pass # ignore failed calls due to max limit being reached
+            else:
+                await self.max_parallel_request_limiter.async_log_failure_call(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+        return
 
 ### DB CONNECTOR ###
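Putting the two hooks together: `pre_call_hook` increments the per-key counter (or raises a 429), the limiter's success callback decrements it when a call finishes, and `post_call_failure_hook` decrements it on any failure except the 429 produced by the limit itself. A sketch of driving that flow directly, assuming `UserAPIKeyAuth` can be constructed from just these two fields (an assumption; the model has more optional fields):

```python
import asyncio

from fastapi import HTTPException
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.utils import ProxyLogging


async def demo():
    proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
    key = UserAPIKeyAuth(api_key="sk-test", max_parallel_requests=1)  # assumed fields

    data = {"model": "azure-model", "messages": [{"role": "user", "content": "hi"}]}
    data = await proxy_logging_obj.pre_call_hook(
        user_api_key_dict=key, data=data, call_type="completion"
    )  # "sk-test_request_count" is now 1

    try:
        # a second in-flight request for the same key trips the limiter
        await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=key, data=data, call_type="completion"
        )
    except HTTPException as e:
        # 429s raised by the limiter itself do not decrement the counter
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=key, original_exception=e
        )


asyncio.run(demo())
```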
@@ -290,65 +338,4 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e
 
-### CALL HOOKS ###
-class CallHooks:
-    """
-    Allows users to modify the incoming request / output to the proxy, without having to deal with parsing Request body.
-
-    Covers:
-    1. /chat/completions
-    2. /embeddings
-    """
-
-    def __init__(self, user_api_key_cache: DualCache):
-        self.call_details: dict = {}
-        self.call_details["user_api_key_cache"] = user_api_key_cache
-
-    def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
-        self.call_details["litellm_settings"] = litellm_settings
-        self.call_details["general_settings"] = general_settings
-        self.call_details["model_list"] = model_list
-
-    async def pre_call(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
-        try:
-            self.call_details["data"] = data
-            self.call_details["call_type"] = call_type
-
-            ## check if max parallel requests set
-            if user_api_key_dict.max_parallel_requests is not None:
-                ## if set, check if request allowed
-                await max_parallel_request_allow_request(
-                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-
-            return data
-        except Exception as e:
-            raise e
-
-    async def post_call_success(self, user_api_key_dict: UserAPIKeyAuth, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
-        try:
-            # check if max parallel requests set
-            if user_api_key_dict.max_parallel_requests is not None:
-                ## decrement call, once complete
-                await max_parallel_request_update_count(
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-
-            return response
-        except Exception as e:
-            raise e
-
-    async def post_call_failure(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
-        # check if max parallel requests set
-        if user_api_key_dict.max_parallel_requests is not None:
-            ## decrement call count if call failed
-            if (hasattr(original_exception, "status_code")
-                and original_exception.status_code == 429
-                and "Max parallel request limit reached" in str(original_exception)):
-                pass # ignore failed calls due to max limit being reached
-            else:
-                await max_parallel_request_update_count(
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-        return
\ No newline at end of file
+ 
\ No newline at end of file
diff --git a/litellm/tests/test_proxy_server_keys.py b/litellm/tests/test_proxy_server_keys.py
index c5984ba13..239442b2c 100644
--- a/litellm/tests/test_proxy_server_keys.py
+++ b/litellm/tests/test_proxy_server_keys.py
@@ -25,7 +25,7 @@ from fastapi.testclient import TestClient
 from fastapi import FastAPI
 from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
 filepath = os.path.dirname(os.path.abspath(__file__))
-config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
+config_fp = f"{filepath}/test_configs/test_config.yaml"
 save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
 app = FastAPI()
 app.include_router(router) # Include your router in the test app
@@ -100,3 +100,37 @@ def test_add_new_key_max_parallel_limit(client):
         _run_in_parallel()
     except Exception as e:
         pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
Exception: {str(e)}") + +def test_add_new_key_max_parallel_limit_streaming(client): + try: + # Your test data + test_data = {"duration": "20m", "max_parallel_requests": 1} + # Your bearer token + token = os.getenv('PROXY_MASTER_KEY') + + headers = { + "Authorization": f"Bearer {token}" + } + response = client.post("/key/generate", json=test_data, headers=headers) + print(f"response: {response.text}") + assert response.status_code == 200 + result = response.json() + def _post_data(): + json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}], "stream": True} + response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"}) + return response + def _run_in_parallel(): + with ThreadPoolExecutor(max_workers=2) as executor: + future1 = executor.submit(_post_data) + future2 = executor.submit(_post_data) + + # Obtain the results from the futures + response1 = future1.result() + response2 = future2.result() + if response1.status_code == 429 or response2.status_code == 429: + pass + else: + raise Exception() + _run_in_parallel() + except Exception as e: + pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}") \ No newline at end of file