diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e74b6596..9fc7bb3d3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: rev: 3.8.4 # The version of flake8 to use hooks: - id: flake8 - exclude: ^litellm/tests/|^litellm/proxy/|^litellm/integrations/ + exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/integrations/ additional_dependencies: [flake8-print] files: litellm/.*\.py - repo: local diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 626ec513c..67cbb6063 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -76,12 +76,13 @@ class ModelParams(BaseModel): protected_namespaces = () class GenerateKeyRequest(BaseModel): - duration: str = "1h" - models: list = [] - aliases: dict = {} - config: dict = {} - spend: int = 0 + duration: Optional[str] = "1h" + models: Optional[list] = [] + aliases: Optional[dict] = {} + config: Optional[dict] = {} + spend: Optional[float] = 0 user_id: Optional[str] = None + max_parallel_requests: Optional[int] = None class GenerateKeyResponse(BaseModel): key: str @@ -96,8 +97,17 @@ class DeleteKeyRequest(BaseModel): class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth + """ + Return the row in the db + """ api_key: Optional[str] = None + models: list = [] + aliases: dict = {} + config: dict = {} + spend: Optional[float] = 0 user_id: Optional[str] = None + max_parallel_requests: Optional[int] = None + duration: str = "1h" class ConfigGeneralSettings(BaseModel): """ diff --git a/litellm/proxy/hooks/__init__.py b/litellm/proxy/hooks/__init__.py new file mode 100644 index 000000000..b6e690fd5 --- /dev/null +++ b/litellm/proxy/hooks/__init__.py @@ -0,0 +1 @@ +from . import * diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py new file mode 100644 index 000000000..14f4c330c --- /dev/null +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -0,0 +1,33 @@ +from typing import Optional +from litellm.caching import DualCache +from fastapi import HTTPException + +async def max_parallel_request_allow_request(max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache): + if api_key is None: + return + + if max_parallel_requests is None: + return + + # CHECK IF REQUEST ALLOWED + request_count_api_key = f"{api_key}_request_count" + current = user_api_key_cache.get_cache(key=request_count_api_key) + if current is None: + user_api_key_cache.set_cache(request_count_api_key, 1) + elif int(current) < max_parallel_requests: + # Increase count for this token + user_api_key_cache.set_cache(request_count_api_key, int(current) + 1) + else: + raise HTTPException(status_code=429, detail="Max parallel request limit reached.") + + +async def max_parallel_request_update_count(api_key: Optional[str], user_api_key_cache: DualCache): + if api_key is None: + return + + request_count_api_key = f"{api_key}_request_count" + # Decrease count for this token + current = user_api_key_cache.get_cache(key=request_count_api_key) or 1 + user_api_key_cache.set_cache(request_count_api_key, int(current) - 1) + + return \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e8b0d9e46..843930ad6 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -102,7 +102,7 @@ from litellm.proxy._types import * from litellm.caching import DualCache from litellm.proxy.health_check import perform_health_check 
litellm.suppress_debug_info = True -from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks +from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks, Header from fastapi.routing import APIRouter from fastapi.security import OAuth2PasswordBearer from fastapi.encoders import jsonable_encoder
@@ -198,7 +198,7 @@ user_custom_auth = None use_background_health_checks = None health_check_interval = None health_check_results = {} -call_hooks = CallHooks() +call_hooks = CallHooks(user_api_key_cache=user_api_key_cache) proxy_logging_obj: Optional[ProxyLogging] = None ### REDIS QUEUE ### async_result = None
@@ -259,10 +259,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap if prisma_client: ## check for cache hit (In-Memory Cache) valid_token = user_api_key_cache.get_cache(key=api_key) + print(f"valid_token from cache: {valid_token}") if valid_token is None: ## check db - cleaned_api_key = api_key - valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow()) + valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow()) user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) elif valid_token is not None: print(f"API Key Cache Hit!")
@@ -274,10 +274,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap llm_model_list = model_list print("\n new llm router model list", llm_model_list) if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called - return_dict = {"api_key": valid_token.token} - if valid_token.user_id: - return_dict["user_id"] = valid_token.user_id - return UserAPIKeyAuth(**return_dict) + api_key = valid_token.token + valid_token_dict = valid_token.model_dump() + valid_token_dict.pop("token", None) + return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) else: data = await request.json() model = data.get("model", None)
@@ -285,10 +285,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap model = litellm.model_alias_map[model] if model and model not in valid_token.models: raise Exception(f"Token not allowed to access model") - return_dict = {"api_key": valid_token.token} - if valid_token.user_id: - return_dict["user_id"] = valid_token.user_id - return UserAPIKeyAuth(**return_dict) + api_key = valid_token.token + valid_token_dict = valid_token.model_dump() + valid_token_dict.pop("token", None) + return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) else: raise Exception(f"Invalid token") except Exception as e:
@@ -588,7 +588,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings) return router, model_list, general_settings -async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None): +async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None): global prisma_client if prisma_client is None:
@@ -617,11 +617,11 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia else: raise ValueError("Unsupported duration unit") - if duration_str is None: # allow tokens that never expire + if 
duration is None: # allow tokens that never expire expires = None else: - duration = _duration_in_seconds(duration=duration_str) - expires = datetime.utcnow() + timedelta(seconds=duration) + duration_s = _duration_in_seconds(duration=duration) + expires = datetime.utcnow() + timedelta(seconds=duration_s) aliases_json = json.dumps(aliases) config_json = json.dumps(config) @@ -635,7 +635,8 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia "aliases": aliases_json, "config": config_json, "spend": spend, - "user_id": user_id + "user_id": user_id, + "max_parallel_requests": max_parallel_requests } new_verification_token = await prisma_client.insert_data(data=verification_token_data) except Exception as e: @@ -755,14 +756,12 @@ def data_generator(response): except: yield f"data: {json.dumps(chunk)}\n\n" -async def async_data_generator(response): +async def async_data_generator(response, user_api_key_dict): global call_hooks print_verbose("inside generator") async for chunk in response: print_verbose(f"returned chunk: {chunk}") - ### CALL HOOKS ### - modify outgoing response - response = call_hooks.post_call_success(chunk=chunk, call_type="completion") try: yield f"data: {json.dumps(chunk.dict())}\n\n" except: @@ -812,36 +811,6 @@ def get_litellm_model_info(model: dict = {}): # if litellm does not have info on the model it should return {} return {} -@app.middleware("http") -async def rate_limit_per_token(request: Request, call_next): - global user_api_key_cache, general_settings - max_parallel_requests = general_settings.get("max_parallel_requests", None) - api_key = request.headers.get("Authorization") - if max_parallel_requests is not None and api_key is not None: # Rate limiting is enabled - api_key = _get_bearer_token(api_key=api_key) - # CHECK IF REQUEST ALLOWED - request_count_api_key = f"{api_key}_request_count" - current = user_api_key_cache.get_cache(key=request_count_api_key) - if current is None: - user_api_key_cache.set_cache(request_count_api_key, 1) - elif int(current) < max_parallel_requests: - # Increase count for this token - user_api_key_cache.set_cache(request_count_api_key, int(current) + 1) - else: - raise HTTPException(status_code=429, detail="Too many requests.") - - - response = await call_next(request) - - # Decrease count for this token - current = user_api_key_cache.get_cache(key=request_count_api_key) - user_api_key_cache.set_cache(request_count_api_key, int(current) - 1) - - return response - else: # Rate limiting is not enabled, just pass the request - response = await call_next(request) - return response - @router.on_event("startup") async def startup_event(): global prisma_client, master_key, use_background_health_checks @@ -868,7 +837,7 @@ async def startup_event(): if prisma_client is not None and master_key is not None: # add master key to db - await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key) + await generate_key_helper_fn(duration=None, models=[], aliases={}, config={}, spend=0, token=master_key) @router.on_event("shutdown") async def shutdown_event(): @@ -1008,7 +977,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap data["api_base"] = user_api_base ### CALL HOOKS ### - modify incoming data before calling the model - data = call_hooks.pre_call(data=data, call_type="completion") + data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion") ### ROUTE THE REQUEST ### router_model_names = 
[m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] @@ -1021,15 +990,19 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap else: # router is not set response = await litellm.acompletion(**data) + print(f"final response: {response}") if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses - return StreamingResponse(async_data_generator(response), media_type='text/event-stream') + return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream') ### CALL HOOKS ### - modify outgoing response - response = call_hooks.post_call_success(response=response, call_type="completion") + response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion") background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL return response - except Exception as e: + except Exception as e: + print(f"Exception received: {str(e)}") + raise e + await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict) print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`") router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] if llm_router is not None and data.get("model", "") in router_model_names: @@ -1046,23 +1019,26 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap print(f"{key}: {value}") if user_debug: traceback.print_exc() - error_traceback = traceback.format_exc() - error_msg = f"{str(e)}\n\n{error_traceback}" - try: - status = e.status_code # type: ignore - except: - status = 500 - raise HTTPException( - status_code=status, - detail=error_msg - ) + + if isinstance(e, HTTPException): + raise e + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}\n\n{error_traceback}" + try: + status = e.status_code # type: ignore + except: + status = 500 + raise HTTPException( + status_code=status, + detail=error_msg + ) @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse) @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse) async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): + global call_hooks try: - global call_hooks - # Use orjson to parse JSON data, orjson speeds up requests significantly body = await request.body() data = orjson.loads(body) @@ -1105,7 +1081,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen break ### CALL HOOKS ### - modify incoming data before calling the model - data = call_hooks.pre_call(data=data, call_type="embeddings") + data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings") ## ROUTE TO CORRECT ENDPOINT ## if llm_router is not None and data["model"] in router_model_names: # model in router model list @@ -1117,19 +1093,18 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL ### CALL HOOKS ### - modify outgoing response - data = call_hooks.post_call_success(response=response, call_type="embeddings") + data = 
call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings") return response except Exception as e: + await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e) traceback.print_exc() raise e - except Exception as e: - pass #### KEY MANAGEMENT #### @router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse) -async def generate_key_fn(request: Request, data: GenerateKeyRequest): +async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorization: Optional[str] = Header(None)): """ Generate an API key based on the provided data. @@ -1141,26 +1116,17 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest): - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models - config: Optional[dict] - any key-specific configs, overrides config in config.yaml - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend + - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x. Returns: - - key: The generated api key - - expires: Datetime object for when key expires. + - key: (str) The generated api key + - expires: (datetime) Datetime object for when key expires. + - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. """ # data = await request.json() - duration_str = data.duration # Default to 1 hour if duration is not provided - models = data.models # Default to an empty list (meaning allow token to call all models) - aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list) - config = data.config - spend = data.spend - user_id = data.user_id - if isinstance(models, list): - response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend, user_id=user_id) - return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"]) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": "models param must be a list"}, - ) + data_json = data.model_dump() + response = await generate_key_helper_fn(**data_json) + return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"]) @router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]) async def delete_key_fn(request: Request, data: DeleteKeyRequest): diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 1ce57e2fe..ab4fc5e00 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -16,4 +16,5 @@ model LiteLLM_VerificationToken { aliases Json @default("{}") config Json @default("{}") user_id String? + max_parallel_requests Int? 
} \ No newline at end of file diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index ab215f9ca..688820e83 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1,7 +1,13 @@ from typing import Optional, List, Any, Literal import os, subprocess, hashlib, importlib, asyncio import litellm, backoff +from litellm.proxy._types import UserAPIKeyAuth +from litellm.caching import DualCache +from litellm.proxy.hooks.parallel_request_limiter import max_parallel_request_allow_request, max_parallel_request_update_count +def print_verbose(print_statement): + if litellm.set_verbose: + print(print_statement) # noqa ### LOGGING ### class ProxyLogging: """ @@ -17,7 +23,6 @@ class ProxyLogging: self._init_litellm_callbacks() pass - def _init_litellm_callbacks(self): if len(litellm.callbacks) > 0: for callback in litellm.callbacks: @@ -69,11 +74,11 @@ class ProxyLogging: # Function to be called whenever a retry is about to happen def on_backoff(details): # The 'tries' key in the details dictionary contains the number of completed tries - print(f"Backing off... this was attempt #{details['tries']}") + print_verbose(f"Backing off... this was attempt #{details['tries']}") class PrismaClient: def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): - print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'") + print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'") ## init logging object self.proxy_logging_obj = proxy_logging_obj @@ -109,20 +114,22 @@ class PrismaClient: max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) - async def get_data(self, token: str, expires: Optional[Any]=None): + async def get_data(self, token: str, expires: Optional[Any]=None): try: - hashed_token = self.hash_token(token=token) + # check if plain text or hash + if token.startswith("sk-"): + token = self.hash_token(token=token) if expires: response = await self.db.litellm_verificationtoken.find_first( where={ - "token": hashed_token, + "token": token, "expires": {"gte": expires} # Check if the token is not expired } ) else: response = await self.db.litellm_verificationtoken.find_unique( where={ - "token": hashed_token + "token": token } ) return response @@ -175,25 +182,23 @@ class PrismaClient: Update existing data """ try: - hashed_token = self.hash_token(token=token) - data["token"] = hashed_token - await self.db.litellm_verificationtoken.update( + print_verbose(f"token: {token}") + # check if plain text or hash + if token.startswith("sk-"): + token = self.hash_token(token=token) + + data["token"] = token + response = await self.db.litellm_verificationtoken.update( where={ - "token": hashed_token + "token": token }, data={**data} # type: ignore ) - print("\033[91m" + f"DB write succeeded" + "\033[0m") + print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m") return {"token": token, "data": data} except Exception as e: asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) - print() - print() - print() - print("\033[91m" + f"DB write failed: {e}" + "\033[0m") - print() - print() - print() + print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m") raise e @@ -252,7 +257,7 @@ class PrismaClient: ### CUSTOM FILE ### def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: try: - print(f"value: {value}") + print_verbose(f"value: {value}") # Split the path by dots to separate module from instance parts = 
value.split(".") @@ -285,8 +290,6 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: except Exception as e: raise e - - ### CALL HOOKS ### class CallHooks: """ @@ -297,20 +300,55 @@ class CallHooks: 2. /embeddings """ - def __init__(self, *args, **kwargs): - self.call_details = {} + def __init__(self, user_api_key_cache: DualCache): + self.call_details: dict = {} + self.call_details["user_api_key_cache"] = user_api_key_cache def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list): self.call_details["litellm_settings"] = litellm_settings self.call_details["general_settings"] = general_settings self.call_details["model_list"] = model_list - def pre_call(self, data: dict, call_type: Literal["completion", "embeddings"]): - self.call_details["data"] = data - return data + async def pre_call(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]): + try: + self.call_details["data"] = data + self.call_details["call_type"] = call_type - def post_call_success(self, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None): - return response + ## check if max parallel requests set + if user_api_key_dict.max_parallel_requests is not None: + ## if set, check if request allowed + await max_parallel_request_allow_request( + max_parallel_requests=user_api_key_dict.max_parallel_requests, + api_key=user_api_key_dict.api_key, + user_api_key_cache=self.call_details["user_api_key_cache"]) + + return data + except Exception as e: + raise e - def post_call_failure(self, *args, **kwargs): - pass \ No newline at end of file + async def post_call_success(self, user_api_key_dict: UserAPIKeyAuth, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None): + try: + # check if max parallel requests set + if user_api_key_dict.max_parallel_requests is not None: + ## decrement call, once complete + await max_parallel_request_update_count( + api_key=user_api_key_dict.api_key, + user_api_key_cache=self.call_details["user_api_key_cache"]) + + return response + except Exception as e: + raise e + + async def post_call_failure(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth): + # check if max parallel requests set + if user_api_key_dict.max_parallel_requests is not None: + ## decrement call count if call failed + if (hasattr(original_exception, "status_code") + and original_exception.status_code == 429 + and "Max parallel request limit reached" in str(original_exception)): + pass # ignore failed calls due to max limit being reached + else: + await max_parallel_request_update_count( + api_key=user_api_key_dict.api_key, + user_api_key_cache=self.call_details["user_api_key_cache"]) + return \ No newline at end of file diff --git a/litellm/tests/test_configs/test_config.yaml b/litellm/tests/test_configs/test_config.yaml index 253a39774..a5e7802a4 100644 --- a/litellm/tests/test_configs/test_config.yaml +++ b/litellm/tests/test_configs/test_config.yaml @@ -3,7 +3,6 @@ general_settings: master_key: os.environ/PROXY_MASTER_KEY litellm_settings: drop_params: true - set_verbose: true success_callback: ["langfuse"] model_list: diff --git a/litellm/tests/test_proxy_server_keys.py b/litellm/tests/test_proxy_server_keys.py index fb0ec2f3c..5ffbfe3b0 100644 --- a/litellm/tests/test_proxy_server_keys.py +++ b/litellm/tests/test_proxy_server_keys.py @@ -1,4 +1,4 @@ -import sys, os 
+import sys, os, time import traceback from dotenv import load_dotenv @@ -19,7 +19,7 @@ logging.basicConfig( level=logging.DEBUG, # Set the desired logging level format="%(asctime)s - %(levelname)s - %(message)s", ) - +from concurrent.futures import ThreadPoolExecutor # test /chat/completion request to the proxy from fastapi.testclient import TestClient from fastapi import FastAPI @@ -62,6 +62,41 @@ def test_add_new_key(client): assert result["key"].startswith("sk-") print(f"Received response: {result}") except Exception as e: - pytest.fail("LiteLLM Proxy test failed. Exception", e) + pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}") -# # Run the test - only runs via pytest \ No newline at end of file +# # Run the test - only runs via pytest + + +def test_add_new_key_max_parallel_limit(client): + try: + # Your test data + test_data = {"duration": "20m", "max_parallel_requests": 1} + # Your bearer token + token = os.getenv("PROXY_MASTER_KEY") + + headers = { + "Authorization": f"Bearer {token}" + } + response = client.post("/key/generate", json=test_data, headers=headers) + print(f"response: {response.text}") + assert response.status_code == 200 + result = response.json() + def _post_data(): + json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]} + response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"}) + return response + def _run_in_parallel(): + with ThreadPoolExecutor(max_workers=2) as executor: + future1 = executor.submit(_post_data) + future2 = executor.submit(_post_data) + + # Obtain the results from the futures + response1 = future1.result() + response2 = future2.result() + if response1.status_code == 429 or response2.status_code == 429: + pass + else: + raise Exception() + _run_in_parallel() + except Exception as e: + pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
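
Usage note (illustrative, not part of the patch): the sketch below mirrors `test_add_new_key_max_parallel_limit` above and shows how the new `max_parallel_requests` key setting is expected to behave against a running proxy. It assumes a proxy serving locally on port 8000, a `PROXY_MASTER_KEY` environment variable, and an `azure-model` entry in the proxy config; the `requests` client and the base URL are illustrative choices, not something this diff adds.

```python
# Minimal sketch: generate a key limited to 1 parallel request, then fire two
# concurrent /chat/completions calls and expect one of them to be rejected with 429.
import os
import time
from concurrent.futures import ThreadPoolExecutor

import requests

BASE_URL = "http://0.0.0.0:8000"          # assumed local proxy address
MASTER_KEY = os.environ["PROXY_MASTER_KEY"]

# 1. Generate a key whose requests are rate limited to one in flight at a time.
key_resp = requests.post(
    f"{BASE_URL}/key/generate",
    json={"duration": "20m", "max_parallel_requests": 1},
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
)
key_resp.raise_for_status()
api_key = key_resp.json()["key"]

def _chat_completion() -> requests.Response:
    # Each in-flight call increments the key's request counter until it completes.
    return requests.post(
        f"{BASE_URL}/chat/completions",
        json={
            "model": "azure-model",  # assumed to exist in the proxy's model_list
            "messages": [{"role": "user", "content": f"write a short poem {time.time()}"}],
        },
        headers={"Authorization": f"Bearer {api_key}"},
    )

# 2. Two concurrent calls against a limit of 1: one should succeed, one should get 429.
with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(_chat_completion) for _ in range(2)]
    status_codes = [f.result().status_code for f in futures]

print(status_codes)  # expect one 200 and one 429 ("Max parallel request limit reached.")
```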