mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix: fix proxy testing
This commit is contained in:
parent
cf033be697
commit
9318a29fb1
4 changed files with 73 additions and 53 deletions
|
@ -168,7 +168,7 @@ def log_input_output(request, response, custom_logger=None):
|
|||
|
||||
from typing import Dict
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
user_api_base = None
|
||||
user_model = None
|
||||
user_debug = False
|
||||
|
@ -213,9 +213,13 @@ def usage_telemetry(
|
|||
|
||||
|
||||
|
||||
async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth:
|
||||
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
|
||||
global master_key, prisma_client, llm_model_list, user_custom_auth
|
||||
try:
|
||||
if isinstance(api_key, str):
|
||||
assert api_key.startswith("Bearer ") # ensure Bearer token passed in
|
||||
api_key = api_key.replace("Bearer ", "") # extract the token
|
||||
print(f"api_key: {api_key}; master_key: {master_key}; user_custom_auth: {user_custom_auth}")
|
||||
### USER-DEFINED AUTH FUNCTION ###
|
||||
if user_custom_auth:
|
||||
response = await user_custom_auth(request=request, api_key=api_key)
|
||||
|
@ -223,15 +227,16 @@ async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_sche
|
|||
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", ""))
|
||||
else:
|
||||
return UserAPIKeyAuth()
|
||||
if api_key is None:
|
||||
return UserAPIKeyAuth(api_key=api_key)
|
||||
else:
|
||||
return UserAPIKeyAuth()
|
||||
|
||||
if api_key is None: # only require api key if master key is set
|
||||
raise Exception("No api key passed in.")
|
||||
route = request.url.path
|
||||
|
||||
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
||||
is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
|
||||
is_master_key_valid = secrets.compare_digest(api_key, master_key)
|
||||
if is_master_key_valid:
|
||||
return UserAPIKeyAuth(api_key=master_key)
|
||||
|
||||
|
@ -241,9 +246,9 @@ async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_sche
|
|||
if prisma_client:
|
||||
## check for cache hit (In-Memory Cache)
|
||||
valid_token = user_api_key_cache.get_cache(key=api_key)
|
||||
if valid_token is None and "Bearer " in api_key:
|
||||
if valid_token is None:
|
||||
## check db
|
||||
cleaned_api_key = api_key[len("Bearer "):]
|
||||
cleaned_api_key = api_key
|
||||
valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
|
||||
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
|
||||
elif valid_token is not None:
|
||||
|
@ -275,10 +280,10 @@ async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_sche
|
|||
raise Exception(f"Invalid token")
|
||||
except Exception as e:
|
||||
print(f"An exception occurred - {traceback.format_exc()}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid user key",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid user key",
|
||||
)
|
||||
|
||||
def prisma_setup(database_url: Optional[str]):
|
||||
global prisma_client
|
||||
|
@ -597,13 +602,17 @@ def initialize(
|
|||
config,
|
||||
use_queue
|
||||
):
|
||||
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings
|
||||
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth
|
||||
generate_feedback_box()
|
||||
user_model = model
|
||||
user_debug = debug
|
||||
dynamic_config = {"general": {}, user_model: {}}
|
||||
if config:
|
||||
llm_router, llm_model_list, general_settings = load_router_config(router=llm_router, config_file_path=config)
|
||||
else:
|
||||
# reset auth if config not passed, needed for consecutive tests on proxy
|
||||
master_key = None
|
||||
user_custom_auth = None
|
||||
if headers: # model-specific param
|
||||
user_headers = headers
|
||||
dynamic_config[user_model]["headers"] = headers
|
||||
|
@ -810,7 +819,6 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
|||
detail=error_msg
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
|
||||
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
|
||||
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue