# litellm-mirror/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py

import ast
import traceback

import httpx
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyException

async_client = httpx.AsyncClient()


async def set_env_variables_in_header(custom_headers: dict):
"""
checks if nay headers on config.yaml are defined as os.environ/COHERE_API_KEY etc
only runs for headers defined on config.yaml
example header can be
{"Authorization": "bearer os.environ/COHERE_API_KEY"}
"""
headers = {}
for key, value in custom_headers.items():
headers[key] = value
if isinstance(value, str) and "os.environ/" in value:
verbose_proxy_logger.debug(
"pass through endpoint - looking up 'os.environ/' variable"
)
# get string section that is os.environ/
start_index = value.find("os.environ/")
_variable_name = value[start_index:]
verbose_proxy_logger.debug(
"pass through endpoint - getting secret for variable name: %s",
_variable_name,
)
_secret_value = litellm.get_secret(_variable_name)
new_value = value.replace(_variable_name, _secret_value)
headers[key] = new_value
return headers
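

# Illustrative usage sketch (not executed here); assumes "COHERE_API_KEY" is set in the
# environment or a configured secret manager so litellm.get_secret() can resolve it:
#
#   resolved = await set_env_variables_in_header(
#       {"Authorization": "bearer os.environ/COHERE_API_KEY"}
#   )
#   # -> {"Authorization": "bearer <value of COHERE_API_KEY>"}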


async def pass_through_request(request: Request, target: str, custom_headers: dict):
    try:
        url = httpx.URL(target)
        headers = custom_headers

        request_body = await request.body()
        _parsed_body = ast.literal_eval(request_body.decode("utf-8"))

        verbose_proxy_logger.debug(
            "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
                url, headers, _parsed_body
            )
        )

        # forward the incoming request to the configured target URL
        response = await async_client.request(
            method=request.method,
            url=url,
            headers=headers,
            params=request.query_params,
            json=_parsed_body,
        )

        if response.status_code != 200:
            raise HTTPException(status_code=response.status_code, detail=response.text)

        content = await response.aread()
        return Response(
            content=content,
            status_code=response.status_code,
            headers=dict(response.headers),
        )
    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.pass through endpoint(): Exception occurred - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e.detail)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        else:
            error_msg = f"{str(e)}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", 500),
            )


def create_pass_through_route(endpoint, target, custom_headers=None):
    # returns a request handler that forwards everything it receives to `target`
    async def endpoint_func(request: Request):
        return await pass_through_request(request, target, custom_headers)

    return endpoint_func


async def initialize_pass_through_endpoints(pass_through_endpoints: list):
    verbose_proxy_logger.debug("initializing pass through endpoints")
    from litellm.proxy.proxy_server import app

    for endpoint in pass_through_endpoints:
        _target = endpoint.get("target", None)
        _path = endpoint.get("path", None)
        _custom_headers = endpoint.get("headers", None)

        if _target is None or _path is None:
            continue

        # only resolve os.environ/ style secrets when headers are configured
        if _custom_headers is not None:
            _custom_headers = await set_env_variables_in_header(
                custom_headers=_custom_headers
            )

        verbose_proxy_logger.debug("adding pass through endpoint: %s", _path)
        app.add_api_route(
            path=_path,
            endpoint=create_pass_through_route(_path, _target, _custom_headers),
            methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
        )
        verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path)