From d1b8c4e08d7470748bb9b5ce52ce023b31b346f5 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 5 Aug 2024 21:45:44 -0700 Subject: [PATCH 1/5] forward headers from request --- .../pass_through_endpoints.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 351b19c25..3c7ea3748 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -244,6 +244,7 @@ async def pass_through_request( target: str, custom_headers: dict, user_api_key_dict: UserAPIKeyAuth, + forward_headers: Optional[bool] = False, ): try: import time @@ -262,6 +263,10 @@ async def pass_through_request( except: _parsed_body = json.loads(body_str) + if forward_headers is True: + request_headers = dict(request.headers) + headers = {**headers, **request_headers} + verbose_proxy_logger.debug( "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( url, headers, _parsed_body @@ -360,7 +365,10 @@ async def pass_through_request( def create_pass_through_route( - endpoint, target: str, custom_headers: Optional[dict] = None + endpoint, + target: str, + custom_headers: Optional[dict] = None, + _forward_headers: Optional[bool] = False, ): # check if target is an adapter.py or a url import uuid @@ -400,6 +408,7 @@ def create_pass_through_route( target=target, custom_headers=custom_headers or {}, user_api_key_dict=user_api_key_dict, + forward_headers=_forward_headers, ) return endpoint_func @@ -418,6 +427,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): _custom_headers = await set_env_variables_in_header( custom_headers=_custom_headers ) + _forward_headers = endpoint.get("forward_headers", None) _auth = endpoint.get("auth", None) _dependencies = None if _auth is not None and str(_auth).lower() == "true": @@ -437,7 +447,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): app.add_api_route( path=_path, - endpoint=create_pass_through_route(_path, _target, _custom_headers), + endpoint=create_pass_through_route( + _path, _target, _custom_headers, _forward_headers + ), methods=["GET", "POST", "PUT", "DELETE", "PATCH"], dependencies=_dependencies, ) From bd1f3232979c4c111eea76ff6952de96d01b2a57 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 6 Aug 2024 11:34:10 -0700 Subject: [PATCH 2/5] use helper to forward headers from request --- .../pass_through_endpoints.py | 32 +++++++++++++++---- litellm/proxy/proxy_config.yaml | 7 ++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 3c7ea3748..0327b1297 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -239,6 +239,26 @@ async def chat_completion_pass_through_endpoint( ) +def forward_headers_from_request( + request: Request, + headers: dict, + forward_headers: Optional[bool] = False, +): + """ + Helper to forward headers from original request + """ + if forward_headers is True: + request_headers = dict(request.headers) + + # Header We Should NOT forward + request_headers.pop("content-length", None) + request_headers.pop("host", None) + + # Combine request headers with custom headers + headers = {**request_headers, **headers} + return headers + + async def pass_through_request( request: Request, target: str, @@ -255,6 +275,9 @@ async def pass_through_request( url = httpx.URL(target) headers = custom_headers + headers = forward_headers_from_request( + request=request, headers=headers, forward_headers=forward_headers + ) request_body = await request.body() body_str = request_body.decode() @@ -263,10 +286,6 @@ async def pass_through_request( except: _parsed_body = json.loads(body_str) - if forward_headers is True: - request_headers = dict(request.headers) - headers = {**headers, **request_headers} - verbose_proxy_logger.debug( "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( url, headers, _parsed_body @@ -443,8 +462,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): if _target is None: continue - verbose_proxy_logger.debug("adding pass through endpoint: %s", _path) - + verbose_proxy_logger.debug( + "adding pass through endpoint: %s, dependencies: %s", _path, _dependencies + ) app.add_api_route( path=_path, endpoint=create_pass_through_route( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 50e8bcd62..97cd407d3 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -40,6 +40,13 @@ files_settings: general_settings: master_key: sk-1234 + pass_through_endpoints: + - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server + target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to + headers: # headers to forward to this URL + content-type: application/json # (Optional) Extra Headers to pass to this endpoint + accept: application/json + forward_headers: True litellm_settings: From c277a71c1ec99f87b3af10fcf68c89c1ede667ce Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 6 Aug 2024 12:04:04 -0700 Subject: [PATCH 3/5] init pass through endpoints --- .../pass_through_endpoints.py | 46 +++++++++++++------ 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 0327b1297..3ab0425a3 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -3,7 +3,7 @@ import asyncio import json import traceback from base64 import b64encode -from typing import Optional +from typing import List, Optional import httpx from fastapi import ( @@ -388,6 +388,7 @@ def create_pass_through_route( target: str, custom_headers: Optional[dict] = None, _forward_headers: Optional[bool] = False, + dependencies: Optional[List] = None, ): # check if target is an adapter.py or a url import uuid @@ -416,19 +417,36 @@ def create_pass_through_route( except Exception: verbose_proxy_logger.warning("Defaulting to target being a url.") + if dependencies is None: - async def endpoint_func( - request: Request, - fastapi_response: Response, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - ): - return await pass_through_request( - request=request, - target=target, - custom_headers=custom_headers or {}, - user_api_key_dict=user_api_key_dict, - forward_headers=_forward_headers, - ) + async def endpoint_func_no_auth( + request: Request, + fastapi_response: Response, + ): + return await pass_through_request( + request=request, + target=target, + custom_headers=custom_headers or {}, + user_api_key_dict=UserAPIKeyAuth(), + forward_headers=_forward_headers, + ) + + return endpoint_func_no_auth + + else: + + async def endpoint_func( + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + ): + return await pass_through_request( + request=request, + target=target, + custom_headers=custom_headers or {}, + user_api_key_dict=user_api_key_dict, + forward_headers=_forward_headers, + ) return endpoint_func @@ -468,7 +486,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): app.add_api_route( path=_path, endpoint=create_pass_through_route( - _path, _target, _custom_headers, _forward_headers + _path, _target, _custom_headers, _forward_headers, _dependencies ), methods=["GET", "POST", "PUT", "DELETE", "PATCH"], dependencies=_dependencies, From 0cd2435aff285c3bb368d1d14e9d759a10968129 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 6 Aug 2024 12:07:21 -0700 Subject: [PATCH 4/5] doc forward_headers --- docs/my-website/docs/proxy/pass_through.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/my-website/docs/proxy/pass_through.md b/docs/my-website/docs/proxy/pass_through.md index d557b81bd..4554f8013 100644 --- a/docs/my-website/docs/proxy/pass_through.md +++ b/docs/my-website/docs/proxy/pass_through.md @@ -35,6 +35,7 @@ general_settings: Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint content-type: application/json # (Optional) Extra Headers to pass to this endpoint accept: application/json + forward_headers: True # (Optional) Forward all headers from the incoming request to the target endpoint ``` **Step 2** Start Proxy Server in detailed_debug mode @@ -220,6 +221,7 @@ general_settings: * `LANGFUSE_PUBLIC_KEY` *string*: Your Langfuse account public key - only set this when forwarding to Langfuse. * `LANGFUSE_SECRET_KEY` *string*: Your Langfuse account secret key - only set this when forwarding to Langfuse. * `` *string*: Pass any custom header key/value pair + * `forward_headers` *Optional(boolean)*: If true, all headers from the incoming request will be forwarded to the target endpoint. Default is `False`. ## Custom Chat Endpoints (Anthropic/Bedrock/Vertex) From 404360b28d2a5a92d0fde71db1caf4de5f1709a4 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 6 Aug 2024 12:16:00 -0700 Subject: [PATCH 5/5] test pass through endpoint --- .circleci/config.yml | 1 + proxy_server_config.yaml | 8 ++++ tests/test_passthrough_endpoints.py | 63 +++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+) create mode 100644 tests/test_passthrough_endpoints.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 5712c71ca..f697be521 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -209,6 +209,7 @@ jobs: -e MISTRAL_API_KEY=$MISTRAL_API_KEY \ -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ -e GROQ_API_KEY=$GROQ_API_KEY \ + -e COHERE_API_KEY=$COHERE_API_KEY \ -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ -e AWS_REGION_NAME=$AWS_REGION_NAME \ -e AUTO_INFER_REGION=True \ diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index dc943d59d..4912ebbbf 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -154,6 +154,14 @@ general_settings: database_connection_pool_limit: 10 # database_url: "postgresql://:@:/" # [OPTIONAL] use for token-based auth to proxy + pass_through_endpoints: + - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server + target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to + headers: # headers to forward to this URL + content-type: application/json # (Optional) Extra Headers to pass to this endpoint + accept: application/json + forward_headers: True + # environment_variables: # settings for using redis caching # REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com diff --git a/tests/test_passthrough_endpoints.py b/tests/test_passthrough_endpoints.py new file mode 100644 index 000000000..69ce71371 --- /dev/null +++ b/tests/test_passthrough_endpoints.py @@ -0,0 +1,63 @@ +import pytest +import asyncio +import aiohttp, openai +from openai import OpenAI, AsyncOpenAI +from typing import Optional, List, Union + +import aiohttp +import asyncio +import json +import os +import dotenv + + +dotenv.load_dotenv() + + +async def cohere_rerank(session): + url = "http://localhost:4000/v1/rerank" + headers = { + "Authorization": f"bearer {os.getenv('COHERE_API_KEY')}", + "Content-Type": "application/json", + "Accept": "application/json", + } + data = { + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": [ + "Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", + "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", + "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.", + ], + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + print(f"Status: {status}") + print(f"Response:\n{response_text}") + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + return await response.json() + + +@pytest.mark.asyncio +async def test_basic_passthrough(): + """ + - Make request to pass through endpoint + + - This SHOULD not go through LiteLLM user_api_key_auth + - This should forward headers from request to pass through endpoint + """ + async with aiohttp.ClientSession() as session: + response = await cohere_rerank(session) + print("response from cohere rerank", response) + + assert response["id"] is not None + assert response["results"] is not None