From d589061a9c2aea4592b153d0cb8288eb92c477fa Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 14 Feb 2024 14:54:29 -0800 Subject: [PATCH 1/5] refactor(proxy_server.py): initial stubbed endpoints for team management --- litellm/proxy/proxy_server.py | 59 ++++++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 3dedc3a71..86d1d0fb9 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2798,7 +2798,7 @@ async def image_generation( ) -#### KEY MANAGEMENT #### +#### KEY MANAGEMENT ##### @router.post( @@ -3159,6 +3159,9 @@ async def info_key_fn( ) +#### SPEND MANAGEMENT ##### + + @router.get( "/spend/keys", tags=["budget & spend Tracking"], @@ -3946,6 +3949,60 @@ async def user_update(data: UpdateUserRequest): ) +#### TEAM MANAGEMENT #### + + +@router.post( + "/team/new", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +async def new_team(): + """ + Create a new team + + Parameters: + - team_alias: Optional[str] - User defined team alias + - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it. + - team_admins: list - A list of user IDs that will be owning the team + - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } + + Returns: + - key: (str) The generated api key + - expires: (datetime) Datetime object for when key expires. + - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id. + """ + pass + + +@router.post( + "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +async def update_team(): + """ + update team and members + """ + pass + + +@router.post( + "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +async def delete_team(): + """ + delete team and team keys + """ + pass + + +@router.post( + "/team/info", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +async def info_team(): + """ + get info on team + related keys + """ + pass + + #### MODEL MANAGEMENT #### From c094db7160df6720a586686e9666a64f57c990fc Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 14 Feb 2024 17:20:41 -0800 Subject: [PATCH 2/5] feat(proxy_server.py): working /team/new endpoint for creating a new team --- litellm/proxy/_types.py | 32 +- litellm/proxy/proxy_server.py | 574 ++++++++++++++++++---------------- litellm/proxy/schema.prisma | 19 ++ litellm/proxy/utils.py | 13 +- schema.prisma | 21 +- 5 files changed, 383 insertions(+), 276 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index c85564231..6d2d6e36c 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -21,10 +21,10 @@ class LiteLLMBase(BaseModel): def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump(**kwargs) # noqa except Exception as e: # if using pydantic v1 - return self.dict() + return self.dict(**kwargs) def fields_set(self): try: @@ -211,6 +211,34 @@ class UpdateUserRequest(GenerateRequestBase): max_budget: Optional[float] = None +class NewTeamRequest(LiteLLMBase): + team_alias: Optional[str] = None + team_id: Optional[str] = None + admins: list = [] + members: list = [] + metadata: Optional[dict] = None + + +class LiteLLM_TeamTable(NewTeamRequest): + max_budget: Optional[float] = None + spend: Optional[float] = None + models: list = [] + max_parallel_requests: Optional[int] = None + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None + budget_duration: Optional[str] = None + budget_reset_at: Optional[datetime] = None + + +class NewTeamResponse(LiteLLMBase): + team_id: str + admins: list + members: list + metadata: dict + created_at: datetime + updated_at: datetime + + class KeyManagementSystem(enum.Enum): GOOGLE_KMS = "google_kms" AZURE_KEY_VAULT = "azure_key_vault" diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 86d1d0fb9..e24fc1a82 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3573,271 +3573,6 @@ async def user_auth(request: Request): return "Email sent!" -@app.get("/sso/key/generate", tags=["experimental"]) -async def google_login(request: Request): - """ - Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env - - PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/" - Example: - - """ - microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) - google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) - - # get url from request - redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) - - ui_username = os.getenv("UI_USERNAME") - if redirect_url.endswith("/"): - redirect_url += "sso/callback" - else: - redirect_url += "/sso/callback" - # Google SSO Auth - if google_client_id is not None: - from fastapi_sso.sso.google import GoogleSSO - - google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) - if google_client_secret is None: - raise ProxyException( - message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", - type="auth_error", - param="GOOGLE_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - - google_sso = GoogleSSO( - client_id=google_client_id, - client_secret=google_client_secret, - redirect_uri=redirect_url, - ) - - verbose_proxy_logger.info( - f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" - ) - - with google_sso: - return await google_sso.get_login_redirect() - - # Microsoft SSO Auth - elif microsoft_client_id is not None: - from fastapi_sso.sso.microsoft import MicrosoftSSO - - microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) - microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) - if microsoft_client_secret is None: - raise ProxyException( - message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", - type="auth_error", - param="MICROSOFT_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - - microsoft_sso = MicrosoftSSO( - client_id=microsoft_client_id, - client_secret=microsoft_client_secret, - tenant=microsoft_tenant, - redirect_uri=redirect_url, - allow_insecure_http=True, - ) - with microsoft_sso: - return await microsoft_sso.get_login_redirect() - elif ui_username is not None: - # No Google, Microsoft SSO - # Use UI Credentials set in .env - from fastapi.responses import HTMLResponse - - return HTMLResponse(content=html_form, status_code=200) - else: - from fastapi.responses import HTMLResponse - - return HTMLResponse(content=html_form, status_code=200) - - -@router.post( - "/login", include_in_schema=False -) # hidden since this is a helper for UI sso login -async def login(request: Request): - try: - import multipart - except ImportError: - subprocess.run(["pip", "install", "python-multipart"]) - global master_key - form = await request.form() - username = str(form.get("username")) - password = str(form.get("password")) - ui_username = os.getenv("UI_USERNAME", "admin") - ui_password = os.getenv("UI_PASSWORD", None) - if ui_password is None: - ui_password = str(master_key) if master_key is not None else None - - if ui_password is None: - raise ProxyException( - message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys", - type="auth_error", - param="UI_PASSWORD", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - - if secrets.compare_digest(username, ui_username) and secrets.compare_digest( - password, ui_password - ): - user_role = "app_owner" - user_id = username - key_user_id = user_id - if ( - os.getenv("PROXY_ADMIN_ID", None) is not None - and os.environ["PROXY_ADMIN_ID"] == user_id - ) or user_id == "admin": - # checks if user is admin - user_role = "app_admin" - key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id") - - # Admin is Authe'd in - generate key for the UI to access Proxy - - if os.getenv("DATABASE_URL") is not None: - response = await generate_key_helper_fn( - **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore - ) - else: - response = { - "token": "sk-gm", - "user_id": "litellm-dashboard", - } - - key = response["token"] # type: ignore - - litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "/") + "ui/" - - import jwt - - jwt_token = jwt.encode( - { - "user_id": user_id, - "key": key, - "user_email": user_id, - "user_role": user_role, - }, - "secret", - algorithm="HS256", - ) - litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token - - # if a user has logged in they should be allowed to create keys - this ensures that it's set to True - general_settings["allow_user_auth"] = True - return RedirectResponse(url=litellm_dashboard_ui, status_code=303) - else: - raise ProxyException( - message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file", - type="auth_error", - param="invalid_credentials", - code=status.HTTP_401_UNAUTHORIZED, - ) - - -@app.get("/sso/callback", tags=["experimental"]) -async def auth_callback(request: Request): - """Verify login""" - global general_settings - microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) - google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) - - # get url from request - redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) - - if redirect_url.endswith("/"): - redirect_url += "sso/callback" - else: - redirect_url += "/sso/callback" - - if google_client_id is not None: - from fastapi_sso.sso.google import GoogleSSO - - google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) - if google_client_secret is None: - raise ProxyException( - message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", - type="auth_error", - param="GOOGLE_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - google_sso = GoogleSSO( - client_id=google_client_id, - redirect_uri=redirect_url, - client_secret=google_client_secret, - ) - result = await google_sso.verify_and_process(request) - - elif microsoft_client_id is not None: - from fastapi_sso.sso.microsoft import MicrosoftSSO - - microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) - microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) - if microsoft_client_secret is None: - raise ProxyException( - message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", - type="auth_error", - param="MICROSOFT_CLIENT_SECRET", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if microsoft_tenant is None: - raise ProxyException( - message="MICROSOFT_TENANT not set. Set it in .env file", - type="auth_error", - param="MICROSOFT_TENANT", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - - microsoft_sso = MicrosoftSSO( - client_id=microsoft_client_id, - client_secret=microsoft_client_secret, - tenant=microsoft_tenant, - redirect_uri=redirect_url, - allow_insecure_http=True, - ) - result = await microsoft_sso.verify_and_process(request) - - # User is Authe'd in - generate key for the UI to access Proxy - user_email = getattr(result, "email", None) - user_id = getattr(result, "id", None) - if user_id is None: - user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "") - - response = await generate_key_helper_fn( - **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore - ) - key = response["token"] # type: ignore - user_id = response["user_id"] # type: ignore - - litellm_dashboard_ui = "/ui/" - - user_role = "app_owner" - if ( - os.getenv("PROXY_ADMIN_ID", None) is not None - and os.environ["PROXY_ADMIN_ID"] == user_id - ): - # checks if user is admin - user_role = "app_admin" - - import jwt - - jwt_token = jwt.encode( - { - "user_id": user_id, - "key": key, - "user_email": user_email, - "user_role": user_role, - }, - "secret", - algorithm="HS256", - ) - litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token - - # if a user has logged in they should be allowed to create keys - this ensures that it's set to True - general_settings["allow_user_auth"] = True - return RedirectResponse(url=litellm_dashboard_ui) - - @router.get( "/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)] ) @@ -3953,24 +3688,51 @@ async def user_update(data: UpdateUserRequest): @router.post( - "/team/new", tags=["team management"], dependencies=[Depends(user_api_key_auth)] + "/team/new", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], + response_model=NewTeamResponse, ) -async def new_team(): +async def new_team( + data: NewTeamRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): """ - Create a new team + Allow users to create a new team. Apply user permissions to their team. Parameters: - team_alias: Optional[str] - User defined team alias - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it. - - team_admins: list - A list of user IDs that will be owning the team + - admins: list - A list of user IDs that will be owning the team + - members: list - A list of user IDs that will be members of the team - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } Returns: - - key: (str) The generated api key - - expires: (datetime) Datetime object for when key expires. - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id. """ - pass + global prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + data.team_id = str(uuid.uuid4()) + + complete_team_data = LiteLLM_TeamTable( + **data.json(), + max_budget=user_api_key_dict.max_budget, + models=user_api_key_dict.models, + max_parallel_requests=user_api_key_dict.max_parallel_requests, + tpm_limit=user_api_key_dict.tpm_limit, + rpm_limit=user_api_key_dict.rpm_limit, + budget_duration=user_api_key_dict.budget_duration, + budget_reset_at=user_api_key_dict.budget_reset_at, + ) + + team_row = await prisma_client.insert_data( + data=complete_team_data.json(exclude_none=True), table_name="team" + ) + return team_row @router.post( @@ -4354,6 +4116,274 @@ async def retrieve_server_log(request: Request): return FileResponse(filepath) +#### LOGIN ENDPOINTS #### + + +@app.get("/sso/key/generate", tags=["experimental"]) +async def google_login(request: Request): + """ + Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env + + PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/" + Example: + + """ + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + + # get url from request + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + + ui_username = os.getenv("UI_USERNAME") + if redirect_url.endswith("/"): + redirect_url += "sso/callback" + else: + redirect_url += "/sso/callback" + # Google SSO Auth + if google_client_id is not None: + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + google_sso = GoogleSSO( + client_id=google_client_id, + client_secret=google_client_secret, + redirect_uri=redirect_url, + ) + + verbose_proxy_logger.info( + f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" + ) + + with google_sso: + return await google_sso.get_login_redirect() + + # Microsoft SSO Auth + elif microsoft_client_id is not None: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + with microsoft_sso: + return await microsoft_sso.get_login_redirect() + elif ui_username is not None: + # No Google, Microsoft SSO + # Use UI Credentials set in .env + from fastapi.responses import HTMLResponse + + return HTMLResponse(content=html_form, status_code=200) + else: + from fastapi.responses import HTMLResponse + + return HTMLResponse(content=html_form, status_code=200) + + +@router.post( + "/login", include_in_schema=False +) # hidden since this is a helper for UI sso login +async def login(request: Request): + try: + import multipart + except ImportError: + subprocess.run(["pip", "install", "python-multipart"]) + global master_key + form = await request.form() + username = str(form.get("username")) + password = str(form.get("password")) + ui_username = os.getenv("UI_USERNAME", "admin") + ui_password = os.getenv("UI_PASSWORD", None) + if ui_password is None: + ui_password = str(master_key) if master_key is not None else None + + if ui_password is None: + raise ProxyException( + message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys", + type="auth_error", + param="UI_PASSWORD", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + if secrets.compare_digest(username, ui_username) and secrets.compare_digest( + password, ui_password + ): + user_role = "app_owner" + user_id = username + key_user_id = user_id + if ( + os.getenv("PROXY_ADMIN_ID", None) is not None + and os.environ["PROXY_ADMIN_ID"] == user_id + ) or user_id == "admin": + # checks if user is admin + user_role = "app_admin" + key_user_id = os.getenv("PROXY_ADMIN_ID", "default_user_id") + + # Admin is Authe'd in - generate key for the UI to access Proxy + + if os.getenv("DATABASE_URL") is not None: + response = await generate_key_helper_fn( + **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"} # type: ignore + ) + else: + response = { + "token": "sk-gm", + "user_id": "litellm-dashboard", + } + + key = response["token"] # type: ignore + + litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "/") + "ui/" + + import jwt + + jwt_token = jwt.encode( + { + "user_id": user_id, + "key": key, + "user_email": user_id, + "user_role": user_role, + }, + "secret", + algorithm="HS256", + ) + litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token + + # if a user has logged in they should be allowed to create keys - this ensures that it's set to True + general_settings["allow_user_auth"] = True + return RedirectResponse(url=litellm_dashboard_ui, status_code=303) + else: + raise ProxyException( + message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file", + type="auth_error", + param="invalid_credentials", + code=status.HTTP_401_UNAUTHORIZED, + ) + + +@app.get("/sso/callback", tags=["experimental"]) +async def auth_callback(request: Request): + """Verify login""" + global general_settings + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) + google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) + + # get url from request + redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) + + if redirect_url.endswith("/"): + redirect_url += "sso/callback" + else: + redirect_url += "/sso/callback" + + if google_client_id is not None: + from fastapi_sso.sso.google import GoogleSSO + + google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) + if google_client_secret is None: + raise ProxyException( + message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="GOOGLE_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + google_sso = GoogleSSO( + client_id=google_client_id, + redirect_uri=redirect_url, + client_secret=google_client_secret, + ) + result = await google_sso.verify_and_process(request) + + elif microsoft_client_id is not None: + from fastapi_sso.sso.microsoft import MicrosoftSSO + + microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) + microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) + if microsoft_client_secret is None: + raise ProxyException( + message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", + type="auth_error", + param="MICROSOFT_CLIENT_SECRET", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if microsoft_tenant is None: + raise ProxyException( + message="MICROSOFT_TENANT not set. Set it in .env file", + type="auth_error", + param="MICROSOFT_TENANT", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + microsoft_sso = MicrosoftSSO( + client_id=microsoft_client_id, + client_secret=microsoft_client_secret, + tenant=microsoft_tenant, + redirect_uri=redirect_url, + allow_insecure_http=True, + ) + result = await microsoft_sso.verify_and_process(request) + + # User is Authe'd in - generate key for the UI to access Proxy + user_email = getattr(result, "email", None) + user_id = getattr(result, "id", None) + if user_id is None: + user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "") + + response = await generate_key_helper_fn( + **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore + ) + key = response["token"] # type: ignore + user_id = response["user_id"] # type: ignore + + litellm_dashboard_ui = "/ui/" + + user_role = "app_owner" + if ( + os.getenv("PROXY_ADMIN_ID", None) is not None + and os.environ["PROXY_ADMIN_ID"] == user_id + ): + # checks if user is admin + user_role = "app_admin" + + import jwt + + jwt_token = jwt.encode( + { + "user_id": user_id, + "key": key, + "user_email": user_email, + "user_role": user_role, + }, + "secret", + algorithm="HS256", + ) + litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token + + # if a user has logged in they should be allowed to create keys - this ensures that it's set to True + general_settings["allow_user_auth"] = True + return RedirectResponse(url=litellm_dashboard_ui) + + #### BASIC ENDPOINTS #### @router.post( "/config/update", diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 3a5c3ea23..dbdcd3f5b 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -8,6 +8,25 @@ generator client { provider = "prisma-client-py" } +// Assign prod keys to groups, not individuals +model LiteLLM_TeamTable { + team_id String @unique + team_alias String? + admins String[] + members String[] + metadata Json @default("{}") + max_budget Float? + spend Float @default(0.0) + models String[] + max_parallel_requests Int? + tpm_limit BigInt? + rpm_limit BigInt? + budget_duration String? + budget_reset_at DateTime? + created_at DateTime @default(now()) @map("created_at") + updated_at DateTime @default(now()) @updatedAt @map("updated_at") +} + // Track spend, rate limit, budget Users model LiteLLM_UserTable { user_id String @unique diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 0350d54bd..0108836fc 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -677,7 +677,7 @@ class PrismaClient: on_backoff=on_backoff, # specifying the function to call on backoff ) async def insert_data( - self, data: dict, table_name: Literal["user", "key", "config", "spend"] + self, data: dict, table_name: Literal["user", "key", "config", "spend", "team"] ): """ Add a key to the database. If it already exists, do nothing. @@ -713,6 +713,17 @@ class PrismaClient: ) verbose_proxy_logger.info(f"Data Inserted into User Table") return new_user_row + elif table_name == "team": + db_data = self.jsonify_object(data=data) + new_team_row = await self.db.litellm_teamtable.upsert( + where={"team_id": data["team_id"]}, + data={ + "create": {**db_data}, # type: ignore + "update": {}, # don't do anything if it already exists + }, + ) + verbose_proxy_logger.info(f"Data Inserted into Team Table") + return new_team_row elif table_name == "config": """ For each param, diff --git a/schema.prisma b/schema.prisma index 3a5c3ea23..f26363059 100644 --- a/schema.prisma +++ b/schema.prisma @@ -1,13 +1,32 @@ datasource client { provider = "postgresql" url = env("DATABASE_URL") - directUrl = env("DIRECT_URL") + directUrl = env("DIRECT_URL")? } generator client { provider = "prisma-client-py" } +// Assign prod keys to groups, not individuals +model LiteLLM_TeamTable { + team_id String @unique + team_alias String? + admins String[] + members String[] + metadata Json @default("{}") + max_budget Float? + spend Float @default(0.0) + models String[] + max_parallel_requests Int? + tpm_limit BigInt? + rpm_limit BigInt? + budget_duration String? + budget_reset_at DateTime? + created_at DateTime @default(now()) @map("created_at") + updated_at DateTime @default(now()) @updatedAt @map("updated_at") +} + // Track spend, rate limit, budget Users model LiteLLM_UserTable { user_id String @unique From 8717ee6d9ab7d540459b14d90bc3a874f55e4316 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 14 Feb 2024 18:12:18 -0800 Subject: [PATCH 3/5] feat(proxy_server.py): enable /team/info endpoint --- litellm/proxy/_types.py | 4 ++ litellm/proxy/proxy_server.py | 84 ++++++++++++++++++++++++++++++++--- litellm/proxy/utils.py | 27 ++++++++++- 3 files changed, 108 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 6d2d6e36c..7a6e400c8 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -239,6 +239,10 @@ class NewTeamResponse(LiteLLMBase): updated_at: datetime +class TeamRequest(LiteLLMBase): + teams: List[str] + + class KeyManagementSystem(enum.Enum): GOOGLE_KMS = "google_kms" AZURE_KEY_VAULT = "azure_key_vault" diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e24fc1a82..ff10b5a17 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2819,7 +2819,8 @@ async def generate_key_fn( Parameters: - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). - key_alias: Optional[str] - User defined key alias - - team_id: Optional[str] - The team id of the user + - team_id: Optional[str] - The team id of the key + - user_id: Optional[str] - The user id of the key - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models) - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models - config: Optional[dict] - any key-specific configs, overrides config in config.yaml @@ -3048,8 +3049,8 @@ async def info_key_fn_v2( Example Curl: ``` curl -X GET "http://0.0.0.0:8000/key/info" \ --H "Authorization: Bearer sk-1234" \ --d {"keys": ["sk-1", "sk-2", "sk-3"]} + -H "Authorization: Bearer sk-1234" \ + -d {"keys": ["sk-1", "sk-2", "sk-3"]} ``` """ global prisma_client @@ -3755,14 +3756,85 @@ async def delete_team(): pass -@router.post( +@router.get( "/team/info", tags=["team management"], dependencies=[Depends(user_api_key_auth)] ) -async def info_team(): +async def team_info( + team_id: str = fastapi.Query( + default=None, description="Team ID in the request parameters" + ) +): """ get info on team + related keys + + ``` + curl --location 'http://localhost:4000/team/info' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "teams": ["",..] + }' + ``` """ - pass + global prisma_client + try: + if prisma_client is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "error": f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + }, + ) + if team_id is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail={"message": "Malformed request. No team id passed in."}, + ) + + team_info = await prisma_client.get_data( + team_id=team_id, table_name="team", query_type="find_unique" + ) + ## GET ALL KEYS ## + keys = await prisma_client.get_data( + team_id=team_id, + table_name="key", + query_type="find_all", + expires=datetime.now(), + ) + + if team_info is None: + ## make sure we still return a total spend ## + spend = 0 + for k in keys: + spend += getattr(k, "spend", 0) + team_info = {"spend": spend} + + ## REMOVE HASHED TOKEN INFO before returning ## + for key in keys: + try: + key = key.model_dump() # noqa + except: + # if using pydantic v1 + key = key.dict() + key.pop("token", None) + return {"team_id": team_id, "team_info": team_info, "keys": keys} + + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) #### MODEL MANAGEMENT #### diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 0108836fc..d82a7231f 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -509,8 +509,9 @@ class PrismaClient: token: Optional[Union[str, list]] = None, user_id: Optional[str] = None, user_id_list: Optional[list] = None, + team_id: Optional[str] = None, key_val: Optional[dict] = None, - table_name: Optional[Literal["user", "key", "config", "spend"]] = None, + table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, query_type: Literal["find_unique", "find_all"] = "find_unique", expires: Optional[datetime] = None, reset_at: Optional[datetime] = None, @@ -545,6 +546,14 @@ class PrismaClient: for r in response: if isinstance(r.expires, datetime): r.expires = r.expires.isoformat() + elif query_type == "find_all" and team_id is not None: + response = await self.db.litellm_verificationtoken.find_many( + where={"team_id": team_id} + ) + if response is not None and len(response) > 0: + for r in response: + if isinstance(r.expires, datetime): + r.expires = r.expires.isoformat() elif ( query_type == "find_all" and expires is not None @@ -657,7 +666,23 @@ class PrismaClient: order={"startTime": "desc"}, ) return response + elif table_name == "team": + if query_type == "find_unique": + response = await self.db.litellm_teamtable.find_unique( + where={"team_id": team_id} # type: ignore + ) + if query_type == "find_all" and team_id is not None: + user_id_values = str(tuple(team_id)) + sql_query = f""" + SELECT * + FROM "LiteLLM_TeamTable" + WHERE "team_id" IN {team_id} + """ + # Execute the raw query + # The asterisk before `team_id` unpacks the list into separate arguments + response = await self.db.query_raw(sql_query) + return response except Exception as e: print_verbose(f"LiteLLM Prisma Client Exception: {e}") import traceback From 4e758ecd64e8b430bc48889a52ee8da02fc8a98b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 14 Feb 2024 21:01:00 -0800 Subject: [PATCH 4/5] fix: prisma schema --- schema.prisma | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema.prisma b/schema.prisma index f26363059..dbdcd3f5b 100644 --- a/schema.prisma +++ b/schema.prisma @@ -1,7 +1,7 @@ datasource client { provider = "postgresql" url = env("DATABASE_URL") - directUrl = env("DIRECT_URL")? + directUrl = env("DIRECT_URL") } generator client { From 97ccc00a76c95c51daf6ed605e8b09904abca0ef Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 14 Feb 2024 21:04:39 -0800 Subject: [PATCH 5/5] refactor(main.py): trigger rebuild --- litellm/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/main.py b/litellm/main.py index bf1017a48..a7990ecfb 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -10,6 +10,7 @@ import os, openai, sys, json, inspect, uuid, datetime, threading from typing import Any, Literal, Union from functools import partial + import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy import httpx