test: fix proxy server testing

This commit is contained in:
Krrish Dholakia 2023-12-06 18:38:44 -08:00
parent 19b1deb200
commit c0eedf28fc
3 changed files with 27 additions and 11 deletions

View file

@ -234,7 +234,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
return UserAPIKeyAuth() return UserAPIKeyAuth()
if api_key is None: # only require api key if master key is set if api_key is None: # only require api key if master key is set
raise Exception("No api key passed in.") raise Exception(f"No api key passed in.")
route = request.url.path route = request.url.path
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
@ -816,11 +817,12 @@ async def startup_event():
@router.on_event("shutdown") @router.on_event("shutdown")
async def shutdown_event(): async def shutdown_event():
global prisma_client global prisma_client, master_key, user_custom_auth
if prisma_client: if prisma_client:
print("Disconnecting from Prisma") print("Disconnecting from Prisma")
await prisma_client.disconnect() await prisma_client.disconnect()
master_key = None
user_custom_auth = None
#### API ENDPOINTS #### #### API ENDPOINTS ####
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)]) @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list @router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list

View file

@ -23,7 +23,7 @@ logging.basicConfig(
# test /chat/completion request to the proxy # test /chat/completion request to the proxy
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from fastapi import FastAPI from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
filepath = os.path.dirname(os.path.abspath(__file__)) filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml" config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
@ -31,8 +31,15 @@ app = FastAPI()
app.include_router(router) # Include your router in the test app app.include_router(router) # Include your router in the test app
@app.on_event("startup") @app.on_event("startup")
async def wrapper_startup_event(): async def wrapper_startup_event():
await startup_event() initialize(config=config_fp)
# Your bearer token
token = os.getenv("PROXY_MASTER_KEY")
headers = {
"Authorization": f"Bearer {token}"
}
# Here you create a fixture that will be used by your tests # Here you create a fixture that will be used by your tests
# Make sure the fixture returns TestClient(app) # Make sure the fixture returns TestClient(app)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -41,6 +48,7 @@ def client():
yield client yield client
def test_chat_completion(client): def test_chat_completion(client):
global headers
try: try:
# Your test data # Your test data
test_data = { test_data = {
@ -53,8 +61,9 @@ def test_chat_completion(client):
], ],
"max_tokens": 10, "max_tokens": 10,
} }
print("testing proxy server") print("testing proxy server")
response = client.post("/v1/chat/completions", json=test_data) response = client.post("/v1/chat/completions", json=test_data, headers=headers)
print(f"response - {response.text}") print(f"response - {response.text}")
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
@ -65,6 +74,7 @@ def test_chat_completion(client):
# Run the test # Run the test
def test_chat_completion_azure(client): def test_chat_completion_azure(client):
global headers
try: try:
# Your test data # Your test data
test_data = { test_data = {
@ -77,8 +87,9 @@ def test_chat_completion_azure(client):
], ],
"max_tokens": 10, "max_tokens": 10,
} }
print("testing proxy server with Azure Request") print("testing proxy server with Azure Request")
response = client.post("/v1/chat/completions", json=test_data) response = client.post("/v1/chat/completions", json=test_data, headers=headers)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
@ -92,13 +103,14 @@ def test_chat_completion_azure(client):
def test_embedding(client): def test_embedding(client):
global headers
try: try:
test_data = { test_data = {
"model": "azure/azure-embedding-model", "model": "azure/azure-embedding-model",
"input": ["good morning from litellm"], "input": ["good morning from litellm"],
} }
print("testing proxy server with OpenAI embedding") print("testing proxy server with OpenAI embedding")
response = client.post("/v1/embeddings", json=test_data) response = client.post("/v1/embeddings", json=test_data, headers=headers)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
@ -112,6 +124,7 @@ def test_embedding(client):
def test_add_new_model(client): def test_add_new_model(client):
global headers
try: try:
test_data = { test_data = {
"model_name": "test_openai_models", "model_name": "test_openai_models",
@ -122,8 +135,8 @@ def test_add_new_model(client):
"description": "this is a test openai model" "description": "this is a test openai model"
} }
} }
client.post("/model/new", json=test_data) client.post("/model/new", json=test_data, headers=headers)
response = client.get("/model/info") response = client.get("/model/info", headers=headers)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
print(f"response: {result}") print(f"response: {result}")
@ -172,7 +185,7 @@ def test_chat_completion_optional_params(client):
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
print("testing proxy server: optional params") print("testing proxy server: optional params")
response = client.post("/v1/chat/completions", json=test_data) response = client.post("/v1/chat/completions", json=test_data, headers=headers)
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
print(f"Received response: {result}") print(f"Received response: {result}")

View file

@ -554,6 +554,7 @@ class Logging:
"litellm_params": self.litellm_params, "litellm_params": self.litellm_params,
"start_time": self.start_time, "start_time": self.start_time,
"stream": self.stream, "stream": self.stream,
"user": user,
**self.optional_params **self.optional_params
} }