Merge pull request #1478 from BerriAI/litellm_unit_test_proxy_key_gen

[Test] Proxy -  Unit Test proxy key gen
This commit is contained in:
Ishaan Jaff 2024-01-17 13:38:38 -08:00 committed by GitHub
commit cbb4484bce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 427 additions and 19 deletions

View file

@ -370,7 +370,7 @@ async def user_api_key_auth(
# Token exists but is expired. # Token exists but is expired.
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="expired user key", detail=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}",
) )
# Token passed all checks # Token passed all checks

View file

@ -7,10 +7,12 @@ from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHand
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException, status from fastapi import HTTPException, status
import smtplib import smtplib
from email.mime.text import MIMEText from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from datetime import datetime
def print_verbose(print_statement): def print_verbose(print_statement):
@ -375,13 +377,14 @@ class PrismaClient:
print_verbose(f"PrismaClient: response={response}") print_verbose(f"PrismaClient: response={response}")
if response is not None: if response is not None:
# for prisma we need to cast the expires time to str # for prisma we need to cast the expires time to str
if isinstance(response.expires, datetime):
response.expires = response.expires.isoformat() response.expires = response.expires.isoformat()
return response return response
else: else:
# Token does not exist. # Token does not exist.
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid user key", detail="Authentication Error: invalid user key - token does not exist",
) )
elif user_id is not None: elif user_id is not None:
response = await self.db.litellm_usertable.find_unique( # type: ignore response = await self.db.litellm_usertable.find_unique( # type: ignore
@ -559,7 +562,13 @@ class PrismaClient:
) )
async def connect(self): async def connect(self):
try: try:
verbose_proxy_logger.debug(
"PrismaClient: connect() called Attempting to Connect to DB"
)
if self.db.is_connected() == False: if self.db.is_connected() == False:
verbose_proxy_logger.debug(
"PrismaClient: DB not connected, Attempting to Connect to DB"
)
await self.db.connect() await self.db.connect()
except Exception as e: except Exception as e:
asyncio.create_task( asyncio.create_task(

View file

@ -32,6 +32,17 @@ from litellm.proxy.utils import DBClient
from starlette.datastructures import URL from starlette.datastructures import URL
request_data = {
"model": "azure-gpt-3.5",
"messages": [
{"role": "user", "content": "this is my new test. respond in 50 lines"}
],
}
@pytest.fixture
def custom_db_client():
# Assuming DBClient is a class that needs to be instantiated
db_args = { db_args = {
"ssl_verify": False, "ssl_verify": False,
"billing_mode": "PAY_PER_REQUEST", "billing_mode": "PAY_PER_REQUEST",
@ -41,16 +52,13 @@ custom_db_client = DBClient(
custom_db_type="dynamo_db", custom_db_type="dynamo_db",
custom_db_args=db_args, custom_db_args=db_args,
) )
# Reset litellm.proxy.proxy_server.prisma_client to None
litellm.proxy.proxy_server.prisma_client = None
request_data = { return custom_db_client
"model": "azure-gpt-3.5",
"messages": [
{"role": "user", "content": "this is my new test. respond in 50 lines"}
],
}
def test_generate_and_call_with_valid_key(): def test_generate_and_call_with_valid_key(custom_db_client):
# 1. Generate a Key, and use it to make a call # 1. Generate a Key, and use it to make a call
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client) setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
@ -76,7 +84,7 @@ def test_generate_and_call_with_valid_key():
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
def test_call_with_invalid_key(): def test_call_with_invalid_key(custom_db_client):
# 2. Make a call with invalid key, expect it to fail # 2. Make a call with invalid key, expect it to fail
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client) setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
@ -101,7 +109,7 @@ def test_call_with_invalid_key():
pass pass
def test_call_with_invalid_model(): def test_call_with_invalid_model(custom_db_client):
# 3. Make a call to a key with an invalid model - expect to fail # 3. Make a call to a key with an invalid model - expect to fail
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client) setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
@ -136,7 +144,7 @@ def test_call_with_invalid_model():
pass pass
def test_call_with_valid_model(): def test_call_with_valid_model(custom_db_client):
# 4. Make a call to a key with a valid model - expect to pass # 4. Make a call to a key with a valid model - expect to pass
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client) setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
@ -167,7 +175,7 @@ def test_call_with_valid_model():
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
def test_call_with_key_over_budget(): def test_call_with_key_over_budget(custom_db_client):
# 5. Make a call with a key over budget, expect to fail # 5. Make a call with a key over budget, expect to fail
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client) setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
@ -233,7 +241,7 @@ def test_call_with_key_over_budget():
print(vars(e)) print(vars(e))
def test_call_with_key_over_budget_stream(): def test_call_with_key_over_budget_stream(custom_db_client):
# 6. Make a call with a key over budget, expect to fail # 6. Make a call with a key over budget, expect to fail
setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client) setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")

View file

@ -0,0 +1,391 @@
# Test the following scenarios:
# 1. Generate a Key, and use it to make a call
# 2. Make a call with invalid key, expect it to fail
# 3. Make a call to a key with invalid model - expect to fail
# 4. Make a call to a key with valid model - expect to pass
# 5. Make a call with key over budget, expect to fail
# 6. Make a streaming chat/completions call with key over budget, expect to fail
# 7. Make a call with an key that never expires, expect to pass
# 8. Make a call with an expired key, expect to fail
# function to call to generate key - async def new_user(data: NewUserRequest):
# function to validate a request - async def user_auth(request: Request):
import sys, os
import traceback
from dotenv import load_dotenv
from fastapi import Request
load_dotenv()
import os, io
# this file is to test litellm/proxy
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm, asyncio
from litellm.proxy.proxy_server import new_user, user_api_key_auth, user_update
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm._logging import verbose_proxy_logger
verbose_proxy_logger.setLevel(level=logging.DEBUG)
from litellm.proxy._types import NewUserRequest, DynamoDBArgs
from litellm.proxy.utils import DBClient
from starlette.datastructures import URL
from litellm.caching import DualCache
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
request_data = {
"model": "azure-gpt-3.5",
"messages": [
{"role": "user", "content": "this is my new test. respond in 50 lines"}
],
}
@pytest.fixture
def prisma_client():
# Assuming DBClient is a class that needs to be instantiated
prisma_client = PrismaClient(
database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
)
# Reset litellm.proxy.proxy_server.prisma_client to None
litellm.proxy.proxy_server.custom_db_client = None
return prisma_client
def test_generate_and_call_with_valid_key(prisma_client):
# 1. Generate a Key, and use it to make a call
print("prisma client=", prisma_client)
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
try:
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
request = NewUserRequest()
key = await new_user(request)
print(key)
generated_key = key.key
bearer_token = "Bearer " + generated_key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
asyncio.run(test())
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
def test_call_with_invalid_key(prisma_client):
# 2. Make a call with invalid key, expect it to fail
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
try:
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
generated_key = "bad-key"
bearer_token = "Bearer " + generated_key
request = Request(scope={"type": "http"}, receive=None)
request._url = URL(url="/chat/completions")
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("got result", result)
pytest.fail(f"This should have failed!. IT's an invalid key")
asyncio.run(test())
except Exception as e:
print("Got Exception", e)
print(e.detail)
assert "Authentication Error" in e.detail
pass
def test_call_with_invalid_model(prisma_client):
# 3. Make a call to a key with an invalid model - expect to fail
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
try:
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
request = NewUserRequest(models=["mistral"])
key = await new_user(request)
print(key)
generated_key = key.key
bearer_token = "Bearer " + generated_key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
async def return_body():
return b'{"model": "gemini-pro-vision"}'
request.body = return_body
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
pytest.fail(f"This should have failed!. IT's an invalid model")
asyncio.run(test())
except Exception as e:
assert (
e.detail
== "Authentication Error, API Key not allowed to access model. This token can only access models=['mistral']. Tried to access gemini-pro-vision"
)
pass
def test_call_with_valid_model(prisma_client):
# 4. Make a call to a key with a valid model - expect to pass
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
try:
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
request = NewUserRequest(models=["mistral"])
key = await new_user(request)
print(key)
generated_key = key.key
bearer_token = "Bearer " + generated_key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
async def return_body():
return b'{"model": "mistral"}'
request.body = return_body
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
asyncio.run(test())
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
def test_call_with_key_over_budget(prisma_client):
# 5. Make a call with a key over budget, expect to fail
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
try:
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
request = NewUserRequest(max_budget=0.00001)
key = await new_user(request)
print(key)
generated_key = key.key
user_id = key.user_id
bearer_token = "Bearer " + generated_key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse(
id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
choices=[
Choices(
finish_reason=None,
index=0,
message=Message(
content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
role="assistant",
),
)
],
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
await track_cost_callback(
kwargs={
"stream": False,
"litellm_params": {
"metadata": {
"user_api_key": generated_key,
"user_api_key_user_id": user_id,
}
},
},
completion_response=resp,
)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
pytest.fail(f"This should have failed!. They key crossed it's budget")
asyncio.run(test())
except Exception as e:
error_detail = e.detail
assert "Authentication Error, ExceededBudget:" in error_detail
print(vars(e))
def test_call_with_key_over_budget_stream(prisma_client):
# 6. Make a call with a key over budget, expect to fail
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
from litellm._logging import verbose_proxy_logger
import logging
litellm.set_verbose = True
verbose_proxy_logger.setLevel(logging.DEBUG)
try:
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
request = NewUserRequest(max_budget=0.00001)
key = await new_user(request)
print(key)
generated_key = key.key
user_id = key.user_id
bearer_token = "Bearer " + generated_key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
# update spend using track_cost callback, make 2nd request, it should fail
from litellm.proxy.proxy_server import track_cost_callback
from litellm import ModelResponse, Choices, Message, Usage
resp = ModelResponse(
id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
choices=[
Choices(
finish_reason=None,
index=0,
message=Message(
content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
role="assistant",
),
)
],
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
await track_cost_callback(
kwargs={
"stream": True,
"complete_streaming_response": resp,
"litellm_params": {
"metadata": {
"user_api_key": generated_key,
"user_api_key_user_id": user_id,
}
},
},
completion_response=ModelResponse(),
)
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
pytest.fail(f"This should have failed!. They key crossed it's budget")
asyncio.run(test())
except Exception as e:
error_detail = e.detail
assert "Authentication Error, ExceededBudget:" in error_detail
print(vars(e))
def test_generate_and_call_with_valid_key_never_expires(prisma_client):
# 7. Make a call with an key that never expires, expect to pass
print("prisma client=", prisma_client)
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
try:
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
request = NewUserRequest(duration=None)
key = await new_user(request)
print(key)
generated_key = key.key
bearer_token = "Bearer " + generated_key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
asyncio.run(test())
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
def test_generate_and_call_with_expired_key(prisma_client):
# 8. Make a call with an expired key, expect to fail
print("prisma client=", prisma_client)
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
try:
async def test():
await litellm.proxy.proxy_server.prisma_client.connect()
request = NewUserRequest(duration="0s")
key = await new_user(request)
print(key)
generated_key = key.key
bearer_token = "Bearer " + generated_key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
pytest.fail(f"This should have failed!. IT's an expired key")
asyncio.run(test())
except Exception as e:
print("Got Exception", e)
print(e.detail)
assert "Authentication Error" in e.detail
pass