diff --git a/docs/my-website/docs/pass_through/bedrock.md b/docs/my-website/docs/pass_through/bedrock.md index bcbeb09db3..ed184f5706 100644 --- a/docs/my-website/docs/pass_through/bedrock.md +++ b/docs/my-website/docs/pass_through/bedrock.md @@ -164,7 +164,7 @@ curl -X POST "http://0.0.0.0:4000/bedrock/knowledgebases/{knowledgeBaseId}/retri #### Direct Bedrock API Call ```bash -curl -X POST "https://bedrock-runtime.us-west-2.amazonaws.com/knowledgebases/{knowledgeBaseId}/retrieve" \ +curl -X POST "https://bedrock-agent-runtime.us-west-2.amazonaws.com/knowledgebases/{knowledgeBaseId}/retrieve" \ -H 'Authorization: AWS4-HMAC-SHA256..' \ -H 'Content-Type: application/json' \ -d '{ diff --git a/litellm/constants.py b/litellm/constants.py index de745c63b8..0cff9ab5ab 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -75,8 +75,20 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv AZURE_STORAGE_MSFT_VERSION = "2019-07-07" ########################### LiteLLM Proxy Specific Constants ########################### +######################################################################################## MAX_SPENDLOG_ROWS_TO_QUERY = ( 1_000_000 # if spendLogs has more than 1M rows, do not query the DB ) # makes it clear this is a rate limit error for a litellm virtual key RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash" + +# pass through route constansts +BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [ + "agents/", + "knowledgebases/", + "flows/", + "retrieveAndGenerate/", + "rerank/", + "generateQuery/", + "optimize-prompt/", +] diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index cae211da77..611a74db93 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -32,6 +32,7 @@ from starlette.datastructures import QueryParams import litellm from litellm._logging import verbose_proxy_logger from litellm.batches.main import FileObject +from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth @@ -247,7 +248,7 @@ async def bedrock_proxy_route( raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME") - if endpoint.startswith("agents/"): # handle bedrock agents + if _is_bedrock_agent_runtime_route(endpoint=endpoint): # handle bedrock agents base_target_url = ( f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com" ) @@ -303,6 +304,16 @@ async def bedrock_proxy_route( return received_value +def _is_bedrock_agent_runtime_route(endpoint: str) -> bool: + """ + Return True, if the endpoint should be routed to the `bedrock-agent-runtime` endpoint. + """ + for _route in BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES: + if _route in endpoint: + return True + return False + + @router.api_route( "/azure/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"], diff --git a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py index d5b6b1c9a9..20a5d8aab6 100644 --- a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py +++ b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py @@ -355,3 +355,29 @@ def test_pass_through_routes_support_all_methods(): # Check both routers check_router_methods(llm_router) check_router_methods(vertex_router) + + +def test_is_bedrock_agent_runtime_route(): + """ + Test that _is_bedrock_agent_runtime_route correctly identifies bedrock agent runtime endpoints + """ + from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( + _is_bedrock_agent_runtime_route, + ) + + # Test agent runtime endpoints (should return True) + assert _is_bedrock_agent_runtime_route("/knowledgebases/kb-123/retrieve") is True + assert ( + _is_bedrock_agent_runtime_route("/agents/knowledgebases/kb-123/retrieve") + is True + ) + + # Test regular bedrock runtime endpoints (should return False) + assert ( + _is_bedrock_agent_runtime_route("/guardrail/test-id/version/1/apply") is False + ) + assert ( + _is_bedrock_agent_runtime_route("/model/cohere.command-r-v1:0/converse") + is False + ) + assert _is_bedrock_agent_runtime_route("/some/random/endpoint") is False