forked from phoenix/litellm-mirror
Merge pull request #5062 from BerriAI/litellm_forward_headers
[Fix-Proxy] allow forwarding headers from request
This commit is contained in:
commit
645d3ae09d
6 changed files with 147 additions and 16 deletions
|
@ -209,6 +209,7 @@ jobs:
|
||||||
-e MISTRAL_API_KEY=$MISTRAL_API_KEY \
|
-e MISTRAL_API_KEY=$MISTRAL_API_KEY \
|
||||||
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
|
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
|
||||||
-e GROQ_API_KEY=$GROQ_API_KEY \
|
-e GROQ_API_KEY=$GROQ_API_KEY \
|
||||||
|
-e COHERE_API_KEY=$COHERE_API_KEY \
|
||||||
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
|
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
|
||||||
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
||||||
-e AUTO_INFER_REGION=True \
|
-e AUTO_INFER_REGION=True \
|
||||||
|
|
|
@ -35,6 +35,7 @@ general_settings:
|
||||||
Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint
|
Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint
|
||||||
content-type: application/json # (Optional) Extra Headers to pass to this endpoint
|
content-type: application/json # (Optional) Extra Headers to pass to this endpoint
|
||||||
accept: application/json
|
accept: application/json
|
||||||
|
forward_headers: True # (Optional) Forward all headers from the incoming request to the target endpoint
|
||||||
```
|
```
|
||||||
|
|
||||||
**Step 2** Start Proxy Server in detailed_debug mode
|
**Step 2** Start Proxy Server in detailed_debug mode
|
||||||
|
@ -220,6 +221,7 @@ general_settings:
|
||||||
* `LANGFUSE_PUBLIC_KEY` *string*: Your Langfuse account public key - only set this when forwarding to Langfuse.
|
* `LANGFUSE_PUBLIC_KEY` *string*: Your Langfuse account public key - only set this when forwarding to Langfuse.
|
||||||
* `LANGFUSE_SECRET_KEY` *string*: Your Langfuse account secret key - only set this when forwarding to Langfuse.
|
* `LANGFUSE_SECRET_KEY` *string*: Your Langfuse account secret key - only set this when forwarding to Langfuse.
|
||||||
* `<your-custom-header>` *string*: Pass any custom header key/value pair
|
* `<your-custom-header>` *string*: Pass any custom header key/value pair
|
||||||
|
* `forward_headers` *Optional(boolean)*: If true, all headers from the incoming request will be forwarded to the target endpoint. Default is `False`.
|
||||||
|
|
||||||
|
|
||||||
## Custom Chat Endpoints (Anthropic/Bedrock/Vertex)
|
## Custom Chat Endpoints (Anthropic/Bedrock/Vertex)
|
||||||
|
|
|
@ -3,7 +3,7 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
|
@ -239,11 +239,32 @@ async def chat_completion_pass_through_endpoint(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward_headers_from_request(
|
||||||
|
request: Request,
|
||||||
|
headers: dict,
|
||||||
|
forward_headers: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper to forward headers from original request
|
||||||
|
"""
|
||||||
|
if forward_headers is True:
|
||||||
|
request_headers = dict(request.headers)
|
||||||
|
|
||||||
|
# Header We Should NOT forward
|
||||||
|
request_headers.pop("content-length", None)
|
||||||
|
request_headers.pop("host", None)
|
||||||
|
|
||||||
|
# Combine request headers with custom headers
|
||||||
|
headers = {**request_headers, **headers}
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
async def pass_through_request(
|
async def pass_through_request(
|
||||||
request: Request,
|
request: Request,
|
||||||
target: str,
|
target: str,
|
||||||
custom_headers: dict,
|
custom_headers: dict,
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
forward_headers: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
import time
|
import time
|
||||||
|
@ -254,6 +275,9 @@ async def pass_through_request(
|
||||||
|
|
||||||
url = httpx.URL(target)
|
url = httpx.URL(target)
|
||||||
headers = custom_headers
|
headers = custom_headers
|
||||||
|
headers = forward_headers_from_request(
|
||||||
|
request=request, headers=headers, forward_headers=forward_headers
|
||||||
|
)
|
||||||
|
|
||||||
request_body = await request.body()
|
request_body = await request.body()
|
||||||
body_str = request_body.decode()
|
body_str = request_body.decode()
|
||||||
|
@ -360,7 +384,11 @@ async def pass_through_request(
|
||||||
|
|
||||||
|
|
||||||
def create_pass_through_route(
|
def create_pass_through_route(
|
||||||
endpoint, target: str, custom_headers: Optional[dict] = None
|
endpoint,
|
||||||
|
target: str,
|
||||||
|
custom_headers: Optional[dict] = None,
|
||||||
|
_forward_headers: Optional[bool] = False,
|
||||||
|
dependencies: Optional[List] = None,
|
||||||
):
|
):
|
||||||
# check if target is an adapter.py or a url
|
# check if target is an adapter.py or a url
|
||||||
import uuid
|
import uuid
|
||||||
|
@ -389,6 +417,23 @@ def create_pass_through_route(
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
verbose_proxy_logger.warning("Defaulting to target being a url.")
|
verbose_proxy_logger.warning("Defaulting to target being a url.")
|
||||||
|
if dependencies is None:
|
||||||
|
|
||||||
|
async def endpoint_func_no_auth(
|
||||||
|
request: Request,
|
||||||
|
fastapi_response: Response,
|
||||||
|
):
|
||||||
|
return await pass_through_request(
|
||||||
|
request=request,
|
||||||
|
target=target,
|
||||||
|
custom_headers=custom_headers or {},
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(),
|
||||||
|
forward_headers=_forward_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
return endpoint_func_no_auth
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
async def endpoint_func(
|
async def endpoint_func(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
@ -400,6 +445,7 @@ def create_pass_through_route(
|
||||||
target=target,
|
target=target,
|
||||||
custom_headers=custom_headers or {},
|
custom_headers=custom_headers or {},
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
forward_headers=_forward_headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
return endpoint_func
|
return endpoint_func
|
||||||
|
@ -418,6 +464,7 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
|
||||||
_custom_headers = await set_env_variables_in_header(
|
_custom_headers = await set_env_variables_in_header(
|
||||||
custom_headers=_custom_headers
|
custom_headers=_custom_headers
|
||||||
)
|
)
|
||||||
|
_forward_headers = endpoint.get("forward_headers", None)
|
||||||
_auth = endpoint.get("auth", None)
|
_auth = endpoint.get("auth", None)
|
||||||
_dependencies = None
|
_dependencies = None
|
||||||
if _auth is not None and str(_auth).lower() == "true":
|
if _auth is not None and str(_auth).lower() == "true":
|
||||||
|
@ -433,11 +480,14 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
|
||||||
if _target is None:
|
if _target is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
verbose_proxy_logger.debug("adding pass through endpoint: %s", _path)
|
verbose_proxy_logger.debug(
|
||||||
|
"adding pass through endpoint: %s, dependencies: %s", _path, _dependencies
|
||||||
|
)
|
||||||
app.add_api_route(
|
app.add_api_route(
|
||||||
path=_path,
|
path=_path,
|
||||||
endpoint=create_pass_through_route(_path, _target, _custom_headers),
|
endpoint=create_pass_through_route(
|
||||||
|
_path, _target, _custom_headers, _forward_headers, _dependencies
|
||||||
|
),
|
||||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||||
dependencies=_dependencies,
|
dependencies=_dependencies,
|
||||||
)
|
)
|
||||||
|
|
|
@ -40,6 +40,13 @@ files_settings:
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
|
pass_through_endpoints:
|
||||||
|
- path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server
|
||||||
|
target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to
|
||||||
|
headers: # headers to forward to this URL
|
||||||
|
content-type: application/json # (Optional) Extra Headers to pass to this endpoint
|
||||||
|
accept: application/json
|
||||||
|
forward_headers: True
|
||||||
|
|
||||||
|
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
|
|
|
@ -154,6 +154,14 @@ general_settings:
|
||||||
database_connection_pool_limit: 10
|
database_connection_pool_limit: 10
|
||||||
# database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy
|
# database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>" # [OPTIONAL] use for token-based auth to proxy
|
||||||
|
|
||||||
|
pass_through_endpoints:
|
||||||
|
- path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server
|
||||||
|
target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to
|
||||||
|
headers: # headers to forward to this URL
|
||||||
|
content-type: application/json # (Optional) Extra Headers to pass to this endpoint
|
||||||
|
accept: application/json
|
||||||
|
forward_headers: True
|
||||||
|
|
||||||
# environment_variables:
|
# environment_variables:
|
||||||
# settings for using redis caching
|
# settings for using redis caching
|
||||||
# REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
|
# REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
|
||||||
|
|
63
tests/test_passthrough_endpoints.py
Normal file
63
tests/test_passthrough_endpoints.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import aiohttp, openai
|
||||||
|
from openai import OpenAI, AsyncOpenAI
|
||||||
|
from typing import Optional, List, Union
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import dotenv
|
||||||
|
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
async def cohere_rerank(session):
|
||||||
|
url = "http://localhost:4000/v1/rerank"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"bearer {os.getenv('COHERE_API_KEY')}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"model": "rerank-english-v3.0",
|
||||||
|
"query": "What is the capital of the United States?",
|
||||||
|
"top_n": 3,
|
||||||
|
"documents": [
|
||||||
|
"Carson City is the capital city of the American state of Nevada.",
|
||||||
|
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
|
||||||
|
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
|
||||||
|
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
|
||||||
|
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
|
status = response.status
|
||||||
|
response_text = await response.text()
|
||||||
|
print(f"Status: {status}")
|
||||||
|
print(f"Response:\n{response_text}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
if status != 200:
|
||||||
|
raise Exception(f"Request did not return a 200 status code: {status}")
|
||||||
|
|
||||||
|
return await response.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_basic_passthrough():
|
||||||
|
"""
|
||||||
|
- Make request to pass through endpoint
|
||||||
|
|
||||||
|
- This SHOULD not go through LiteLLM user_api_key_auth
|
||||||
|
- This should forward headers from request to pass through endpoint
|
||||||
|
"""
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
response = await cohere_rerank(session)
|
||||||
|
print("response from cohere rerank", response)
|
||||||
|
|
||||||
|
assert response["id"] is not None
|
||||||
|
assert response["results"] is not None
|
Loading…
Add table
Add a link
Reference in a new issue