mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00

fix: fix proxy testing

parent a602d59645
commit b46c73a46e

4 changed files with 73 additions and 53 deletions

@@ -168,7 +168,7 @@ def log_input_output(request, response, custom_logger=None):
 
 from typing import Dict
 
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 user_api_base = None
 user_model = None
 user_debug = False
@@ -213,9 +213,13 @@ def usage_telemetry(
 
 
 
-async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth:
+async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
     global master_key, prisma_client, llm_model_list, user_custom_auth
     try:
+        if isinstance(api_key, str):
+            assert api_key.startswith("Bearer ") # ensure Bearer token passed in
+            api_key = api_key.replace("Bearer ", "") # extract the token
+        print(f"api_key: {api_key}; master_key: {master_key}; user_custom_auth: {user_custom_auth}")
         ### USER-DEFINED AUTH FUNCTION ###
         if user_custom_auth:
             response = await user_custom_auth(request=request, api_key=api_key)
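
The switch above, from `OAuth2PasswordBearer` to `APIKeyHeader` read via `fastapi.Security`, hands the raw `Authorization` header to the dependency, which is why the new code strips the `Bearer ` prefix itself. A minimal self-contained sketch of that pattern, with a toy `WhoAmI` model and `/whoami` route standing in for the proxy's types:

# Hedged sketch of header-based key auth with manual "Bearer " stripping.
# WhoAmI and the /whoami route are illustrative stand-ins, not litellm code.
from fastapi import FastAPI, Request, Security
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel

api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
app = FastAPI()

class WhoAmI(BaseModel):
    api_key: str = ""

async def toy_api_key_auth(request: Request, api_key: str = Security(api_key_header)) -> WhoAmI:
    # The dependency yields the raw header value ("Bearer sk-..."), or None if the header is absent.
    if isinstance(api_key, str) and api_key.startswith("Bearer "):
        api_key = api_key[len("Bearer "):]
    return WhoAmI(api_key=api_key or "")

@app.get("/whoami")
async def whoami(user: WhoAmI = Security(toy_api_key_auth)):
    return {"api_key_prefix": user.api_key[:6]}
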
@@ -223,15 +227,16 @@ async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_sche
 
         if master_key is None:
             if isinstance(api_key, str):
-                return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", ""))
+                return UserAPIKeyAuth(api_key=api_key)
             else:
                 return UserAPIKeyAuth()
-        if api_key is None:
+        if api_key is None: # only require api key if master key is set
             raise Exception("No api key passed in.")
         route = request.url.path
 
         # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
-        is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
+        is_master_key_valid = secrets.compare_digest(api_key, master_key)
         if is_master_key_valid:
             return UserAPIKeyAuth(api_key=master_key)
 
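
As the comment in this hunk notes, API keys should never be compared with `==`; a short illustration of the timing-safe check the proxy relies on (the key values below are made up):

import secrets

master_key = "sk-1234"        # made-up value
candidate = "Bearer sk-1234"  # raw Authorization header

# "==" short-circuits at the first differing character, which leaks timing
# information; compare_digest always scans the full string.
stripped = candidate.replace("Bearer ", "")
is_master_key_valid = secrets.compare_digest(stripped, master_key)
print(is_master_key_valid)  # True
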
@@ -241,9 +246,9 @@ async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_sche
         if prisma_client:
             ## check for cache hit (In-Memory Cache)
             valid_token = user_api_key_cache.get_cache(key=api_key)
-            if valid_token is None and "Bearer " in api_key:
+            if valid_token is None:
                 ## check db
-                cleaned_api_key = api_key[len("Bearer "):]
+                cleaned_api_key = api_key
                 valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
                 user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
             elif valid_token is not None:
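
The lookup order here is cache first, database only on a miss, then re-populate the cache with a 60-second TTL. A hedged sketch of that flow with stand-ins for the cache and the Prisma client (neither class below is litellm code):

# Stand-ins for the in-memory cache and the Prisma client; not litellm code.
import asyncio, time

class TTLCache:
    def __init__(self):
        self._store = {}

    def get_cache(self, key):
        value, expires_at = self._store.get(key, (None, 0.0))
        return value if time.time() < expires_at else None

    def set_cache(self, key, value, ttl):
        self._store[key] = (value, time.time() + ttl)

class FakeDB:
    def __init__(self, tokens):
        self._tokens = tokens

    async def get_data(self, token):
        return self._tokens.get(token)  # None when the key is unknown

async def check_key(api_key, cache, db):
    valid_token = cache.get_cache(key=api_key)
    if valid_token is None:
        valid_token = await db.get_data(token=api_key)  # only hit the DB on a cache miss
        cache.set_cache(key=api_key, value=valid_token, ttl=60)
    return valid_token

print(asyncio.run(check_key("sk-x", TTLCache(), FakeDB({"sk-x": {"models": ["azure-model"]}}))))
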
@@ -597,13 +602,17 @@ def initialize(
     config,
     use_queue
 ):
-    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings
+    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth
     generate_feedback_box()
     user_model = model
     user_debug = debug
     dynamic_config = {"general": {}, user_model: {}}
     if config:
         llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config)
+    else:
+        # reset auth if config not passed, needed for consecutive tests on proxy
+        master_key = None
+        user_custom_auth = None
     if headers: # model-specific param
         user_headers = headers
         dynamic_config[user_model]["headers"] = headers
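
The new `else:` branch matters because `master_key` and `user_custom_auth` are module-level globals: without a reset, a config loaded by one test would leak auth state into the next. A toy illustration of that failure mode, using generic names rather than the proxy's:

# Toy module mimicking the proxy's global-config pattern; names are generic.
master_key = None

def initialize(config=None):
    global master_key
    if config:
        master_key = config.get("master_key")
    else:
        # reset auth when no config is passed, so an earlier call can't leak state
        master_key = None

initialize(config={"master_key": "sk-1234"})  # first test loads a config with auth
assert master_key == "sk-1234"
initialize()                                  # later test runs with no config
assert master_key is None                     # without the reset branch this would still be "sk-1234"
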
@@ -810,7 +819,6 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
             detail=error_msg
         )
 
-
 @router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
@@ -1,24 +1,27 @@
-model_list:
-  - model_name: "azure-model"
-    litellm_params:
-      model: "azure/gpt-35-turbo"
-      api_key: "os.environ/AZURE_EUROPE_API_KEY"
-      api_base: "https://my-endpoint-europe-berri-992.openai.azure.com/"
-  - model_name: "azure-model"
-    litellm_params:
-      model: "azure/gpt-35-turbo"
-      api_key: "os.environ/AZURE_CANADA_API_KEY"
-      api_base: "https://my-endpoint-canada-berri992.openai.azure.com"
-  - model_name: "azure-model"
-    litellm_params:
-      model: "azure/gpt-turbo"
-      api_key: "os.environ/AZURE_FRANCE_API_KEY"
-      api_base: "https://openai-france-1234.openai.azure.com"
-
-litellm_settings:
-  drop_params: True
-  set_verbose: True
-
 general_settings:
-  master_key: "os.environ/PROXY_MASTER_KEY"
-  database_url: "os.environ/PROXY_DATABASE_URL" # [OPTIONAL] use for token-based auth to proxy
+  database_url: os.environ/PROXY_DATABASE_URL
+  master_key: os.environ/PROXY_MASTER_KEY
+litellm_settings:
+  drop_params: true
+  set_verbose: true
+model_list:
+- litellm_params:
+    api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
+    api_key: os.environ/AZURE_EUROPE_API_KEY
+    model: azure/gpt-35-turbo
+  model_name: azure-model
+- litellm_params:
+    api_base: https://my-endpoint-canada-berri992.openai.azure.com
+    api_key: os.environ/AZURE_CANADA_API_KEY
+    model: azure/gpt-35-turbo
+  model_name: azure-model
+- litellm_params:
+    api_base: https://openai-france-1234.openai.azure.com
+    api_key: os.environ/AZURE_FRANCE_API_KEY
+    model: azure/gpt-turbo
+  model_name: azure-model
+- litellm_params:
+    model: gpt-3.5-turbo
+  model_info:
+    description: this is a test openai model
+  model_name: test_openai_models
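
Strings of the form `os.environ/NAME` in this config are references that get replaced with environment variables when the config is loaded. A hedged sketch of that resolution step; the `resolve_env_refs` helper below is illustrative, not the proxy's loader:

import os
import yaml  # PyYAML

RAW = """
general_settings:
  database_url: os.environ/PROXY_DATABASE_URL
  master_key: os.environ/PROXY_MASTER_KEY
"""

def resolve_env_refs(value):
    """Swap strings of the form 'os.environ/NAME' for the value of $NAME."""
    if isinstance(value, str) and value.startswith("os.environ/"):
        return os.environ.get(value.split("/", 1)[1], "")
    if isinstance(value, dict):
        return {k: resolve_env_refs(v) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve_env_refs(v) for v in value]
    return value

config = resolve_env_refs(yaml.safe_load(RAW))
print(config["general_settings"]["master_key"])  # contents of $PROXY_MASTER_KEY, or ""
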
@@ -18,7 +18,7 @@ from litellm import RateLimitError
 # test /chat/completion request to the proxy
 from fastapi.testclient import TestClient
 from fastapi import FastAPI
-from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
+from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
 filepath = os.path.dirname(os.path.abspath(__file__))
 config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
 save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
@@ -26,7 +26,7 @@ app = FastAPI()
 app.include_router(router) # Include your router in the test app
 @app.on_event("startup")
 async def wrapper_startup_event():
-    await startup_event()
+    initialize(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
 
 # Here you create a fixture that will be used by your tests
 # Make sure the fixture returns TestClient(app)
@@ -18,11 +18,22 @@ from litellm import RateLimitError
 # test /chat/completion request to the proxy
 from fastapi.testclient import TestClient
 from fastapi import FastAPI
-from litellm.proxy.proxy_server import router # Replace with the actual module where your FastAPI router is defined
+from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
+save_worker_config(config=None, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
 app = FastAPI()
 app.include_router(router) # Include your router in the test app
-client = TestClient(app)
-def test_chat_completion():
+@app.on_event("startup")
+async def wrapper_startup_event(): # required to reset config on app init - b/c pytest collects across multiple files - which sets the fastapi client + WORKER CONFIG to whatever was collected last
+    initialize(config=None, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
+
+# Here you create a fixture that will be used by your tests
+# Make sure the fixture returns TestClient(app)
+@pytest.fixture(autouse=True)
+def client():
+    with TestClient(app) as client:
+        yield client
+
+def test_chat_completion(client):
     try:
         # Your test data
         test_data = {
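
The new fixture is what makes the `startup` hook (and therefore `initialize()`) actually run: a bare `TestClient(app)` created at import time never fires startup events, while entering the client as a context manager does. A minimal sketch of that behaviour with a toy app:

# Toy app demonstrating why the fixture enters TestClient as a context manager.
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()
state = {"ready": False}

@app.on_event("startup")
async def startup():
    state["ready"] = True  # stands in for initialize(config=...)

@app.get("/health")
def health():
    return {"ready": state["ready"]}

@pytest.fixture(autouse=True)
def client():
    # Entering the context manager runs the startup handlers; a plain
    # TestClient(app) created at import time would skip them.
    with TestClient(app) as c:
        yield c

def test_health(client):
    assert client.get("/health").json() == {"ready": True}
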
@@ -37,18 +48,16 @@ def test_chat_completion():
         }
         print("testing proxy server")
         response = client.post("/v1/chat/completions", json=test_data)
+        print(f"response - {response.text}")
         assert response.status_code == 200
         result = response.json()
         print(f"Received response: {result}")
     except Exception as e:
-        pytest.fail("LiteLLM Proxy test failed. Exception", e)
+        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
 
 # Run the test
-# test_chat_completion()
 
-def test_chat_completion_azure():
+def test_chat_completion_azure(client):
     try:
         # Your test data
         test_data = {
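
The `pytest.fail` change in this hunk and the two below is a genuine fix: the second positional argument of `pytest.fail` is the boolean `pytrace`, so passing the exception object there silently drops the exception text from the failure report. The corrected form interpolates it instead:

import pytest

def test_example():
    try:
        raise RuntimeError("boom")  # stand-in for a failing proxy call
    except Exception as e:
        # pytest.fail(reason, pytrace): the old call put `e` where pytrace belongs,
        # so the exception text never reached the report.
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
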
@@ -69,13 +78,13 @@ def test_chat_completion_azure():
         print(f"Received response: {result}")
         assert len(result["choices"][0]["message"]["content"]) > 0
     except Exception as e:
-        pytest.fail("LiteLLM Proxy test failed. Exception", e)
+        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
 
 # Run the test
 # test_chat_completion_azure()
 
 
-def test_embedding():
+def test_embedding(client):
     try:
         test_data = {
             "model": "azure/azure-embedding-model",
@@ -89,13 +98,13 @@ def test_embedding():
         print(len(result["data"][0]["embedding"]))
         assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so
     except Exception as e:
-        pytest.fail("LiteLLM Proxy test failed. Exception", e)
+        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
 
 # Run the test
 # test_embedding()
 
 
-def test_add_new_model():
+def test_add_new_model(client):
     try:
         test_data = {
             "model_name": "test_openai_models",
@@ -135,7 +144,7 @@ class MyCustomHandler(CustomLogger):
 customHandler = MyCustomHandler()
 
 
-def test_chat_completion_optional_params():
+def test_chat_completion_optional_params(client):
     # [PROXY: PROD TEST] - DO NOT DELETE
     # This tests if all the /chat/completion params are passed to litellm