Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
(fix) passthrough - allow internal users to access /anthropic (#6843)

* fix /anthropic/
* test llm_passthrough_router
* fix test_gemini_pass_through_endpoint
This commit is contained in:
parent 50d2510b60
commit a7d5536872

5 changed files with 22 additions and 8 deletions
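In practice, the fix means a virtual key issued to an internal user (not only a proxy admin) can now reach the /anthropic pass-through. A minimal sketch, assuming a proxy running at http://localhost:4000, a hypothetical internal-user key sk-internal-1234, and an illustrative model name:

import httpx

response = httpx.post(
    "http://localhost:4000/anthropic/v1/messages",
    headers={"Authorization": "Bearer sk-internal-1234"},  # litellm virtual key, not an Anthropic key
    json={
        "model": "claude-3-5-sonnet-20241022",
        "max_tokens": 128,
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(response.status_code)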
@@ -0,0 +1,322 @@
"""
What is this?

Provider-specific Pass-Through Endpoints

Use litellm with Anthropic SDK, Vertex AI SDK, Cohere SDK, etc.
"""

import ast
import asyncio
import traceback
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from urllib.parse import urlencode

import fastapi
import httpx
from fastapi import (
    APIRouter,
    Depends,
    File,
    Form,
    Header,
    HTTPException,
    Request,
    Response,
    UploadFile,
    status,
)
from starlette.datastructures import QueryParams

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.main import FileObject
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
    create_pass_through_route,
)
from litellm.secret_managers.main import get_secret_str

router = APIRouter()
default_vertex_config = None


def create_request_copy(request: Request):
    return {
        "method": request.method,
        "url": str(request.url),
        "headers": dict(request.headers),
        "cookies": request.cookies,
        "query_params": dict(request.query_params),
    }


@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def gemini_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
    api_key = request.query_params.get("key")

    user_api_key_dict = await user_api_key_auth(
        request=request, api_key="Bearer {}".format(api_key)
    )

    base_target_url = "https://generativelanguage.googleapis.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    gemini_api_key: Optional[str] = litellm.utils.get_secret(  # type: ignore
        secret_name="GEMINI_API_KEY"
    )
    if gemini_api_key is None:
        raise Exception(
            "Required 'GEMINI_API_KEY' in environment to make pass-through calls to Google AI Studio."
        )
    # Merge query parameters, giving precedence to those in updated_url
    merged_params = dict(request.query_params)
    merged_params.update({"key": gemini_api_key})

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        query_params=merged_params,  # type: ignore
        stream=is_streaming_request,  # type: ignore
    )

    return received_value
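# Example (sketch): calling the /gemini route above with httpx, assuming a proxy
# at http://localhost:4000 and "sk-1234" as a litellm virtual key. The key rides
# in the ?key= query param (where Google AI Studio clients put theirs); the
# handler then overwrites it with GEMINI_API_KEY before forwarding upstream.
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent",
#       params={"key": "sk-1234"},
#       json={"contents": [{"parts": [{"text": "Hello"}]}]},
#   )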


@router.api_route("/cohere/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def cohere_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    base_target_url = "https://api.cohere.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    cohere_api_key = litellm.utils.get_secret(secret_name="COHERE_API_KEY")

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"Authorization": "Bearer {}".format(cohere_api_key)},
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value
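# Example (sketch): calling the /cohere route above, assuming a proxy at
# http://localhost:4000 and "sk-1234" as a litellm virtual key. The client
# authenticates to the proxy with the virtual key; the handler swaps in
# "Bearer ${COHERE_API_KEY}" for the upstream call.
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:4000/cohere/v1/chat",
#       headers={"Authorization": "Bearer sk-1234"},
#       json={"message": "Hello"},
#   )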


@router.api_route(
    "/anthropic/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
)
async def anthropic_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    base_target_url = "https://api.anthropic.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    anthropic_api_key = litellm.utils.get_secret(secret_name="ANTHROPIC_API_KEY")

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"x-api-key": "{}".format(anthropic_api_key)},
        _forward_headers=True,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value
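# Example (sketch): pointing the Anthropic SDK at the /anthropic route above,
# assuming a proxy at http://localhost:4000 and "sk-1234" as a litellm virtual
# key, and assuming the proxy accepts the virtual key in the header the SDK
# sends. With _forward_headers=True the handler forwards client headers (e.g.
# anthropic-version) and injects the real ANTHROPIC_API_KEY as x-api-key.
#
#   from anthropic import Anthropic
#   client = Anthropic(base_url="http://localhost:4000/anthropic", api_key="sk-1234")
#   message = client.messages.create(
#       model="claude-3-5-sonnet-20241022",  # illustrative model name
#       max_tokens=128,
#       messages=[{"role": "user", "content": "Hello"}],
#   )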


@router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def bedrock_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    create_request_copy(request)

    try:
        import boto3
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest
        from botocore.credentials import Credentials
    except ImportError:
        raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

    aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
    if endpoint.startswith("agents/"):  # handle bedrock agents
        base_target_url = (
            f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
        )
    else:
        base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    from litellm.llms.bedrock.chat import BedrockConverseLLM

    credentials: Credentials = BedrockConverseLLM().get_credentials()
    sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
    headers = {"Content-Type": "application/json"}
    # Assuming the body contains JSON data, parse it
    try:
        data = await request.json()
    except Exception as e:
        raise HTTPException(status_code=400, detail={"error": e})
    _request = AWSRequest(
        method="POST", url=str(updated_url), data=json.dumps(data), headers=headers
    )
    sigv4.add_auth(_request)
    prepped = _request.prepare()

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(prepped.url),
        custom_headers=prepped.headers,  # type: ignore
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
        custom_body=data,  # type: ignore
        query_params={},  # type: ignore
    )

    return received_value
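# Example (sketch): calling the /bedrock route above, assuming a proxy at
# http://localhost:4000, "sk-1234" as a litellm virtual key, and an illustrative
# Bedrock model ID. The handler re-signs the request with SigV4 using the
# proxy's AWS credentials, so the client never needs AWS keys.
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:4000/bedrock/model/anthropic.claude-3-sonnet-20240229-v1:0/converse",
#       headers={"Authorization": "Bearer sk-1234"},
#       json={"messages": [{"role": "user", "content": [{"text": "Hello"}]}]},
#   )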


@router.api_route("/azure/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def azure_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
    if base_target_url is None:
        raise Exception(
            "Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
        )
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={
            "authorization": "Bearer {}".format(azure_api_key),
            "api-key": "{}".format(azure_api_key),
        },
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
        query_params=dict(request.query_params),  # type: ignore
    )

    return received_value
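# Example (sketch): calling the /azure route above, assuming a proxy at
# http://localhost:4000, "sk-1234" as a litellm virtual key, a deployment named
# "gpt-4o", and an api-version supported by the AZURE_API_BASE resource. Query
# params are forwarded as-is; the handler attaches the real AZURE_API_KEY in
# both "authorization" and "api-key" headers for the upstream call.
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:4000/azure/openai/deployments/gpt-4o/chat/completions",
#       params={"api-version": "2024-06-01"},
#       headers={"Authorization": "Bearer sk-1234"},
#       json={"messages": [{"role": "user", "content": "Hello"}]},
#   )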