litellm-mirror/tests/litellm/rerank_api/test_main.py
Krish Dholakia b682dc4ec8
Add cost tracking for rerank via bedrock (#8691)
* feat(bedrock/rerank): infer model region if model given as arn

* test: add unit testing to ensure bedrock region name inferred from arn on rerank

* feat(bedrock/rerank/transformation.py): include search units for bedrock rerank result

Resolves https://github.com/BerriAI/litellm/issues/7258#issuecomment-2671557137

* test(test_bedrock_completion.py): add testing for bedrock cohere rerank

* feat(cost_calculator.py): refactor rerank cost tracking to support bedrock cost tracking

* build(model_prices_and_context_window.json): add amazon.rerank model to model cost map

* fix(cost_calculator.py): bedrock/common_utils.py

get base model from model w/ arn -> handles rerank model

* build(model_prices_and_context_window.json): add bedrock cohere rerank pricing

* feat(bedrock/rerank): migrate bedrock config to basererank config

* Revert "feat(bedrock/rerank): migrate bedrock config to basererank config"

This reverts commit 84fae1f167.

* test: add testing to ensure large doc / queries are correctly counted

* Revert "test: add testing to ensure large doc / queries are correctly counted"

This reverts commit 4337f1657e.

* fix(migrate-jina-ai-to-rerank-config): enables cost tracking

* refactor(jina_ai/): finish migrating jina ai to base rerank config

enables cost tracking

* fix(jina_ai/rerank): e2e jina ai rerank cost tracking

* fix: cleanup dead code

* fix: fix python3.8 compatibility error

* test: fix test

* test: add e2e testing for azure ai rerank

* fix: fix linting error

* test: mark cohere as flaky
2025-02-20 21:00:18 -08:00
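The commit message says Bedrock rerank results now report search units and that rerank cost tracking was refactored to use them. The arithmetic can be sketched as follows. This sketch assumes a Cohere-style search unit: one unit per started block of 100 documents in a query, billed at a flat price per unit. The function names, the 100-document block size, and the price parameter are hypothetical illustrations. They are not litellm's implementation or AWS's published rates.

```python
import math


def rerank_search_units(num_documents: int, docs_per_unit: int = 100) -> int:
    """Hypothetical: count one search unit per started block of `docs_per_unit` documents."""
    if num_documents < 0:
        raise ValueError("num_documents must be non-negative")
    if docs_per_unit <= 0:
        raise ValueError("docs_per_unit must be positive")
    # e.g. 250 documents with 100 per unit -> ceil(2.5) = 3 units (assumed billing rule)
    return math.ceil(num_documents / docs_per_unit)


def rerank_cost(num_documents: int, cost_per_search_unit: float) -> float:
    """Hypothetical cost model: search units multiplied by a flat price per unit."""
    return rerank_search_units(num_documents) * cost_per_search_unit


if __name__ == "__main__":
    print(rerank_search_units(2))   # a two-document query fits in a single unit
    print(rerank_cost(250, 2.0))    # 3 units at an illustrative price of 2.0 per unit
```

Counting units instead of tokens matches how the commit describes rerank pricing (per search unit, not per token). The real price per unit would come from the model cost map entries the commit adds for `amazon.rerank` and Bedrock Cohere rerank.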


```python
import json
import os
import sys

import pytest
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the directory three levels above the working directory (the repo root when run from this test's folder) to sys.path
from unittest.mock import MagicMock, patch

from litellm import rerank
from litellm.llms.custom_httpx.http_handler import HTTPHandler


def test_rerank_infer_region_from_model_arn(monkeypatch):
    mock_response = MagicMock()
    # The env region deliberately differs from the ARN region so the test can tell which one wins.
    monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")
    args = {
        "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
        "query": "hello",
        "documents": ["hello", "world"],
    }

    def return_val():
        # Canned rerank payload: one relevance score per document index.
        return {
            "results": [
                {"index": 0, "relevanceScore": 0.6716859340667725},
                {"index": 1, "relevanceScore": 0.0004994205664843321},
            ]
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    client = HTTPHandler()

    # Intercept the outgoing POST so no real request is sent to AWS.
    with patch.object(client, "post", return_value=mock_response) as mock_post:
        rerank(
            model=args["model"],
            query=args["query"],
            documents=args["documents"],
            client=client,
        )
        mock_post.assert_called_once()
        print(f"mock_post.call_args: {mock_post.call_args.kwargs}")
        # The request URL must use the region from the ARN, not AWS_REGION_NAME.
        assert "us-west-2" in mock_post.call_args.kwargs["url"]
        assert "us-east-1" not in mock_post.call_args.kwargs["url"]
```
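The test checks that the region embedded in a Bedrock model ARN takes precedence over `AWS_REGION_NAME`. The commit message also mentions getting the base model from an ARN for cost lookup. Both behaviours can be sketched as follows, assuming the standard ARN layout `arn:partition:service:region:account-id:resource`. The helpers `parse_bedrock_arn_model`, `resolve_region`, and `rerank_url` are hypothetical and not litellm functions. The `bedrock-agent-runtime.<region>.amazonaws.com/rerank` endpoint shape is also an assumption.

```python
from typing import Mapping, Optional, Tuple


def parse_bedrock_arn_model(model: str) -> Optional[Tuple[str, str]]:
    """Hypothetical: return (region, base_model) if `model` is a Bedrock ARN, else None."""
    # Drop the provider prefix used in litellm model strings, e.g. "bedrock/arn:aws:...".
    if model.startswith("bedrock/"):
        model = model[len("bedrock/"):]
    if not model.startswith("arn:"):
        return None
    # Split into at most 6 fields so colons inside the resource (e.g. "v1:0") are preserved.
    parts = model.split(":", 5)
    if len(parts) != 6:
        return None
    region, resource = parts[3], parts[5]
    # Resource looks like "foundation-model/amazon.rerank-v1:0"; keep only the model id.
    base_model = resource.split("/", 1)[-1]
    return region, base_model


def resolve_region(model: str, env: Mapping[str, str]) -> Optional[str]:
    """Prefer the region embedded in an ARN; otherwise fall back to AWS_REGION_NAME."""
    parsed = parse_bedrock_arn_model(model)
    if parsed is not None and parsed[0]:
        return parsed[0]
    return env.get("AWS_REGION_NAME")


def rerank_url(region: str) -> str:
    """Hypothetical endpoint shape for a Bedrock rerank request in `region`."""
    return f"https://bedrock-agent-runtime.{region}.amazonaws.com/rerank"


if __name__ == "__main__":
    model = "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0"
    print(parse_bedrock_arn_model(model))
    print(rerank_url(resolve_region(model, {"AWS_REGION_NAME": "us-east-1"})))
```

Passing the environment in as a mapping, instead of reading `os.environ` directly, keeps the precedence rule testable without `monkeypatch`.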