diff --git a/.circleci/config.yml b/.circleci/config.yml index 5712c71ca..f697be521 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -209,6 +209,7 @@ jobs: -e MISTRAL_API_KEY=$MISTRAL_API_KEY \ -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ -e GROQ_API_KEY=$GROQ_API_KEY \ + -e COHERE_API_KEY=$COHERE_API_KEY \ -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ -e AWS_REGION_NAME=$AWS_REGION_NAME \ -e AUTO_INFER_REGION=True \ diff --git a/docs/my-website/docs/proxy/pass_through.md b/docs/my-website/docs/proxy/pass_through.md index d557b81bd..4554f8013 100644 --- a/docs/my-website/docs/proxy/pass_through.md +++ b/docs/my-website/docs/proxy/pass_through.md @@ -35,6 +35,7 @@ general_settings: Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint content-type: application/json # (Optional) Extra Headers to pass to this endpoint accept: application/json + forward_headers: True # (Optional) Forward all headers from the incoming request to the target endpoint ``` **Step 2** Start Proxy Server in detailed_debug mode @@ -220,6 +221,7 @@ general_settings: * `LANGFUSE_PUBLIC_KEY` *string*: Your Langfuse account public key - only set this when forwarding to Langfuse. * `LANGFUSE_SECRET_KEY` *string*: Your Langfuse account secret key - only set this when forwarding to Langfuse. * `` *string*: Pass any custom header key/value pair + * `forward_headers` *Optional(boolean)*: If true, all headers from the incoming request will be forwarded to the target endpoint. Default is `False`. ## Custom Chat Endpoints (Anthropic/Bedrock/Vertex) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 351b19c25..3ab0425a3 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -3,7 +3,7 @@ import asyncio import json import traceback from base64 import b64encode -from typing import Optional +from typing import List, Optional import httpx from fastapi import ( @@ -239,11 +239,32 @@ async def chat_completion_pass_through_endpoint( ) +def forward_headers_from_request( + request: Request, + headers: dict, + forward_headers: Optional[bool] = False, +): + """ + Helper to forward headers from original request + """ + if forward_headers is True: + request_headers = dict(request.headers) + + # Header We Should NOT forward + request_headers.pop("content-length", None) + request_headers.pop("host", None) + + # Combine request headers with custom headers + headers = {**request_headers, **headers} + return headers + + async def pass_through_request( request: Request, target: str, custom_headers: dict, user_api_key_dict: UserAPIKeyAuth, + forward_headers: Optional[bool] = False, ): try: import time @@ -254,6 +275,9 @@ async def pass_through_request( url = httpx.URL(target) headers = custom_headers + headers = forward_headers_from_request( + request=request, headers=headers, forward_headers=forward_headers + ) request_body = await request.body() body_str = request_body.decode() @@ -360,7 +384,11 @@ async def pass_through_request( def create_pass_through_route( - endpoint, target: str, custom_headers: Optional[dict] = None + endpoint, + target: str, + custom_headers: Optional[dict] = None, + _forward_headers: Optional[bool] = False, + dependencies: Optional[List] = None, ): # check if target is an adapter.py or a url import uuid @@ -389,18 +417,36 @@ def create_pass_through_route( except Exception: verbose_proxy_logger.warning("Defaulting to target being a url.") + if dependencies is None: - async def endpoint_func( - request: Request, - fastapi_response: Response, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - ): - return await pass_through_request( - request=request, - target=target, - custom_headers=custom_headers or {}, - user_api_key_dict=user_api_key_dict, - ) + async def endpoint_func_no_auth( + request: Request, + fastapi_response: Response, + ): + return await pass_through_request( + request=request, + target=target, + custom_headers=custom_headers or {}, + user_api_key_dict=UserAPIKeyAuth(), + forward_headers=_forward_headers, + ) + + return endpoint_func_no_auth + + else: + + async def endpoint_func( + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + ): + return await pass_through_request( + request=request, + target=target, + custom_headers=custom_headers or {}, + user_api_key_dict=user_api_key_dict, + forward_headers=_forward_headers, + ) return endpoint_func @@ -418,6 +464,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): _custom_headers = await set_env_variables_in_header( custom_headers=_custom_headers ) + _forward_headers = endpoint.get("forward_headers", None) _auth = endpoint.get("auth", None) _dependencies = None if _auth is not None and str(_auth).lower() == "true": @@ -433,11 +480,14 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list): if _target is None: continue - verbose_proxy_logger.debug("adding pass through endpoint: %s", _path) - + verbose_proxy_logger.debug( + "adding pass through endpoint: %s, dependencies: %s", _path, _dependencies + ) app.add_api_route( path=_path, - endpoint=create_pass_through_route(_path, _target, _custom_headers), + endpoint=create_pass_through_route( + _path, _target, _custom_headers, _forward_headers, _dependencies + ), methods=["GET", "POST", "PUT", "DELETE", "PATCH"], dependencies=_dependencies, ) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 50e8bcd62..97cd407d3 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -40,6 +40,13 @@ files_settings: general_settings: master_key: sk-1234 + pass_through_endpoints: + - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server + target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to + headers: # headers to forward to this URL + content-type: application/json # (Optional) Extra Headers to pass to this endpoint + accept: application/json + forward_headers: True litellm_settings: diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index dc943d59d..4912ebbbf 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -154,6 +154,14 @@ general_settings: database_connection_pool_limit: 10 # database_url: "postgresql://:@:/" # [OPTIONAL] use for token-based auth to proxy + pass_through_endpoints: + - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server + target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to + headers: # headers to forward to this URL + content-type: application/json # (Optional) Extra Headers to pass to this endpoint + accept: application/json + forward_headers: True + # environment_variables: # settings for using redis caching # REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com diff --git a/tests/test_passthrough_endpoints.py b/tests/test_passthrough_endpoints.py new file mode 100644 index 000000000..69ce71371 --- /dev/null +++ b/tests/test_passthrough_endpoints.py @@ -0,0 +1,63 @@ +import pytest +import asyncio +import aiohttp, openai +from openai import OpenAI, AsyncOpenAI +from typing import Optional, List, Union + +import aiohttp +import asyncio +import json +import os +import dotenv + + +dotenv.load_dotenv() + + +async def cohere_rerank(session): + url = "http://localhost:4000/v1/rerank" + headers = { + "Authorization": f"bearer {os.getenv('COHERE_API_KEY')}", + "Content-Type": "application/json", + "Accept": "application/json", + } + data = { + "model": "rerank-english-v3.0", + "query": "What is the capital of the United States?", + "top_n": 3, + "documents": [ + "Carson City is the capital city of the American state of Nevada.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", + "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", + "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.", + ], + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + print(f"Status: {status}") + print(f"Response:\n{response_text}") + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + return await response.json() + + +@pytest.mark.asyncio +async def test_basic_passthrough(): + """ + - Make request to pass through endpoint + + - This SHOULD not go through LiteLLM user_api_key_auth + - This should forward headers from request to pass through endpoint + """ + async with aiohttp.ClientSession() as session: + response = await cohere_rerank(session) + print("response from cohere rerank", response) + + assert response["id"] is not None + assert response["results"] is not None