litellm-mirror/litellm/llms/bedrock/rerank/handler.py
Krish Dholakia 30a4f2abc2 Add cost tracking for rerank via bedrock (#8691)
* feat(bedrock/rerank): infer model region if model given as arn

* test: add unit testing to ensure bedrock region name inferred from arn on rerank

* feat(bedrock/rerank/transformation.py): include search units for bedrock rerank result

Resolves https://github.com/BerriAI/litellm/issues/7258#issuecomment-2671557137

* test(test_bedrock_completion.py): add testing for bedrock cohere rerank

* feat(cost_calculator.py): refactor rerank cost tracking to support bedrock cost tracking

* build(model_prices_and_context_window.json): add amazon.rerank model to model cost map

* fix(cost_calculator.py): bedrock/common_utils.py

get base model from model w/ arn -> handles rerank model

* build(model_prices_and_context_window.json): add bedrock cohere rerank pricing

* feat(bedrock/rerank): migrate bedrock config to basererank config

* Revert "feat(bedrock/rerank): migrate bedrock config to basererank config"

This reverts commit 84fae1f167.

* test: add testing to ensure large doc / queries are correctly counted

* Revert "test: add testing to ensure large doc / queries are correctly counted"

This reverts commit 4337f1657e.

* fix(migrate-jina-ai-to-rerank-config): enables cost tracking

* refactor(jina_ai/): finish migrating jina ai to base rerank config

enables cost tracking

* fix(jina_ai/rerank): e2e jina ai rerank cost tracking

* fix: cleanup dead code

* fix: fix python3.8 compatibility error

* test: fix test

* test: add e2e testing for azure ai rerank

* fix: fix linting error

* test: mark cohere as flaky
2025-02-20 21:00:18 -08:00

168 lines
6 KiB
Python

import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.bedrock import BedrockPreparedRequest
from litellm.types.rerank import RerankRequest
from litellm.types.utils import RerankResponse
from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
from .transformation import BedrockRerankConfig
if TYPE_CHECKING:
from botocore.awsrequest import AWSPreparedRequest
else:
AWSPreparedRequest = Any
class BedrockRerankHandler(BaseAWSLLM):
async def arerank(
self,
prepared_request: BedrockPreparedRequest,
client: Optional[AsyncHTTPHandler] = None,
):
if client is None:
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
try:
response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return BedrockRerankConfig()._transform_response(response.json())
def rerank(
self,
model: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
optional_params: dict,
logging_obj: LitellmLogging,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False,
api_base: Optional[str] = None,
extra_headers: Optional[dict] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
request_data = RerankRequest(
model=model,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
)
data = BedrockRerankConfig()._transform_request(request_data)
prepared_request = self._prepare_request(
model=model,
optional_params=optional_params,
api_base=api_base,
extra_headers=extra_headers,
data=cast(dict, data),
)
logging_obj.pre_call(
input=data,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepared_request["endpoint_url"],
"headers": prepared_request["prepped"].headers,
},
)
if _is_async:
return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None) # type: ignore
if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client()
try:
response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
logging_obj.post_call(
original_response=response.text,
api_key="",
)
response_json = response.json()
return BedrockRerankConfig()._transform_response(response_json)
def _prepare_request(
self,
model: str,
api_base: Optional[str],
extra_headers: Optional[dict],
data: dict,
optional_params: dict,
) -> BedrockPreparedRequest:
try:
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
optional_params, model
)
### SET RUNTIME ENDPOINT ###
_, proxy_endpoint_url = self.get_runtime_endpoint(
api_base=api_base,
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
aws_region_name=boto3_credentials_info.aws_region_name,
)
proxy_endpoint_url = proxy_endpoint_url.replace(
"bedrock-runtime", "bedrock-agent-runtime"
)
proxy_endpoint_url = f"{proxy_endpoint_url}/rerank"
sigv4 = SigV4Auth(
boto3_credentials_info.credentials,
"bedrock",
boto3_credentials_info.aws_region_name,
)
# Make POST Request
body = json.dumps(data).encode("utf-8")
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=proxy_endpoint_url, data=body, headers=headers
)
sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped = request.prepare()
return BedrockPreparedRequest(
endpoint_url=proxy_endpoint_url,
prepped=prepped,
body=body,
data=data,
)