diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 5b9e04d1f..bc16c3555 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -283,6 +283,7 @@ async def pass_through_request( target: str, custom_headers: dict, user_api_key_dict: UserAPIKeyAuth, + custom_body: Optional[dict] = None, forward_headers: Optional[bool] = False, query_params: Optional[dict] = None, stream: Optional[bool] = None, @@ -300,12 +301,15 @@ async def pass_through_request( request=request, headers=headers, forward_headers=forward_headers ) - request_body = await request.body() - body_str = request_body.decode() - try: - _parsed_body = ast.literal_eval(body_str) - except Exception: - _parsed_body = json.loads(body_str) + if custom_body: + _parsed_body = custom_body + else: + request_body = await request.body() + body_str = request_body.decode() + try: + _parsed_body = ast.literal_eval(body_str) + except Exception: + _parsed_body = json.loads(body_str) verbose_proxy_logger.debug( "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( @@ -356,22 +360,24 @@ async def pass_through_request( # combine url with query params for logging - requested_query_params = query_params or request.query_params.__dict__ - requested_query_params_str = "&".join( - f"{k}={v}" for k, v in requested_query_params.items() - ) + # requested_query_params = query_params or request.query_params.__dict__ + # requested_query_params_str = "&".join( + # f"{k}={v}" for k, v in requested_query_params.items() + # ) - if "?" in str(url): - logging_url = str(url) + "&" + requested_query_params_str - else: - logging_url = str(url) + "?" + requested_query_params_str + requested_query_params = None + + # if "?" in str(url): + # logging_url = str(url) + "&" + requested_query_params_str + # else: + # logging_url = str(url) + "?" + requested_query_params_str logging_obj.pre_call( input=[{"role": "user", "content": "no-message-pass-through-endpoint"}], api_key="", additional_args={ "complete_input_dict": _parsed_body, - "api_base": logging_url, + "api_base": str(url), "headers": headers, }, ) @@ -526,6 +532,7 @@ def create_pass_through_route( fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), query_params: Optional[dict] = None, + custom_body: Optional[dict] = None, stream: Optional[ bool ] = None, # if pass-through endpoint is a streaming request @@ -538,6 +545,7 @@ def create_pass_through_route( forward_headers=_forward_headers, query_params=query_params, stream=stream, + custom_body=custom_body, ) return endpoint_func diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py index c798e091f..cfb5231f9 100644 --- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py @@ -45,6 +45,16 @@ router = APIRouter() default_vertex_config = None +def create_request_copy(request: Request): + return { + "method": request.method, + "url": str(request.url), + "headers": dict(request.headers), + "cookies": request.cookies, + "query_params": dict(request.query_params), + } + + @router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def gemini_proxy_route( endpoint: str, @@ -136,3 +146,72 @@ async def cohere_proxy_route( ) return received_value + + +@router.api_route("/bedrock/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def bedrock_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + request_copy = create_request_copy(request) + + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError as e: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + + aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME") + base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + encoded_endpoint = httpx.URL(endpoint).path + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + from litellm.llms.bedrock_httpx import BedrockConverseLLM + + credentials: Credentials = BedrockConverseLLM().get_credentials() + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + headers = {"Content-Type": "application/json"} + # Assuming the body contains JSON data, parse it + try: + data = await request.json() + except Exception as e: + raise HTTPException(status_code=400, detail={"error": e}) + _request = AWSRequest( + method="POST", url=str(updated_url), data=json.dumps(data), headers=headers + ) + sigv4.add_auth(_request) + prepped = _request.prepare() + + ## check for streaming + is_streaming_request = False + if "stream" in str(updated_url): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(prepped.url), + custom_headers=prepped.headers, + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, + custom_body=data, + query_params={}, + ) + + return received_value diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 654b210ff..0a7037428 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -3375,14 +3375,14 @@ def response_format_tests(response: litellm.ModelResponse): @pytest.mark.parametrize( "model", [ - "bedrock/cohere.command-r-plus-v1:0", + # "bedrock/cohere.command-r-plus-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", - "anthropic.claude-instant-v1", - "bedrock/ai21.j2-mid", - "mistral.mistral-7b-instruct-v0:2", - "bedrock/amazon.titan-tg1-large", - "meta.llama3-8b-instruct-v1:0", - "cohere.command-text-v14", + # "anthropic.claude-instant-v1", + # "bedrock/ai21.j2-mid", + # "mistral.mistral-7b-instruct-v0:2", + # "bedrock/amazon.titan-tg1-large", + # "meta.llama3-8b-instruct-v1:0", + # "cohere.command-text-v14", ], ) @pytest.mark.parametrize("sync_mode", [True, False]) diff --git a/litellm/utils.py b/litellm/utils.py index 2371a2a43..eff3b4346 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -9053,6 +9053,9 @@ class CustomStreamWrapper: text = "" is_finished = False finish_reason = "" + index: Optional[int] = None + if "index" in data_json: + index = data_json.get("index") if "text" in data_json: text = data_json["text"] elif "is_finished" in data_json: @@ -9061,6 +9064,7 @@ class CustomStreamWrapper: else: raise Exception(data_json) return { + "index": index, "text": text, "is_finished": is_finished, "finish_reason": finish_reason, @@ -10246,6 +10250,10 @@ class CustomStreamWrapper: completion_obj["role"] = "assistant" self.sent_first_chunk = True model_response.choices[0].delta = Delta(**completion_obj) + if completion_obj.get("index") is not None: + model_response.choices[0].index = completion_obj.get( + "index" + ) print_verbose(f"returning model_response: {model_response}") return model_response else: