feat(proxy_server.py): allow user to override api key auth

Krrish Dholakia 2023-12-04 18:32:47 -08:00
parent 51cddf1e97
commit 030bd22078
9 changed files with 274 additions and 118 deletions
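In short: the proxy's API-key check is now pluggable. Setting custom_auth: <module>.<function> under general_settings in the proxy config makes load_router_config resolve that function with the new get_instance_fn helper (relative to the config file's directory) and store it as user_custom_auth; the user_api_key_auth dependency then awaits the custom function instead of running the built-in master-key/token checks. Auth results are now typed as the new UserAPIKeyAuth pydantic model, and the proxy's request/key classes move into litellm/proxy/types.py.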


@@ -0,0 +1,14 @@
+from litellm.proxy.types import UserAPIKeyAuth
+from fastapi import Request
+from dotenv import load_dotenv
+import os
+load_dotenv()
+
+async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
+    try:
+        modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"
+        if api_key == modified_master_key:
+            return UserAPIKeyAuth(api_key=api_key)
+        raise Exception
+    except:
+        raise Exception
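With the override above loaded, only a bearer token equal to the master key plus a "-1234" suffix authenticates. A minimal client sketch, assuming a proxy on localhost:8000 and PROXY_MASTER_KEY=sk-1234 (both values hypothetical):

    import requests

    # the custom user_api_key_auth above only accepts "<PROXY_MASTER_KEY>-1234"
    response = requests.post(
        "http://localhost:8000/chat/completions",
        headers={"Authorization": "Bearer sk-1234-1234"},
        json={"model": "openai-model", "messages": [{"role": "user", "content": "hi"}]},
    )
    print(response.status_code)

Note that OAuth2PasswordBearer strips the "Bearer " scheme before the token reaches the custom function, so the comparison sees the bare key.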


@@ -92,12 +92,16 @@ def generate_feedback_box():
 import litellm
 from litellm.proxy.utils import (
-    PrismaClient
+    PrismaClient,
+    get_instance_fn
 )
+import pydantic
+from litellm.proxy.types import *
 from litellm.caching import DualCache
 litellm.suppress_debug_info = True
 from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
 from fastapi.routing import APIRouter
+from fastapi.security import OAuth2PasswordBearer
 from fastapi.encoders import jsonable_encoder
 from fastapi.responses import StreamingResponse, FileResponse, ORJSONResponse
 from fastapi.middleware.cors import CORSMiddleware
@@ -163,70 +167,8 @@ def log_input_output(request, response, custom_logger=None):
     return True
 from typing import Dict
-from pydantic import BaseModel
-######### Request Class Definition ######
-class ProxyChatCompletionRequest(BaseModel):
-    model: str
-    messages: List[Dict[str, str]]
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    n: Optional[int] = None
-    stream: Optional[bool] = None
-    stop: Optional[List[str]] = None
-    max_tokens: Optional[int] = None
-    presence_penalty: Optional[float] = None
-    frequency_penalty: Optional[float] = None
-    logit_bias: Optional[Dict[str, float]] = None
-    user: Optional[str] = None
-    response_format: Optional[Dict[str, str]] = None
-    seed: Optional[int] = None
-    tools: Optional[List[str]] = None
-    tool_choice: Optional[str] = None
-    functions: Optional[List[str]] = None  # soon to be deprecated
-    function_call: Optional[str] = None  # soon to be deprecated
-    # Optional LiteLLM params
-    caching: Optional[bool] = None
-    api_base: Optional[str] = None
-    api_version: Optional[str] = None
-    api_key: Optional[str] = None
-    num_retries: Optional[int] = None
-    context_window_fallback_dict: Optional[Dict[str, str]] = None
-    fallbacks: Optional[List[str]] = None
-    metadata: Optional[Dict[str, str]] = {}
-    deployment_id: Optional[str] = None
-    request_timeout: Optional[int] = None
-
-    class Config:
-        extra = 'allow'  # allow params not defined here, these fall in litellm.completion(**kwargs)
-
-class ModelParams(BaseModel):
-    model_name: str
-    litellm_params: dict
-    model_info: Optional[dict]
-
-    class Config:
-        protected_namespaces = ()
-
-class GenerateKeyRequest(BaseModel):
-    duration: str = "1h"
-    models: list = []
-    aliases: dict = {}
-    config: dict = {}
-    spend: int = 0
-    user_id: Optional[str] = None
-
-class GenerateKeyResponse(BaseModel):
-    key: str
-    expires: datetime
-    user_id: str
-
-class _DeleteKeyObject(BaseModel):
-    key: str
-
-class DeleteKeyRequest(BaseModel):
-    keys: List[_DeleteKeyObject]
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
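None of the request/key classes removed here are lost: they move verbatim into the new litellm/proxy/types.py shown further down, and come back into proxy_server.py's namespace through the from litellm.proxy.types import * added above.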
 user_api_base = None
 user_model = None
 user_debug = False
@@ -249,6 +191,7 @@ master_key = None
 otel_logging = False
 prisma_client: Optional[PrismaClient] = None
 user_api_key_cache = DualCache()
+user_custom_auth = None
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -268,21 +211,21 @@ def usage_telemetry(
         target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
     ).start()
-api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
-async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
-    global master_key, prisma_client, llm_model_list
-    print(f"master_key - {master_key}; api_key - {api_key}")
-    if master_key is None:
-        if isinstance(api_key, str):
-            return {
-                "api_key": api_key.replace("Bearer ", "")
-            }
-        else:
-            return {
-                "api_key": api_key
-            }
+async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth:
+    global master_key, prisma_client, llm_model_list, user_custom_auth
     try:
+        ### USER-DEFINED AUTH FUNCTION ###
+        if user_custom_auth:
+            response = await user_custom_auth(request=request, api_key=api_key)
+            return UserAPIKeyAuth.model_validate(response)
+        if master_key is None:
+            if isinstance(api_key, str):
+                return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", ""))
+            else:
+                return UserAPIKeyAuth()
         if api_key is None:
             raise Exception("No api key passed in.")
         route = request.url.path
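Since the custom function's result goes through UserAPIKeyAuth.model_validate, an override may return either a UserAPIKeyAuth instance or a plain dict with matching fields; pydantic v2's model_validate accepts both. A sketch (values hypothetical):

    from litellm.proxy.types import UserAPIKeyAuth

    UserAPIKeyAuth.model_validate(UserAPIKeyAuth(api_key="sk-1234-1234"))  # instance passes through
    UserAPIKeyAuth.model_validate({"api_key": "sk-1234-1234", "user_id": "user-7"})  # dict is coerced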
@@ -290,9 +233,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
         # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
         is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
         if is_master_key_valid:
-            return {
-                "api_key": master_key
-            }
+            return UserAPIKeyAuth(api_key=master_key)
         if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
             raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
@@ -318,7 +259,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
                 return_dict = {"api_key": valid_token.token}
                 if valid_token.user_id:
                     return_dict["user_id"] = valid_token.user_id
-                return return_dict
+                return UserAPIKeyAuth(**return_dict)
             else:
                 data = await request.json()
                 model = data.get("model", None)
@@ -329,14 +270,14 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
                 return_dict = {"api_key": valid_token.token}
                 if valid_token.user_id:
                     return_dict["user_id"] = valid_token.user_id
-                return return_dict
+                return UserAPIKeyAuth(**return_dict)
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
         print(f"An exception occurred - {traceback.format_exc()}")
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail={"error": "invalid user key"},
+            detail="invalid user key",
         )
 def prisma_setup(database_url: Optional[str]):
@@ -464,7 +405,7 @@ def run_ollama_serve():
     """)
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
-    global master_key, user_config_file_path, otel_logging
+    global master_key, user_config_file_path, otel_logging, user_custom_auth
     config = {}
     try:
         if os.path.exists(config_file_path):
@@ -499,7 +440,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         ### LOAD FROM AZURE KEY VAULT ###
         use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
         load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
-
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
         if database_url and database_url.startswith("os.environ/"):
@@ -514,12 +454,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         master_key = general_settings.get("master_key", None)
         if master_key and master_key.startswith("os.environ/"):
             master_key = litellm.get_secret(master_key)
         #### OpenTelemetry Logging (OTEL) ########
         otel_logging = general_settings.get("otel", False)
         if otel_logging == True:
             print("\nOpenTelemetry Logging Activated")
+        ### CUSTOM API KEY AUTH ###
+        custom_auth = general_settings.get("custom_auth", None)
+        if custom_auth:
+            user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
         ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
         litellm_settings = config.get('litellm_settings', None)
         if litellm_settings:
@@ -549,23 +491,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
                         password=cache_password
                     )
                 elif key == "callbacks":
-                    print(f"{blue_color_code}\nSetting custom callbacks on Proxy")
-                    passed_module, instance_name = value.split(".")
-                    # Dynamically import the module
-                    module = importlib.import_module(passed_module)
-                    # Get the instance from the module
-                    instance = getattr(module, instance_name)
-                    methods = [method for method in dir(instance) if callable(getattr(instance, method))]
-                    # Print the methods
-                    print("Methods in the custom callbacks instance:")
-                    for method in methods:
-                        print(method)
-                    litellm.callbacks = [instance]
-                    print()
+                    litellm.callbacks = [get_instance_fn(value=value)]
                 else:
                     setattr(litellm, key, value)
@@ -844,7 +770,7 @@ def model_list():
 @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
-async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
+async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
     try:
         body = await request.body()
         body_str = body.decode()
@@ -853,7 +779,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
         except:
             data = json.loads(body_str)
-        data["user"] = user_api_key_dict.get("user_id", None)
+        data["user"] = user_api_key_dict.user_id
         data["model"] = (
             general_settings.get("completion_model", None)  # server default
             or user_model  # model name passed via cli args
@@ -864,9 +790,9 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
             data["model"] = user_model
         data["call_type"] = "text_completion"
         if "metadata" in data:
-            data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
+            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         else:
-            data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
+            data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
         return litellm_completion(
             **data
@@ -888,7 +814,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
 @router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])  # azure compatible endpoint
-async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
+async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
     global general_settings, user_debug
     try:
         data = {}
@@ -905,13 +831,13 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
         # users can pass in 'user' param to /chat/completions. Don't override it
         if data.get("user", None) is None:
             # if users are using user_api_key_auth, set `user` in `data`
-            data["user"] = user_api_key_dict.get("user_id", None)
+            data["user"] = user_api_key_dict.user_id
         if "metadata" in data:
-            data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
+            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
             data["metadata"]["headers"] = request.headers
         else:
-            data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
+            data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
             data["metadata"]["headers"] = request.headers
         global user_temperature, user_request_timeout, user_max_tokens, user_api_base
         # override with user settings, these are params passed via cli
@@ -962,14 +888,14 @@
 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
-async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
+async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
         data = orjson.loads(body)
-        data["user"] = user_api_key_dict.get("user_id", None)
+        data["user"] = user_api_key_dict.user_id
         data["model"] = (
             general_settings.get("embedding_model", None)  # server default
             or user_model  # model name passed via cli args
@@ -978,9 +904,9 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
         if user_model:
             data["model"] = user_model
         if "metadata" in data:
-            data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
+            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         else:
-            data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
+            data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
         ## ROUTE TO CORRECT ENDPOINT ##
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []

litellm/proxy/types.py (new file)

@@ -0,0 +1,70 @@
+from pydantic import BaseModel
+from typing import Optional, List, Union, Dict
+from datetime import datetime
+
+######### Request Class Definition ######
+class ProxyChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[Dict[str, str]]
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    n: Optional[int] = None
+    stream: Optional[bool] = None
+    stop: Optional[List[str]] = None
+    max_tokens: Optional[int] = None
+    presence_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+    response_format: Optional[Dict[str, str]] = None
+    seed: Optional[int] = None
+    tools: Optional[List[str]] = None
+    tool_choice: Optional[str] = None
+    functions: Optional[List[str]] = None  # soon to be deprecated
+    function_call: Optional[str] = None  # soon to be deprecated
+    # Optional LiteLLM params
+    caching: Optional[bool] = None
+    api_base: Optional[str] = None
+    api_version: Optional[str] = None
+    api_key: Optional[str] = None
+    num_retries: Optional[int] = None
+    context_window_fallback_dict: Optional[Dict[str, str]] = None
+    fallbacks: Optional[List[str]] = None
+    metadata: Optional[Dict[str, str]] = {}
+    deployment_id: Optional[str] = None
+    request_timeout: Optional[int] = None
+
+    class Config:
+        extra = 'allow'  # allow params not defined here, these fall in litellm.completion(**kwargs)
+
+class ModelParams(BaseModel):
+    model_name: str
+    litellm_params: dict
+    model_info: Optional[dict]
+
+    class Config:
+        protected_namespaces = ()
+
+class GenerateKeyRequest(BaseModel):
+    duration: str = "1h"
+    models: list = []
+    aliases: dict = {}
+    config: dict = {}
+    spend: int = 0
+    user_id: Optional[str] = None
+
+class GenerateKeyResponse(BaseModel):
+    key: str
+    expires: datetime
+    user_id: str
+
+class _DeleteKeyObject(BaseModel):
+    key: str
+
+class DeleteKeyRequest(BaseModel):
+    keys: List[_DeleteKeyObject]
+
+class UserAPIKeyAuth(BaseModel):  # the expected response object for user api key auth
+    api_key: Optional[str] = None
+    user_id: Optional[str] = None


@@ -1,6 +1,7 @@
 from typing import Optional, List, Any
-import os, subprocess, hashlib
+import os, subprocess, hashlib, importlib
+### DB CONNECTOR ###
 class PrismaClient:
     def __init__(self, database_url: str):
         print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
@@ -95,3 +96,60 @@ class PrismaClient:
     async def disconnect(self):
         await self.db.disconnect()
+
+# ### CUSTOM FILE ###
+# def get_instance_fn(value: str, config_file_path: Optional[str]=None):
+#     try:
+#         # Split the path by dots to separate module from instance
+#         parts = value.split(".")
+#         # The module path is all but the last part, and the instance is the last part
+#         module_path = ".".join(parts[:-1])
+#         instance_name = parts[-1]
+#         if config_file_path is not None:
+#             directory = os.path.dirname(config_file_path)
+#             module_path = os.path.join(directory, module_path)
+#         # Dynamically import the module
+#         module = importlib.import_module(module_path)
+#         # Get the instance from the module
+#         instance = getattr(module, instance_name)
+#         return instance
+#     except ImportError as e:
+#         print(e)
+#         raise ImportError(f"Could not import file at {value}")
+
+def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
+    try:
+        print(f"value: {value}")
+        # Split the path by dots to separate module from instance
+        parts = value.split(".")
+        # The module path is all but the last part, and the instance_name is the last part
+        module_name = ".".join(parts[:-1])
+        instance_name = parts[-1]
+        # If config_file_path is provided, use it to determine the module spec and load the module
+        if config_file_path is not None:
+            directory = os.path.dirname(config_file_path)
+            module_file_path = os.path.join(directory, *module_name.split('.'))
+            module_file_path += '.py'
+            spec = importlib.util.spec_from_file_location(module_name, module_file_path)
+            if spec is None:
+                raise ImportError(f"Could not find a module specification for {module_file_path}")
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+        else:
+            # Dynamically import the module
+            module = importlib.import_module(module_name)
+        # Get the instance from the module
+        instance = getattr(module, instance_name)
+        return instance
+    except ImportError as e:
+        # Print the error message for easier debugging
+        print(e)
+        # Re-raise the exception with a user-friendly message
+        raise ImportError(f"Could not import {instance_name} from {module_name}") from e
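A sketch of what this loader does for the test config below (paths illustrative, matching the test layout):

    from litellm.proxy.utils import get_instance_fn

    # resolves litellm/tests/test_configs/custom_auth.py via spec_from_file_location
    # and returns its user_api_key_auth coroutine function
    fn = get_instance_fn(
        value="custom_auth.user_api_key_auth",
        config_file_path="litellm/tests/test_configs/test_config_custom_auth.yaml",
    )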


@@ -0,0 +1,14 @@
+from litellm.proxy.types import UserAPIKeyAuth
+from fastapi import Request
+from dotenv import load_dotenv
+import os
+load_dotenv()
+
+async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
+    try:
+        print(f"api_key: {api_key}")
+        if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
+            return UserAPIKeyAuth(api_key=api_key)
+        raise Exception
+    except:
+        raise Exception


@@ -0,0 +1,11 @@
+model_list:
+  - model_name: "openai-model"
+    litellm_params:
+      model: "gpt-3.5-turbo"
+
+litellm_settings:
+  drop_params: True
+  set_verbose: True
+
+general_settings:
+  custom_auth: custom_auth.user_api_key_auth
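Because get_instance_fn joins the dotted module path onto the config file's directory, the custom_auth.py module named here has to sit next to this yaml (as it does in the test_configs/ folder used by the test below).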


@@ -0,0 +1,63 @@
+import sys, os
+import traceback
+from dotenv import load_dotenv
+load_dotenv()
+import os, io
+
+# this file is to test litellm/proxy
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+from litellm import embedding, completion, completion_cost, Timeout
+from litellm import RateLimitError
+
+# test /chat/completion request to the proxy
+from fastapi.testclient import TestClient
+from fastapi import FastAPI
+from litellm.proxy.proxy_server import router, save_worker_config, startup_event  # Replace with the actual module where your FastAPI router is defined
+filepath = os.path.dirname(os.path.abspath(__file__))
+config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
+save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
+app = FastAPI()
+app.include_router(router)  # Include your router in the test app
+
+@app.on_event("startup")
+async def wrapper_startup_event():
+    await startup_event()
+
+# Here you create a fixture that will be used by your tests
+# Make sure the fixture returns TestClient(app)
+@pytest.fixture(autouse=True)
+def client():
+    with TestClient(app) as client:
+        yield client
+
+def test_custom_auth(client):
+    try:
+        # Your test data
+        test_data = {
+            "model": "openai-model",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": "hi"
+                },
+            ],
+            "max_tokens": 10,
+        }
+        # Your bearer token
+        token = os.getenv("PROXY_MASTER_KEY")
+        headers = {
+            "Authorization": f"Bearer {token}"
+        }
+        response = client.post("/chat/completions", json=test_data, headers=headers)
+        print(f"response: {response.text}")
+        assert response.status_code == 401
+        result = response.json()
+        print(f"Received response: {result}")
+    except Exception as e:
+        pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
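The 401 assertion is the point of this test: once custom_auth is set, the raw master key no longer authenticates, because the override only accepts the key with the "-1234" suffix and raises for everything else.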


@@ -25,7 +25,7 @@ from fastapi.testclient import TestClient
 from fastapi import FastAPI
 from litellm.proxy.proxy_server import router, save_worker_config, startup_event  # Replace with the actual module where your FastAPI router is defined
 filepath = os.path.dirname(os.path.abspath(__file__))
-config_fp = f"{filepath}/test_config.yaml"
+config_fp = f"{filepath}/test_configs/test_config.yaml"
 save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
 app = FastAPI()
 app.include_router(router)  # Include your router in the test app