forked from phoenix/litellm-mirror
Merge pull request #5264 from BerriAI/litellm_bedrock_pass_through
feat: Bedrock pass-through endpoint support (All endpoints)
This commit is contained in:
commit
f42ac2c9d8
4 changed files with 117 additions and 22 deletions
|
@ -283,6 +283,7 @@ async def pass_through_request(
|
|||
target: str,
|
||||
custom_headers: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
custom_body: Optional[dict] = None,
|
||||
forward_headers: Optional[bool] = False,
|
||||
query_params: Optional[dict] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
@ -300,12 +301,15 @@ async def pass_through_request(
|
|||
request=request, headers=headers, forward_headers=forward_headers
|
||||
)
|
||||
|
||||
request_body = await request.body()
|
||||
body_str = request_body.decode()
|
||||
try:
|
||||
_parsed_body = ast.literal_eval(body_str)
|
||||
except Exception:
|
||||
_parsed_body = json.loads(body_str)
|
||||
if custom_body:
|
||||
_parsed_body = custom_body
|
||||
else:
|
||||
request_body = await request.body()
|
||||
body_str = request_body.decode()
|
||||
try:
|
||||
_parsed_body = ast.literal_eval(body_str)
|
||||
except Exception:
|
||||
_parsed_body = json.loads(body_str)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
|
||||
|
@ -356,22 +360,24 @@ async def pass_through_request(
|
|||
|
||||
# combine url with query params for logging
|
||||
|
||||
requested_query_params = query_params or request.query_params.__dict__
|
||||
requested_query_params_str = "&".join(
|
||||
f"{k}={v}" for k, v in requested_query_params.items()
|
||||
)
|
||||
# requested_query_params = query_params or request.query_params.__dict__
|
||||
# requested_query_params_str = "&".join(
|
||||
# f"{k}={v}" for k, v in requested_query_params.items()
|
||||
# )
|
||||
|
||||
if "?" in str(url):
|
||||
logging_url = str(url) + "&" + requested_query_params_str
|
||||
else:
|
||||
logging_url = str(url) + "?" + requested_query_params_str
|
||||
requested_query_params = None
|
||||
|
||||
# if "?" in str(url):
|
||||
# logging_url = str(url) + "&" + requested_query_params_str
|
||||
# else:
|
||||
# logging_url = str(url) + "?" + requested_query_params_str
|
||||
|
||||
logging_obj.pre_call(
|
||||
input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": _parsed_body,
|
||||
"api_base": logging_url,
|
||||
"api_base": str(url),
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
@ -526,6 +532,7 @@ def create_pass_through_route(
|
|||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
query_params: Optional[dict] = None,
|
||||
custom_body: Optional[dict] = None,
|
||||
stream: Optional[
|
||||
bool
|
||||
] = None, # if pass-through endpoint is a streaming request
|
||||
|
@ -538,6 +545,7 @@ def create_pass_through_route(
|
|||
forward_headers=_forward_headers,
|
||||
query_params=query_params,
|
||||
stream=stream,
|
||||
custom_body=custom_body,
|
||||
)
|
||||
|
||||
return endpoint_func
|
||||
|
|
|
@ -45,6 +45,16 @@ router = APIRouter()
|
|||
default_vertex_config = None
|
||||
|
||||
|
||||
def create_request_copy(request: Request):
|
||||
return {
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
"headers": dict(request.headers),
|
||||
"cookies": request.cookies,
|
||||
"query_params": dict(request.query_params),
|
||||
}
|
||||
|
||||
|
||||
@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
async def gemini_proxy_route(
|
||||
endpoint: str,
|
||||
|
@ -136,3 +146,72 @@ async def cohere_proxy_route(
|
|||
)
|
||||
|
||||
return received_value
|
||||
|
||||
|
||||
@router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
async def bedrock_proxy_route(
|
||||
endpoint: str,
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
request_copy = create_request_copy(request)
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError as e:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
|
||||
base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
encoded_endpoint = httpx.URL(endpoint).path
|
||||
|
||||
# Ensure endpoint starts with '/' for proper URL construction
|
||||
if not encoded_endpoint.startswith("/"):
|
||||
encoded_endpoint = "/" + encoded_endpoint
|
||||
|
||||
# Construct the full target URL using httpx
|
||||
base_url = httpx.URL(base_target_url)
|
||||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
from litellm.llms.bedrock_httpx import BedrockConverseLLM
|
||||
|
||||
credentials: Credentials = BedrockConverseLLM().get_credentials()
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
# Assuming the body contains JSON data, parse it
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail={"error": e})
|
||||
_request = AWSRequest(
|
||||
method="POST", url=str(updated_url), data=json.dumps(data), headers=headers
|
||||
)
|
||||
sigv4.add_auth(_request)
|
||||
prepped = _request.prepare()
|
||||
|
||||
## check for streaming
|
||||
is_streaming_request = False
|
||||
if "stream" in str(updated_url):
|
||||
is_streaming_request = True
|
||||
|
||||
## CREATE PASS-THROUGH
|
||||
endpoint_func = create_pass_through_route(
|
||||
endpoint=endpoint,
|
||||
target=str(prepped.url),
|
||||
custom_headers=prepped.headers,
|
||||
) # dynamically construct pass-through endpoint based on incoming path
|
||||
received_value = await endpoint_func(
|
||||
request,
|
||||
fastapi_response,
|
||||
user_api_key_dict,
|
||||
stream=is_streaming_request,
|
||||
custom_body=data,
|
||||
query_params={},
|
||||
)
|
||||
|
||||
return received_value
|
||||
|
|
|
@ -3375,14 +3375,14 @@ def response_format_tests(response: litellm.ModelResponse):
|
|||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"bedrock/cohere.command-r-plus-v1:0",
|
||||
# "bedrock/cohere.command-r-plus-v1:0",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"anthropic.claude-instant-v1",
|
||||
"bedrock/ai21.j2-mid",
|
||||
"mistral.mistral-7b-instruct-v0:2",
|
||||
"bedrock/amazon.titan-tg1-large",
|
||||
"meta.llama3-8b-instruct-v1:0",
|
||||
"cohere.command-text-v14",
|
||||
# "anthropic.claude-instant-v1",
|
||||
# "bedrock/ai21.j2-mid",
|
||||
# "mistral.mistral-7b-instruct-v0:2",
|
||||
# "bedrock/amazon.titan-tg1-large",
|
||||
# "meta.llama3-8b-instruct-v1:0",
|
||||
# "cohere.command-text-v14",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
|
|
|
@ -9053,6 +9053,9 @@ class CustomStreamWrapper:
|
|||
text = ""
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
index: Optional[int] = None
|
||||
if "index" in data_json:
|
||||
index = data_json.get("index")
|
||||
if "text" in data_json:
|
||||
text = data_json["text"]
|
||||
elif "is_finished" in data_json:
|
||||
|
@ -9061,6 +9064,7 @@ class CustomStreamWrapper:
|
|||
else:
|
||||
raise Exception(data_json)
|
||||
return {
|
||||
"index": index,
|
||||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
|
@ -10246,6 +10250,10 @@ class CustomStreamWrapper:
|
|||
completion_obj["role"] = "assistant"
|
||||
self.sent_first_chunk = True
|
||||
model_response.choices[0].delta = Delta(**completion_obj)
|
||||
if completion_obj.get("index") is not None:
|
||||
model_response.choices[0].index = completion_obj.get(
|
||||
"index"
|
||||
)
|
||||
print_verbose(f"returning model_response: {model_response}")
|
||||
return model_response
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue