forked from phoenix/litellm-mirror
Merge branch 'main' into litellm_budget_per_key
This commit is contained in:
commit
db68774d60
20 changed files with 731 additions and 183 deletions
|
@ -35,6 +35,7 @@ import pytest, logging, asyncio
|
|||
import litellm, asyncio
|
||||
from litellm.proxy.proxy_server import (
|
||||
new_user,
|
||||
generate_key_fn,
|
||||
user_api_key_auth,
|
||||
user_update,
|
||||
delete_key_fn,
|
||||
|
@ -53,6 +54,7 @@ from litellm.proxy._types import (
|
|||
DynamoDBArgs,
|
||||
DeleteKeyRequest,
|
||||
UpdateKeyRequest,
|
||||
GenerateKeyRequest,
|
||||
)
|
||||
from litellm.proxy.utils import DBClient
|
||||
from starlette.datastructures import URL
|
||||
|
@ -597,6 +599,85 @@ def test_generate_and_update_key(prisma_client):
|
|||
print("Got Exception", e)
|
||||
print(e.detail)
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
def test_key_generate_with_custom_auth(prisma_client):
|
||||
# custom - generate key function
|
||||
async def custom_generate_key_fn(data: GenerateKeyRequest) -> dict:
|
||||
"""
|
||||
Asynchronous function for generating a key based on the input data.
|
||||
|
||||
Args:
|
||||
data (GenerateKeyRequest): The input data for key generation.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the decision and an optional message.
|
||||
{
|
||||
"decision": False,
|
||||
"message": "This violates LiteLLM Proxy Rules. No team id provided.",
|
||||
}
|
||||
"""
|
||||
|
||||
# decide if a key should be generated or not
|
||||
print("using custom auth function!")
|
||||
data_json = data.json() # type: ignore
|
||||
|
||||
# Unpacking variables
|
||||
team_id = data_json.get("team_id")
|
||||
duration = data_json.get("duration")
|
||||
models = data_json.get("models")
|
||||
aliases = data_json.get("aliases")
|
||||
config = data_json.get("config")
|
||||
spend = data_json.get("spend")
|
||||
user_id = data_json.get("user_id")
|
||||
max_parallel_requests = data_json.get("max_parallel_requests")
|
||||
metadata = data_json.get("metadata")
|
||||
tpm_limit = data_json.get("tpm_limit")
|
||||
rpm_limit = data_json.get("rpm_limit")
|
||||
|
||||
if team_id is not None and team_id == "litellm-core-infra@gmail.com":
|
||||
# only team_id="litellm-core-infra@gmail.com" can make keys
|
||||
return {
|
||||
"decision": True,
|
||||
}
|
||||
else:
|
||||
print("Failed custom auth")
|
||||
return {
|
||||
"decision": False,
|
||||
"message": "This violates LiteLLM Proxy Rules. No team id provided.",
|
||||
}
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
setattr(
|
||||
litellm.proxy.proxy_server, "user_custom_key_generate", custom_generate_key_fn
|
||||
)
|
||||
try:
|
||||
request = GenerateKeyRequest()
|
||||
key = await generate_key_fn(request)
|
||||
pytest.fail(f"Expected an exception. Got {key}")
|
||||
except Exception as e:
|
||||
# this should fail
|
||||
print("Got Exception", e)
|
||||
print(e.detail)
|
||||
print("First request failed!. This is expected")
|
||||
assert (
|
||||
"This violates LiteLLM Proxy Rules. No team id provided."
|
||||
in e.detail
|
||||
)
|
||||
|
||||
request_2 = GenerateKeyRequest(
|
||||
team_id="litellm-core-infra@gmail.com",
|
||||
)
|
||||
|
||||
key = await generate_key_fn(request_2)
|
||||
print(key)
|
||||
generated_key = key.key
|
||||
|
||||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
print("Got Exception", e)
|
||||
print(e.detail)
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
def test_call_with_key_over_budget(prisma_client):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue