From 1a29272b47a135154fcf330fb5346bb98a73bc91 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 19 Jan 2024 10:22:27 -0800
Subject: [PATCH] fix(parallel_request_limiter.py): handle tpm/rpm limits being null

---
 litellm/proxy/_types.py                     | 12 ++--
 .../proxy/hooks/parallel_request_limiter.py |  4 +-
 tests/README.MD                             |  1 +
 tests/test_chat_completion.py               | 58 +++++++++++++++++++
 tests/test_parallel_key_gen.py              | 33 +++++++++++
 5 files changed, 100 insertions(+), 8 deletions(-)
 create mode 100644 tests/README.MD
 create mode 100644 tests/test_chat_completion.py
 create mode 100644 tests/test_parallel_key_gen.py

diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index e033994d9..3315fe607 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -132,8 +132,8 @@ class GenerateKeyRequest(LiteLLMBase):
     team_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
     metadata: Optional[dict] = {}
-    tpm_limit: int = sys.maxsize
-    rpm_limit: int = sys.maxsize
+    tpm_limit: Optional[int] = None
+    rpm_limit: Optional[int] = None
 
 
 class UpdateKeyRequest(LiteLLMBase):
@@ -148,8 +148,8 @@ class UpdateKeyRequest(LiteLLMBase):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
     metadata: Optional[dict] = None
-    tpm_limit: int = sys.maxsize
-    rpm_limit: int = sys.maxsize
+    tpm_limit: Optional[int] = None
+    rpm_limit: Optional[int] = None
 
 
 class UserAPIKeyAuth(LiteLLMBase):  # the expected response object for user api key auth
@@ -166,8 +166,8 @@ class UserAPIKeyAuth(LiteLLMBase):  # the expected response object for user api
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"
     metadata: dict = {}
-    tpm_limit: int = sys.maxsize
-    rpm_limit: int = sys.maxsize
+    tpm_limit: Optional[int] = None
+    rpm_limit: Optional[int] = None
 
 
 class GenerateKeyResponse(LiteLLMBase):
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 2ef19a149..0a38e5ede 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -29,8 +29,8 @@ class MaxParallelRequestsHandler(CustomLogger):
         self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
         max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
-        tpm_limit = user_api_key_dict.tpm_limit
-        rpm_limit = user_api_key_dict.rpm_limit
+        tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
+        rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
 
         if api_key is None:
             return
diff --git a/tests/README.MD b/tests/README.MD
new file mode 100644
index 000000000..6555b3728
--- /dev/null
+++ b/tests/README.MD
@@ -0,0 +1 @@
+Most tests are in `/litellm/tests`. These are just the tests for the proxy Docker image, used for CircleCI.
diff --git a/tests/test_chat_completion.py b/tests/test_chat_completion.py
new file mode 100644
index 000000000..b8da94155
--- /dev/null
+++ b/tests/test_chat_completion.py
@@ -0,0 +1,58 @@
+# What this tests?
+## Tests /chat/completions by generating a key and then making a chat completions request
+import pytest
+import asyncio
+import aiohttp
+
+
+async def generate_key(session):
+    url = "http://0.0.0.0:4000/key/generate"
+    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
+    data = {
+        "models": ["gpt-4"],
+        "duration": None,
+    }
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+        return await response.json()
+
+
+async def chat_completion(session, key):
+    url = "http://0.0.0.0:4000/chat/completions"
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Content-Type": "application/json",
+    }
+    data = {
+        "model": "gpt-4",
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+        ],
+    }
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+
+@pytest.mark.asyncio
+async def test_chat_completion():
+    async with aiohttp.ClientSession() as session:
+        key_gen = await generate_key(session=session)
+        key = key_gen["key"]
+        await chat_completion(session=session, key=key)
diff --git a/tests/test_parallel_key_gen.py b/tests/test_parallel_key_gen.py
new file mode 100644
index 000000000..36595b4c3
--- /dev/null
+++ b/tests/test_parallel_key_gen.py
@@ -0,0 +1,33 @@
+# What this tests?
+## Tests /key/generate by making 10 parallel requests and asserting all are successful
+import pytest
+import asyncio
+import aiohttp
+
+
+async def generate_key(session, i):
+    url = "http://0.0.0.0:4000/key/generate"
+    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
+    data = {
+        "models": ["azure-models"],
+        "aliases": {"mistral-7b": "gpt-3.5-turbo"},
+        "duration": None,
+    }
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(f"Response {i} (Status code: {status}):")
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request {i} did not return a 200 status code: {status}")
+
+
+@pytest.mark.asyncio
+async def test_key_gen():
+    async with aiohttp.ClientSession() as session:
+        tasks = [generate_key(session, i) for i in range(1, 11)]
+        await asyncio.gather(*tasks)
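
Note on the core fix: the key models now default tpm_limit/rpm_limit to None rather than sys.maxsize, and the pre-call hook coalesces a missing limit back to sys.maxsize with `or`. Below is a minimal standalone sketch of that fallback pattern, assuming pydantic is installed; the KeyLimits model and effective_limits helper are hypothetical stand-ins for the patched UserAPIKeyAuth fields and hook logic, not code from this patch.

import sys
from typing import Optional, Tuple

from pydantic import BaseModel


class KeyLimits(BaseModel):
    # Hypothetical stand-in for the patched UserAPIKeyAuth fields:
    # an unset limit is now None instead of a sys.maxsize sentinel.
    tpm_limit: Optional[int] = None
    rpm_limit: Optional[int] = None


def effective_limits(key: KeyLimits) -> Tuple[int, int]:
    # Mirrors the hook's `user_api_key_dict.tpm_limit or sys.maxsize`:
    # a None limit falls back to "unlimited". Because `or` tests truthiness,
    # an explicit limit of 0 would also fall back to sys.maxsize here.
    return key.tpm_limit or sys.maxsize, key.rpm_limit or sys.maxsize


assert effective_limits(KeyLimits()) == (sys.maxsize, sys.maxsize)
assert effective_limits(KeyLimits(tpm_limit=1000, rpm_limit=60)) == (1000, 60)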