mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00

refactor location of proxy

This commit is contained in:
parent baa5564f95
commit ce58c53ff1

413 changed files with 2087 additions and 2088 deletions
@@ -0,0 +1,965 @@
"""
|
||||
What is this?
|
||||
|
||||
Provider-specific Pass-Through Endpoints
|
||||
|
||||
Use litellm with Anthropic SDK, Vertex AI SDK, Cohere SDK, etc.
|
||||
"""
|
||||
|
||||
import json
import os
from abc import ABC, abstractmethod
from typing import Literal, Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.secret_managers.main import get_secret_str
from litellm_proxy._types import *
from litellm_proxy.auth.route_checks import RouteChecks
from litellm_proxy.auth.user_api_key_auth import user_api_key_auth
from litellm_proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key
from litellm_proxy.pass_through_endpoints.pass_through_endpoints import (
    create_pass_through_route,
)

from .passthrough_endpoint_router import PassthroughEndpointRouter

vertex_llm_base = VertexBase()
router = APIRouter()
default_vertex_config = None

passthrough_endpoint_router = PassthroughEndpointRouter()


def create_request_copy(request: Request):
    return {
        "method": request.method,
        "url": str(request.url),
        "headers": dict(request.headers),
        "cookies": request.cookies,
        "query_params": dict(request.query_params),
    }


async def llm_passthrough_factory_proxy_route(
    custom_llm_provider: str,
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Factory function for creating pass-through endpoints for LLM providers.
    """
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    provider_config = ProviderConfigManager.get_provider_model_info(
        provider=LlmProviders(custom_llm_provider),
        model=None,
    )
    if provider_config is None:
        raise HTTPException(
            status_code=404, detail=f"Provider {custom_llm_provider} not found"
        )
    base_target_url = provider_config.get_api_base()

    if base_target_url is None:
        raise HTTPException(
            status_code=404, detail=f"Provider {custom_llm_provider} api base not found"
        )

    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Look up the provider credentials configured on the proxy
    provider_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider=custom_llm_provider,
        region_name=None,
    )

    auth_headers = provider_config.validate_environment(
        headers={},
        model="",
        messages=[],
        optional_params={},
        litellm_params={},
        api_key=provider_api_key,
        api_base=base_target_url,
    )

    ## check for streaming
    is_streaming_request = False
    # the request is streaming when 'stream' = True is in the body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers=auth_headers,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


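# --- Illustrative sketch (not part of the original module) ---
# Any provider whose config exposes get_api_base() and validate_environment()
# can be wired through the factory with a thin route. A hypothetical "groq"
# route (provider name is an assumption) would mirror the concrete routes below:
#
#   @router.api_route(
#       "/groq/{endpoint:path}",
#       methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
#       tags=["Groq Pass-through", "pass-through"],
#   )
#   async def groq_proxy_route(
#       endpoint: str,
#       request: Request,
#       fastapi_response: Response,
#       user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
#   ):
#       return await llm_passthrough_factory_proxy_route(
#           custom_llm_provider="groq",
#           endpoint=endpoint,
#           request=request,
#           fastapi_response=fastapi_response,
#           user_api_key_dict=user_api_key_dict,
#       )

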
@router.api_route(
    "/gemini/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Google AI Studio Pass-through", "pass-through"],
)
async def gemini_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/google_ai_studio)
    """
    ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?key=LITELLM_API_KEY
    google_ai_studio_api_key = request.query_params.get("key") or request.headers.get(
        "x-goog-api-key"
    )

    user_api_key_dict = await user_api_key_auth(
        request=request, api_key=f"Bearer {google_ai_studio_api_key}"
    )

    base_target_url = "https://generativelanguage.googleapis.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Swap the caller's virtual key for the real Gemini API key
    gemini_api_key: Optional[str] = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="gemini",
        region_name=None,
    )
    if gemini_api_key is None:
        raise Exception(
            "Required 'GEMINI_API_KEY' in environment to make pass-through calls to Google AI Studio."
        )
    # Merge query parameters, giving precedence to the real API key
    merged_params = dict(request.query_params)
    merged_params.update({"key": gemini_api_key})

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        query_params=merged_params,  # type: ignore
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


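# Usage sketch (assumptions: proxy at http://localhost:4000; GEMINI_API_KEY set
# on the proxy; the caller passes a LiteLLM virtual key as `key`, which the
# route above swaps for the real Gemini key):
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent",
#       params={"key": "sk-litellm-virtual-key"},
#       json={"contents": [{"parts": [{"text": "hello"}]}]},
#   )

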
@router.api_route(
    "/cohere/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Cohere Pass-through", "pass-through"],
)
async def cohere_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/cohere)
    """
    base_target_url = "https://api.cohere.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Look up the Cohere API key configured on the proxy
    cohere_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="cohere",
        region_name=None,
    )

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"Authorization": "Bearer {}".format(cohere_api_key)},
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


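# Usage sketch (assumptions: proxy at http://localhost:4000; the Cohere SDK's
# base_url override is used to point it at this route):
#
#   import cohere
#   client = cohere.Client(
#       api_key="sk-litellm-virtual-key",
#       base_url="http://localhost:4000/cohere",
#   )

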
@router.api_route(
    "/vllm/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["VLLM Pass-through", "pass-through"],
)
async def vllm_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/vllm)
    """
    return await llm_passthrough_factory_proxy_route(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        custom_llm_provider="vllm",
    )


@router.api_route(
    "/mistral/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Mistral Pass-through", "pass-through"],
)
async def mistral_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/mistral)
    """
    base_target_url = os.getenv("MISTRAL_API_BASE") or "https://api.mistral.ai"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Look up the Mistral API key configured on the proxy
    mistral_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="mistral",
        region_name=None,
    )

    ## check for streaming
    is_streaming_request = False
    # mistral is streaming when 'stream' = True is in the body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"Authorization": "Bearer {}".format(mistral_api_key)},
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


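# Usage sketch (assumptions: proxy at http://localhost:4000; the Mistral SDK's
# server_url override is used to point it at this route):
#
#   from mistralai import Mistral
#   client = Mistral(
#       api_key="sk-litellm-virtual-key",
#       server_url="http://localhost:4000/mistral",
#   )

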
@router.api_route(
    "/anthropic/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Anthropic Pass-through", "pass-through"],
)
async def anthropic_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/anthropic_completion)
    """
    base_target_url = "https://api.anthropic.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Look up the Anthropic API key configured on the proxy
    anthropic_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="anthropic",
        region_name=None,
    )

    ## check for streaming
    is_streaming_request = False
    # anthropic is streaming when 'stream' = True is in the body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"x-api-key": "{}".format(anthropic_api_key)},
        _forward_headers=True,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


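# Usage sketch (assumption: proxy at http://localhost:4000). The Anthropic SDK
# works unchanged once base_url points at this route; the real x-api-key is
# injected server-side above:
#
#   import anthropic
#   client = anthropic.Anthropic(
#       base_url="http://localhost:4000/anthropic",
#       api_key="sk-litellm-virtual-key",
#   )
#   msg = client.messages.create(
#       model="claude-3-5-sonnet-20240620",
#       max_tokens=100,
#       messages=[{"role": "user", "content": "hello"}],
#   )

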
@router.api_route(
    "/bedrock/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Bedrock Pass-through", "pass-through"],
)
async def bedrock_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/bedrock)
    """
    create_request_copy(request)

    try:
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest
        from botocore.credentials import Credentials
    except ImportError:
        raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

    aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
    if _is_bedrock_agent_runtime_route(endpoint=endpoint):  # handle bedrock agents
        base_target_url = (
            f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
        )
    else:
        base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Sign the outbound request with SigV4 using the proxy's AWS credentials
    from litellm.llms.bedrock.chat import BedrockConverseLLM

    credentials: Credentials = BedrockConverseLLM().get_credentials()
    sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
    headers = {"Content-Type": "application/json"}
    # Assuming the body contains JSON data, parse it
    try:
        data = await request.json()
    except Exception as e:
        raise HTTPException(status_code=400, detail={"error": str(e)})
    _request = AWSRequest(
        method="POST", url=str(updated_url), data=json.dumps(data), headers=headers
    )
    sigv4.add_auth(_request)
    prepped = _request.prepare()

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(prepped.url),
        custom_headers=prepped.headers,  # type: ignore
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
        custom_body=data,  # type: ignore
        query_params={},  # type: ignore
    )

    return received_value


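# Usage sketch (assumptions: proxy at http://localhost:4000; AWS credentials and
# AWS_REGION_NAME configured on the proxy; the model id is illustrative):
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:4000/bedrock/model/"
#       "anthropic.claude-3-sonnet-20240229-v1:0/invoke",
#       headers={"Authorization": "Bearer sk-litellm-virtual-key"},
#       json={
#           "anthropic_version": "bedrock-2023-05-31",
#           "max_tokens": 100,
#           "messages": [{"role": "user", "content": "hello"}],
#       },
#   )

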
def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
    """
    Return True if the endpoint should be routed to the `bedrock-agent-runtime` endpoint.
    """
    for _route in BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES:
        if _route in endpoint:
            return True
    return False


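# Illustrative behavior (the matched substrings are assumptions; the real list
# lives in BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES):
#
#   # if the constant contains "knowledgebases", then:
#   _is_bedrock_agent_runtime_route("knowledgebases/kb-123/retrieve")  # True
#   _is_bedrock_agent_runtime_route("model/my-model/invoke")  # False, no match

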
@router.api_route(
    "/assemblyai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["AssemblyAI Pass-through", "pass-through"],
)
@router.api_route(
    "/eu.assemblyai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["AssemblyAI EU Pass-through", "pass-through"],
)
async def assemblyai_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://api.assemblyai.com)
    """
    from litellm_proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
        AssemblyAIPassthroughLoggingHandler,
    )

    # Set base URL based on the route (US vs. EU)
    assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
        url=str(request.url)
    )
    base_target_url = (
        AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(
            region=assembly_region
        )
    )
    encoded_endpoint = httpx.URL(endpoint).path
    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Look up the AssemblyAI API key for the matched region
    assemblyai_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="assemblyai",
        region_name=assembly_region,
    )

    ## check for streaming
    is_streaming_request = False
    # assemblyai is streaming when 'stream' = True is in the body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"Authorization": "{}".format(assemblyai_api_key)},
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


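# Usage sketch (assumptions: proxy at http://localhost:4000; the AssemblyAI
# SDK's base_url setting is used to point it at this route):
#
#   import assemblyai as aai
#   aai.settings.api_key = "sk-litellm-virtual-key"
#   aai.settings.base_url = "http://localhost:4000/assemblyai"
#   transcript = aai.Transcriber().transcribe("https://example.com/audio.mp3")

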
@router.api_route(
    "/azure/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Azure Pass-through", "pass-through"],
)
async def azure_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Call any Azure endpoint using the proxy.

    Just use `{PROXY_BASE_URL}/azure/{endpoint:path}`
    """
    base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
    if base_target_url is None:
        raise Exception(
            "Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
        )
    # Look up the Azure API key configured on the proxy
    azure_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider=litellm.LlmProviders.AZURE.value,
        region_name=None,
    )
    if azure_api_key is None:
        raise Exception(
            "Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
        )

    return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        base_target_url=base_target_url,
        api_key=azure_api_key,
        custom_llm_provider=litellm.LlmProviders.AZURE,
    )


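# Usage sketch (assumptions: proxy at http://localhost:4000; AZURE_API_BASE and
# AZURE_API_KEY set on the proxy; deployment name and api-version are
# illustrative):
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:4000/azure/openai/deployments/my-gpt-4o/chat/completions",
#       params={"api-version": "2024-02-15-preview"},
#       headers={"Authorization": "Bearer sk-litellm-virtual-key"},
#       json={"messages": [{"role": "user", "content": "hello"}]},
#   )

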
class BaseVertexAIPassThroughHandler(ABC):
    @staticmethod
    @abstractmethod
    def get_default_base_target_url(vertex_location: Optional[str]) -> str:
        pass

    @staticmethod
    @abstractmethod
    def update_base_target_url_with_credential_location(
        base_target_url: str, vertex_location: Optional[str]
    ) -> str:
        pass


class VertexAIDiscoveryPassThroughHandler(BaseVertexAIPassThroughHandler):
    @staticmethod
    def get_default_base_target_url(vertex_location: Optional[str]) -> str:
        return "https://discoveryengine.googleapis.com/"

    @staticmethod
    def update_base_target_url_with_credential_location(
        base_target_url: str, vertex_location: Optional[str]
    ) -> str:
        return base_target_url


class VertexAIPassThroughHandler(BaseVertexAIPassThroughHandler):
    @staticmethod
    def get_default_base_target_url(vertex_location: Optional[str]) -> str:
        return f"https://{vertex_location}-aiplatform.googleapis.com/"

    @staticmethod
    def update_base_target_url_with_credential_location(
        base_target_url: str, vertex_location: Optional[str]
    ) -> str:
        return f"https://{vertex_location}-aiplatform.googleapis.com/"


def get_vertex_pass_through_handler(
    call_type: Literal["discovery", "aiplatform"]
) -> BaseVertexAIPassThroughHandler:
    if call_type == "discovery":
        return VertexAIDiscoveryPassThroughHandler()
    elif call_type == "aiplatform":
        return VertexAIPassThroughHandler()
    else:
        raise ValueError(f"Invalid call type: {call_type}")


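# e.g. get_vertex_pass_through_handler(call_type="discovery")
#      .get_default_base_target_url(None) -> "https://discoveryengine.googleapis.com/"
# while call_type="aiplatform" yields the region-scoped
# "https://{vertex_location}-aiplatform.googleapis.com/" host.

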
async def _base_vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    get_vertex_pass_through_handler: BaseVertexAIPassThroughHandler,
    user_api_key_dict: Optional[UserAPIKeyAuth] = None,
):
    """
    Base function for Vertex AI passthrough routes.
    Handles common logic for all Vertex AI services.

    Default base_target_url is `https://{vertex_location}-aiplatform.googleapis.com/`
    """
    from litellm.llms.vertex_ai.common_utils import (
        construct_target_url,
        get_vertex_location_from_url,
        get_vertex_project_id_from_url,
    )

    encoded_endpoint = httpx.URL(endpoint).path
    verbose_proxy_logger.debug("requested endpoint %s", endpoint)
    headers: dict = {}

    # Authenticate only if the caller didn't already pass a validated key
    if user_api_key_dict is None:
        api_key_to_use = get_litellm_virtual_key(request=request)
        user_api_key_dict = await user_api_key_auth(
            request=request,
            api_key=api_key_to_use,
        )

    vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint)
    vertex_location: Optional[str] = get_vertex_location_from_url(endpoint)
    vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
        project_id=vertex_project,
        location=vertex_location,
    )

    base_target_url = get_vertex_pass_through_handler.get_default_base_target_url(
        vertex_location
    )

    headers_passed_through = False
    # Use headers from the incoming request if no vertex credentials are found
    if vertex_credentials is None or vertex_credentials.vertex_project is None:
        headers = dict(request.headers) or {}
        headers_passed_through = True
        verbose_proxy_logger.debug(
            "default_vertex_config not set, incoming request headers %s", headers
        )
        headers.pop("content-length", None)
        headers.pop("host", None)
    else:
        vertex_project = vertex_credentials.vertex_project
        vertex_location = vertex_credentials.vertex_location
        vertex_credentials_str = vertex_credentials.vertex_credentials

        _auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async(
            credentials=vertex_credentials_str,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )

        auth_header, _ = vertex_llm_base._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials_str,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base="",
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
        }

        base_target_url = get_vertex_pass_through_handler.update_base_target_url_with_credential_location(
            base_target_url, vertex_location
        )

    if base_target_url is None:
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

    request_route = encoded_endpoint
    verbose_proxy_logger.debug("request_route %s", request_route)

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    updated_url = construct_target_url(
        base_url=base_target_url,
        requested_route=encoded_endpoint,
        vertex_location=vertex_location,
        vertex_project=vertex_project,
    )

    verbose_proxy_logger.debug("updated url %s", updated_url)

    ## check for streaming
    target = str(updated_url)
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True
        target += "?alt=sse"

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=target,
        custom_headers=headers,
    )  # dynamically construct pass-through endpoint based on incoming path

    try:
        received_value = await endpoint_func(
            request,
            fastapi_response,
            user_api_key_dict,
            stream=is_streaming_request,  # type: ignore
        )
    except ProxyException as e:
        if headers_passed_through:
            e.message = f"No credentials found on proxy for project_name={vertex_project} + location={vertex_location}, check `/model/info` for allowed project + region combinations with `use_in_pass_through: true`. Headers were passed through directly but request failed with error: {e.message}"
        raise e

    return received_value


@router.api_route(
    "/vertex_ai/discovery/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_discovery_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    Call any vertex discovery endpoint using the proxy.

    Just use `{PROXY_BASE_URL}/vertex_ai/discovery/{endpoint:path}`

    Target url: `https://discoveryengine.googleapis.com`
    """
    discovery_handler = get_vertex_pass_through_handler(call_type="discovery")
    return await _base_vertex_proxy_route(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        get_vertex_pass_through_handler=discovery_handler,
    )


@router.api_route(
    "/vertex-ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
    include_in_schema=False,
)
@router.api_route(
    "/vertex_ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Call LiteLLM proxy via Vertex AI SDK.

    [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
    """
    ai_platform_handler = get_vertex_pass_through_handler(call_type="aiplatform")

    return await _base_vertex_proxy_route(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        get_vertex_pass_through_handler=ai_platform_handler,
        user_api_key_dict=user_api_key_dict,
    )


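# Usage sketch (assumptions: proxy at http://localhost:4000; project, location,
# and model are illustrative). A raw call against the aiplatform surface:
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:4000/vertex_ai/v1/projects/my-project/locations/"
#       "us-central1/publishers/google/models/gemini-1.5-flash:generateContent",
#       headers={"Authorization": "Bearer sk-litellm-virtual-key"},
#       json={"contents": [{"role": "user", "parts": [{"text": "hello"}]}]},
#   )

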
@router.api_route(
    "/openai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["OpenAI Pass-through", "pass-through"],
)
async def openai_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Simple pass-through for OpenAI. Use this if you want to directly send a request to OpenAI.
    """
    base_target_url = "https://api.openai.com/"
    # Look up the OpenAI API key configured on the proxy
    openai_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider=litellm.LlmProviders.OPENAI.value,
        region_name=None,
    )
    if openai_api_key is None:
        raise Exception(
            "Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
        )

    return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        base_target_url=base_target_url,
        api_key=openai_api_key,
        custom_llm_provider=litellm.LlmProviders.OPENAI,
    )


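# Usage sketch (assumption: proxy at http://localhost:4000). The OpenAI SDK can
# target this route directly; the handler below inserts the missing /v1 prefix:
#
#   from openai import OpenAI
#   client = OpenAI(
#       base_url="http://localhost:4000/openai",
#       api_key="sk-litellm-virtual-key",
#   )
#   out = client.chat.completions.create(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "hello"}],
#   )

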
class BaseOpenAIPassThroughHandler:
    @staticmethod
    async def _base_openai_pass_through_handler(
        endpoint: str,
        request: Request,
        fastapi_response: Response,
        user_api_key_dict: UserAPIKeyAuth,
        base_target_url: str,
        api_key: str,
        custom_llm_provider: litellm.LlmProviders,
    ):
        encoded_endpoint = httpx.URL(endpoint).path
        # Ensure endpoint starts with '/' for proper URL construction
        if not encoded_endpoint.startswith("/"):
            encoded_endpoint = "/" + encoded_endpoint

        # Construct the full target URL by properly joining the base URL and endpoint path
        base_url = httpx.URL(base_target_url)
        updated_url = BaseOpenAIPassThroughHandler._join_url_paths(
            base_url=base_url,
            path=encoded_endpoint,
            custom_llm_provider=custom_llm_provider,
        )

        ## check for streaming
        is_streaming_request = False
        if "stream" in str(updated_url):
            is_streaming_request = True

        ## CREATE PASS-THROUGH
        endpoint_func = create_pass_through_route(
            endpoint=endpoint,
            target=str(updated_url),
            custom_headers=BaseOpenAIPassThroughHandler._assemble_headers(
                api_key=api_key, request=request
            ),
        )  # dynamically construct pass-through endpoint based on incoming path
        received_value = await endpoint_func(
            request,
            fastapi_response,
            user_api_key_dict,
            stream=is_streaming_request,  # type: ignore
            query_params=dict(request.query_params),  # type: ignore
        )

        return received_value

    @staticmethod
    def _append_openai_beta_header(headers: dict, request: Request) -> dict:
        """
        Appends the OpenAI-Beta header to the headers if the request is an OpenAI Assistants API request
        """
        if (
            RouteChecks._is_assistants_api_request(request) is True
            and "OpenAI-Beta" not in headers
        ):
            headers["OpenAI-Beta"] = "assistants=v2"
        return headers

    @staticmethod
    def _assemble_headers(api_key: str, request: Request) -> dict:
        base_headers = {
            "authorization": "Bearer {}".format(api_key),
            "api-key": "{}".format(api_key),
        }
        return BaseOpenAIPassThroughHandler._append_openai_beta_header(
            headers=base_headers,
            request=request,
        )

    @staticmethod
    def _join_url_paths(
        base_url: httpx.URL, path: str, custom_llm_provider: litellm.LlmProviders
    ) -> str:
        """
        Properly joins a base URL with a path, preserving any existing path in the base URL.
        """
        # Join paths correctly by removing trailing/leading slashes as needed
        if not base_url.path or base_url.path == "/":
            # If base URL has no path, just use the new path
            joined_path_str = str(base_url.copy_with(path=path))
        else:
            # Otherwise, combine the paths
            base_path = base_url.path.rstrip("/")
            clean_path = path.lstrip("/")
            full_path = f"{base_path}/{clean_path}"
            joined_path_str = str(base_url.copy_with(path=full_path))

        # Apply OpenAI-specific path handling for both branches
        if (
            custom_llm_provider == litellm.LlmProviders.OPENAI
            and "/v1/" not in joined_path_str
        ):
            # Insert v1 after api.openai.com for OpenAI requests
            joined_path_str = joined_path_str.replace(
                "api.openai.com/", "api.openai.com/v1/"
            )

        return joined_path_str

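# Worked example of _join_url_paths (follows directly from the logic above):
#
#   _join_url_paths(
#       base_url=httpx.URL("https://api.openai.com/"),
#       path="/chat/completions",
#       custom_llm_provider=litellm.LlmProviders.OPENAI,
#   )
#   # -> "https://api.openai.com/v1/chat/completions"  ("/v1/" inserted for OpenAI)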