diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 9a5acc4406..0c38504b89 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -140,6 +140,7 @@ class GenerateRequestBase(LiteLLMBase): class GenerateKeyRequest(GenerateRequestBase): + key_alias: Optional[str] = None duration: Optional[str] = "1h" aliases: Optional[dict] = {} config: Optional[dict] = {} @@ -304,6 +305,8 @@ class ConfigYAML(LiteLLMBase): class LiteLLM_VerificationToken(LiteLLMBase): token: str + key_name: Optional[str] = None + key_alias: Optional[str] = None spend: float = 0.0 max_budget: Optional[float] = None expires: Union[str, None] diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 21e2bb0592..4a854ec761 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -244,6 +244,8 @@ async def user_api_key_auth( response = await user_custom_auth(request=request, api_key=api_key) return UserAPIKeyAuth.model_validate(response) ### LITELLM-DEFINED AUTH FUNCTION ### + if isinstance(api_key, str): + assert api_key.startswith("sk-") # prevent token hashes from being used if master_key is None: if isinstance(api_key, str): return UserAPIKeyAuth(api_key=api_key) @@ -1247,6 +1249,7 @@ async def generate_key_helper_fn( rpm_limit: Optional[int] = None, query_type: Literal["insert_data", "update_data"] = "insert_data", update_key_values: Optional[dict] = None, + key_alias: Optional[str] = None, ): global prisma_client, custom_db_client @@ -1320,6 +1323,7 @@ async def generate_key_helper_fn( } key_data = { "token": token, + "key_alias": key_alias, "expires": expires, "models": models, "aliases": aliases_json, @@ -1335,6 +1339,8 @@ async def generate_key_helper_fn( "budget_duration": key_budget_duration, "budget_reset_at": key_reset_at, } + if general_settings.get("allow_user_auth", False) == True: + key_data["key_name"] = f"sk-...{token[-4:]}" if prisma_client is not None: ## CREATE USER (If necessary) verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}") diff --git a/litellm/tests/test_configs/test_config_no_auth.yaml b/litellm/tests/test_configs/test_config_no_auth.yaml index be85765a86..8441018e35 100644 --- a/litellm/tests/test_configs/test_config_no_auth.yaml +++ b/litellm/tests/test_configs/test_config_no_auth.yaml @@ -53,9 +53,9 @@ model_list: api_key: os.environ/AZURE_API_KEY api_version: 2023-07-01-preview model: azure/azure-embedding-model - model_name: azure-embedding-model model_info: - mode: "embedding" + mode: embedding + model_name: azure-embedding-model - litellm_params: model: gpt-3.5-turbo model_info: @@ -80,43 +80,49 @@ model_list: description: this is a test openai model id: 9b1ef341-322c-410a-8992-903987fef439 model_name: test_openai_models -- model_name: amazon-embeddings - litellm_params: - model: "bedrock/amazon.titan-embed-text-v1" +- litellm_params: + model: bedrock/amazon.titan-embed-text-v1 model_info: mode: embedding -- model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)" - litellm_params: - model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16" + model_name: amazon-embeddings +- litellm_params: + model: sagemaker/berri-benchmarking-gpt-j-6b-fp16 model_info: mode: embedding -- model_name: dall-e-3 - litellm_params: + model_name: GPT-J 6B - Sagemaker Text Embedding (Internal) +- litellm_params: model: dall-e-3 model_info: mode: image_generation -- model_name: dall-e-3 - litellm_params: - model: "azure/dall-e-3-test" - api_version: "2023-12-01-preview" - api_base: "os.environ/AZURE_SWEDEN_API_BASE" - api_key: "os.environ/AZURE_SWEDEN_API_KEY" + model_name: dall-e-3 +- litellm_params: + api_base: os.environ/AZURE_SWEDEN_API_BASE + api_key: os.environ/AZURE_SWEDEN_API_KEY + api_version: 2023-12-01-preview + model: azure/dall-e-3-test model_info: mode: image_generation -- model_name: dall-e-2 - litellm_params: - model: "azure/" - api_version: "2023-06-01-preview" - api_base: "os.environ/AZURE_API_BASE" - api_key: "os.environ/AZURE_API_KEY" + model_name: dall-e-3 +- litellm_params: + api_base: os.environ/AZURE_API_BASE + api_key: os.environ/AZURE_API_KEY + api_version: 2023-06-01-preview + model: azure/ model_info: mode: image_generation -- model_name: text-embedding-ada-002 - litellm_params: + model_name: dall-e-2 +- litellm_params: + api_base: os.environ/AZURE_API_BASE + api_key: os.environ/AZURE_API_KEY + api_version: 2023-07-01-preview model: azure/azure-embedding-model - api_base: "os.environ/AZURE_API_BASE" - api_key: "os.environ/AZURE_API_KEY" - api_version: "2023-07-01-preview" model_info: + base_model: text-embedding-ada-002 mode: embedding - base_model: text-embedding-ada-002 \ No newline at end of file + model_name: text-embedding-ada-002 +- litellm_params: + model: gpt-3.5-turbo + model_info: + description: this is a test openai model + id: 34cb2419-7c63-44ae-a189-53f1d1ce5953 + model_name: test_openai_models diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index ab490063f5..9f6bfd1be9 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -12,6 +12,8 @@ # 11. Generate a Key, cal key/info, call key/update, call key/info # 12. Make a call with key over budget, expect to fail # 14. Make a streaming chat/completions call with key over budget, expect to fail +# 15. Generate key, when `allow_user_auth`=False - check if `/key/info` returns key_name=null +# 16. Generate key, when `allow_user_auth`=True - check if `/key/info` returns key_name=sk... # function to call to generate key - async def new_user(data: NewUserRequest): @@ -86,6 +88,7 @@ def prisma_client(): litellm.proxy.proxy_server.litellm_proxy_budget_name = ( f"litellm-proxy-budget-{time.time()}" ) + litellm.proxy.proxy_server.user_custom_key_generate = None return prisma_client @@ -1140,3 +1143,48 @@ async def test_view_spend_per_key(prisma_client): except Exception as e: print("Got Exception", e) pytest.fail(f"Got exception {e}") + + +@pytest.mark.asyncio() +async def test_key_name_null(prisma_client): + """ + - create key + - get key info + - assert key_name is null + """ + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + try: + request = GenerateKeyRequest() + key = await generate_key_fn(request) + generated_key = key.key + result = await info_key_fn(key=generated_key) + print("result from info_key_fn", result) + assert result["info"]["key_name"] is None + except Exception as e: + print("Got Exception", e) + pytest.fail(f"Got exception {e}") + + +@pytest.mark.asyncio() +async def test_key_name_set(prisma_client): + """ + - create key + - get key info + - assert key_name is not null + """ + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True}) + await litellm.proxy.proxy_server.prisma_client.connect() + try: + request = GenerateKeyRequest() + key = await generate_key_fn(request) + generated_key = key.key + result = await info_key_fn(key=generated_key) + print("result from info_key_fn", result) + assert isinstance(result["info"]["key_name"], str) + except Exception as e: + print("Got Exception", e) + pytest.fail(f"Got exception {e}") diff --git a/litellm/tests/test_proxy_pass_user_config.py b/litellm/tests/test_proxy_pass_user_config.py index 30fa1eeb11..12def1160f 100644 --- a/litellm/tests/test_proxy_pass_user_config.py +++ b/litellm/tests/test_proxy_pass_user_config.py @@ -32,7 +32,7 @@ from litellm.proxy.proxy_server import ( ) # Replace with the actual module where your FastAPI router is defined # Your bearer token -token = "" +token = "sk-1234" headers = {"Authorization": f"Bearer {token}"} diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 972c4a583a..4e0f706eb0 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -31,7 +31,7 @@ from litellm.proxy.proxy_server import ( ) # Replace with the actual module where your FastAPI router is defined # Your bearer token -token = "" +token = "sk-1234" headers = {"Authorization": f"Bearer {token}"} diff --git a/litellm/tests/test_proxy_server_caching.py b/litellm/tests/test_proxy_server_caching.py index a1935bd05b..a9cf3504e4 100644 --- a/litellm/tests/test_proxy_server_caching.py +++ b/litellm/tests/test_proxy_server_caching.py @@ -33,7 +33,7 @@ from litellm.proxy.proxy_server import ( ) # Replace with the actual module where your FastAPI router is defined # Your bearer token -token = "" +token = "sk-1234" headers = {"Authorization": f"Bearer {token}"}