From 4eb244c3caeeaac1c8c60daf37a07a2ddc86aef7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 11 Mar 2024 12:13:30 -0700 Subject: [PATCH 01/11] fix(proxy_server.py): prevent user from deleting non-user owned keys when they use ui --- litellm/proxy/proxy_server.py | 38 ++++++++++++++++++++++++++--------- litellm/proxy/utils.py | 18 ++++++++++++++--- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8510b35016..7917f14c3f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2103,12 +2103,14 @@ async def generate_key_helper_fn( return key_data -async def delete_verification_token(tokens: List): +async def delete_verification_token(tokens: List, user_id: Optional[str] = None): global prisma_client try: if prisma_client: # Assuming 'db' is your Prisma Client instance - deleted_tokens = await prisma_client.delete_data(tokens=tokens) + deleted_tokens = await prisma_client.delete_data( + tokens=tokens, user_id=user_id + ) else: raise Exception except Exception as e: @@ -3744,7 +3746,10 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest): @router.post( "/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)] ) -async def delete_key_fn(data: KeyRequest): +async def delete_key_fn( + data: KeyRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): """ Delete a key from the key management system. @@ -3769,11 +3774,28 @@ async def delete_key_fn(data: KeyRequest): code=status.HTTP_400_BAD_REQUEST, ) - result = await delete_verification_token(tokens=keys) - verbose_proxy_logger.debug("/key/delete - deleted_keys=", result) + ## only allow user to delete keys they own + user_id = user_api_key_dict.user_id + if ( + user_api_key_dict.user_role is not None + and user_api_key_dict.user_role == "proxy_admin" + ): + user_id = None # unless they're admin - number_deleted_keys = len(result["deleted_keys"]) - assert len(keys) == number_deleted_keys + number_deleted_keys = await delete_verification_token( + tokens=keys, user_id=user_id + ) + verbose_proxy_logger.debug("/key/delete - deleted_keys=", number_deleted_keys) + + try: + assert len(keys) == number_deleted_keys + except Exception as e: + raise HTTPException( + status_code=400, + detail={ + "error": "Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in." + }, + ) for key in keys: user_api_key_cache.delete_cache(key) @@ -6529,8 +6551,6 @@ async def login(request: Request): algorithm="HS256", ) litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token - # if a user has logged in they should be allowed to create keys - this ensures that it's set to True - general_settings["allow_user_auth"] = True return RedirectResponse(url=litellm_dashboard_ui, status_code=303) else: raise ProxyException( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 270b53647b..d95f5a5500 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1356,9 +1356,12 @@ class PrismaClient: tokens: Optional[List] = None, team_id_list: Optional[List] = None, table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, + user_id: Optional[str] = None, ): """ Allow user to delete a key(s) + + Ensure user owns that key, unless admin. """ try: if tokens is not None and isinstance(tokens, List): @@ -1369,15 +1372,23 @@ class PrismaClient: else: hashed_token = token hashed_tokens.append(hashed_token) - await self.db.litellm_verificationtoken.delete_many( - where={"token": {"in": hashed_tokens}} + filter_query: dict = {} + if user_id is not None: + filter_query = { + "AND": [{"token": {"in": hashed_tokens}}, {"user_id": user_id}] + } + else: + filter_query = {"token": {"in": hashed_tokens}} + deleted_tokens = await self.db.litellm_verificationtoken.delete_many( + where=filter_query # type: ignore ) - return {"deleted_keys": tokens} + return {"deleted_keys": deleted_tokens} elif ( table_name == "team" and team_id_list is not None and isinstance(team_id_list, List) ): + # admin only endpoint -> `/team/delete` await self.db.litellm_teamtable.delete_many( where={"team_id": {"in": team_id_list}} ) @@ -1387,6 +1398,7 @@ class PrismaClient: and team_id_list is not None and isinstance(team_id_list, List) ): + # admin only endpoint -> `/team/delete` await self.db.litellm_verificationtoken.delete_many( where={"team_id": {"in": team_id_list}} ) From 1369e18e8578eda5eb135ba8c1f9576a1a8d897c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 11 Mar 2024 13:43:50 -0700 Subject: [PATCH 02/11] build: fix default config.yaml --- proxy_server_config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index 5f4875a786..0976103ef4 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -6,13 +6,13 @@ model_list: model: azure/chatgpt-v-2 api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" - api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - model_name: gpt-4 litellm_params: model: azure/chatgpt-v-2 api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" - api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - model_name: sagemaker-completion-model litellm_params: model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4 @@ -20,7 +20,7 @@ model_list: - model_name: text-embedding-ada-002 litellm_params: model: azure/azure-embedding-model - api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" model_info: @@ -31,7 +31,7 @@ model_list: model: azure/dall-e-2 api_version: 2023-06-01-preview api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ - api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault litellm_settings: drop_params: True From 2addd663939a88b02eab5ccb217cfc70f269fcb6 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 11 Mar 2024 13:54:58 -0700 Subject: [PATCH 03/11] fix(proxy_server.py): bug fix --- litellm/proxy/proxy_server.py | 6 ++++-- litellm/proxy/utils.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 7917f14c3f..86fd69fa6f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3785,10 +3785,12 @@ async def delete_key_fn( number_deleted_keys = await delete_verification_token( tokens=keys, user_id=user_id ) - verbose_proxy_logger.debug("/key/delete - deleted_keys=", number_deleted_keys) + verbose_proxy_logger.debug( + f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}" + ) try: - assert len(keys) == number_deleted_keys + assert len(keys) == number_deleted_keys["deleted_keys"] except Exception as e: raise HTTPException( status_code=400, diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index d95f5a5500..4bfb87058b 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1379,9 +1379,11 @@ class PrismaClient: } else: filter_query = {"token": {"in": hashed_tokens}} + deleted_tokens = await self.db.litellm_verificationtoken.delete_many( where=filter_query # type: ignore ) + verbose_proxy_logger.debug(f"deleted_tokens: {deleted_tokens}") return {"deleted_keys": deleted_tokens} elif ( table_name == "team" From f683acda61883c8b3f08e8bff69bd0fe63b18d19 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 11 Mar 2024 13:56:10 -0700 Subject: [PATCH 04/11] build: fix default config --- proxy_server_config.yaml | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index 0976103ef4..83bcc0626f 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -1,18 +1,16 @@ model_list: - # NOTE: This is the default config users use with Dockerfile. - # DO not expect users to pass os.environ/<> vars here, this will lead to proxy startup failing for them if they don't have the expected env vars - model_name: gpt-3.5-turbo litellm_params: model: azure/chatgpt-v-2 api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" - api_key: os.environ/AZURE_API_KEY # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - model_name: gpt-4 litellm_params: model: azure/chatgpt-v-2 api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" - api_key: os.environ/AZURE_API_KEY # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - model_name: sagemaker-completion-model litellm_params: model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4 @@ -20,7 +18,7 @@ model_list: - model_name: text-embedding-ada-002 litellm_params: model: azure/azure-embedding-model - api_key: os.environ/AZURE_API_KEY # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" model_info: @@ -28,10 +26,13 @@ model_list: base_model: text-embedding-ada-002 - model_name: dall-e-2 litellm_params: - model: azure/dall-e-2 + model: azure/ api_version: 2023-06-01-preview api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ - api_key: os.environ/AZURE_API_KEY # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY + - model_name: openai-dall-e-3 + litellm_params: + model: dall-e-3 litellm_settings: drop_params: True @@ -39,7 +40,7 @@ litellm_settings: budget_duration: 30d num_retries: 5 request_timeout: 600 -general_settings: +general_settings: master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234) proxy_budget_rescheduler_min_time: 60 proxy_budget_rescheduler_max_time: 64 From e07174736feae74006fb8be4e43a3c97b8001f2d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 11 Mar 2024 13:57:40 -0700 Subject: [PATCH 05/11] refactor(main.py): trigger new build --- litellm/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/litellm/main.py b/litellm/main.py index 41848028ee..114b469488 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -12,7 +12,6 @@ from typing import Any, Literal, Union, BinaryIO from functools import partial import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy - import httpx import litellm from ._logging import verbose_logger From d1644db8ce0e94541729ec35619f85c492a3f400 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 11 Mar 2024 14:18:01 -0700 Subject: [PATCH 06/11] test(test_key_generate_prisma.py): fix test to only let admin delete a key --- litellm/proxy/proxy_server.py | 3 +++ litellm/tests/test_key_generate_prisma.py | 13 ++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 86fd69fa6f..45f432f9df 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3776,6 +3776,9 @@ async def delete_key_fn( ## only allow user to delete keys they own user_id = user_api_key_dict.user_id + verbose_proxy_logger.debug( + f"user_api_key_dict.user_role: {user_api_key_dict.user_role}" + ) if ( user_api_key_dict.user_role is not None and user_api_key_dict.user_role == "proxy_admin" diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 524eee6f29..74ff61abd1 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -737,8 +737,19 @@ def test_delete_key(prisma_client): delete_key_request = KeyRequest(keys=[generated_key]) + bearer_token = "Bearer sk-1234" + + request = Request(scope={"type": "http"}) + request._url = URL(url="/key/delete") + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + print(f"result: {result}") + result.user_role = "proxy_admin" # delete the key - result_delete_key = await delete_key_fn(data=delete_key_request) + result_delete_key = await delete_key_fn( + data=delete_key_request, user_api_key_dict=result + ) print("result from delete key", result_delete_key) assert result_delete_key == {"deleted_keys": [generated_key]} From c4db5d4c33eb4b7b0c7dfb87895c14bc3ca944e9 Mon Sep 17 00:00:00 2001 From: Elad Segal Date: Mon, 11 Mar 2024 23:35:03 +0200 Subject: [PATCH 07/11] Make `argon2-cffi` optional, used only for proxy --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 06cbe89a7d..0d44b366bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ click = "*" jinja2 = "^3.1.2" aiohttp = "*" requests = "^2.31.0" -argon2-cffi = "^23.1.0" uvicorn = {version = "^0.22.0", optional = true} gunicorn = {version = "^21.2.0", optional = true} @@ -36,6 +35,7 @@ streamlit = {version = "^1.29.0", optional = true} fastapi-sso = { version = "^0.10.0", optional = true } PyJWT = { version = "^2.8.0", optional = true } python-multipart = { version = "^0.0.6", optional = true } +argon2-cffi = { version = "^23.1.0", optional = true } [tool.poetry.extras] proxy = [ @@ -50,6 +50,7 @@ proxy = [ "fastapi-sso", "PyJWT", "python-multipart", + "argon2-cffi", ] extra_proxy = [ From 9735250db78ff819c38e174745a49950aa8993c7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 11 Mar 2024 14:51:22 -0700 Subject: [PATCH 08/11] fix(router.py): support fallbacks / retries with sync embedding calls --- litellm/router.py | 111 +++++++++++++------- litellm/tests/test_router_fallbacks.py | 51 +++++++++ litellm/tests/test_router_with_fallbacks.py | 56 ++++++++++ 3 files changed, 181 insertions(+), 37 deletions(-) create mode 100644 litellm/tests/test_router_with_fallbacks.py diff --git a/litellm/router.py b/litellm/router.py index e4b14dd097..7f23e19d74 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -970,44 +970,81 @@ class Router: is_async: Optional[bool] = False, **kwargs, ) -> Union[List[float], None]: - # pick the one that is available (lowest TPM/RPM) - deployment = self.get_available_deployment( - model=model, - input=input, - specific_deployment=kwargs.pop("specific_deployment", None), - ) - kwargs.setdefault("model_info", {}) - kwargs.setdefault("metadata", {}).update( - {"model_group": model, "deployment": deployment["litellm_params"]["model"]} - ) # [TODO]: move to using async_function_with_fallbacks - data = deployment["litellm_params"].copy() - for k, v in self.default_litellm_params.items(): + try: + kwargs["model"] = model + kwargs["input"] = input + kwargs["original_function"] = self._embedding + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + timeout = kwargs.get("request_timeout", self.timeout) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + response = self.function_with_fallbacks(**kwargs) + return response + except Exception as e: + raise e + + def _embedding(self, input: Union[str, List], model: str, **kwargs): + try: + verbose_router_logger.debug( + f"Inside embedding()- model: {model}; kwargs: {kwargs}" + ) + deployment = self.get_available_deployment( + model=model, + input=input, + specific_deployment=kwargs.pop("specific_deployment", None), + ) + kwargs.setdefault("metadata", {}).update( + { + "deployment": deployment["litellm_params"]["model"], + "model_info": deployment.get("model_info", {}), + } + ) + kwargs["model_info"] = deployment.get("model_info", {}) + data = deployment["litellm_params"].copy() + model_name = data["model"] + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + + potential_model_client = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="sync" + ) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) if ( - k not in kwargs - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - potential_model_client = self._get_client(deployment=deployment, kwargs=kwargs) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client - return litellm.embedding( - **{ - **data, - "input": input, - "caching": self.cache_responses, - "client": model_client, - **kwargs, - } - ) + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client + + self.total_calls[model_name] += 1 + response = litellm.embedding( + **{ + **data, + "input": input, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.embedding(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.embedding(model={model_name})\033[31m Exception {str(e)}\033[0m" + ) + if model_name is not None: + self.fail_calls[model_name] += 1 + raise e async def aembedding( self, diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py index 5d17d36c9f..98a2449f06 100644 --- a/litellm/tests/test_router_fallbacks.py +++ b/litellm/tests/test_router_fallbacks.py @@ -227,6 +227,57 @@ async def test_async_fallbacks(): # test_async_fallbacks() +def test_sync_fallbacks_embeddings(): + litellm.set_verbose = False + model_list = [ + { # list of model deployments + "model_name": "bad-azure-embedding-model", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/azure-embedding-model", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { # list of model deployments + "model_name": "good-azure-embedding-model", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/azure-embedding-model", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + ] + + router = Router( + model_list=model_list, + fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}], + set_verbose=False, + ) + customHandler = MyCustomHandler() + litellm.callbacks = [customHandler] + user_message = "Hello, how are you?" + input = [user_message] + try: + kwargs = {"model": "bad-azure-embedding-model", "input": input} + response = router.embedding(**kwargs) + print(f"customHandler.previous_models: {customHandler.previous_models}") + time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread + assert customHandler.previous_models == 1 # 0 retries, 1 fallback + router.reset() + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"An exception occurred: {e}") + finally: + router.reset() + + @pytest.mark.asyncio async def test_async_fallbacks_embeddings(): litellm.set_verbose = False diff --git a/litellm/tests/test_router_with_fallbacks.py b/litellm/tests/test_router_with_fallbacks.py new file mode 100644 index 0000000000..deabf73750 --- /dev/null +++ b/litellm/tests/test_router_with_fallbacks.py @@ -0,0 +1,56 @@ +# [LOCAL TEST] - runs against mock openai proxy +# # What this tests? +# ## This tests if fallbacks works for 429 errors + +# import sys, os, time +# import traceback, asyncio +# import pytest + +# sys.path.insert( +# 0, os.path.abspath("../..") +# ) # Adds the parent directory to the system path +# import litellm +# from litellm import Router + +# model_list = [ +# { # list of model deployments +# "model_name": "text-embedding-ada-002", # model alias +# "litellm_params": { # params for litellm completion/embedding call +# "model": "text-embedding-ada-002", # actual model name +# "api_key": "sk-fakekey", +# "api_base": "http://0.0.0.0:8080", +# }, +# "tpm": 1000, +# "rpm": 6, +# }, +# { +# "model_name": "text-embedding-ada-002-fallback", +# "litellm_params": { # params for litellm completion/embedding call +# "model": "openai/text-embedding-ada-002-anything-else", # actual model name +# "api_key": "sk-fakekey2", +# "api_base": "http://0.0.0.0:8080", +# }, +# "tpm": 1000, +# "rpm": 6, +# }, +# ] + +# router = Router( +# model_list=model_list, +# fallbacks=[ +# {"text-embedding-ada-002": ["text-embedding-ada-002-fallback"]}, +# {"text-embedding-ada-002-fallback": ["text-embedding-ada-002"]}, +# ], +# set_verbose=True, +# num_retries=0, +# debug_level="INFO", +# routing_strategy="usage-based-routing", +# ) + + +# def test_embedding_with_fallbacks(): +# response = router.embedding(model="text-embedding-ada-002", input=["Hello world"]) +# print(f"response: {response}") + + +# test_embedding_with_fallbacks() From 64aeb088d9f9d4d0472735392c6b41a5920348e5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 11 Mar 2024 14:59:11 -0700 Subject: [PATCH 09/11] test(test_key_generate_prisma.py): fix test --- litellm/tests/test_key_generate_prisma.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 74ff61abd1..91c6d24dbd 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -722,6 +722,7 @@ def test_delete_key(prisma_client): setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "user_custom_auth", None) try: async def test(): From 917f92800de03322682a00ce68dc8e91aee7ad86 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 11 Mar 2024 15:24:42 -0700 Subject: [PATCH 10/11] test(test_key_generate_prisma.py): fix tests --- litellm/tests/test_key_generate_prisma.py | 43 ++++++++++++++++++++--- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 91c6d24dbd..62f6c38a95 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -788,7 +788,19 @@ def test_delete_key_auth(prisma_client): delete_key_request = KeyRequest(keys=[generated_key]) # delete the key - result_delete_key = await delete_key_fn(data=delete_key_request) + bearer_token = "Bearer sk-1234" + + request = Request(scope={"type": "http"}) + request._url = URL(url="/key/delete") + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + print(f"result: {result}") + result.user_role = "proxy_admin" + + result_delete_key = await delete_key_fn( + data=delete_key_request, user_api_key_dict=result + ) print("result from delete key", result_delete_key) assert result_delete_key == {"deleted_keys": [generated_key]} @@ -803,6 +815,7 @@ def test_delete_key_auth(prisma_client): ) # use generated key to auth in + bearer_token = "Bearer " + generated_key result = await user_api_key_auth(request=request, api_key=bearer_token) print("got result", result) pytest.fail(f"This should have failed!. IT's an invalid key") @@ -847,9 +860,19 @@ def test_generate_and_call_key_info(prisma_client): # cleanup - delete key delete_key_request = KeyRequest(keys=[generated_key]) + bearer_token = "Bearer sk-1234" - # delete the key - await delete_key_fn(data=delete_key_request) + request = Request(scope={"type": "http"}) + request._url = URL(url="/key/delete") + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + print(f"result: {result}") + result.user_role = "proxy_admin" + + result_delete_key = await delete_key_fn( + data=delete_key_request, user_api_key_dict=result + ) asyncio.run(test()) except Exception as e: @@ -928,7 +951,19 @@ def test_generate_and_update_key(prisma_client): delete_key_request = KeyRequest(keys=[generated_key]) # delete the key - await delete_key_fn(data=delete_key_request) + bearer_token = "Bearer sk-1234" + + request = Request(scope={"type": "http"}) + request._url = URL(url="/key/delete") + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + print(f"result: {result}") + result.user_role = "proxy_admin" + + result_delete_key = await delete_key_fn( + data=delete_key_request, user_api_key_dict=result + ) asyncio.run(test()) except Exception as e: From 3dda6f0cf30ef2de85b0b9bcf638284ce6dc6161 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 11 Mar 2024 16:38:31 -0700 Subject: [PATCH 11/11] (fix) test_mem_usage --- litellm/tests/test_mem_usage.py | 242 ++++++++++++++++---------------- 1 file changed, 121 insertions(+), 121 deletions(-) diff --git a/litellm/tests/test_mem_usage.py b/litellm/tests/test_mem_usage.py index 31e15c6d6b..95bf3993f7 100644 --- a/litellm/tests/test_mem_usage.py +++ b/litellm/tests/test_mem_usage.py @@ -1,149 +1,149 @@ -#### What this tests #### +# #### What this tests #### -from memory_profiler import profile, memory_usage -import sys, os, time -import traceback, asyncio -import pytest +# from memory_profiler import profile, memory_usage +# import sys, os, time +# import traceback, asyncio +# import pytest -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -import litellm -from litellm import Router -from concurrent.futures import ThreadPoolExecutor -from collections import defaultdict -from dotenv import load_dotenv -import uuid -import tracemalloc -import objgraph +# sys.path.insert( +# 0, os.path.abspath("../..") +# ) # Adds the parent directory to the system path +# import litellm +# from litellm import Router +# from concurrent.futures import ThreadPoolExecutor +# from collections import defaultdict +# from dotenv import load_dotenv +# import uuid +# import tracemalloc +# import objgraph -objgraph.growth(shortnames=True) -objgraph.show_most_common_types(limit=10) +# objgraph.growth(shortnames=True) +# objgraph.show_most_common_types(limit=10) -from mem_top import mem_top +# from mem_top import mem_top -load_dotenv() +# load_dotenv() -model_list = [ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE"), - }, - "tpm": 240000, - "rpm": 1800, - }, - { - "model_name": "bad-model", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": "bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE"), - }, - "tpm": 240000, - "rpm": 1800, - }, - { - "model_name": "text-embedding-ada-002", - "litellm_params": { - "model": "azure/azure-embedding-model", - "api_key": os.environ["AZURE_API_KEY"], - "api_base": os.environ["AZURE_API_BASE"], - }, - "tpm": 100000, - "rpm": 10000, - }, -] -litellm.set_verbose = True -litellm.cache = litellm.Cache( - type="s3", s3_bucket_name="litellm-my-test-bucket-2", s3_region_name="us-east-1" -) -router = Router( - model_list=model_list, - fallbacks=[ - {"bad-model": ["gpt-3.5-turbo"]}, - ], -) # type: ignore +# model_list = [ +# { +# "model_name": "gpt-3.5-turbo", # openai model name +# "litellm_params": { # params for litellm completion/embedding call +# "model": "azure/chatgpt-v-2", +# "api_key": os.getenv("AZURE_API_KEY"), +# "api_version": os.getenv("AZURE_API_VERSION"), +# "api_base": os.getenv("AZURE_API_BASE"), +# }, +# "tpm": 240000, +# "rpm": 1800, +# }, +# { +# "model_name": "bad-model", # openai model name +# "litellm_params": { # params for litellm completion/embedding call +# "model": "azure/chatgpt-v-2", +# "api_key": "bad-key", +# "api_version": os.getenv("AZURE_API_VERSION"), +# "api_base": os.getenv("AZURE_API_BASE"), +# }, +# "tpm": 240000, +# "rpm": 1800, +# }, +# { +# "model_name": "text-embedding-ada-002", +# "litellm_params": { +# "model": "azure/azure-embedding-model", +# "api_key": os.environ["AZURE_API_KEY"], +# "api_base": os.environ["AZURE_API_BASE"], +# }, +# "tpm": 100000, +# "rpm": 10000, +# }, +# ] +# litellm.set_verbose = True +# litellm.cache = litellm.Cache( +# type="s3", s3_bucket_name="litellm-my-test-bucket-2", s3_region_name="us-east-1" +# ) +# router = Router( +# model_list=model_list, +# fallbacks=[ +# {"bad-model": ["gpt-3.5-turbo"]}, +# ], +# ) # type: ignore -async def router_acompletion(): - # embedding call - question = f"This is a test: {uuid.uuid4()}" * 1 +# async def router_acompletion(): +# # embedding call +# question = f"This is a test: {uuid.uuid4()}" * 1 - response = await router.acompletion( - model="bad-model", messages=[{"role": "user", "content": question}] - ) - print("completion-resp", response) - return response +# response = await router.acompletion( +# model="bad-model", messages=[{"role": "user", "content": question}] +# ) +# print("completion-resp", response) +# return response -async def main(): - for i in range(1): - start = time.time() - n = 15 # Number of concurrent tasks - tasks = [router_acompletion() for _ in range(n)] +# async def main(): +# for i in range(1): +# start = time.time() +# n = 15 # Number of concurrent tasks +# tasks = [router_acompletion() for _ in range(n)] - chat_completions = await asyncio.gather(*tasks) +# chat_completions = await asyncio.gather(*tasks) - successful_completions = [c for c in chat_completions if c is not None] +# successful_completions = [c for c in chat_completions if c is not None] - # Write errors to error_log.txt - with open("error_log.txt", "a") as error_log: - for completion in chat_completions: - if isinstance(completion, str): - error_log.write(completion + "\n") +# # Write errors to error_log.txt +# with open("error_log.txt", "a") as error_log: +# for completion in chat_completions: +# if isinstance(completion, str): +# error_log.write(completion + "\n") - print(n, time.time() - start, len(successful_completions)) - print() - print(vars(router)) +# print(n, time.time() - start, len(successful_completions)) +# print() +# print(vars(router)) -if __name__ == "__main__": - # Blank out contents of error_log.txt - open("error_log.txt", "w").close() +# if __name__ == "__main__": +# # Blank out contents of error_log.txt +# open("error_log.txt", "w").close() - import tracemalloc +# import tracemalloc - tracemalloc.start(25) +# tracemalloc.start(25) - # ... run your application ... +# # ... run your application ... - asyncio.run(main()) - print(mem_top()) +# asyncio.run(main()) +# print(mem_top()) - snapshot = tracemalloc.take_snapshot() - # top_stats = snapshot.statistics('lineno') +# snapshot = tracemalloc.take_snapshot() +# # top_stats = snapshot.statistics('lineno') - # print("[ Top 10 ]") - # for stat in top_stats[:50]: - # print(stat) +# # print("[ Top 10 ]") +# # for stat in top_stats[:50]: +# # print(stat) - top_stats = snapshot.statistics("traceback") +# top_stats = snapshot.statistics("traceback") - # pick the biggest memory block - stat = top_stats[0] - print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024)) - for line in stat.traceback.format(): - print(line) - print() - stat = top_stats[1] - print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024)) - for line in stat.traceback.format(): - print(line) +# # pick the biggest memory block +# stat = top_stats[0] +# print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024)) +# for line in stat.traceback.format(): +# print(line) +# print() +# stat = top_stats[1] +# print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024)) +# for line in stat.traceback.format(): +# print(line) - print() - stat = top_stats[2] - print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024)) - for line in stat.traceback.format(): - print(line) - print() +# print() +# stat = top_stats[2] +# print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024)) +# for line in stat.traceback.format(): +# print(line) +# print() - stat = top_stats[3] - print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024)) - for line in stat.traceback.format(): - print(line) +# stat = top_stats[3] +# print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024)) +# for line in stat.traceback.format(): +# print(line)