fix use os.environ/ vars for pass through endpoints

This commit is contained in:
Ishaan Jaff 2024-06-28 15:30:31 -07:00
parent cf9636cc59
commit 69deb65c04
2 changed files with 42 additions and 4 deletions

View file

@ -5,20 +5,49 @@ import httpx
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyException from litellm.proxy._types import ProxyException
async_client = httpx.AsyncClient() async_client = httpx.AsyncClient()
async def set_env_variables_in_header(custom_headers: dict):
"""
checks if nay headers on config.yaml are defined as os.environ/COHERE_API_KEY etc
only runs for headers defined on config.yaml
example header can be
{"Authorization": "bearer os.environ/COHERE_API_KEY"}
"""
headers = {}
for key, value in custom_headers.items():
headers[key] = value
if isinstance(value, str) and "os.environ/" in value:
verbose_proxy_logger.debug(
"pass through endpoint - looking up 'os.environ/' variable"
)
# get string section that is os.environ/
start_index = value.find("os.environ/")
_variable_name = value[start_index:]
verbose_proxy_logger.debug(
"pass through endpoint - getting secret for variable name: %s",
_variable_name,
)
_secret_value = litellm.get_secret(_variable_name)
new_value = value.replace(_variable_name, _secret_value)
headers[key] = new_value
return headers
async def pass_through_request(request: Request, target: str, custom_headers: dict): async def pass_through_request(request: Request, target: str, custom_headers: dict):
try: try:
url = httpx.URL(target) url = httpx.URL(target)
# Start with the original request headers
headers = custom_headers headers = custom_headers
# headers = dict(request.headers)
request_body = await request.body() request_body = await request.body()
_parsed_body = ast.literal_eval(request_body.decode("utf-8")) _parsed_body = ast.literal_eval(request_body.decode("utf-8"))
@ -86,6 +115,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
_target = endpoint.get("target", None) _target = endpoint.get("target", None)
_path = endpoint.get("path", None) _path = endpoint.get("path", None)
_custom_headers = endpoint.get("headers", None) _custom_headers = endpoint.get("headers", None)
_custom_headers = await set_env_variables_in_header(
custom_headers=_custom_headers
)
if _target is None: if _target is None:
continue continue

View file

@ -22,6 +22,13 @@ general_settings:
master_key: sk-1234 master_key: sk-1234
alerting: ["slack", "email"] alerting: ["slack", "email"]
public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"] public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
pass_through_endpoints:
- path: "/v1/rerank"
target: "https://api.cohere.com/v1/rerank"
headers:
Authorization: "bearer os.environ/COHERE_API_KEY"
content-type: application/json
accept: application/json
litellm_settings: litellm_settings:
@ -34,6 +41,5 @@ litellm_settings:
- user - user
- metadata - metadata
- metadata.generation_name - metadata.generation_name
cache: True