diff --git a/litellm/main.py b/litellm/main.py index 322127de4..352ce1882 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -10,6 +10,7 @@ import os, openai, sys, json, inspect, uuid, datetime, threading from typing import Any, Literal, Union from functools import partial + import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy import httpx diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index c85564231..7a6e400c8 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -21,10 +21,10 @@ class LiteLLMBase(BaseModel): def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump(**kwargs) # noqa except Exception as e: # if using pydantic v1 - return self.dict() + return self.dict(**kwargs) def fields_set(self): try: @@ -211,6 +211,38 @@ class UpdateUserRequest(GenerateRequestBase): max_budget: Optional[float] = None +class NewTeamRequest(LiteLLMBase): + team_alias: Optional[str] = None + team_id: Optional[str] = None + admins: list = [] + members: list = [] + metadata: Optional[dict] = None + + +class LiteLLM_TeamTable(NewTeamRequest): + max_budget: Optional[float] = None + spend: Optional[float] = None + models: list = [] + max_parallel_requests: Optional[int] = None + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None + budget_duration: Optional[str] = None + budget_reset_at: Optional[datetime] = None + + +class NewTeamResponse(LiteLLMBase): + team_id: str + admins: list + members: list + metadata: dict + created_at: datetime + updated_at: datetime + + +class TeamRequest(LiteLLMBase): + teams: List[str] + + class KeyManagementSystem(enum.Enum): GOOGLE_KMS = "google_kms" AZURE_KEY_VAULT = "azure_key_vault" diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 24bd59b8a..07e08ac61 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2799,6 +2799,7 @@ async def image_generation( ) + @router.post( "/v1/moderations", dependencies=[Depends(user_api_key_auth)], @@ -2973,7 +2974,8 @@ async def generate_key_fn( Parameters: - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). - key_alias: Optional[str] - User defined key alias - - team_id: Optional[str] - The team id of the user + - team_id: Optional[str] - The team id of the key + - user_id: Optional[str] - The user id of the key - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models) - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models - config: Optional[dict] - any key-specific configs, overrides config in config.yaml @@ -3202,8 +3204,8 @@ async def info_key_fn_v2( Example Curl: ``` curl -X GET "http://0.0.0.0:8000/key/info" \ --H "Authorization: Bearer sk-1234" \ --d {"keys": ["sk-1", "sk-2", "sk-3"]} + -H "Authorization: Bearer sk-1234" \ + -d {"keys": ["sk-1", "sk-2", "sk-3"]} ``` """ global prisma_client @@ -3313,6 +3315,9 @@ async def info_key_fn( ) +#### SPEND MANAGEMENT ##### + + @router.get( "/spend/keys", tags=["budget & spend Tracking"], @@ -3724,232 +3729,200 @@ async def user_auth(request: Request): return "Email sent!" -@app.get("/sso/key/generate", tags=["experimental"]) -async def google_login(request: Request): +@router.get( + "/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)] +) +async def user_info( + user_id: Optional[str] = fastapi.Query( + default=None, description="User ID in the request parameters" + ) +): """ - Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env - - PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/" - Example: + Use this to get user information. (user row + all user key info) + Example request + ``` + curl -X GET 'http://localhost:8000/user/info?user_id=krrish7%40berri.ai' \ + --header 'Authorization: Bearer sk-1234' + ``` """ - microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) - google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) - generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) - - # get url from request - redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) - - ui_username = os.getenv("UI_USERNAME") - if redirect_url.endswith("/"): - redirect_url += "sso/callback" - else: - redirect_url += "/sso/callback" - # Google SSO Auth - if google_client_id is not None: - from fastapi_sso.sso.google import GoogleSSO - - google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) - if google_client_secret is None: - raise ProxyException( - message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", - type="auth_error", - param="GOOGLE_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, + global prisma_client + try: + if prisma_client is None: + raise Exception( + f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" ) - - google_sso = GoogleSSO( - client_id=google_client_id, - client_secret=google_client_secret, - redirect_uri=redirect_url, + ## GET USER ROW ## + if user_id is not None: + user_info = await prisma_client.get_data(user_id=user_id) + else: + user_info = None + ## GET ALL KEYS ## + keys = await prisma_client.get_data( + user_id=user_id, + table_name="key", + query_type="find_all", + expires=datetime.now(), ) - verbose_proxy_logger.info( - f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" - ) + if user_info is None: + ## make sure we still return a total spend ## + spend = 0 + for k in keys: + spend += getattr(k, "spend", 0) + user_info = {"spend": spend} - with google_sso: - return await google_sso.get_login_redirect() - - # Microsoft SSO Auth - elif microsoft_client_id is not None: - from fastapi_sso.sso.microsoft import MicrosoftSSO - - microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) - microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) - if microsoft_client_secret is None: + ## REMOVE HASHED TOKEN INFO before returning ## + for key in keys: + try: + key = key.model_dump() # noqa + except: + # if using pydantic v1 + key = key.dict() + key.pop("token", None) + return {"user_id": user_id, "user_info": user_info, "keys": keys} + except Exception as e: + if isinstance(e, HTTPException): raise ProxyException( - message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + message=getattr(e, "detail", f"Authentication Error({str(e)})"), type="auth_error", - param="MICROSOFT_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), ) - - microsoft_sso = MicrosoftSSO( - client_id=microsoft_client_id, - client_secret=microsoft_client_secret, - tenant=microsoft_tenant, - redirect_uri=redirect_url, - allow_insecure_http=True, + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, ) - with microsoft_sso: - return await microsoft_sso.get_login_redirect() - elif generic_client_id is not None: - from fastapi_sso.sso.generic import create_provider, DiscoveryDocument - - generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) - generic_authorization_endpoint = os.getenv( - "GENERIC_AUTHORIZATION_ENDPOINT", None - ) - generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) - generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) - if generic_client_secret is None: - raise ProxyException( - message="GENERIC_CLIENT_SECRET not set. Set it in .env file", - type="auth_error", - param="GENERIC_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_authorization_endpoint is None: - raise ProxyException( - message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", - type="auth_error", - param="GENERIC_AUTHORIZATION_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_token_endpoint is None: - raise ProxyException( - message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", - type="auth_error", - param="GENERIC_TOKEN_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if generic_userinfo_endpoint is None: - raise ProxyException( - message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", - type="auth_error", - param="GENERIC_USERINFO_ENDPOINT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - - verbose_proxy_logger.debug( - f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" - ) - - verbose_proxy_logger.debug( - f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" - ) - - discovery = DiscoveryDocument( - authorization_endpoint=generic_authorization_endpoint, - token_endpoint=generic_token_endpoint, - userinfo_endpoint=generic_userinfo_endpoint, - ) - - SSOProvider = create_provider(name="oidc", discovery_document=discovery) - generic_sso = SSOProvider( - client_id=generic_client_id, - client_secret=generic_client_secret, - redirect_uri=redirect_url, - allow_insecure_http=True, - ) - - with generic_sso: - return await generic_sso.get_login_redirect() - - elif ui_username is not None: - # No Google, Microsoft SSO - # Use UI Credentials set in .env - from fastapi.responses import HTMLResponse - - return HTMLResponse(content=html_form, status_code=200) - else: - from fastapi.responses import HTMLResponse - - return HTMLResponse(content=html_form, status_code=200) @router.post( - "/login", include_in_schema=False -) # hidden since this is a helper for UI sso login -async def login(request: Request): + "/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)] +) +async def user_update(data: UpdateUserRequest): + """ + [TODO]: Use this to update user budget + """ + global prisma_client try: - import multipart - except ImportError: - subprocess.run(["pip", "install", "python-multipart"]) - global master_key - form = await request.form() - username = str(form.get("username")) - password = str(form.get("password")) - ui_username = os.getenv("UI_USERNAME", "admin") - ui_password = os.getenv("UI_PASSWORD", None) - if ui_password is None: - ui_password = str(master_key) if master_key is not None else None + data_json: dict = data.json() + # get the row from db + if prisma_client is None: + raise Exception("Not connected to DB!") - if ui_password is None: - raise ProxyException( - message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys", - type="auth_error", - param="UI_PASSWORD", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, + non_default_values = {k: v for k, v in data_json.items() if v is not None} + response = await prisma_client.update_data( + user_id=data_json["user_id"], + data=non_default_values, + update_key_values=non_default_values, ) - - if secrets.compare_digest(username, ui_username) and secrets.compare_digest( - password, ui_password - ): - user_role = "app_owner" - user_id = username - key_user_id = user_id - if ( - os.getenv("PROXY_ADMIN_ID", None) is not None - and os.environ["PROXY_ADMIN_ID"] == user_id - ) or user_id == "admin": - # checks if user is admin - user_role = "app_admin" - key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id") - - # Admin is Authe'd in - generate key for the UI to access Proxy - - if os.getenv("DATABASE_URL") is not None: - response = await generate_key_helper_fn( - **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore + return {"user_id": data_json["user_id"], **non_default_values} + # update based on remaining passed in values + except Exception as e: + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), ) - else: - response = { - "token": "sk-gm", - "user_id": "litellm-dashboard", - } - - key = response["token"] # type: ignore - - litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "/") + "ui/" - - import jwt - - jwt_token = jwt.encode( - { - "user_id": user_id, - "key": key, - "user_email": user_id, - "user_role": user_role, - }, - "secret", - algorithm="HS256", - ) - litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token - - # if a user has logged in they should be allowed to create keys - this ensures that it's set to True - general_settings["allow_user_auth"] = True - return RedirectResponse(url=litellm_dashboard_ui, status_code=303) - else: + elif isinstance(e, ProxyException): + raise e raise ProxyException( - message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file", + message="Authentication Error, " + str(e), type="auth_error", - param="invalid_credentials", - code=status.HTTP_401_UNAUTHORIZED, + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, ) +#### TEAM MANAGEMENT #### + +@router.post( + "/team/new", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], + response_model=NewTeamResponse, +) +async def new_team( + data: NewTeamRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Allow users to create a new team. Apply user permissions to their team. + + Parameters: + - team_alias: Optional[str] - User defined team alias + - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it. + - admins: list - A list of user IDs that will be owning the team + - members: list - A list of user IDs that will be members of the team + - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } + + Returns: + - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id. + """ + global prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + data.team_id = str(uuid.uuid4()) + + complete_team_data = LiteLLM_TeamTable( + **data.json(), + max_budget=user_api_key_dict.max_budget, + models=user_api_key_dict.models, + max_parallel_requests=user_api_key_dict.max_parallel_requests, + tpm_limit=user_api_key_dict.tpm_limit, + rpm_limit=user_api_key_dict.rpm_limit, + budget_duration=user_api_key_dict.budget_duration, + budget_reset_at=user_api_key_dict.budget_reset_at, + ) + + team_row = await prisma_client.insert_data( + data=complete_team_data.json(exclude_none=True), table_name="team" + ) + return team_row + + +@router.post( + "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +async def update_team(): + """ + update team and members + """ + pass + + +@router.post( + "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +async def delete_team(): + """ + delete team and team keys + """ + pass + + +@router.get( + "/team/info", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +async def team_info( + team_id: str = fastapi.Query( + default=None, description="Team ID in the request parameters" + ) +): + """ + get info on team + related keys + """ + pass + @app.get("/sso/callback", tags=["experimental"]) async def auth_callback(request: Request): """Verify login""" @@ -4125,117 +4098,6 @@ async def auth_callback(request: Request): return RedirectResponse(url=litellm_dashboard_ui) -@router.get( - "/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)] -) -async def user_info( - user_id: Optional[str] = fastapi.Query( - default=None, description="User ID in the request parameters" - ) -): - """ - Use this to get user information. (user row + all user key info) - - Example request - ``` - curl -X GET 'http://localhost:8000/user/info?user_id=krrish7%40berri.ai' \ - --header 'Authorization: Bearer sk-1234' - ``` - """ - global prisma_client - try: - if prisma_client is None: - raise Exception( - f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - ## GET USER ROW ## - if user_id is not None: - user_info = await prisma_client.get_data(user_id=user_id) - else: - user_info = None - ## GET ALL KEYS ## - keys = await prisma_client.get_data( - user_id=user_id, - table_name="key", - query_type="find_all", - expires=datetime.now(), - ) - - if user_info is None: - ## make sure we still return a total spend ## - spend = 0 - for k in keys: - spend += getattr(k, "spend", 0) - user_info = {"spend": spend} - - ## REMOVE HASHED TOKEN INFO before returning ## - for key in keys: - try: - key = key.model_dump() # noqa - except: - # if using pydantic v1 - key = key.dict() - key.pop("token", None) - return {"user_id": user_id, "user_info": user_info, "keys": keys} - except Exception as e: - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type="auth_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type="auth_error", - param=getattr(e, "param", "None"), - code=status.HTTP_400_BAD_REQUEST, - ) - - -@router.post( - "/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)] -) -async def user_update(data: UpdateUserRequest): - """ - [TODO]: Use this to update user budget - """ - global prisma_client - try: - data_json: dict = data.json() - # get the row from db - if prisma_client is None: - raise Exception("Not connected to DB!") - - non_default_values = {k: v for k, v in data_json.items() if v is not None} - response = await prisma_client.update_data( - user_id=data_json["user_id"], - data=non_default_values, - update_key_values=non_default_values, - ) - return {"user_id": data_json["user_id"], **non_default_values} - # update based on remaining passed in values - except Exception as e: - traceback.print_exc() - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type="auth_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type="auth_error", - param=getattr(e, "param", "None"), - code=status.HTTP_400_BAD_REQUEST, - ) - - #### MODEL MANAGEMENT #### @@ -4587,6 +4449,341 @@ async def retrieve_server_log(request: Request): return FileResponse(filepath) +#### LOGIN ENDPOINTS #### + + +@app.get("/sso/key/generate", tags=["experimental"]) +async def google_login(request: Request): + """ + Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env + + PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/" + Example: + + """ + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + + # get url from request + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + + ui_username = os.getenv("UI_USERNAME") + if redirect_url.endswith("/"): + redirect_url += "sso/callback" + else: + redirect_url += "/sso/callback" + # Google SSO Auth + if google_client_id is not None: + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + google_sso = GoogleSSO( + client_id=google_client_id, + client_secret=google_client_secret, + redirect_uri=redirect_url, + ) + + verbose_proxy_logger.info( + f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" + ) + + with google_sso: + return await google_sso.get_login_redirect() + + # Microsoft SSO Auth + elif microsoft_client_id is not None: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + with microsoft_sso: + return await microsoft_sso.get_login_redirect() + elif generic_client_id is not None: + from fastapi_sso.sso.generic import create_provider, DiscoveryDocument + + generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) + generic_authorization_endpoint = os.getenv( + "GENERIC_AUTHORIZATION_ENDPOINT", None + ) + generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) + generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) + if generic_client_secret is None: + raise ProxyException( + message="GENERIC_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="GENERIC_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_authorization_endpoint is None: + raise ProxyException( + message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", + type="auth_error", + param="GENERIC_AUTHORIZATION_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_token_endpoint is None: + raise ProxyException( + message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", + type="auth_error", + param="GENERIC_TOKEN_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if generic_userinfo_endpoint is None: + raise ProxyException( + message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", + type="auth_error", + param="GENERIC_USERINFO_ENDPOINT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + verbose_proxy_logger.debug( + f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" + ) + + verbose_proxy_logger.debug( + f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" + ) + + discovery = DiscoveryDocument( + authorization_endpoint=generic_authorization_endpoint, + token_endpoint=generic_token_endpoint, + userinfo_endpoint=generic_userinfo_endpoint, + ) + + SSOProvider = create_provider(name="oidc", discovery_document=discovery) + generic_sso = SSOProvider( + client_id=generic_client_id, + client_secret=generic_client_secret, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + + with generic_sso: + return await generic_sso.get_login_redirect() + + elif ui_username is not None: + # No Google, Microsoft SSO + # Use UI Credentials set in .env + from fastapi.responses import HTMLResponse + elif ui_username is not None: + # No Google, Microsoft SSO + # Use UI Credentials set in .env + from fastapi.responses import HTMLResponse + + return HTMLResponse(content=html_form, status_code=200) + else: + from fastapi.responses import HTMLResponse + + return HTMLResponse(content=html_form, status_code=200) + + +@router.post( + "/login", include_in_schema=False +) # hidden since this is a helper for UI sso login +async def login(request: Request): + try: + import multipart + except ImportError: + subprocess.run(["pip", "install", "python-multipart"]) + global master_key + form = await request.form() + username = str(form.get("username")) + password = str(form.get("password")) + ui_username = os.getenv("UI_USERNAME", "admin") + ui_password = os.getenv("UI_PASSWORD", None) + if ui_password is None: + ui_password = str(master_key) if master_key is not None else None + + if ui_password is None: + raise ProxyException( + message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys", + type="auth_error", + param="UI_PASSWORD", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + if secrets.compare_digest(username, ui_username) and secrets.compare_digest( + password, ui_password + ): + user_role = "app_owner" + user_id = username + key_user_id = user_id + if ( + os.getenv("PROXY_ADMIN_ID", None) is not None + and os.environ["PROXY_ADMIN_ID"] == user_id + ) or user_id == "admin": + # checks if user is admin + user_role = "app_admin" + key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id") + + # Admin is Authe'd in - generate key for the UI to access Proxy + + if os.getenv("DATABASE_URL") is not None: + response = await generate_key_helper_fn( + **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore + ) + else: + response = { + "token": "sk-gm", + "user_id": "litellm-dashboard", + } + + key = response["token"] # type: ignore + + litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "/") + "ui/" + + import jwt + + jwt_token = jwt.encode( + { + "user_id": user_id, + "key": key, + "user_email": user_id, + "user_role": user_role, + }, + "secret", + algorithm="HS256", + ) + litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token + + # if a user has logged in they should be allowed to create keys - this ensures that it's set to True + general_settings["allow_user_auth"] = True + return RedirectResponse(url=litellm_dashboard_ui, status_code=303) + else: + raise ProxyException( + message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file", + type="auth_error", + param="invalid_credentials", + code=status.HTTP_401_UNAUTHORIZED, + ) + + +@app.get("/sso/callback", tags=["experimental"]) +async def auth_callback(request: Request): + """Verify login""" + global general_settings + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + + # get url from request + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + + if redirect_url.endswith("/"): + redirect_url += "sso/callback" + else: + redirect_url += "/sso/callback" + + if google_client_id is not None: + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + google_sso = GoogleSSO( + client_id=google_client_id, + redirect_uri=redirect_url, + client_secret=google_client_secret, + ) + result = await google_sso.verify_and_process(request) + + elif microsoft_client_id is not None: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if microsoft_tenant is None: + raise ProxyException( + message="MICROSOFT_TENANT not set. Set it in .env file", + type="auth_error", + param="MICROSOFT_TENANT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + result = await microsoft_sso.verify_and_process(request) + + # User is Authe'd in - generate key for the UI to access Proxy + user_email = getattr(result, "email", None) + user_id = getattr(result, "id", None) + if user_id is None: + user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "") + + response = await generate_key_helper_fn( + **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore + ) + key = response["token"] # type: ignore + user_id = response["user_id"] # type: ignore + + litellm_dashboard_ui = "/ui/" + + user_role = "app_owner" + if ( + os.getenv("PROXY_ADMIN_ID", None) is not None + and os.environ["PROXY_ADMIN_ID"] == user_id + ): + # checks if user is admin + user_role = "app_admin" + + import jwt + + jwt_token = jwt.encode( + { + "user_id": user_id, + "key": key, + "user_email": user_email, + "user_role": user_role, + }, + "secret", + algorithm="HS256", + ) + litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token + + # if a user has logged in they should be allowed to create keys - this ensures that it's set to True + general_settings["allow_user_auth"] = True + return RedirectResponse(url=litellm_dashboard_ui) + + #### BASIC ENDPOINTS #### @router.post( "/config/update", diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 3a5c3ea23..dbdcd3f5b 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -8,6 +8,25 @@ generator client { provider = "prisma-client-py" } +// Assign prod keys to groups, not individuals +model LiteLLM_TeamTable { + team_id String @unique + team_alias String? + admins String[] + members String[] + metadata Json @default("{}") + max_budget Float? + spend Float @default(0.0) + models String[] + max_parallel_requests Int? + tpm_limit BigInt? + rpm_limit BigInt? + budget_duration String? + budget_reset_at DateTime? + created_at DateTime @default(now()) @map("created_at") + updated_at DateTime @default(now()) @updatedAt @map("updated_at") +} + // Track spend, rate limit, budget Users model LiteLLM_UserTable { user_id String @unique diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 4f28e2adb..d3c95f350 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -511,8 +511,9 @@ class PrismaClient: token: Optional[Union[str, list]] = None, user_id: Optional[str] = None, user_id_list: Optional[list] = None, + team_id: Optional[str] = None, key_val: Optional[dict] = None, - table_name: Optional[Literal["user", "key", "config", "spend"]] = None, + table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, query_type: Literal["find_unique", "find_all"] = "find_unique", expires: Optional[datetime] = None, reset_at: Optional[datetime] = None, @@ -547,6 +548,14 @@ class PrismaClient: for r in response: if isinstance(r.expires, datetime): r.expires = r.expires.isoformat() + elif query_type == "find_all" and team_id is not None: + response = await self.db.litellm_verificationtoken.find_many( + where={"team_id": team_id} + ) + if response is not None and len(response) > 0: + for r in response: + if isinstance(r.expires, datetime): + r.expires = r.expires.isoformat() elif ( query_type == "find_all" and expires is not None @@ -659,7 +668,23 @@ class PrismaClient: order={"startTime": "desc"}, ) return response + elif table_name == "team": + if query_type == "find_unique": + response = await self.db.litellm_teamtable.find_unique( + where={"team_id": team_id} # type: ignore + ) + if query_type == "find_all" and team_id is not None: + user_id_values = str(tuple(team_id)) + sql_query = f""" + SELECT * + FROM "LiteLLM_TeamTable" + WHERE "team_id" IN {team_id} + """ + # Execute the raw query + # The asterisk before `team_id` unpacks the list into separate arguments + response = await self.db.query_raw(sql_query) + return response except Exception as e: print_verbose(f"LiteLLM Prisma Client Exception: {e}") import traceback @@ -679,7 +704,7 @@ class PrismaClient: on_backoff=on_backoff, # specifying the function to call on backoff ) async def insert_data( - self, data: dict, table_name: Literal["user", "key", "config", "spend"] + self, data: dict, table_name: Literal["user", "key", "config", "spend", "team"] ): """ Add a key to the database. If it already exists, do nothing. @@ -715,6 +740,17 @@ class PrismaClient: ) verbose_proxy_logger.info(f"Data Inserted into User Table") return new_user_row + elif table_name == "team": + db_data = self.jsonify_object(data=data) + new_team_row = await self.db.litellm_teamtable.upsert( + where={"team_id": data["team_id"]}, + data={ + "create": {**db_data}, # type: ignore + "update": {}, # don't do anything if it already exists + }, + ) + verbose_proxy_logger.info(f"Data Inserted into Team Table") + return new_team_row elif table_name == "config": """ For each param, diff --git a/schema.prisma b/schema.prisma index 3a5c3ea23..dbdcd3f5b 100644 --- a/schema.prisma +++ b/schema.prisma @@ -8,6 +8,25 @@ generator client { provider = "prisma-client-py" } +// Assign prod keys to groups, not individuals +model LiteLLM_TeamTable { + team_id String @unique + team_alias String? + admins String[] + members String[] + metadata Json @default("{}") + max_budget Float? + spend Float @default(0.0) + models String[] + max_parallel_requests Int? + tpm_limit BigInt? + rpm_limit BigInt? + budget_duration String? + budget_reset_at DateTime? + created_at DateTime @default(now()) @map("created_at") + updated_at DateTime @default(now()) @updatedAt @map("updated_at") +} + // Track spend, rate limit, budget Users model LiteLLM_UserTable { user_id String @unique