mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
* feat(bedrock/rerank): infer model region if model given as arn * test: add unit testing to ensure bedrock region name inferred from arn on rerank * feat(bedrock/rerank/transformation.py): include search units for bedrock rerank result Resolves https://github.com/BerriAI/litellm/issues/7258#issuecomment-2671557137 * test(test_bedrock_completion.py): add testing for bedrock cohere rerank * feat(cost_calculator.py): refactor rerank cost tracking to support bedrock cost tracking * build(model_prices_and_context_window.json): add amazon.rerank model to model cost map * fix(cost_calculator.py): bedrock/common_utils.py get base model from model w/ arn -> handles rerank model * build(model_prices_and_context_window.json): add bedrock cohere rerank pricing * feat(bedrock/rerank): migrate bedrock config to basererank config * Revert "feat(bedrock/rerank): migrate bedrock config to basererank config" This reverts commit 84fae1f167.
* test: add testing to ensure large doc / queries are correctly counted * Revert "test: add testing to ensure large doc / queries are correctly counted" This reverts commit 4337f1657e.
* fix(migrate-jina-ai-to-rerank-config): enables cost tracking * refactor(jina_ai/): finish migrating jina ai to base rerank config enables cost tracking * fix(jina_ai/rerank): e2e jina ai rerank cost tracking * fix: cleanup dead code * fix: fix python3.8 compatibility error * test: fix test * test: add e2e testing for azure ai rerank * fix: fix linting error * test: mark cohere as flaky
51 lines
1.5 KiB
Python
51 lines
1.5 KiB
Python
import json
|
|
import os
|
|
import sys
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../../..")
|
|
) # Adds the parent directory to the system path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from litellm import rerank
|
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
|
|
|
|
|
def test_rerank_infer_region_from_model_arn(monkeypatch):
    """Check that the region in a Bedrock model ARN is used for the request URL.

    ``AWS_REGION_NAME`` is set to ``us-east-1`` while the model ARN names
    ``us-west-2``. The handler's ``post`` method is patched to return a
    canned rerank payload, and the test asserts that the URL passed to
    ``post`` contains the ARN's region and not the environment's region.
    """
    # The environment region deliberately disagrees with the ARN's region.
    monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")

    model_arn = (
        "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0"
    )
    query_text = "hello"
    candidate_docs = ["hello", "world"]

    def canned_payload():
        # Stand-in for the JSON body returned by the patched HTTP call.
        return {
            "results": [
                {"index": 0, "relevanceScore": 0.6716859340667725},
                {"index": 1, "relevanceScore": 0.0004994205664843321},
            ]
        }

    fake_response = MagicMock()
    fake_response.json = canned_payload
    fake_response.headers = {"key": "value"}
    fake_response.status_code = 200

    http_client = HTTPHandler()

    with patch.object(
        http_client, "post", return_value=fake_response
    ) as post_spy:
        rerank(
            model=model_arn,
            query=query_text,
            documents=candidate_docs,
            client=http_client,
        )

        post_spy.assert_called_once()
        post_kwargs = post_spy.call_args.kwargs
        print(f"mock_post.call_args: {post_kwargs}")

        request_url = post_kwargs["url"]
        assert "us-west-2" in request_url
        assert "us-east-1" not in request_url
|