fix(proxy/_types.py): allow internal user to access sso routes

This commit is contained in:
Krrish Dholakia 2024-06-17 18:34:20 -07:00
parent 27722c816f
commit ec56ae7c9a
2 changed files with 63 additions and 58 deletions

View file

@ -1,13 +1,17 @@
from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict
from dataclasses import fields
import enum
from typing import Optional, List, Union, Dict, Literal, Any, TypedDict, TYPE_CHECKING
import json
import os
import sys
import uuid
from dataclasses import fields
from datetime import datetime
import uuid, json, sys, os
from litellm.types.router import UpdateRouterConfig
from litellm.types.utils import ProviderField
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union
from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator
from typing_extensions import Annotated
from litellm.types.router import UpdateRouterConfig
from litellm.types.utils import ProviderField
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@ -283,12 +287,16 @@ class LiteLLMRoutes(enum.Enum):
"/metrics",
]
internal_user_routes: List = [
internal_user_routes: List = (
[
"/key/generate",
"/key/update",
"/key/delete",
"/key/info",
] + spend_tracking_routes
]
+ spend_tracking_routes
+ sso_only_routes
)
# class LiteLLMAllowedRoutes(LiteLLMBase):

View file

@ -7,59 +7,56 @@ Returns a UserAPIKeyAuth object if the API key is valid
"""
import asyncio
import json
import secrets
import traceback
from datetime import datetime, timedelta, timezone
from typing import Optional
import secrets
from uuid import uuid4
import fastapi
from fastapi import Request
from pydantic import BaseModel
import litellm
import traceback
import asyncio
from fastapi import (
FastAPI,
Request,
HTTPException,
status,
Path,
Depends,
Header,
Response,
Form,
UploadFile,
FastAPI,
File,
Form,
Header,
HTTPException,
Path,
Request,
Response,
UploadFile,
status,
)
from fastapi.responses import (
StreamingResponse,
FileResponse,
ORJSONResponse,
JSONResponse,
)
from fastapi.openapi.utils import get_openapi
from fastapi.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.openapi.utils import get_openapi
from fastapi.responses import (
FileResponse,
JSONResponse,
ORJSONResponse,
RedirectResponse,
StreamingResponse,
)
from fastapi.security.api_key import APIKeyHeader
from litellm.proxy._types import *
from litellm._logging import verbose_logger, verbose_proxy_logger
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
allowed_routes_check,
common_checks,
get_actual_routes,
get_end_user_object,
get_org_object,
get_team_object,
get_user_object,
allowed_routes_check,
get_actual_routes,
log_to_opentelemetry,
)
from litellm.proxy.utils import _to_ns
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import _to_ns
api_key_header = APIKeyHeader(
name="Authorization", auto_error=False, description="Bearer token"
@ -88,20 +85,20 @@ async def user_api_key_auth(
) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import (
litellm_proxy_admin_name,
allowed_routes_check,
common_checks,
master_key,
prisma_client,
llm_model_list,
user_custom_auth,
custom_db_client,
general_settings,
proxy_logging_obj,
open_telemetry_logger,
user_api_key_cache,
jwt_handler,
allowed_routes_check,
get_actual_routes,
jwt_handler,
litellm_proxy_admin_name,
llm_model_list,
master_key,
open_telemetry_logger,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
user_custom_auth,
)
try:
@ -1004,7 +1001,7 @@ async def user_api_key_auth(
):
pass
elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY:
elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value:
if route in LiteLLMRoutes.openai_routes.value:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@ -1031,7 +1028,7 @@ async def user_api_key_auth(
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
)
elif (
_user_role == LitellmUserRoles.INTERNAL_USER
_user_role == LitellmUserRoles.INTERNAL_USER.value
and route in LiteLLMRoutes.internal_user_routes.value
):
pass
@ -1059,6 +1056,7 @@ async def user_api_key_auth(
# this token is only used for managing the ui
allowed_routes = [
"/sso",
"/sso/get/logout_url",
"/login",
"/key/generate",
"/key/update",
@ -1144,8 +1142,8 @@ async def user_api_key_auth(
raise Exception()
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}".format(
str(e)
"litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
@ -1156,7 +1154,6 @@ async def user_api_key_auth(
user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span),
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, litellm.BudgetExceededError):
raise ProxyException(
message=e.message, type="auth_error", param=None, code=400