Merge pull request #5062 from BerriAI/litellm_forward_headers

[Fix-Proxy] allow forwarding headers from request
This commit is contained in:
Ishaan Jaff 2024-08-06 12:34:25 -07:00 committed by GitHub
commit 645d3ae09d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 147 additions and 16 deletions

View file

@ -209,6 +209,7 @@ jobs:
-e MISTRAL_API_KEY=$MISTRAL_API_KEY \
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
-e GROQ_API_KEY=$GROQ_API_KEY \
-e COHERE_API_KEY=$COHERE_API_KEY \
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
-e AWS_REGION_NAME=$AWS_REGION_NAME \
-e AUTO_INFER_REGION=True \

View file

@ -35,6 +35,7 @@ general_settings:
Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint
content-type: application/json # (Optional) Extra Headers to pass to this endpoint
accept: application/json
forward_headers: True # (Optional) Forward all headers from the incoming request to the target endpoint
```
**Step 2** Start Proxy Server in detailed_debug mode
@ -220,6 +221,7 @@ general_settings:
* `LANGFUSE_PUBLIC_KEY` *string*: Your Langfuse account public key - only set this when forwarding to Langfuse.
* `LANGFUSE_SECRET_KEY` *string*: Your Langfuse account secret key - only set this when forwarding to Langfuse.
* `<your-custom-header>` *string*: Pass any custom header key/value pair
* `forward_headers` *Optional(boolean)*: If true, all headers from the incoming request will be forwarded to the target endpoint. Default is `False`.
## Custom Chat Endpoints (Anthropic/Bedrock/Vertex)

View file

@ -3,7 +3,7 @@ import asyncio
import json
import traceback
from base64 import b64encode
from typing import Optional
from typing import List, Optional
import httpx
from fastapi import (
@ -239,11 +239,32 @@ async def chat_completion_pass_through_endpoint(
)
def forward_headers_from_request(
request: Request,
headers: dict,
forward_headers: Optional[bool] = False,
):
"""
Helper to forward headers from original request
"""
if forward_headers is True:
request_headers = dict(request.headers)
# Header We Should NOT forward
request_headers.pop("content-length", None)
request_headers.pop("host", None)
# Combine request headers with custom headers
headers = {**request_headers, **headers}
return headers
async def pass_through_request(
request: Request,
target: str,
custom_headers: dict,
user_api_key_dict: UserAPIKeyAuth,
forward_headers: Optional[bool] = False,
):
try:
import time
@ -254,6 +275,9 @@ async def pass_through_request(
url = httpx.URL(target)
headers = custom_headers
headers = forward_headers_from_request(
request=request, headers=headers, forward_headers=forward_headers
)
request_body = await request.body()
body_str = request_body.decode()
@ -360,7 +384,11 @@ async def pass_through_request(
def create_pass_through_route(
endpoint, target: str, custom_headers: Optional[dict] = None
endpoint,
target: str,
custom_headers: Optional[dict] = None,
_forward_headers: Optional[bool] = False,
dependencies: Optional[List] = None,
):
# check if target is an adapter.py or a url
import uuid
@ -389,18 +417,36 @@ def create_pass_through_route(
except Exception:
verbose_proxy_logger.warning("Defaulting to target being a url.")
if dependencies is None:
async def endpoint_func(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
return await pass_through_request(
request=request,
target=target,
custom_headers=custom_headers or {},
user_api_key_dict=user_api_key_dict,
)
async def endpoint_func_no_auth(
request: Request,
fastapi_response: Response,
):
return await pass_through_request(
request=request,
target=target,
custom_headers=custom_headers or {},
user_api_key_dict=UserAPIKeyAuth(),
forward_headers=_forward_headers,
)
return endpoint_func_no_auth
else:
async def endpoint_func(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
return await pass_through_request(
request=request,
target=target,
custom_headers=custom_headers or {},
user_api_key_dict=user_api_key_dict,
forward_headers=_forward_headers,
)
return endpoint_func
@ -418,6 +464,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
_custom_headers = await set_env_variables_in_header(
custom_headers=_custom_headers
)
_forward_headers = endpoint.get("forward_headers", None)
_auth = endpoint.get("auth", None)
_dependencies = None
if _auth is not None and str(_auth).lower() == "true":
@ -433,11 +480,14 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
if _target is None:
continue
verbose_proxy_logger.debug("adding pass through endpoint: %s", _path)
verbose_proxy_logger.debug(
"adding pass through endpoint: %s, dependencies: %s", _path, _dependencies
)
app.add_api_route(
path=_path,
endpoint=create_pass_through_route(_path, _target, _custom_headers),
endpoint=create_pass_through_route(
_path, _target, _custom_headers, _forward_headers, _dependencies
),
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
dependencies=_dependencies,
)

View file

@ -40,6 +40,13 @@ files_settings:
general_settings:
master_key: sk-1234
pass_through_endpoints:
- path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server
target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to
headers: # headers to forward to this URL
content-type: application/json # (Optional) Extra Headers to pass to this endpoint
accept: application/json
forward_headers: True
litellm_settings:

View file

@ -154,6 +154,14 @@ general_settings:
database_connection_pool_limit: 10
# database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy
pass_through_endpoints:
- path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server
target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to
headers: # headers to forward to this URL
content-type: application/json # (Optional) Extra Headers to pass to this endpoint
accept: application/json
forward_headers: True
# environment_variables:
# settings for using redis caching
# REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com

View file

@ -0,0 +1,63 @@
import pytest
import asyncio
import aiohttp, openai
from openai import OpenAI, AsyncOpenAI
from typing import Optional, List, Union
import aiohttp
import asyncio
import json
import os
import dotenv
dotenv.load_dotenv()
async def cohere_rerank(session):
url = "http://localhost:4000/v1/rerank"
headers = {
"Authorization": f"bearer {os.getenv('COHERE_API_KEY')}",
"Content-Type": "application/json",
"Accept": "application/json",
}
data = {
"model": "rerank-english-v3.0",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
],
}
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(f"Status: {status}")
print(f"Response:\n{response_text}")
print()
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json()
@pytest.mark.asyncio
async def test_basic_passthrough():
"""
- Make request to pass through endpoint
- This SHOULD not go through LiteLLM user_api_key_auth
- This should forward headers from request to pass through endpoint
"""
async with aiohttp.ClientSession() as session:
response = await cohere_rerank(session)
print("response from cohere rerank", response)
assert response["id"] is not None
assert response["results"] is not None