From f5ced089d6f0af05600062e25a981fdabebba815 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 19 Jan 2024 14:54:15 -0800 Subject: [PATCH] test(tests/): add unit testing for proxy server endpoints --- litellm/proxy/_types.py | 4 +- litellm/proxy/proxy_server.py | 69 ++++-- .../tests/test_amazing_vertex_completion.py | 5 +- proxy_server_config.yaml | 21 ++ tests/test_chat_completion.py | 58 ----- tests/test_health.py | 115 ++++++++++ tests/test_keys.py | 183 ++++++++++++++++ tests/test_models.py | 190 +++++++++++++++++ tests/test_openai_endpoints.py | 201 ++++++++++++++++++ tests/test_parallel_key_gen.py | 33 --- tests/test_users.py | 102 +++++++++ 11 files changed, 870 insertions(+), 111 deletions(-) delete mode 100644 tests/test_chat_completion.py create mode 100644 tests/test_health.py create mode 100644 tests/test_keys.py create mode 100644 tests/test_models.py create mode 100644 tests/test_openai_endpoints.py delete mode 100644 tests/test_parallel_key_gen.py create mode 100644 tests/test_users.py diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 3315fe607..72b7273e5 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -13,7 +13,7 @@ class LiteLLMBase(BaseModel): def json(self, **kwargs): try: return self.model_dump() # noqa - except: + except Exception as e: # if using pydantic v1 return self.dict() @@ -177,7 +177,7 @@ class GenerateKeyResponse(LiteLLMBase): class DeleteKeyRequest(LiteLLMBase): - keys: List[str] + keys: List class NewUserRequest(GenerateKeyRequest): diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index acc48af60..15fda1501 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -264,16 +264,6 @@ async def user_api_key_auth( if route.startswith("/config/") and not is_master_key_valid: raise Exception(f"Only admin can modify config") - if ( - (route.startswith("/key/") or route.startswith("/user/")) - or route.startswith("/model/") - and not is_master_key_valid - and general_settings.get("allow_user_auth", False) != True - ): - raise Exception( - f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users" - ) - if ( prisma_client is None and custom_db_client is None ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error @@ -432,6 +422,39 @@ async def user_api_key_auth( db=custom_db_client, ) ) + + if ( + (route.startswith("/key/") or route.startswith("/user/")) + or route.startswith("/model/") + and not is_master_key_valid + and general_settings.get("allow_user_auth", False) != True + ): + if route == "/key/info": + # check if user can access this route + query_params = request.query_params + key = query_params.get("key") + if prisma_client.hash_token(token=key) != api_key: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="user not allowed to access this key's info", + ) + elif route == "/user/info": + # check if user can access this route + query_params = request.query_params + user_id = query_params.get("user_id") + if user_id != valid_token.user_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="user not allowed to access this key's info", + ) + elif route == "/model/info": + # /model/info just shows models user has access to + pass + else: + raise Exception( + f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users" + ) + return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) else: raise Exception(f"Invalid Key Passed to LiteLLM Proxy") @@ -2160,7 +2183,7 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest): @router.post( "/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)] ) -async def delete_key_fn(request: Request, data: DeleteKeyRequest): +async def delete_key_fn(data: DeleteKeyRequest): """ Delete a key from the key management system. @@ -2203,6 +2226,9 @@ async def info_key_fn( f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" ) key_info = await prisma_client.get_data(token=key) + ## REMOVE HASHED TOKEN INFO BEFORE RETURNING ## + key_info = key_info.model_dump() + key_info.pop("token") return {"key": key, "info": key_info} except Exception as e: raise HTTPException( @@ -2338,6 +2364,10 @@ async def user_info( keys = await prisma_client.get_data( user_id=user_id, table_name="key", query_type="find_all" ) + ## REMOVE HASHED TOKEN INFO before returning ## + for key in keys: + key = key.model_dump() + key.pop("token", None) return {"user_id": user_id, "user_info": user_info, "keys": keys} except Exception as e: raise HTTPException( @@ -2415,13 +2445,19 @@ async def add_new_model(model_params: ModelParams): tags=["model management"], dependencies=[Depends(user_api_key_auth)], ) -async def model_info_v1(request: Request): +async def model_info_v1( + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): global llm_model_list, general_settings, user_config_file_path, proxy_config # Load existing config config = await proxy_config.get_config() - all_models = config["model_list"] + if len(user_api_key_dict.models) > 0: + model_names = user_api_key_dict.models + all_models = [m for m in config["model_list"] if m in model_names] + else: + all_models = config["model_list"] for model in all_models: # provided model_info in config.yaml model_info = model.get("model_info", {}) @@ -2750,7 +2786,7 @@ async def test_endpoint(request: Request): @router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)]) async def health_endpoint( - request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), model: Optional[str] = fastapi.Query( None, description="Specify the model name (optional)" ), @@ -2785,6 +2821,11 @@ async def health_endpoint( detail={"error": "Model list not initialized"}, ) + ### FILTER MODELS FOR ONLY THOSE USER HAS ACCESS TO ### + if len(user_api_key_dict.models) > 0: + allowed_model_names = user_api_key_dict.models + else: + allowed_model_names = [] # if use_background_health_checks: return health_check_results else: diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index a56e0343c..8467e4434 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -302,10 +302,7 @@ def test_gemini_pro_vision(): assert prompt_tokens == 263 # the gemini api returns 263 to us except Exception as e: - import traceback - - traceback.print_exc() - raise e + pytest.fail(f"An exception occurred - {str(e)}") # test_gemini_pro_vision() diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index abe999858..5a089c764 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -1,4 +1,10 @@ model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: azure/chatgpt-v-2 + api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ + api_version: "2023-05-15" + api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - model_name: gpt-4 litellm_params: model: azure/chatgpt-v-2 @@ -17,6 +23,21 @@ model_list: api_key: os.environ/AZURE_EUROPE_API_KEY api_base: https://my-endpoint-europe-berri-992.openai.azure.com rpm: 10 + - model_name: text-embedding-ada-002 + litellm_params: + model: azure/azure-embedding-model + api_key: os.environ/AZURE_API_KEY + api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ + api_version: "2023-05-15" + model_info: + mode: embedding + base_model: text-embedding-ada-002 + - model_name: dall-e-2 + litellm_params: + model: azure/ + api_version: 2023-06-01-preview + api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ + api_key: os.environ/AZURE_API_KEY litellm_settings: drop_params: True diff --git a/tests/test_chat_completion.py b/tests/test_chat_completion.py deleted file mode 100644 index b8da94155..000000000 --- a/tests/test_chat_completion.py +++ /dev/null @@ -1,58 +0,0 @@ -# What this tests ? -## Tests /chat/completions by generating a key and then making a chat completions request -import pytest -import asyncio -import aiohttp - - -async def generate_key(session): - url = "http://0.0.0.0:4000/key/generate" - headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} - data = { - "models": ["gpt-4"], - "duration": None, - } - - async with session.post(url, headers=headers, json=data) as response: - status = response.status - response_text = await response.text() - - print(response_text) - print() - - if status != 200: - raise Exception(f"Request did not return a 200 status code: {status}") - return await response.json() - - -async def chat_completion(session, key): - url = "http://0.0.0.0:4000/chat/completions" - headers = { - "Authorization": f"Bearer {key}", - "Content-Type": "application/json", - } - data = { - "model": "gpt-4", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello!"}, - ], - } - - async with session.post(url, headers=headers, json=data) as response: - status = response.status - response_text = await response.text() - - print(response_text) - print() - - if status != 200: - raise Exception(f"Request did not return a 200 status code: {status}") - - -@pytest.mark.asyncio -async def test_key_gen(): - async with aiohttp.ClientSession() as session: - key_gen = await generate_key(session=session) - key = key_gen["key"] - await chat_completion(session=session, key=key) diff --git a/tests/test_health.py b/tests/test_health.py new file mode 100644 index 000000000..f0a89f529 --- /dev/null +++ b/tests/test_health.py @@ -0,0 +1,115 @@ +# What this tests? +## Tests /health + /routes endpoints. + +import pytest +import asyncio +import aiohttp + + +async def health(session, call_key): + url = "http://0.0.0.0:4000/health" + headers = { + "Authorization": f"Bearer {call_key}", + "Content-Type": "application/json", + } + + async with session.get(url, headers=headers) as response: + status = response.status + response_text = await response.text() + + print(f"Response (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + return await response.json() + + +async def generate_key(session): + url = "http://0.0.0.0:4000/key/generate" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "models": ["gpt-4", "text-embedding-ada-002", "dall-e-2"], + "duration": None, + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + +@pytest.mark.asyncio +async def test_health(): + """ + - Call /health + """ + async with aiohttp.ClientSession() as session: + # as admin # + all_healthy_models = await health(session=session, call_key="sk-1234") + total_model_count = ( + all_healthy_models["healthy_count"] + all_healthy_models["unhealthy_count"] + ) + assert total_model_count > 0 + + +@pytest.mark.asyncio +async def test_health_readiness(): + """ + Check if 200 + """ + async with aiohttp.ClientSession() as session: + url = "http://0.0.0.0:4000/health/readiness" + async with session.get(url) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +@pytest.mark.asyncio +async def test_health_liveliness(): + """ + Check if 200 + """ + async with aiohttp.ClientSession() as session: + url = "http://0.0.0.0:4000/health/liveliness" + async with session.get(url) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +@pytest.mark.asyncio +async def test_routes(): + """ + Check if 200 + """ + async with aiohttp.ClientSession() as session: + url = "http://0.0.0.0:4000/routes" + async with session.get(url) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") diff --git a/tests/test_keys.py b/tests/test_keys.py new file mode 100644 index 000000000..f209f4c5a --- /dev/null +++ b/tests/test_keys.py @@ -0,0 +1,183 @@ +# What this tests ? +## Tests /key endpoints. + +import pytest +import asyncio +import aiohttp + + +async def generate_key(session, i): + url = "http://0.0.0.0:4000/key/generate" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "models": ["azure-models"], + "aliases": {"mistral-7b": "gpt-3.5-turbo"}, + "duration": None, + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + +@pytest.mark.asyncio +async def test_key_gen(): + async with aiohttp.ClientSession() as session: + tasks = [generate_key(session, i) for i in range(1, 11)] + await asyncio.gather(*tasks) + + +async def update_key(session, get_key): + """ + Make sure only models user has access to are returned + """ + url = "http://0.0.0.0:4000/key/update" + headers = { + "Authorization": f"Bearer sk-1234", + "Content-Type": "application/json", + } + data = {"key": get_key, "models": ["gpt-4"]} + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + +async def chat_completion(session, key, model="gpt-4"): + url = "http://0.0.0.0:4000/chat/completions" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + data = { + "model": model, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +@pytest.mark.asyncio +async def test_key_update(): + """ + Create key + Update key with new model + Test key w/ model + """ + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session, i=0) + key = key_gen["key"] + await update_key( + session=session, + get_key=key, + ) + await chat_completion(session=session, key=key) + + +async def delete_key(session, get_key): + """ + Delete key + """ + url = "http://0.0.0.0:4000/key/delete" + headers = { + "Authorization": f"Bearer sk-1234", + "Content-Type": "application/json", + } + data = {"keys": [get_key]} + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + +@pytest.mark.asyncio +async def test_key_delete(): + """ + Delete key + """ + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session, i=0) + key = key_gen["key"] + await delete_key( + session=session, + get_key=key, + ) + + +async def get_key_info(session, get_key, call_key): + """ + Make sure only models user has access to are returned + """ + url = f"http://0.0.0.0:4000/key/info?key={get_key}" + headers = { + "Authorization": f"Bearer {call_key}", + "Content-Type": "application/json", + } + + async with session.get(url, headers=headers) as response: + status = response.status + response_text = await response.text() + print(response_text) + print() + + if status != 200: + if call_key != get_key: + return status + else: + print(f"call_key: {call_key}; get_key: {get_key}") + raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + +@pytest.mark.asyncio +async def test_key_info(): + """ + Get key info + - as admin -> 200 + - as key itself -> 200 + - as random key -> 403 + """ + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session, i=0) + key = key_gen["key"] + # as admin # + await get_key_info(session=session, get_key=key, call_key="sk-1234") + # as key itself # + await get_key_info(session=session, get_key=key, call_key=key) + # as random key # + key_gen = await generate_key(session=session, i=0) + random_key = key_gen["key"] + status = await get_key_info(session=session, get_key=key, call_key=random_key) + assert status == 403 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 000000000..b76dfb116 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,190 @@ +# What this tests ? +## Tests /models and /model/* endpoints + +import pytest +import asyncio +import aiohttp + + +async def generate_key(session, models=[]): + url = "http://0.0.0.0:4000/key/generate" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "models": models, + "duration": None, + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + +async def get_models(session, key): + url = "http://0.0.0.0:4000/models" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + + async with session.get(url, headers=headers) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +@pytest.mark.asyncio +async def test_get_models(): + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session) + key = key_gen["key"] + await get_models(session=session, key=key) + + +async def add_models(session, model_id="123"): + url = "http://0.0.0.0:4000/model/new" + headers = { + "Authorization": f"Bearer sk-1234", + "Content-Type": "application/json", + } + + data = { + "model_name": "azure-gpt-3.5", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": "os.environ/AZURE_API_KEY", + "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/", + "api_version": "2023-05-15", + }, + "model_info": {"id": model_id}, + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Add models {response_text}") + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +async def get_model_info(session, key): + """ + Make sure only models user has access to are returned + """ + url = "http://0.0.0.0:4000/model/info" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + + async with session.get(url, headers=headers) as response: + status = response.status + response_text = await response.text() + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + +async def chat_completion(session, key): + url = "http://0.0.0.0:4000/chat/completions" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + data = { + "model": "azure-gpt-3.5", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +@pytest.mark.asyncio +async def test_add_models(): + """ + Add model + Call new model + """ + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session) + key = key_gen["key"] + await add_models(session=session) + await chat_completion(session=session, key=key) + + +@pytest.mark.asyncio +async def test_get_models(): + """ + Get models user has access to + """ + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session, models=["gpt-4"]) + key = key_gen["key"] + response = await get_model_info(session=session, key=key) + models = [m["model_name"] for m in response["data"]] + for m in models: + assert m == "gpt-4" + + +async def delete_model(session, model_id="123"): + """ + Make sure only models user has access to are returned + """ + url = "http://0.0.0.0:4000/model/delete" + headers = { + "Authorization": f"Bearer sk-1234", + "Content-Type": "application/json", + } + data = {"id": model_id} + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + +@pytest.mark.asyncio +async def test_delete_models(): + """ + Get models user has access to + """ + model_id = "12345" + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session) + key = key_gen["key"] + await add_models(session=session, model_id=model_id) + await chat_completion(session=session, key=key) + await delete_model(session=session, model_id=model_id) diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py new file mode 100644 index 000000000..5a91bffa7 --- /dev/null +++ b/tests/test_openai_endpoints.py @@ -0,0 +1,201 @@ +# What this tests ? +## Tests /chat/completions by generating a key and then making a chat completions request +import pytest +import asyncio +import aiohttp + + +async def generate_key(session): + url = "http://0.0.0.0:4000/key/generate" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "models": ["gpt-4", "text-embedding-ada-002", "dall-e-2"], + "duration": None, + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + +async def new_user(session): + url = "http://0.0.0.0:4000/user/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "models": ["gpt-4", "text-embedding-ada-002", "dall-e-2"], + "duration": None, + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + +async def chat_completion(session, key): + url = "http://0.0.0.0:4000/chat/completions" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + data = { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +@pytest.mark.asyncio +async def test_chat_completion(): + """ + - Create key + Make chat completion call + - Create user + make chat completion call + """ + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session) + key = key_gen["key"] + await chat_completion(session=session, key=key) + key_gen = await new_user(session=session) + key_2 = key_gen["key"] + await chat_completion(session=session, key=key_2) + + +async def completion(session, key): + url = "http://0.0.0.0:4000/completions" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + data = {"model": "gpt-4", "prompt": "Hello!"} + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +@pytest.mark.asyncio +async def test_completion(): + """ + - Create key + Make chat completion call + - Create user + make chat completion call + """ + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session) + key = key_gen["key"] + await completion(session=session, key=key) + key_gen = await new_user(session=session) + key_2 = key_gen["key"] + await completion(session=session, key=key_2) + + +async def embeddings(session, key): + url = "http://0.0.0.0:4000/embeddings" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + data = { + "model": "text-embedding-ada-002", + "input": ["hello world"], + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +@pytest.mark.asyncio +async def test_embeddings(): + """ + - Create key + Make embeddings call + - Create user + make embeddings call + """ + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session) + key = key_gen["key"] + await embeddings(session=session, key=key) + key_gen = await new_user(session=session) + key_2 = key_gen["key"] + await embeddings(session=session, key=key_2) + + +async def image_generation(session, key): + url = "http://0.0.0.0:4000/images/generations" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + data = { + "model": "dall-e-2", + "prompt": "A cute baby sea otter", + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + +@pytest.mark.asyncio +async def test_image_generation(): + """ + - Create key + Make embeddings call + - Create user + make embeddings call + """ + async with aiohttp.ClientSession() as session: + key_gen = await generate_key(session=session) + key = key_gen["key"] + await image_generation(session=session, key=key) + key_gen = await new_user(session=session) + key_2 = key_gen["key"] + await image_generation(session=session, key=key_2) diff --git a/tests/test_parallel_key_gen.py b/tests/test_parallel_key_gen.py deleted file mode 100644 index 36595b4c3..000000000 --- a/tests/test_parallel_key_gen.py +++ /dev/null @@ -1,33 +0,0 @@ -# What this tests ? -## Tests /key/generate by making 10 parallel requests, and asserting all are successful -import pytest -import asyncio -import aiohttp - - -async def generate_key(session, i): - url = "http://0.0.0.0:4000/key/generate" - headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} - data = { - "models": ["azure-models"], - "aliases": {"mistral-7b": "gpt-3.5-turbo"}, - "duration": None, - } - - async with session.post(url, headers=headers, json=data) as response: - status = response.status - response_text = await response.text() - - print(f"Response {i} (Status code: {status}):") - print(response_text) - print() - - if status != 200: - raise Exception(f"Request {i} did not return a 200 status code: {status}") - - -@pytest.mark.asyncio -async def test_key_gen(): - async with aiohttp.ClientSession() as session: - tasks = [generate_key(session, i) for i in range(1, 11)] - await asyncio.gather(*tasks) diff --git a/tests/test_users.py b/tests/test_users.py new file mode 100644 index 000000000..81b6166dc --- /dev/null +++ b/tests/test_users.py @@ -0,0 +1,102 @@ +# What this tests ? +## Tests /user endpoints. +import pytest +import asyncio +import aiohttp +import time + + +async def new_user(session, i, user_id=None): + url = "http://0.0.0.0:4000/user/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "models": ["azure-models"], + "aliases": {"mistral-7b": "gpt-3.5-turbo"}, + "duration": None, + } + + if user_id is not None: + data["user_id"] = user_id + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + if status != 200: + raise Exception(f"Request {i} did not return a 200 status code: {status}") + + return await response.json() + + +@pytest.mark.asyncio +async def test_user_new(): + """ + Make 20 parallel calls to /user/new. Assert all worked. + """ + async with aiohttp.ClientSession() as session: + tasks = [new_user(session, i) for i in range(1, 11)] + await asyncio.gather(*tasks) + + +async def get_user_info(session, get_user, call_user): + """ + Make sure only models user has access to are returned + """ + url = f"http://0.0.0.0:4000/user/info?key={get_user}" + headers = { + "Authorization": f"Bearer {call_user}", + "Content-Type": "application/json", + } + + async with session.get(url, headers=headers) as response: + status = response.status + response_text = await response.text() + print(response_text) + print() + + if status != 200: + if call_user != get_user: + return status + else: + print(f"call_user: {call_user}; get_user: {get_user}") + raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + +@pytest.mark.asyncio +async def test_user_info(): + """ + Get user info + - as admin + - as user themself + - as random + """ + get_user = f"krrish_{time.time()}@berri.ai" + async with aiohttp.ClientSession() as session: + key_gen = await new_user(session, 0, user_id=get_user) + key = key_gen["key"] + ## as admin ## + await get_user_info(session=session, get_user=get_user, call_user="sk-1234") + ## as user themself ## + await get_user_info(session=session, get_user=get_user, call_user=key) + # as random user # + key_gen = await new_user(session=session, i=0) + random_key = key_gen["key"] + status = await get_user_info( + session=session, get_user=get_user, call_user=random_key + ) + assert status == 403 + + +@pytest.mark.asyncio +async def test_user_update(): + """ + Create user + Update user access to new model + Make chat completion call + """ + pass