Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
test: fix proxy server testing
This commit is contained in:
parent 19b1deb200
commit c0eedf28fc
3 changed files with 27 additions and 11 deletions
litellm/proxy/proxy_server.py

@@ -234,7 +234,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         return UserAPIKeyAuth()
 
     if api_key is None: # only require api key if master key is set
-        raise Exception("No api key passed in.")
+        raise Exception(f"No api key passed in.")
 
     route = request.url.path
 
+    # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
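Note: the comment added above names the mitigation but the hunk does not show it in use. A minimal sketch of a constant-time key check with Python's standard-library secrets module; the function name and arguments are illustrative, not part of this commit:

import secrets

def keys_match(provided_key: str, valid_key: str) -> bool:
    # compare_digest runs in time independent of where the first mismatch
    # occurs, so response timing leaks nothing about the stored key.
    return secrets.compare_digest(provided_key.encode(), valid_key.encode())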
@@ -816,11 +817,12 @@ async def startup_event():
 
 @router.on_event("shutdown")
 async def shutdown_event():
-    global prisma_client
+    global prisma_client, master_key, user_custom_auth
     if prisma_client:
         print("Disconnecting from Prisma")
         await prisma_client.disconnect()
 
+    master_key = None
+    user_custom_auth = None
 
 #### API ENDPOINTS ####
 @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
 @router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
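For readers outside the codebase, the shutdown hook above follows the usual FastAPI pattern of clearing module-level state so nothing leaks between runs of the app in the same process. A self-contained sketch under that assumption (the app and the master_key global here are illustrative, not litellm's):

from fastapi import FastAPI

app = FastAPI()
master_key = "sk-1234"  # illustrative module-level state set at startup

@app.on_event("shutdown")
async def shutdown_event():
    # Reset globals so a test session that boots the app twice starts clean.
    global master_key
    master_key = None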
litellm/tests/test_proxy_server.py

@@ -23,7 +23,7 @@ logging.basicConfig(
 # test /chat/completion request to the proxy
 from fastapi.testclient import TestClient
 from fastapi import FastAPI
-from litellm.proxy.proxy_server import router, save_worker_config, startup_event  # Replace with the actual module where your FastAPI router is defined
+from litellm.proxy.proxy_server import router, save_worker_config, initialize  # Replace with the actual module where your FastAPI router is defined
 filepath = os.path.dirname(os.path.abspath(__file__))
 config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
 save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
@@ -31,8 +31,15 @@ app = FastAPI()
 app.include_router(router)  # Include your router in the test app
-@app.on_event("startup")
-async def wrapper_startup_event():
-    await startup_event()
+initialize(config=config_fp)
+
+# Your bearer token
+token = os.getenv("PROXY_MASTER_KEY")
+
+headers = {
+    "Authorization": f"Bearer {token}"
+}
 
 # Here you create a fixture that will be used by your tests
 # Make sure the fixture returns TestClient(app)
 @pytest.fixture(autouse=True)
@@ -41,6 +48,7 @@ def client():
     yield client
 
 def test_chat_completion(client):
+    global headers
     try:
         # Your test data
         test_data = {
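The diff only shows the tail of the client fixture (yield client). A plausible full body, assuming it simply wraps the module-level app in a TestClient; this reconstruction is illustrative, not taken from the file:

import pytest
from fastapi.testclient import TestClient

@pytest.fixture(autouse=True)
def client():
    # A fresh TestClient per test keeps request state isolated.
    client = TestClient(app)
    yield client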
@@ -53,8 +61,9 @@ def test_chat_completion(client):
             ],
             "max_tokens": 10,
         }
 
         print("testing proxy server")
-        response = client.post("/v1/chat/completions", json=test_data)
+        response = client.post("/v1/chat/completions", json=test_data, headers=headers)
+        print(f"response - {response.text}")
         assert response.status_code == 200
         result = response.json()
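Beyond the status code, tests like the one above can also assert on the payload, since the proxy returns the OpenAI chat-completions schema. A small hypothetical helper (not part of this commit) that such tests could call on result:

def assert_openai_chat_shape(result: dict) -> None:
    # Every successful chat completion should carry at least one choice
    # whose message has string content.
    assert "choices" in result and len(result["choices"]) > 0
    choice = result["choices"][0]
    assert "message" in choice
    assert isinstance(choice["message"]["content"], str)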
@@ -65,6 +74,7 @@ def test_chat_completion(client):
 # Run the test
 
 def test_chat_completion_azure(client):
+    global headers
     try:
         # Your test data
         test_data = {
@@ -77,8 +87,9 @@ def test_chat_completion_azure(client):
             ],
             "max_tokens": 10,
         }
 
         print("testing proxy server with Azure Request")
-        response = client.post("/v1/chat/completions", json=test_data)
+        response = client.post("/v1/chat/completions", json=test_data, headers=headers)
+
         assert response.status_code == 200
         result = response.json()
@@ -92,13 +103,14 @@ def test_chat_completion_azure(client):
 
 
 def test_embedding(client):
+    global headers
     try:
         test_data = {
             "model": "azure/azure-embedding-model",
             "input": ["good morning from litellm"],
         }
         print("testing proxy server with OpenAI embedding")
-        response = client.post("/v1/embeddings", json=test_data)
+        response = client.post("/v1/embeddings", json=test_data, headers=headers)
 
         assert response.status_code == 200
         result = response.json()
@@ -112,6 +124,7 @@ def test_embedding(client):
 
 
 def test_add_new_model(client):
+    global headers
     try:
         test_data = {
             "model_name": "test_openai_models",
@@ -122,8 +135,8 @@ def test_add_new_model(client):
                 "description": "this is a test openai model"
             }
         }
-        client.post("/model/new", json=test_data)
-        response = client.get("/model/info")
+        client.post("/model/new", json=test_data, headers=headers)
+        response = client.get("/model/info", headers=headers)
         assert response.status_code == 200
         result = response.json()
        print(f"response: {result}")
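The test_data payload for /model/new is truncated in the hunk above. One plausible full shape consistent with the visible fields; the litellm_params/model_info split is an assumption about the endpoint's schema, not something this diff confirms:

test_data = {
    "model_name": "test_openai_models",
    "litellm_params": {  # assumed: the routing parameters for the model
        "model": "gpt-3.5-turbo",
    },
    "model_info": {  # assumed: free-form metadata about the model
        "description": "this is a test openai model"
    },
}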
@@ -172,7 +185,7 @@ def test_chat_completion_optional_params(client):
 
         litellm.callbacks = [customHandler]
         print("testing proxy server: optional params")
-        response = client.post("/v1/chat/completions", json=test_data)
+        response = client.post("/v1/chat/completions", json=test_data, headers=headers)
         assert response.status_code == 200
         result = response.json()
         print(f"Received response: {result}")
litellm/utils.py

@@ -554,6 +554,7 @@ class Logging:
                 "litellm_params": self.litellm_params,
                 "start_time": self.start_time,
                 "stream": self.stream,
+                "user": user,
                 **self.optional_params
             }
 
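One subtlety in the hunk above: **self.optional_params is unpacked after the explicit keys, so an optional_params entry named user would silently override the one this commit adds. A standalone illustration of that merge order (values are made up):

explicit = {"user": "from-explicit-key"}
optional_params = {"user": "from-optional-params"}

details = {**explicit, **optional_params}  # later unpacking wins on collisions
assert details["user"] == "from-optional-params"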