fix(proxy/_types.py): allow internal user to access sso routes

This commit is contained in:
Krrish Dholakia 2024-06-17 18:34:20 -07:00
parent 27722c816f
commit ec56ae7c9a
2 changed files with 63 additions and 58 deletions

View file

@ -1,13 +1,17 @@
from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict
from dataclasses import fields
import enum import enum
from typing import Optional, List, Union, Dict, Literal, Any, TypedDict, TYPE_CHECKING import json
import os
import sys
import uuid
from dataclasses import fields
from datetime import datetime from datetime import datetime
import uuid, json, sys, os from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union
from litellm.types.router import UpdateRouterConfig
from litellm.types.utils import ProviderField from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator
from typing_extensions import Annotated from typing_extensions import Annotated
from litellm.types.router import UpdateRouterConfig
from litellm.types.utils import ProviderField
if TYPE_CHECKING: if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span from opentelemetry.trace import Span as _Span
@ -283,12 +287,16 @@ class LiteLLMRoutes(enum.Enum):
"/metrics", "/metrics",
] ]
internal_user_routes: List = [ internal_user_routes: List = (
[
"/key/generate", "/key/generate",
"/key/update", "/key/update",
"/key/delete", "/key/delete",
"/key/info", "/key/info",
] + spend_tracking_routes ]
+ spend_tracking_routes
+ sso_only_routes
)
# class LiteLLMAllowedRoutes(LiteLLMBase): # class LiteLLMAllowedRoutes(LiteLLMBase):

View file

@ -7,59 +7,56 @@ Returns a UserAPIKeyAuth object if the API key is valid
""" """
import asyncio
import json import json
import secrets
import traceback
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Optional from typing import Optional
import secrets
from uuid import uuid4 from uuid import uuid4
import fastapi import fastapi
from fastapi import Request
from pydantic import BaseModel
import litellm
import traceback
import asyncio
from fastapi import ( from fastapi import (
FastAPI,
Request,
HTTPException,
status,
Path,
Depends, Depends,
Header, FastAPI,
Response,
Form,
UploadFile,
File, File,
Form,
Header,
HTTPException,
Path,
Request,
Response,
UploadFile,
status,
) )
from fastapi.responses import (
StreamingResponse,
FileResponse,
ORJSONResponse,
JSONResponse,
)
from fastapi.openapi.utils import get_openapi
from fastapi.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles from fastapi.openapi.utils import get_openapi
from fastapi.responses import (
FileResponse,
JSONResponse,
ORJSONResponse,
RedirectResponse,
StreamingResponse,
)
from fastapi.security.api_key import APIKeyHeader from fastapi.security.api_key import APIKeyHeader
from litellm.proxy._types import * from fastapi.staticfiles import StaticFiles
from litellm._logging import verbose_logger, verbose_proxy_logger from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import ( from litellm.proxy.auth.auth_checks import (
allowed_routes_check,
common_checks, common_checks,
get_actual_routes,
get_end_user_object, get_end_user_object,
get_org_object, get_org_object,
get_team_object, get_team_object,
get_user_object, get_user_object,
allowed_routes_check,
get_actual_routes,
log_to_opentelemetry, log_to_opentelemetry,
) )
from litellm.proxy.utils import _to_ns
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import _to_ns
api_key_header = APIKeyHeader( api_key_header = APIKeyHeader(
name="Authorization", auto_error=False, description="Bearer token" name="Authorization", auto_error=False, description="Bearer token"
@ -88,20 +85,20 @@ async def user_api_key_auth(
) -> UserAPIKeyAuth: ) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import ( from litellm.proxy.proxy_server import (
litellm_proxy_admin_name, allowed_routes_check,
common_checks, common_checks,
master_key,
prisma_client,
llm_model_list,
user_custom_auth,
custom_db_client, custom_db_client,
general_settings, general_settings,
proxy_logging_obj,
open_telemetry_logger,
user_api_key_cache,
jwt_handler,
allowed_routes_check,
get_actual_routes, get_actual_routes,
jwt_handler,
litellm_proxy_admin_name,
llm_model_list,
master_key,
open_telemetry_logger,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
user_custom_auth,
) )
try: try:
@ -1004,7 +1001,7 @@ async def user_api_key_auth(
): ):
pass pass
elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY: elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value:
if route in LiteLLMRoutes.openai_routes.value: if route in LiteLLMRoutes.openai_routes.value:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@ -1031,7 +1028,7 @@ async def user_api_key_auth(
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}", detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
) )
elif ( elif (
_user_role == LitellmUserRoles.INTERNAL_USER _user_role == LitellmUserRoles.INTERNAL_USER.value
and route in LiteLLMRoutes.internal_user_routes.value and route in LiteLLMRoutes.internal_user_routes.value
): ):
pass pass
@ -1059,6 +1056,7 @@ async def user_api_key_auth(
# this token is only used for managing the ui # this token is only used for managing the ui
allowed_routes = [ allowed_routes = [
"/sso", "/sso",
"/sso/get/logout_url",
"/login", "/login",
"/key/generate", "/key/generate",
"/key/update", "/key/update",
@ -1144,8 +1142,8 @@ async def user_api_key_auth(
raise Exception() raise Exception()
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}".format( "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\n{}".format(
str(e) str(e), traceback.format_exc()
) )
) )
@ -1156,7 +1154,6 @@ async def user_api_key_auth(
user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span), user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span),
) )
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, litellm.BudgetExceededError): if isinstance(e, litellm.BudgetExceededError):
raise ProxyException( raise ProxyException(
message=e.message, type="auth_error", param=None, code=400 message=e.message, type="auth_error", param=None, code=400