Add /openai pass through route on litellm proxy (#7412)

* add pt oai route - proxy

* pass through use safe read request body
This commit is contained in:
Ishaan Jaff 2024-12-25 20:15:59 -08:00 committed by GitHub
parent c9f61b3d23
commit ced059e371
2 changed files with 63 additions and 14 deletions

View file

@ -313,6 +313,65 @@ async def azure_proxy_route(
raise Exception( raise Exception(
"Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure." "Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
) )
# Add or update query parameters
azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")
if azure_api_key is None:
raise Exception(
"Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
)
return await _base_openai_pass_through_handler(
endpoint=endpoint,
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
base_target_url=base_target_url,
api_key=azure_api_key,
)
@router.api_route(
"/openai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["OpenAI Pass-through", "pass-through"],
)
async def openai_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Simple pass-through for OpenAI. Use this if you want to directly send a request to OpenAI.
"""
base_target_url = "https://api.openai.com"
# Add or update query parameters
openai_api_key = get_secret_str(secret_name="OPENAI_API_KEY")
if openai_api_key is None:
raise Exception(
"Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
)
return await _base_openai_pass_through_handler(
endpoint=endpoint,
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
base_target_url=base_target_url,
api_key=openai_api_key,
)
async def _base_openai_pass_through_handler(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth,
base_target_url: str,
api_key: str,
):
encoded_endpoint = httpx.URL(endpoint).path encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction # Ensure endpoint starts with '/' for proper URL construction
@ -323,9 +382,6 @@ async def azure_proxy_route(
base_url = httpx.URL(base_target_url) base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint) updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")
## check for streaming ## check for streaming
is_streaming_request = False is_streaming_request = False
if "stream" in str(updated_url): if "stream" in str(updated_url):
@ -336,8 +392,8 @@ async def azure_proxy_route(
endpoint=endpoint, endpoint=endpoint,
target=str(updated_url), target=str(updated_url),
custom_headers={ custom_headers={
"authorization": "Bearer {}".format(azure_api_key), "authorization": "Bearer {}".format(api_key),
"api-key": "{}".format(azure_api_key), "api-key": "{}".format(api_key),
}, },
) # dynamically construct pass-through endpoint based on incoming path ) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func( received_value = await endpoint_func(

View file

@ -22,6 +22,7 @@ from litellm.proxy._types import (
UserAPIKeyAuth, UserAPIKeyAuth,
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.secret_managers.main import get_secret_str from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.custom_http import httpxSpecialProvider from litellm.types.llms.custom_http import httpxSpecialProvider
@ -330,15 +331,7 @@ async def pass_through_request( # noqa: PLR0915
if custom_body: if custom_body:
_parsed_body = custom_body _parsed_body = custom_body
else: else:
request_body = await request.body() _parsed_body = await _read_request_body(request)
if request_body == b"" or request_body is None:
_parsed_body = None
else:
body_str = request_body.decode()
try:
_parsed_body = ast.literal_eval(body_str)
except Exception:
_parsed_body = json.loads(body_str)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
url, headers, _parsed_body url, headers, _parsed_body