mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
(feat) Add Bedrock knowledge base pass through endpoints (#7267)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 56s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 56s
* bugfix: Proxy Routing for Bedrock Knowledgebase URLs are incorrect (#7097) * Fixing routing bug where bedrock knowledgebase urls were being generated incorrectly * Preparing for PR * Preparing for PR * Preparing for PR --------- Co-authored-by: Luke Birk <lb0737@att.com> * fix _is_bedrock_agent_runtime_route * docs - Query Knowledge Base * test_is_bedrock_agent_runtime_route * fix bedrock_proxy_route --------- Co-authored-by: LBirk <2731718+LBirk@users.noreply.github.com> Co-authored-by: Luke Birk <lb0737@att.com>
This commit is contained in:
parent
3c984ed60e
commit
f3b13a9af3
4 changed files with 51 additions and 2 deletions
|
@ -164,7 +164,7 @@ curl -X POST "http://0.0.0.0:4000/bedrock/knowledgebases/{knowledgeBaseId}/retri
|
||||||
#### Direct Bedrock API Call
|
#### Direct Bedrock API Call
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl -X POST "https://bedrock-runtime.us-west-2.amazonaws.com/knowledgebases/{knowledgeBaseId}/retrieve" \
|
curl -X POST "https://bedrock-agent-runtime.us-west-2.amazonaws.com/knowledgebases/{knowledgeBaseId}/retrieve" \
|
||||||
-H 'Authorization: AWS4-HMAC-SHA256..' \
|
-H 'Authorization: AWS4-HMAC-SHA256..' \
|
||||||
-H 'Content-Type: application/json' \
|
-H 'Content-Type: application/json' \
|
||||||
-d '{
|
-d '{
|
||||||
|
|
|
@ -75,8 +75,20 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv
|
||||||
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
|
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
|
||||||
|
|
||||||
########################### LiteLLM Proxy Specific Constants ###########################
|
########################### LiteLLM Proxy Specific Constants ###########################
|
||||||
|
########################################################################################
|
||||||
MAX_SPENDLOG_ROWS_TO_QUERY = (
|
MAX_SPENDLOG_ROWS_TO_QUERY = (
|
||||||
1_000_000 # if spendLogs has more than 1M rows, do not query the DB
|
1_000_000 # if spendLogs has more than 1M rows, do not query the DB
|
||||||
)
|
)
|
||||||
# makes it clear this is a rate limit error for a litellm virtual key
|
# makes it clear this is a rate limit error for a litellm virtual key
|
||||||
RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash"
|
RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash"
|
||||||
|
|
||||||
|
# pass through route constants
|
||||||
|
# Path fragments used by the Bedrock pass-through proxy to recognise requests
# that should be sent to the `bedrock-agent-runtime` endpoint. Each entry keeps
# its trailing slash so it is compared as a path segment fragment.
BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
    "agents/",
    "knowledgebases/",
    "flows/",
    "retrieveAndGenerate/",
    "rerank/",
    "generateQuery/",
    "optimize-prompt/",
]
|
||||||
|
|
|
@ -32,6 +32,7 @@ from starlette.datastructures import QueryParams
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.batches.main import FileObject
|
from litellm.batches.main import FileObject
|
||||||
|
from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES
|
||||||
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
|
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
@ -247,7 +248,7 @@ async def bedrock_proxy_route(
|
||||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
|
||||||
aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
|
aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
|
||||||
if endpoint.startswith("agents/"): # handle bedrock agents
|
if _is_bedrock_agent_runtime_route(endpoint=endpoint): # handle bedrock agents
|
||||||
base_target_url = (
|
base_target_url = (
|
||||||
f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
|
f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
|
||||||
)
|
)
|
||||||
|
@ -303,6 +304,16 @@ async def bedrock_proxy_route(
|
||||||
return received_value
|
return received_value
|
||||||
|
|
||||||
|
|
||||||
|
def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
    """
    Return True if the endpoint should be routed to the `bedrock-agent-runtime` endpoint.

    Args:
        endpoint: The request path captured by the Bedrock pass-through route,
            with or without a leading slash.

    Returns:
        True when any entry of `BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES`
        occurs anywhere in `endpoint`, otherwise False.
    """
    # Substring (not prefix) matching is deliberate: it lets both
    # "knowledgebases/..." and "/knowledgebases/..." match.
    return any(
        route in endpoint for route in BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES
    )
|
||||||
|
|
||||||
|
|
||||||
@router.api_route(
|
@router.api_route(
|
||||||
"/azure/{endpoint:path}",
|
"/azure/{endpoint:path}",
|
||||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||||
|
|
|
@ -355,3 +355,29 @@ def test_pass_through_routes_support_all_methods():
|
||||||
# Check both routers
|
# Check both routers
|
||||||
check_router_methods(llm_router)
|
check_router_methods(llm_router)
|
||||||
check_router_methods(vertex_router)
|
check_router_methods(vertex_router)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_bedrock_agent_runtime_route():
    """
    Test that _is_bedrock_agent_runtime_route correctly identifies bedrock agent runtime endpoints
    """
    # Imported inside the test, as in the original.
    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
        _is_bedrock_agent_runtime_route,
    )

    # Test agent runtime endpoints (should return True)
    assert _is_bedrock_agent_runtime_route("/knowledgebases/kb-123/retrieve") is True
    assert (
        _is_bedrock_agent_runtime_route("/agents/knowledgebases/kb-123/retrieve")
        is True
    )

    # Test regular bedrock runtime endpoints (should return False)
    assert (
        _is_bedrock_agent_runtime_route("/guardrail/test-id/version/1/apply") is False
    )
    assert (
        _is_bedrock_agent_runtime_route("/model/cohere.command-r-v1:0/converse")
        is False
    )
    assert _is_bedrock_agent_runtime_route("/some/random/endpoint") is False
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue