feat(proxy_server.py): introduce new /user/auth endpoint for handling user email auth

Decouples the Streamlit UI from the proxy server; the proxy must therefore handle user auth separately.
This commit is contained in:
Krrish Dholakia 2024-01-01 13:42:36 +05:30
parent 52db2a6040
commit 24e7dc359d
2 changed files with 73 additions and 120 deletions

View file

@ -101,7 +101,7 @@ from typing import Union
app = FastAPI( app = FastAPI(
docs_url="/", docs_url="/",
title="LiteLLM API", title="LiteLLM API",
description="Proxy Server to call 100+ LLMs in the OpenAI format\n\nAdmin Panel on `/admin` endpoint", description="Proxy Server to call 100+ LLMs in the OpenAI format\n\nAdmin Panel on `https://dashboard.litellm.ai/admin`",
) )
router = APIRouter() router = APIRouter()
origins = ["*"] origins = ["*"]
@ -199,18 +199,27 @@ async def user_api_key_auth(
if user_custom_auth: if user_custom_auth:
response = await user_custom_auth(request=request, api_key=api_key) response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response) return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ###
if master_key is None: if master_key is None:
if isinstance(api_key, str): if isinstance(api_key, str):
return UserAPIKeyAuth(api_key=api_key) return UserAPIKeyAuth(api_key=api_key)
else: else:
return UserAPIKeyAuth() return UserAPIKeyAuth()
route: str = request.url.path
print(f"route: {route}")
if route == "/user/auth":
if general_settings.get("allow_user_auth", False) == True:
return UserAPIKeyAuth()
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="'allow_user_auth' not set or set to False",
)
if api_key is None: # only require api key if master key is set if api_key is None: # only require api key if master key is set
raise Exception(f"No api key passed in.") raise Exception(f"No api key passed in.")
route: str = request.url.path
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
is_master_key_valid = secrets.compare_digest(api_key, master_key) is_master_key_valid = secrets.compare_digest(api_key, master_key)
if is_master_key_valid: if is_master_key_valid:
@ -347,36 +356,6 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
) )
async def run_streamlit_ui():
# Save the current working directory
original_dir = os.getcwd()
# set the working directory to where this script is
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
try:
# Start Streamlit without opening the browser automatically
process = subprocess.Popen(
[
"streamlit",
"run",
"admin_ui.py",
"--server.headless=true",
"--browser.serverAddress=0.0.0.0",
"--server.enableCORS=false",
]
)
# Wait for the server to start before exiting the context manager
await asyncio.sleep(1)
print("Streamlit UI server has started successfully.")
os.chdir(original_dir)
# Keep the background task running
while True:
await asyncio.sleep(3600)
except Exception as e:
print_verbose(f"Admin UI - Streamlit. An error occurred: {e}")
def cost_tracking(): def cost_tracking():
global prisma_client global prisma_client
if prisma_client is not None: if prisma_client is not None:
@ -1678,8 +1657,6 @@ async def info_key_fn(
#### USER MANAGEMENT #### #### USER MANAGEMENT ####
@router.post( @router.post(
"/user/new", "/user/new",
tags=["user management"], tags=["user management"],
@ -1719,6 +1696,57 @@ async def new_user(data: NewUserRequest):
) )
@router.post(
"/user/auth", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
)
async def user_auth(request: Request):
"""
Allows UI ("https://dashboard.litellm.ai/", or self-hosted - os.getenv("LITELLM_HOSTED_UI")) to request a magic link to be sent to user email, for auth to proxy.
Only allows emails from accepted email subdomains.
Rate limit: 1 request every 60s.
Only works, if you enable 'allow_user_auth' in general settings:
e.g.:
```yaml
general_settings:
allow_user_auth: true
```
Requirements:
This uses [Resend](https://resend.com/) for sending emails. Needs these 2 keys in your .env:
```env
RESEND_API_KEY = "my-resend-api-key"
RESEND_API_EMAIL = "my-sending-email"
```
"""
data = await request.json() # type: ignore
user_email = data["user_email"]
import os
import resend
## [TODO]: Check if user exists, if so - use an existing key, if not - create new user -> return new key
response = await generate_key_helper_fn(
**{"duration": "1hr", "models": [], "aliases": {}, "config": {}, "spend": 0} # type: ignore
)
base_url = os.getenv("LITELLM_HOSTED_UI", "https://dashboard.litellm.ai/")
resend.api_key = os.getenv("RESEND_API_KEY")
params = {
"from": f"LiteLLM Proxy <{os.getenv('RESEND_API_EMAIL')}>",
"to": [user_email],
"subject": "Your Magic Link",
"html": f"<strong> Follow this link, to login:\n\n{base_url}user/?token={response['token']}&user_id={response['user_id']}</strong>",
}
email = resend.Emails.send(params)
print(email)
return "Email sent!"
@router.post( @router.post(
"/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)] "/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
) )
@ -1798,6 +1826,12 @@ async def add_new_model(model_params: ModelParams):
tags=["model management"], tags=["model management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
@router.get(
"/v1/model/info",
description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
tags=["model management"],
dependencies=[Depends(user_api_key_auth)],
)
async def model_info_v1(request: Request): async def model_info_v1(request: Request):
global llm_model_list, general_settings, user_config_file_path global llm_model_list, general_settings, user_config_file_path
# Load existing config # Load existing config
@ -1822,55 +1856,6 @@ async def model_info_v1(request: Request):
return {"data": all_models} return {"data": all_models}
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933
@router.get(
"/v1/model/info",
description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
tags=["model management"],
dependencies=[Depends(user_api_key_auth)],
)
async def model_info(request: Request):
global llm_model_list, general_settings, user_config_file_path
# Load existing config
with open(f"{user_config_file_path}", "r") as config_file:
config = yaml.safe_load(config_file)
all_models = config["model_list"]
for model in all_models:
# get the model cost map info
## make an api call
data = copy.deepcopy(model["litellm_params"])
data["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
data["max_tokens"] = 10
print(f"data going to litellm acompletion: {data}")
response = await litellm.acompletion(**data)
response_model = response["model"]
print(f"response model: {response_model}; response - {response}")
litellm_model_info = litellm.get_model_info(response_model)
model_info = model.get("model_info", {})
for k, v in litellm_model_info.items():
if k not in model_info:
model_info[k] = v
model["model_info"] = model_info
# don't return the api key
model["litellm_params"].pop("api_key", None)
# all_models = list(set([m["model_name"] for m in llm_model_list]))
print_verbose(f"all_models: {all_models}")
return dict(
data=[
{
"id": model,
"object": "model",
"created": 1677610602,
"owned_by": "openai",
}
for model in all_models
],
object="list",
)
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@router.post( @router.post(
"/model/delete", "/model/delete",
@ -2083,36 +2068,6 @@ async def retrieve_server_log(request: Request):
return FileResponse(filepath) return FileResponse(filepath)
#### ADMIN UI ENDPOINTS ####
@router.get("/admin")
async def admin_page(request: Request):
from fastapi.responses import HTMLResponse
# Assuming your Streamlit app is running on localhost port 8501
html_content = """
<html>
<head>
<title>Admin Page</title>
<style>
html, body, iframe {
margin: 0;
padding: 0;
width: 100%;
height: 100%;
border: none;
}
</style>
</head>
<body>
<iframe src="http://localhost:8501"></iframe>
</body>
</html>
"""
return HTMLResponse(content=html_content)
#### BASIC ENDPOINTS #### #### BASIC ENDPOINTS ####

View file

@ -36,15 +36,13 @@ proxy = [
"orjson", "orjson",
] ]
proxy-ui = [
"streamlit"
]
extra_proxy = [ extra_proxy = [
"prisma", "prisma",
"azure-identity", "azure-identity",
"azure-keyvault-secrets", "azure-keyvault-secrets",
"google-cloud-kms" "google-cloud-kms",
"streamlit",
"resend"
] ]