(feat) Add Bedrock knowledge base pass through endpoints (#7267)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 56s

* bugfix: Proxy Routing for Bedrock Knowledgebase URLs are incorrect (#7097)

* Fixing routing bug where bedrock knowledgebase urls were being generated incorrectly

* Preparing for PR

* Preparing for PR

* Preparing for PR

---------

Co-authored-by: Luke Birk <lb0737@att.com>

* fix _is_bedrock_agent_runtime_route

* docs - Query Knowledge Base

* test_is_bedrock_agent_runtime_route

* fix bedrock_proxy_route

---------

Co-authored-by: LBirk <2731718+LBirk@users.noreply.github.com>
Co-authored-by: Luke Birk <lb0737@att.com>
This commit is contained in:
Ishaan Jaff 2024-12-16 22:19:34 -08:00 committed by GitHub
parent 3c984ed60e
commit f3b13a9af3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 51 additions and 2 deletions

View file

@ -164,7 +164,7 @@ curl -X POST "http://0.0.0.0:4000/bedrock/knowledgebases/{knowledgeBaseId}/retri
#### Direct Bedrock API Call
```bash
curl -X POST "https://bedrock-runtime.us-west-2.amazonaws.com/knowledgebases/{knowledgeBaseId}/retrieve" \
curl -X POST "https://bedrock-agent-runtime.us-west-2.amazonaws.com/knowledgebases/{knowledgeBaseId}/retrieve" \
-H 'Authorization: AWS4-HMAC-SHA256..' \
-H 'Content-Type: application/json' \
-d '{

View file

@ -75,8 +75,20 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
########################### LiteLLM Proxy Specific Constants ###########################
########################################################################################
MAX_SPENDLOG_ROWS_TO_QUERY = (
1_000_000 # if spendLogs has more than 1M rows, do not query the DB
)
# makes it clear this is a rate limit error for a litellm virtual key
RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash"
# pass through route constansts
BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
"agents/",
"knowledgebases/",
"flows/",
"retrieveAndGenerate/",
"rerank/",
"generateQuery/",
"optimize-prompt/",
]

View file

@ -32,6 +32,7 @@ from starlette.datastructures import QueryParams
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.main import FileObject
from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
@ -247,7 +248,7 @@ async def bedrock_proxy_route(
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
if endpoint.startswith("agents/"): # handle bedrock agents
if _is_bedrock_agent_runtime_route(endpoint=endpoint): # handle bedrock agents
base_target_url = (
f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
)
@ -303,6 +304,16 @@ async def bedrock_proxy_route(
return received_value
def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
"""
Return True, if the endpoint should be routed to the `bedrock-agent-runtime` endpoint.
"""
for _route in BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES:
if _route in endpoint:
return True
return False
@router.api_route(
"/azure/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],

View file

@ -355,3 +355,29 @@ def test_pass_through_routes_support_all_methods():
# Check both routers
check_router_methods(llm_router)
check_router_methods(vertex_router)
def test_is_bedrock_agent_runtime_route():
"""
Test that _is_bedrock_agent_runtime_route correctly identifies bedrock agent runtime endpoints
"""
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
_is_bedrock_agent_runtime_route,
)
# Test agent runtime endpoints (should return True)
assert _is_bedrock_agent_runtime_route("/knowledgebases/kb-123/retrieve") is True
assert (
_is_bedrock_agent_runtime_route("/agents/knowledgebases/kb-123/retrieve")
is True
)
# Test regular bedrock runtime endpoints (should return False)
assert (
_is_bedrock_agent_runtime_route("/guardrail/test-id/version/1/apply") is False
)
assert (
_is_bedrock_agent_runtime_route("/model/cohere.command-r-v1:0/converse")
is False
)
assert _is_bedrock_agent_runtime_route("/some/random/endpoint") is False