Merge pull request #5264 from BerriAI/litellm_bedrock_pass_through

feat: Bedrock pass-through endpoint support (All endpoints)
This commit is contained in:
Krish Dholakia 2024-08-18 09:55:22 -07:00 committed by GitHub
commit f42ac2c9d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 117 additions and 22 deletions

View file

@ -283,6 +283,7 @@ async def pass_through_request(
target: str,
custom_headers: dict,
user_api_key_dict: UserAPIKeyAuth,
custom_body: Optional[dict] = None,
forward_headers: Optional[bool] = False,
query_params: Optional[dict] = None,
stream: Optional[bool] = None,
@ -300,12 +301,15 @@ async def pass_through_request(
request=request, headers=headers, forward_headers=forward_headers
)
request_body = await request.body()
body_str = request_body.decode()
try:
_parsed_body = ast.literal_eval(body_str)
except Exception:
_parsed_body = json.loads(body_str)
if custom_body:
_parsed_body = custom_body
else:
request_body = await request.body()
body_str = request_body.decode()
try:
_parsed_body = ast.literal_eval(body_str)
except Exception:
_parsed_body = json.loads(body_str)
verbose_proxy_logger.debug(
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
@ -356,22 +360,24 @@ async def pass_through_request(
# combine url with query params for logging
requested_query_params = query_params or request.query_params.__dict__
requested_query_params_str = "&".join(
f"{k}={v}" for k, v in requested_query_params.items()
)
# requested_query_params = query_params or request.query_params.__dict__
# requested_query_params_str = "&".join(
# f"{k}={v}" for k, v in requested_query_params.items()
# )
if "?" in str(url):
logging_url = str(url) + "&" + requested_query_params_str
else:
logging_url = str(url) + "?" + requested_query_params_str
requested_query_params = None
# if "?" in str(url):
# logging_url = str(url) + "&" + requested_query_params_str
# else:
# logging_url = str(url) + "?" + requested_query_params_str
logging_obj.pre_call(
input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
api_key="",
additional_args={
"complete_input_dict": _parsed_body,
"api_base": logging_url,
"api_base": str(url),
"headers": headers,
},
)
@ -526,6 +532,7 @@ def create_pass_through_route(
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
query_params: Optional[dict] = None,
custom_body: Optional[dict] = None,
stream: Optional[
bool
] = None, # if pass-through endpoint is a streaming request
@ -538,6 +545,7 @@ def create_pass_through_route(
forward_headers=_forward_headers,
query_params=query_params,
stream=stream,
custom_body=custom_body,
)
return endpoint_func

View file

@ -45,6 +45,16 @@ router = APIRouter()
default_vertex_config = None
def create_request_copy(request: Request):
return {
"method": request.method,
"url": str(request.url),
"headers": dict(request.headers),
"cookies": request.cookies,
"query_params": dict(request.query_params),
}
@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def gemini_proxy_route(
endpoint: str,
@ -136,3 +146,72 @@ async def cohere_proxy_route(
)
return received_value
@router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def bedrock_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
request_copy = create_request_copy(request)
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError as e:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
from litellm.llms.bedrock_httpx import BedrockConverseLLM
credentials: Credentials = BedrockConverseLLM().get_credentials()
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
headers = {"Content-Type": "application/json"}
# Assuming the body contains JSON data, parse it
try:
data = await request.json()
except Exception as e:
raise HTTPException(status_code=400, detail={"error": e})
_request = AWSRequest(
method="POST", url=str(updated_url), data=json.dumps(data), headers=headers
)
sigv4.add_auth(_request)
prepped = _request.prepare()
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(prepped.url),
custom_headers=prepped.headers,
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request,
custom_body=data,
query_params={},
)
return received_value

View file

@ -3375,14 +3375,14 @@ def response_format_tests(response: litellm.ModelResponse):
@pytest.mark.parametrize(
"model",
[
"bedrock/cohere.command-r-plus-v1:0",
# "bedrock/cohere.command-r-plus-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-instant-v1",
"bedrock/ai21.j2-mid",
"mistral.mistral-7b-instruct-v0:2",
"bedrock/amazon.titan-tg1-large",
"meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14",
# "anthropic.claude-instant-v1",
# "bedrock/ai21.j2-mid",
# "mistral.mistral-7b-instruct-v0:2",
# "bedrock/amazon.titan-tg1-large",
# "meta.llama3-8b-instruct-v1:0",
# "cohere.command-text-v14",
],
)
@pytest.mark.parametrize("sync_mode", [True, False])

View file

@ -9053,6 +9053,9 @@ class CustomStreamWrapper:
text = ""
is_finished = False
finish_reason = ""
index: Optional[int] = None
if "index" in data_json:
index = data_json.get("index")
if "text" in data_json:
text = data_json["text"]
elif "is_finished" in data_json:
@ -9061,6 +9064,7 @@ class CustomStreamWrapper:
else:
raise Exception(data_json)
return {
"index": index,
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
@ -10246,6 +10250,10 @@ class CustomStreamWrapper:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
if completion_obj.get("index") is not None:
model_response.choices[0].index = completion_obj.get(
"index"
)
print_verbose(f"returning model_response: {model_response}")
return model_response
else: