""" What is this? Google AI Studio Pass-Through Endpoints """ """ 1. Create pass-through endpoints for any LITELLM_BASE_URL/gemini/ map to https://generativelanguage.googleapis.com/ """ import ast import asyncio import traceback from datetime import datetime, timedelta, timezone from typing import List, Optional from urllib.parse import urlencode import fastapi import httpx from fastapi import ( APIRouter, Depends, File, Form, Header, HTTPException, Request, Response, UploadFile, status, ) from starlette.datastructures import QueryParams import litellm from litellm._logging import verbose_proxy_logger from litellm.batches.main import FileObject from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( create_pass_through_route, ) router = APIRouter() default_vertex_config = None def create_request_copy(request: Request): return { "method": request.method, "url": str(request.url), "headers": dict(request.headers), "cookies": request.cookies, "query_params": dict(request.query_params), } @router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def gemini_proxy_route( endpoint: str, request: Request, fastapi_response: Response, ): ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY api_key = request.query_params.get("key") user_api_key_dict = await user_api_key_auth( request=request, api_key="Bearer {}".format(api_key) ) base_target_url = "https://generativelanguage.googleapis.com" encoded_endpoint = httpx.URL(endpoint).path # Ensure endpoint starts with '/' for proper URL construction if not encoded_endpoint.startswith("/"): encoded_endpoint = "/" + encoded_endpoint # Construct the full target URL using httpx base_url = httpx.URL(base_target_url) updated_url = base_url.copy_with(path=encoded_endpoint) # Add or update query parameters gemini_api_key = litellm.utils.get_secret(secret_name="GEMINI_API_KEY") # Merge query parameters, giving precedence to those in updated_url merged_params = dict(request.query_params) merged_params.update({"key": gemini_api_key}) ## check for streaming is_streaming_request = False if "stream" in str(updated_url): is_streaming_request = True ## CREATE PASS-THROUGH endpoint_func = create_pass_through_route( endpoint=endpoint, target=str(updated_url), ) # dynamically construct pass-through endpoint based on incoming path received_value = await endpoint_func( request, fastapi_response, user_api_key_dict, query_params=merged_params, stream=is_streaming_request, ) return received_value @router.api_route("/cohere/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def cohere_proxy_route( endpoint: str, request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): base_target_url = "https://api.cohere.com" encoded_endpoint = httpx.URL(endpoint).path # Ensure endpoint starts with '/' for proper URL construction if not encoded_endpoint.startswith("/"): encoded_endpoint = "/" + encoded_endpoint # Construct the full target URL using httpx base_url = httpx.URL(base_target_url) updated_url = base_url.copy_with(path=encoded_endpoint) # Add or update query parameters cohere_api_key = litellm.utils.get_secret(secret_name="COHERE_API_KEY") ## check for streaming is_streaming_request = False if "stream" in str(updated_url): is_streaming_request = True ## CREATE PASS-THROUGH endpoint_func = create_pass_through_route( endpoint=endpoint, target=str(updated_url), custom_headers={"Authorization": "Bearer {}".format(cohere_api_key)}, ) # dynamically construct pass-through endpoint based on incoming path received_value = await endpoint_func( request, fastapi_response, user_api_key_dict, stream=is_streaming_request, ) return received_value @router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def bedrock_proxy_route( endpoint: str, request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): request_copy = create_request_copy(request) try: import boto3 from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials except ImportError as e: raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME") base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" encoded_endpoint = httpx.URL(endpoint).path # Ensure endpoint starts with '/' for proper URL construction if not encoded_endpoint.startswith("/"): encoded_endpoint = "/" + encoded_endpoint # Construct the full target URL using httpx base_url = httpx.URL(base_target_url) updated_url = base_url.copy_with(path=encoded_endpoint) # Add or update query parameters from litellm.llms.bedrock_httpx import BedrockConverseLLM credentials: Credentials = BedrockConverseLLM().get_credentials() sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) headers = {"Content-Type": "application/json"} # Assuming the body contains JSON data, parse it try: data = await request.json() except Exception as e: raise HTTPException(status_code=400, detail={"error": e}) _request = AWSRequest( method="POST", url=str(updated_url), data=json.dumps(data), headers=headers ) sigv4.add_auth(_request) prepped = _request.prepare() ## check for streaming is_streaming_request = False if "stream" in str(updated_url): is_streaming_request = True ## CREATE PASS-THROUGH endpoint_func = create_pass_through_route( endpoint=endpoint, target=str(prepped.url), custom_headers=prepped.headers, ) # dynamically construct pass-through endpoint based on incoming path received_value = await endpoint_func( request, fastapi_response, user_api_key_dict, stream=is_streaming_request, custom_body=data, query_params={}, ) return received_value