mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
test: fix proxy server testing
This commit is contained in:
parent
19b1deb200
commit
c0eedf28fc
3 changed files with 27 additions and 11 deletions
|
@ -234,7 +234,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
return UserAPIKeyAuth()
|
return UserAPIKeyAuth()
|
||||||
|
|
||||||
if api_key is None: # only require api key if master key is set
|
if api_key is None: # only require api key if master key is set
|
||||||
raise Exception("No api key passed in.")
|
raise Exception(f"No api key passed in.")
|
||||||
|
|
||||||
route = request.url.path
|
route = request.url.path
|
||||||
|
|
||||||
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
||||||
|
@ -816,11 +817,12 @@ async def startup_event():
|
||||||
|
|
||||||
@router.on_event("shutdown")
|
@router.on_event("shutdown")
|
||||||
async def shutdown_event():
|
async def shutdown_event():
|
||||||
global prisma_client
|
global prisma_client, master_key, user_custom_auth
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
print("Disconnecting from Prisma")
|
print("Disconnecting from Prisma")
|
||||||
await prisma_client.disconnect()
|
await prisma_client.disconnect()
|
||||||
|
master_key = None
|
||||||
|
user_custom_auth = None
|
||||||
#### API ENDPOINTS ####
|
#### API ENDPOINTS ####
|
||||||
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
|
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
|
||||||
@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
|
@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
|
||||||
|
|
|
@ -23,7 +23,7 @@ logging.basicConfig(
|
||||||
# test /chat/completion request to the proxy
|
# test /chat/completion request to the proxy
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
|
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
||||||
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
||||||
|
@ -31,8 +31,15 @@ app = FastAPI()
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def wrapper_startup_event():
|
async def wrapper_startup_event():
|
||||||
await startup_event()
|
initialize(config=config_fp)
|
||||||
|
|
||||||
|
# Your bearer token
|
||||||
|
token = os.getenv("PROXY_MASTER_KEY")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {token}"
|
||||||
|
}
|
||||||
|
|
||||||
# Here you create a fixture that will be used by your tests
|
# Here you create a fixture that will be used by your tests
|
||||||
# Make sure the fixture returns TestClient(app)
|
# Make sure the fixture returns TestClient(app)
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
|
@ -41,6 +48,7 @@ def client():
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
def test_chat_completion(client):
|
def test_chat_completion(client):
|
||||||
|
global headers
|
||||||
try:
|
try:
|
||||||
# Your test data
|
# Your test data
|
||||||
test_data = {
|
test_data = {
|
||||||
|
@ -53,8 +61,9 @@ def test_chat_completion(client):
|
||||||
],
|
],
|
||||||
"max_tokens": 10,
|
"max_tokens": 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
print("testing proxy server")
|
print("testing proxy server")
|
||||||
response = client.post("/v1/chat/completions", json=test_data)
|
response = client.post("/v1/chat/completions", json=test_data, headers=headers)
|
||||||
print(f"response - {response.text}")
|
print(f"response - {response.text}")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
|
@ -65,6 +74,7 @@ def test_chat_completion(client):
|
||||||
# Run the test
|
# Run the test
|
||||||
|
|
||||||
def test_chat_completion_azure(client):
|
def test_chat_completion_azure(client):
|
||||||
|
global headers
|
||||||
try:
|
try:
|
||||||
# Your test data
|
# Your test data
|
||||||
test_data = {
|
test_data = {
|
||||||
|
@ -77,8 +87,9 @@ def test_chat_completion_azure(client):
|
||||||
],
|
],
|
||||||
"max_tokens": 10,
|
"max_tokens": 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
print("testing proxy server with Azure Request")
|
print("testing proxy server with Azure Request")
|
||||||
response = client.post("/v1/chat/completions", json=test_data)
|
response = client.post("/v1/chat/completions", json=test_data, headers=headers)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
|
@ -92,13 +103,14 @@ def test_chat_completion_azure(client):
|
||||||
|
|
||||||
|
|
||||||
def test_embedding(client):
|
def test_embedding(client):
|
||||||
|
global headers
|
||||||
try:
|
try:
|
||||||
test_data = {
|
test_data = {
|
||||||
"model": "azure/azure-embedding-model",
|
"model": "azure/azure-embedding-model",
|
||||||
"input": ["good morning from litellm"],
|
"input": ["good morning from litellm"],
|
||||||
}
|
}
|
||||||
print("testing proxy server with OpenAI embedding")
|
print("testing proxy server with OpenAI embedding")
|
||||||
response = client.post("/v1/embeddings", json=test_data)
|
response = client.post("/v1/embeddings", json=test_data, headers=headers)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
|
@ -112,6 +124,7 @@ def test_embedding(client):
|
||||||
|
|
||||||
|
|
||||||
def test_add_new_model(client):
|
def test_add_new_model(client):
|
||||||
|
global headers
|
||||||
try:
|
try:
|
||||||
test_data = {
|
test_data = {
|
||||||
"model_name": "test_openai_models",
|
"model_name": "test_openai_models",
|
||||||
|
@ -122,8 +135,8 @@ def test_add_new_model(client):
|
||||||
"description": "this is a test openai model"
|
"description": "this is a test openai model"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
client.post("/model/new", json=test_data)
|
client.post("/model/new", json=test_data, headers=headers)
|
||||||
response = client.get("/model/info")
|
response = client.get("/model/info", headers=headers)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
print(f"response: {result}")
|
print(f"response: {result}")
|
||||||
|
@ -172,7 +185,7 @@ def test_chat_completion_optional_params(client):
|
||||||
|
|
||||||
litellm.callbacks = [customHandler]
|
litellm.callbacks = [customHandler]
|
||||||
print("testing proxy server: optional params")
|
print("testing proxy server: optional params")
|
||||||
response = client.post("/v1/chat/completions", json=test_data)
|
response = client.post("/v1/chat/completions", json=test_data, headers=headers)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
print(f"Received response: {result}")
|
print(f"Received response: {result}")
|
||||||
|
|
|
@ -554,6 +554,7 @@ class Logging:
|
||||||
"litellm_params": self.litellm_params,
|
"litellm_params": self.litellm_params,
|
||||||
"start_time": self.start_time,
|
"start_time": self.start_time,
|
||||||
"stream": self.stream,
|
"stream": self.stream,
|
||||||
|
"user": user,
|
||||||
**self.optional_params
|
**self.optional_params
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue