diff --git a/litellm/main.py b/litellm/main.py index 41848028ee..114b469488 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -12,7 +12,6 @@ from typing import Any, Literal, Union, BinaryIO from functools import partial import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy - import httpx import litellm from ._logging import verbose_logger diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8510b35016..45f432f9df 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2103,12 +2103,14 @@ async def generate_key_helper_fn( return key_data -async def delete_verification_token(tokens: List): +async def delete_verification_token(tokens: List, user_id: Optional[str] = None): global prisma_client try: if prisma_client: # Assuming 'db' is your Prisma Client instance - deleted_tokens = await prisma_client.delete_data(tokens=tokens) + deleted_tokens = await prisma_client.delete_data( + tokens=tokens, user_id=user_id + ) else: raise Exception except Exception as e: @@ -3744,7 +3746,10 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest): @router.post( "/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)] ) -async def delete_key_fn(data: KeyRequest): +async def delete_key_fn( + data: KeyRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): """ Delete a key from the key management system. @@ -3769,11 +3774,33 @@ async def delete_key_fn(data: KeyRequest): code=status.HTTP_400_BAD_REQUEST, ) - result = await delete_verification_token(tokens=keys) - verbose_proxy_logger.debug("/key/delete - deleted_keys=", result) + ## only allow user to delete keys they own + user_id = user_api_key_dict.user_id + verbose_proxy_logger.debug( + f"user_api_key_dict.user_role: {user_api_key_dict.user_role}" + ) + if ( + user_api_key_dict.user_role is not None + and user_api_key_dict.user_role == "proxy_admin" + ): + user_id = None # unless they're admin - number_deleted_keys = len(result["deleted_keys"]) - assert len(keys) == number_deleted_keys + number_deleted_keys = await delete_verification_token( + tokens=keys, user_id=user_id + ) + verbose_proxy_logger.debug( + f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}" + ) + + try: + assert len(keys) == number_deleted_keys["deleted_keys"] + except Exception as e: + raise HTTPException( + status_code=400, + detail={ + "error": "Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in." + }, + ) for key in keys: user_api_key_cache.delete_cache(key) @@ -6529,8 +6556,6 @@ async def login(request: Request): algorithm="HS256", ) litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token - # if a user has logged in they should be allowed to create keys - this ensures that it's set to True - general_settings["allow_user_auth"] = True return RedirectResponse(url=litellm_dashboard_ui, status_code=303) else: raise ProxyException( diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 270b53647b..4bfb87058b 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1356,9 +1356,12 @@ class PrismaClient: tokens: Optional[List] = None, team_id_list: Optional[List] = None, table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, + user_id: Optional[str] = None, ): """ Allow user to delete a key(s) + + Ensure user owns that key, unless admin. """ try: if tokens is not None and isinstance(tokens, List): @@ -1369,15 +1372,25 @@ class PrismaClient: else: hashed_token = token hashed_tokens.append(hashed_token) - await self.db.litellm_verificationtoken.delete_many( - where={"token": {"in": hashed_tokens}} + filter_query: dict = {} + if user_id is not None: + filter_query = { + "AND": [{"token": {"in": hashed_tokens}}, {"user_id": user_id}] + } + else: + filter_query = {"token": {"in": hashed_tokens}} + + deleted_tokens = await self.db.litellm_verificationtoken.delete_many( + where=filter_query # type: ignore ) - return {"deleted_keys": tokens} + verbose_proxy_logger.debug(f"deleted_tokens: {deleted_tokens}") + return {"deleted_keys": deleted_tokens} elif ( table_name == "team" and team_id_list is not None and isinstance(team_id_list, List) ): + # admin only endpoint -> `/team/delete` await self.db.litellm_teamtable.delete_many( where={"team_id": {"in": team_id_list}} ) @@ -1387,6 +1400,7 @@ class PrismaClient: and team_id_list is not None and isinstance(team_id_list, List) ): + # admin only endpoint -> `/team/delete` await self.db.litellm_verificationtoken.delete_many( where={"team_id": {"in": team_id_list}} ) diff --git a/litellm/router.py b/litellm/router.py index 2869e1fb45..c6a2bc8fed 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -967,44 +967,81 @@ class Router: is_async: Optional[bool] = False, **kwargs, ) -> Union[List[float], None]: - # pick the one that is available (lowest TPM/RPM) - deployment = self.get_available_deployment( - model=model, - input=input, - specific_deployment=kwargs.pop("specific_deployment", None), - ) - kwargs.setdefault("model_info", {}) - kwargs.setdefault("metadata", {}).update( - {"model_group": model, "deployment": deployment["litellm_params"]["model"]} - ) # [TODO]: move to using async_function_with_fallbacks - data = deployment["litellm_params"].copy() - for k, v in self.default_litellm_params.items(): + try: + kwargs["model"] = model + kwargs["input"] = input + kwargs["original_function"] = self._embedding + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + timeout = kwargs.get("request_timeout", self.timeout) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + response = self.function_with_fallbacks(**kwargs) + return response + except Exception as e: + raise e + + def _embedding(self, input: Union[str, List], model: str, **kwargs): + try: + verbose_router_logger.debug( + f"Inside embedding()- model: {model}; kwargs: {kwargs}" + ) + deployment = self.get_available_deployment( + model=model, + input=input, + specific_deployment=kwargs.pop("specific_deployment", None), + ) + kwargs.setdefault("metadata", {}).update( + { + "deployment": deployment["litellm_params"]["model"], + "model_info": deployment.get("model_info", {}), + } + ) + kwargs["model_info"] = deployment.get("model_info", {}) + data = deployment["litellm_params"].copy() + model_name = data["model"] + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs + ): # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + + potential_model_client = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="sync" + ) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) if ( - k not in kwargs - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) - potential_model_client = self._get_client(deployment=deployment, kwargs=kwargs) - # check if provided keys == client keys # - dynamic_api_key = kwargs.get("api_key", None) - if ( - dynamic_api_key is not None - and potential_model_client is not None - and dynamic_api_key != potential_model_client.api_key - ): - model_client = None - else: - model_client = potential_model_client - return litellm.embedding( - **{ - **data, - "input": input, - "caching": self.cache_responses, - "client": model_client, - **kwargs, - } - ) + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client + + self.total_calls[model_name] += 1 + response = litellm.embedding( + **{ + **data, + "input": input, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"litellm.embedding(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"litellm.embedding(model={model_name})\033[31m Exception {str(e)}\033[0m" + ) + if model_name is not None: + self.fail_calls[model_name] += 1 + raise e async def aembedding( self, diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 524eee6f29..62f6c38a95 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -722,6 +722,7 @@ def test_delete_key(prisma_client): setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "user_custom_auth", None) try: async def test(): @@ -737,8 +738,19 @@ def test_delete_key(prisma_client): delete_key_request = KeyRequest(keys=[generated_key]) + bearer_token = "Bearer sk-1234" + + request = Request(scope={"type": "http"}) + request._url = URL(url="/key/delete") + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + print(f"result: {result}") + result.user_role = "proxy_admin" # delete the key - result_delete_key = await delete_key_fn(data=delete_key_request) + result_delete_key = await delete_key_fn( + data=delete_key_request, user_api_key_dict=result + ) print("result from delete key", result_delete_key) assert result_delete_key == {"deleted_keys": [generated_key]} @@ -776,7 +788,19 @@ def test_delete_key_auth(prisma_client): delete_key_request = KeyRequest(keys=[generated_key]) # delete the key - result_delete_key = await delete_key_fn(data=delete_key_request) + bearer_token = "Bearer sk-1234" + + request = Request(scope={"type": "http"}) + request._url = URL(url="/key/delete") + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + print(f"result: {result}") + result.user_role = "proxy_admin" + + result_delete_key = await delete_key_fn( + data=delete_key_request, user_api_key_dict=result + ) print("result from delete key", result_delete_key) assert result_delete_key == {"deleted_keys": [generated_key]} @@ -791,6 +815,7 @@ def test_delete_key_auth(prisma_client): ) # use generated key to auth in + bearer_token = "Bearer " + generated_key result = await user_api_key_auth(request=request, api_key=bearer_token) print("got result", result) pytest.fail(f"This should have failed!. IT's an invalid key") @@ -835,9 +860,19 @@ def test_generate_and_call_key_info(prisma_client): # cleanup - delete key delete_key_request = KeyRequest(keys=[generated_key]) + bearer_token = "Bearer sk-1234" - # delete the key - await delete_key_fn(data=delete_key_request) + request = Request(scope={"type": "http"}) + request._url = URL(url="/key/delete") + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + print(f"result: {result}") + result.user_role = "proxy_admin" + + result_delete_key = await delete_key_fn( + data=delete_key_request, user_api_key_dict=result + ) asyncio.run(test()) except Exception as e: @@ -916,7 +951,19 @@ def test_generate_and_update_key(prisma_client): delete_key_request = KeyRequest(keys=[generated_key]) # delete the key - await delete_key_fn(data=delete_key_request) + bearer_token = "Bearer sk-1234" + + request = Request(scope={"type": "http"}) + request._url = URL(url="/key/delete") + + # use generated key to auth in + result = await user_api_key_auth(request=request, api_key=bearer_token) + print(f"result: {result}") + result.user_role = "proxy_admin" + + result_delete_key = await delete_key_fn( + data=delete_key_request, user_api_key_dict=result + ) asyncio.run(test()) except Exception as e: diff --git a/litellm/tests/test_mem_usage.py b/litellm/tests/test_mem_usage.py index 90540ddd03..4a804b4033 100644 --- a/litellm/tests/test_mem_usage.py +++ b/litellm/tests/test_mem_usage.py @@ -85,7 +85,7 @@ # async def main(): # for i in range(1): # start = time.time() -# n = 20 # Number of concurrent tasks +# n = 15 # Number of concurrent tasks # tasks = [router_acompletion() for _ in range(n)] # chat_completions = await asyncio.gather(*tasks) diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py index 5d17d36c9f..98a2449f06 100644 --- a/litellm/tests/test_router_fallbacks.py +++ b/litellm/tests/test_router_fallbacks.py @@ -227,6 +227,57 @@ async def test_async_fallbacks(): # test_async_fallbacks() +def test_sync_fallbacks_embeddings(): + litellm.set_verbose = False + model_list = [ + { # list of model deployments + "model_name": "bad-azure-embedding-model", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/azure-embedding-model", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + { # list of model deployments + "model_name": "good-azure-embedding-model", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/azure-embedding-model", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "tpm": 240000, + "rpm": 1800, + }, + ] + + router = Router( + model_list=model_list, + fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}], + set_verbose=False, + ) + customHandler = MyCustomHandler() + litellm.callbacks = [customHandler] + user_message = "Hello, how are you?" + input = [user_message] + try: + kwargs = {"model": "bad-azure-embedding-model", "input": input} + response = router.embedding(**kwargs) + print(f"customHandler.previous_models: {customHandler.previous_models}") + time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread + assert customHandler.previous_models == 1 # 0 retries, 1 fallback + router.reset() + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"An exception occurred: {e}") + finally: + router.reset() + + @pytest.mark.asyncio async def test_async_fallbacks_embeddings(): litellm.set_verbose = False diff --git a/litellm/tests/test_router_with_fallbacks.py b/litellm/tests/test_router_with_fallbacks.py new file mode 100644 index 0000000000..deabf73750 --- /dev/null +++ b/litellm/tests/test_router_with_fallbacks.py @@ -0,0 +1,56 @@ +# [LOCAL TEST] - runs against mock openai proxy +# # What this tests? +# ## This tests if fallbacks works for 429 errors + +# import sys, os, time +# import traceback, asyncio +# import pytest + +# sys.path.insert( +# 0, os.path.abspath("../..") +# ) # Adds the parent directory to the system path +# import litellm +# from litellm import Router + +# model_list = [ +# { # list of model deployments +# "model_name": "text-embedding-ada-002", # model alias +# "litellm_params": { # params for litellm completion/embedding call +# "model": "text-embedding-ada-002", # actual model name +# "api_key": "sk-fakekey", +# "api_base": "http://0.0.0.0:8080", +# }, +# "tpm": 1000, +# "rpm": 6, +# }, +# { +# "model_name": "text-embedding-ada-002-fallback", +# "litellm_params": { # params for litellm completion/embedding call +# "model": "openai/text-embedding-ada-002-anything-else", # actual model name +# "api_key": "sk-fakekey2", +# "api_base": "http://0.0.0.0:8080", +# }, +# "tpm": 1000, +# "rpm": 6, +# }, +# ] + +# router = Router( +# model_list=model_list, +# fallbacks=[ +# {"text-embedding-ada-002": ["text-embedding-ada-002-fallback"]}, +# {"text-embedding-ada-002-fallback": ["text-embedding-ada-002"]}, +# ], +# set_verbose=True, +# num_retries=0, +# debug_level="INFO", +# routing_strategy="usage-based-routing", +# ) + + +# def test_embedding_with_fallbacks(): +# response = router.embedding(model="text-embedding-ada-002", input=["Hello world"]) +# print(f"response: {response}") + + +# test_embedding_with_fallbacks() diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index 5f4875a786..83bcc0626f 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -1,18 +1,16 @@ model_list: - # NOTE: This is the default config users use with Dockerfile. - # DO not expect users to pass os.environ/<> vars here, this will lead to proxy startup failing for them if they don't have the expected env vars - model_name: gpt-3.5-turbo litellm_params: model: azure/chatgpt-v-2 api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" - api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - model_name: gpt-4 litellm_params: model: azure/chatgpt-v-2 api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" - api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - model_name: sagemaker-completion-model litellm_params: model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4 @@ -20,7 +18,7 @@ model_list: - model_name: text-embedding-ada-002 litellm_params: model: azure/azure-embedding-model - api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_version: "2023-05-15" model_info: @@ -28,10 +26,13 @@ model_list: base_model: text-embedding-ada-002 - model_name: dall-e-2 litellm_params: - model: azure/dall-e-2 + model: azure/ api_version: 2023-06-01-preview api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ - api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + api_key: os.environ/AZURE_API_KEY + - model_name: openai-dall-e-3 + litellm_params: + model: dall-e-3 litellm_settings: drop_params: True @@ -39,7 +40,7 @@ litellm_settings: budget_duration: 30d num_retries: 5 request_timeout: 600 -general_settings: +general_settings: master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234) proxy_budget_rescheduler_min_time: 60 proxy_budget_rescheduler_max_time: 64 diff --git a/pyproject.toml b/pyproject.toml index 06cbe89a7d..0d44b366bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ click = "*" jinja2 = "^3.1.2" aiohttp = "*" requests = "^2.31.0" -argon2-cffi = "^23.1.0" uvicorn = {version = "^0.22.0", optional = true} gunicorn = {version = "^21.2.0", optional = true} @@ -36,6 +35,7 @@ streamlit = {version = "^1.29.0", optional = true} fastapi-sso = { version = "^0.10.0", optional = true } PyJWT = { version = "^2.8.0", optional = true } python-multipart = { version = "^0.0.6", optional = true } +argon2-cffi = { version = "^23.1.0", optional = true } [tool.poetry.extras] proxy = [ @@ -50,6 +50,7 @@ proxy = [ "fastapi-sso", "PyJWT", "python-multipart", + "argon2-cffi", ] extra_proxy = [