test: fix proxy server testing

This commit is contained in:
Krrish Dholakia 2023-12-06 18:38:44 -08:00
parent 19b1deb200
commit c0eedf28fc
3 changed files with 27 additions and 11 deletions

View file

@ -234,7 +234,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
return UserAPIKeyAuth()
if api_key is None: # only require api key if master key is set
raise Exception("No api key passed in.")
raise Exception(f"No api key passed in.")
route = request.url.path
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
@ -816,11 +817,12 @@ async def startup_event():
@router.on_event("shutdown")
async def shutdown_event():
global prisma_client
global prisma_client, master_key, user_custom_auth
if prisma_client:
print("Disconnecting from Prisma")
await prisma_client.disconnect()
master_key = None
user_custom_auth = None
#### API ENDPOINTS ####
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list

View file

@ -23,7 +23,7 @@ logging.basicConfig(
# test /chat/completion request to the proxy
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
@ -31,7 +31,14 @@ app = FastAPI()
app.include_router(router) # Include your router in the test app
@app.on_event("startup")
async def wrapper_startup_event():
await startup_event()
initialize(config=config_fp)
# Your bearer token
token = os.getenv("PROXY_MASTER_KEY")
headers = {
"Authorization": f"Bearer {token}"
}
# Here you create a fixture that will be used by your tests
# Make sure the fixture returns TestClient(app)
@ -41,6 +48,7 @@ def client():
yield client
def test_chat_completion(client):
global headers
try:
# Your test data
test_data = {
@ -53,8 +61,9 @@ def test_chat_completion(client):
],
"max_tokens": 10,
}
print("testing proxy server")
response = client.post("/v1/chat/completions", json=test_data)
response = client.post("/v1/chat/completions", json=test_data, headers=headers)
print(f"response - {response.text}")
assert response.status_code == 200
result = response.json()
@ -65,6 +74,7 @@ def test_chat_completion(client):
# Run the test
def test_chat_completion_azure(client):
global headers
try:
# Your test data
test_data = {
@ -77,8 +87,9 @@ def test_chat_completion_azure(client):
],
"max_tokens": 10,
}
print("testing proxy server with Azure Request")
response = client.post("/v1/chat/completions", json=test_data)
response = client.post("/v1/chat/completions", json=test_data, headers=headers)
assert response.status_code == 200
result = response.json()
@ -92,13 +103,14 @@ def test_chat_completion_azure(client):
def test_embedding(client):
global headers
try:
test_data = {
"model": "azure/azure-embedding-model",
"input": ["good morning from litellm"],
}
print("testing proxy server with OpenAI embedding")
response = client.post("/v1/embeddings", json=test_data)
response = client.post("/v1/embeddings", json=test_data, headers=headers)
assert response.status_code == 200
result = response.json()
@ -112,6 +124,7 @@ def test_embedding(client):
def test_add_new_model(client):
global headers
try:
test_data = {
"model_name": "test_openai_models",
@ -122,8 +135,8 @@ def test_add_new_model(client):
"description": "this is a test openai model"
}
}
client.post("/model/new", json=test_data)
response = client.get("/model/info")
client.post("/model/new", json=test_data, headers=headers)
response = client.get("/model/info", headers=headers)
assert response.status_code == 200
result = response.json()
print(f"response: {result}")
@ -172,7 +185,7 @@ def test_chat_completion_optional_params(client):
litellm.callbacks = [customHandler]
print("testing proxy server: optional params")
response = client.post("/v1/chat/completions", json=test_data)
response = client.post("/v1/chat/completions", json=test_data, headers=headers)
assert response.status_code == 200
result = response.json()
print(f"Received response: {result}")

View file

@ -554,6 +554,7 @@ class Logging:
"litellm_params": self.litellm_params,
"start_time": self.start_time,
"stream": self.stream,
"user": user,
**self.optional_params
}