fix(pass_through_endpoints): support bedrock agents via pass through (#5527)

This commit is contained in:
Krish Dholakia 2024-09-04 22:22:22 -07:00 committed by GitHub
parent 1e7e538261
commit ca37bb9de5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 24 additions and 16 deletions

View file

@ -1,16 +1,5 @@
model_list:
- model_name: gpt-4o-mini-2024-07-18
- model_name: "*"
litellm_params:
api_key: API_KEY
model: openai/gpt-4o-mini-2024-07-18
rpm: 0
tpm: 100
router_settings:
num_retries: 0
routing_strategy: usage-based-routing-v2
timeout: 10
litellm_settings:
callbacks: custom_callbacks.proxy_handler_instance
model: openai/*

View file

@ -244,6 +244,13 @@ class LiteLLMRoutes(enum.Enum):
"/utils/token_counter",
]
mapped_pass_through_routes: List = [
"/bedrock",
"/vertex-ai",
"/gemini",
"/langfuse",
]
anthropic_routes: List = [
"/v1/messages",
]

View file

@ -387,9 +387,16 @@ async def user_api_key_auth(
)
#### ELSE ####
## CHECK PASS-THROUGH ENDPOINTS ##
is_mapped_pass_through_route: bool = False
for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value:
if route.startswith(mapped_route):
is_mapped_pass_through_route = True
if is_mapped_pass_through_route:
if request.headers.get("litellm_user_api_key") is not None:
api_key = request.headers.get("litellm_user_api_key") or ""
if pass_through_endpoints is not None:
for endpoint in pass_through_endpoints:
if endpoint.get("path", "") == route:
if isinstance(endpoint, dict) and endpoint.get("path", "") == route:
## IF AUTH DISABLED
if endpoint.get("auth") is not True:
return UserAPIKeyAuth()

View file

@ -166,6 +166,11 @@ async def bedrock_proxy_route(
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
if endpoint.startswith("agents/"): # handle bedrock agents
base_target_url = (
f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
)
else:
base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
encoded_endpoint = httpx.URL(endpoint).path