From 3dc6430fef6238e86b2fb53d0fa20dd208de6fe3 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sat, 29 Jun 2024 08:38:44 -0700 Subject: [PATCH] feat - setting up auth on pass through endpoint --- .../pass_through_endpoints.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index d4b4484965..218032e012 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -3,12 +3,21 @@ import traceback from base64 import b64encode import httpx -from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status +from fastapi import ( + APIRouter, + Depends, + FastAPI, + HTTPException, + Request, + Response, + status, +) from fastapi.responses import StreamingResponse import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy._types import ProxyException +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth async_client = httpx.AsyncClient() @@ -129,7 +138,8 @@ def create_pass_through_route(endpoint, target, custom_headers=None): async def initialize_pass_through_endpoints(pass_through_endpoints: list): verbose_proxy_logger.debug("initializing pass through endpoints") - from litellm.proxy.proxy_server import app + from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes + from litellm.proxy.proxy_server import app, premium_user for endpoint in pass_through_endpoints: _target = endpoint.get("target", None) @@ -138,6 +148,15 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): _custom_headers = await set_env_variables_in_header( custom_headers=_custom_headers ) + _auth = endpoint.get("auth", None) + _dependencies = None + if _auth is not None and str(_auth).lower() == "true": + if premium_user is not True: + raise ValueError( + f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}" + ) + _dependencies = [Depends(user_api_key_auth)] + LiteLLMRoutes.openai_routes.value.append(_path) if _target is None: continue @@ -148,6 +167,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): path=_path, endpoint=create_pass_through_route(_path, _target, _custom_headers), methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + dependencies=_dependencies, ) verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path)