Merge branch 'main' into litellm_no_store_cache_control

Krish Dholakia 2024-01-30 21:44:57 -08:00 committed by GitHub
commit ce415a243d
7 changed files with 266 additions and 31 deletions


@@ -76,6 +76,7 @@ from litellm.proxy.utils import (
get_logging_payload,
reset_budget,
hash_token,
html_form,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
import pydantic
@@ -94,6 +95,7 @@ from fastapi import (
BackgroundTasks,
Header,
Response,
Form,
)
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
@@ -1268,7 +1270,7 @@ async def generate_key_helper_fn(
key_alias: Optional[str] = None,
allowed_cache_controls: Optional[list] = [],
):
global prisma_client, custom_db_client
global prisma_client, custom_db_client, user_api_key_cache
if prisma_client is None and custom_db_client is None:
raise Exception(
@@ -1361,6 +1363,18 @@ async def generate_key_helper_fn(
}
if general_settings.get("allow_user_auth", False) == True:
key_data["key_name"] = f"sk-...{token[-4:]}"
saved_token = copy.deepcopy(key_data)
if isinstance(saved_token["aliases"], str):
saved_token["aliases"] = json.loads(saved_token["aliases"])
if isinstance(saved_token["config"], str):
saved_token["config"] = json.loads(saved_token["config"])
if isinstance(saved_token["metadata"], str):
saved_token["metadata"] = json.loads(saved_token["metadata"])
user_api_key_cache.set_cache(
key=key_data["token"],
value=LiteLLM_VerificationToken(**saved_token), # type: ignore
ttl=60,
)
if prisma_client is not None:
## CREATE USER (If necessary)
verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
@@ -1675,14 +1689,16 @@ async def startup_event():
if prisma_client is not None and master_key is not None:
# add master key to db
await generate_key_helper_fn(
duration=None,
models=[],
aliases={},
config={},
spend=0,
token=master_key,
user_id="default_user_id",
asyncio.create_task(
generate_key_helper_fn(
duration=None,
models=[],
aliases={},
config={},
spend=0,
token=master_key,
user_id="default_user_id",
)
)
if prisma_client is not None and litellm.max_budget > 0:
@@ -1692,20 +1708,22 @@ async def startup_event():
)
# add proxy budget to db in the user table
await generate_key_helper_fn(
user_id=litellm_proxy_budget_name,
duration=None,
models=[],
aliases={},
config={},
spend=0,
max_budget=litellm.max_budget,
budget_duration=litellm.budget_duration,
query_type="update_data",
update_key_values={
"max_budget": litellm.max_budget,
"budget_duration": litellm.budget_duration,
},
asyncio.create_task(
generate_key_helper_fn(
user_id=litellm_proxy_budget_name,
duration=None,
models=[],
aliases={},
config={},
spend=0,
max_budget=litellm.max_budget,
budget_duration=litellm.budget_duration,
query_type="update_data",
update_key_values={
"max_budget": litellm.max_budget,
"budget_duration": litellm.budget_duration,
},
)
)
verbose_proxy_logger.debug(
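
Note on the two hunks above: both replace a blocking await generate_key_helper_fn(...) call in startup_event with asyncio.create_task(...), so startup no longer waits for the master-key and proxy-budget rows to be written. A small fire-and-forget sketch of that change follows; seed_master_key and main are hypothetical stand-ins, not litellm functions.

# Fire-and-forget sketch of the startup change above (hypothetical names).
import asyncio


async def seed_master_key(token: str) -> None:
    await asyncio.sleep(2)  # stand-in for the slow Prisma insert
    print(f"stored master key ending in {token[-4:]}")


async def main() -> None:
    # Before: `await seed_master_key(...)` held up startup for the full write.
    # After: the write is scheduled and the startup path continues immediately.
    task = asyncio.create_task(seed_master_key("sk-1234"))
    print("startup path continues without waiting for the DB write")
    await asyncio.sleep(3)  # only to keep this demo's event loop alive
    assert task.done()


asyncio.run(main())

The trade-off is that any exception raised by the background write is no longer surfaced by the startup event itself; it only shows up in the task's own error handling or logs.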
@@ -2962,6 +2980,60 @@ async def google_login(request: Request):
)
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
else:
# No Google, Microsoft SSO
# Use UI Credentials set in .env
from fastapi.responses import HTMLResponse
return HTMLResponse(content=html_form, status_code=200)
@router.post(
"/login", include_in_schema=False
) # hidden since this is a helper for UI sso login
async def login(request: Request):
try:
import multipart
except ImportError:
subprocess.run(["pip", "install", "python-multipart"])
form = await request.form()
username = str(form.get("username"))
password = form.get("password")
ui_username = os.getenv("UI_USERNAME")
ui_password = os.getenv("UI_PASSWORD")
if username == ui_username and password == ui_password:
user_id = username
response = await generate_key_helper_fn(
**{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard"} # type: ignore
)
key = response["token"] # type: ignore
user_id = response["user_id"] # type: ignore
litellm_dashboard_ui = "https://litellm-dashboard.vercel.app/"
# if user set LITELLM_UI_LINK in .env, use that
litellm_ui_link_in_env = os.getenv("LITELLM_UI_LINK", None)
if litellm_ui_link_in_env is not None:
litellm_dashboard_ui = litellm_ui_link_in_env
litellm_dashboard_ui += (
"?userID="
+ user_id
+ "&accessToken="
+ key
+ "&proxyBaseUrl="
+ os.getenv("PROXY_BASE_URL")
)
return RedirectResponse(url=litellm_dashboard_ui)
else:
raise ProxyException(
message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
type="auth_error",
param="invalid_credentials",
code=status.HTTP_401_UNAUTHORIZED,
)
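
Note on the hunk above: when no Google or Microsoft SSO is configured, /login accepts the credentials from the html_form page, checks them against UI_USERNAME and UI_PASSWORD, mints a 24-hour dashboard key via generate_key_helper_fn, and redirects to the dashboard URL with userID, accessToken, and proxyBaseUrl query parameters; bad credentials return HTTP 401. A hedged client-side smoke test is sketched below; the proxy URL and credentials are assumptions, not values from this commit.

# Hypothetical smoke test for the new /login form endpoint.
import requests

resp = requests.post(
    "http://localhost:4000/login",  # assumed local proxy address
    data={"username": "admin", "password": "secret"},  # must match UI_USERNAME / UI_PASSWORD
    allow_redirects=False,
)

if resp.status_code in (302, 303, 307):
    # On success the handler returns a RedirectResponse whose Location carries
    # the userID, accessToken, and proxyBaseUrl query parameters.
    print("redirected to:", resp.headers["Location"])
else:
    # Invalid credentials raise a ProxyException with HTTP 401.
    print("login rejected:", resp.status_code, resp.text)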
@app.get("/sso/callback", tags=["experimental"])