diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index d4b4484965..218032e012 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -3,12 +3,21 @@ import traceback from base64 import b64encode import httpx -from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status +from fastapi import ( + APIRouter, + Depends, + FastAPI, + HTTPException, + Request, + Response, + status, +) from fastapi.responses import StreamingResponse import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy._types import ProxyException +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth async_client = httpx.AsyncClient() @@ -129,7 +138,8 @@ def create_pass_through_route(endpoint, target, custom_headers=None): async def initialize_pass_through_endpoints(pass_through_endpoints: list): verbose_proxy_logger.debug("initializing pass through endpoints") - from litellm.proxy.proxy_server import app + from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes + from litellm.proxy.proxy_server import app, premium_user for endpoint in pass_through_endpoints: _target = endpoint.get("target", None) @@ -138,6 +148,15 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): _custom_headers = await set_env_variables_in_header( custom_headers=_custom_headers ) + _auth = endpoint.get("auth", None) + _dependencies = None + if _auth is not None and str(_auth).lower() == "true": + if premium_user is not True: + raise ValueError( + f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}" + ) + _dependencies = [Depends(user_api_key_auth)] + LiteLLMRoutes.openai_routes.value.append(_path) if _target is None: continue @@ -148,6 +167,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): path=_path, endpoint=create_pass_through_route(_path, _target, _custom_headers), methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + dependencies=_dependencies, ) verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path)