mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Add cost tracking for rerank via bedrock (#8691)
* feat(bedrock/rerank): infer model region if model given as arn
* test: add unit testing to ensure bedrock region name inferred from arn on rerank
* feat(bedrock/rerank/transformation.py): include search units for bedrock rerank result Resolves https://github.com/BerriAI/litellm/issues/7258#issuecomment-2671557137
* test(test_bedrock_completion.py): add testing for bedrock cohere rerank
* feat(cost_calculator.py): refactor rerank cost tracking to support bedrock cost tracking
* build(model_prices_and_context_window.json): add amazon.rerank model to model cost map
* fix(cost_calculator.py): bedrock/common_utils.py get base model from model w/ arn -> handles rerank model
* build(model_prices_and_context_window.json): add bedrock cohere rerank pricing
* feat(bedrock/rerank): migrate bedrock config to basererank config
* Revert "feat(bedrock/rerank): migrate bedrock config to basererank config" This reverts commit 84fae1f167.
* test: add testing to ensure large doc / queries are correctly counted
* Revert "test: add testing to ensure large doc / queries are correctly counted" This reverts commit 4337f1657e.
* fix(migrate-jina-ai-to-rerank-config): enables cost tracking
* refactor(jina_ai/): finish migrating jina ai to base rerank config enables cost tracking
* fix(jina_ai/rerank): e2e jina ai rerank cost tracking
* fix: cleanup dead code
* fix: fix python3.8 compatibility error
* test: fix test
* test: add e2e testing for azure ai rerank
* fix: fix linting error
* test: mark cohere as flaky
This commit is contained in:
parent
e5f7bde268
commit
30a4f2abc2
26 changed files with 524 additions and 296 deletions
|
@ -6,6 +6,8 @@ import httpx
|
|||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
|
@ -27,8 +29,10 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
async def arerank(
|
||||
self,
|
||||
prepared_request: BedrockPreparedRequest,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
):
|
||||
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
|
||||
if client is None:
|
||||
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
|
||||
try:
|
||||
response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
||||
response.raise_for_status()
|
||||
|
@ -54,7 +58,9 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
_is_async: Optional[bool] = False,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
) -> RerankResponse:
|
||||
|
||||
request_data = RerankRequest(
|
||||
model=model,
|
||||
query=query,
|
||||
|
@ -66,6 +72,7 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
data = BedrockRerankConfig()._transform_request(request_data)
|
||||
|
||||
prepared_request = self._prepare_request(
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
|
@ -83,9 +90,10 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
)
|
||||
|
||||
if _is_async:
|
||||
return self.arerank(prepared_request) # type: ignore
|
||||
return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None) # type: ignore
|
||||
|
||||
client = _get_httpx_client()
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = _get_httpx_client()
|
||||
try:
|
||||
response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
||||
response.raise_for_status()
|
||||
|
@ -95,10 +103,18 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return BedrockRerankConfig()._transform_response(response.json())
|
||||
logging_obj.post_call(
|
||||
original_response=response.text,
|
||||
api_key="",
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
return BedrockRerankConfig()._transform_response(response_json)
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
extra_headers: Optional[dict],
|
||||
data: dict,
|
||||
|
@ -110,7 +126,7 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
||||
optional_params
|
||||
optional_params, model
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue