# litellm-mirror/tests/llm_translation/test_rerank.py
# 2024-12-30 10:12:56 -08:00
#
# 387 lines
# 12 KiB
# Python

import asyncio
import json
import os
import sys
import traceback
from dotenv import load_dotenv
# Load variables from a local .env file (if one is found) into the process
# environment; the live-provider tests below read credentials via os.getenv.
load_dotenv()
import io
import os
# Put the directory two levels above the current working directory at the front
# of sys.path. Presumably this is the repo root, so the local litellm checkout is
# imported (assumes tests are run from tests/llm_translation — TODO confirm).
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the directory two levels up (relative to the CWD) to the system path
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import litellm
from litellm.types.rerank import RerankResponse
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
def assert_response_shape(response, custom_llm_provider):
    """Assert that a rerank response has the expected field types.

    For every provider:
      * ``response.id`` is a ``str``;
      * ``response.results`` is a ``list``, and each entry has an ``int``
        ``"index"`` and a ``float`` ``"relevance_score"``;
      * ``response.meta`` is a ``dict``.

    Only when ``custom_llm_provider == "cohere"`` are the nested
    ``meta["api_version"]["version"]`` (``str``) and
    ``meta["billed_units"]["search_units"]`` (``int``) fields checked too.

    Args:
        response: rerank response object under test.
        custom_llm_provider: provider name; selects whether the Cohere-only
            ``meta`` checks run.

    Raises:
        AssertionError: if any checked field has an unexpected type.
    """
    assert isinstance(response.id, str)
    assert isinstance(response.results, list)
    for item in response.results:
        assert isinstance(item["index"], int)
        assert isinstance(item["relevance_score"], float)
    assert isinstance(response.meta, dict)

    if custom_llm_provider != "cohere":
        return

    # Cohere responses get the stricter nested-meta checks, in this order.
    for section, key, expected_type in (
        ("api_version", "version", str),
        ("billed_units", "search_units", int),
    ):
        assert isinstance(response.meta[section], dict)
        assert isinstance(response.meta[section][key], expected_type)
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank(sync_mode):
    """Live Cohere rerank call through ``litellm.rerank`` or ``litellm.arerank``.

    ``sync_mode`` picks the entry point. Both paths send the same request and
    run the same assertions, including the Cohere-specific shape checks.
    """
    litellm.set_verbose = True
    request = dict(
        model="cohere/rerank-english-v3.0",
        query="hello",
        documents=["hello", "world"],
        top_n=3,
    )
    if sync_mode is True:
        response = litellm.rerank(**request)
        label = "re rank response: "
    else:
        response = await litellm.arerank(**request)
        label = "async re rank response: "

    print(label, response)
    assert response.id is not None
    assert response.results is not None
    assert_response_shape(response, custom_llm_provider="cohere")

    print("response", response.model_dump_json(indent=4))
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank_together_ai(sync_mode):
    """Live Together AI rerank call (Salesforce Llama-Rank-V1), sync or async per ``sync_mode``."""
    rerank_kwargs = {
        "model": "together_ai/Salesforce/Llama-Rank-V1",
        "query": "hello",
        "documents": ["hello", "world"],
        "top_n": 3,
    }
    if sync_mode is not True:
        response = await litellm.arerank(**rerank_kwargs)
        print("async re rank response: ", response)
    else:
        response = litellm.rerank(**rerank_kwargs)
        print("re rank response: ", response)

    assert response.id is not None
    assert response.results is not None
    assert_response_shape(response, custom_llm_provider="together_ai")
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank_azure_ai(sync_mode):
    """Live rerank against an Azure AI-hosted Cohere deployment, sync or async.

    Credentials are read from the ``AZURE_AI_COHERE_API_KEY`` and
    ``AZURE_AI_COHERE_API_BASE`` environment variables. ``sync_mode`` chooses
    between ``litellm.rerank`` and ``litellm.arerank``.
    """
    litellm.set_verbose = True
    request = dict(
        model="azure_ai/Cohere-rerank-v3-multilingual-ko",
        query="hello",
        documents=["hello", "world"],
        top_n=3,
        api_key=os.getenv("AZURE_AI_COHERE_API_KEY"),
        api_base=os.getenv("AZURE_AI_COHERE_API_BASE"),
    )
    if sync_mode is True:
        response = litellm.rerank(**request)
        print("re rank response: ", response)
    else:
        response = await litellm.arerank(**request)
        print("async re rank response: ", response)

    assert response.id is not None
    assert response.results is not None
    # Label with the provider actually under test. assert_response_shape only
    # runs its stricter meta checks for "cohere", so this gets the generic checks.
    assert_response_shape(response, custom_llm_provider="azure_ai")
@pytest.mark.asyncio()
async def test_rerank_custom_api_base():
    """Cohere rerank honours a caller-supplied ``api_base``.

    ``AsyncHTTPHandler.post`` is patched to return a canned rerank body, so no
    network request is made. The test checks the URL litellm built from the
    custom base, the JSON payload it sent, and the shape of the parsed response.
    """

    def canned_body():
        # Return a fresh copy of the canned rerank payload on every call.
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95}],
            "meta": {
                "api_version": {"version": "1.0"},
                "billed_units": {"search_units": 1},
            },
        }

    fake_http_response = AsyncMock()
    fake_http_response.json = canned_body
    fake_http_response.headers = {"key": "value"}
    fake_http_response.status_code = 200

    expected_payload = {
        "model": "Salesforce/Llama-Rank-V1",
        "query": "hello",
        "top_n": 3,
        "documents": ["hello", "world"],
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=fake_http_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="cohere/Salesforce/Llama-Rank-V1",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
            api_base="https://exampleopenaiendpoint-production.up.railway.app/",
        )
        print("async re rank response: ", response)

        # Exactly one outbound request, to <api_base>/v1/rerank.
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        sent_kwargs = mock_post.call_args.kwargs
        args_to_api = sent_kwargs["data"]
        _url = sent_kwargs["url"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", _url)
        assert (
            _url == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
        )

        # The provider prefix is stripped from the model; other fields pass through.
        request_data = json.loads(args_to_api)
        for field in ("query", "documents", "top_n", "model"):
            assert request_data[field] == expected_payload[field]

        assert response.id is not None
        assert response.results is not None
        assert_response_shape(response, custom_llm_provider="cohere")
class TestLogger(CustomLogger):
    """CustomLogger that records the most recent async success event it receives.

    test_rerank_custom_callbacks registers it in ``litellm.callbacks`` and then
    inspects ``kwargs`` (e.g. ``response_cost``) and ``response_obj``.
    """

    # The name starts with "Test", so pytest would otherwise try to collect this
    # helper as a test class and warn because it defines __init__.
    __test__ = False

    def __init__(self):
        # Both stay None until async_log_success_event fires.
        self.kwargs = None
        self.response_obj = None
        super().__init__()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Store the callback's kwargs and response object for later assertions."""
        print("in success event for rerank, kwargs = ", kwargs)
        print("in success event for rerank, response_obj = ", response_obj)
        self.kwargs = kwargs
        self.response_obj = response_obj
@pytest.mark.asyncio()
async def test_rerank_custom_callbacks():
    """A CustomLogger in ``litellm.callbacks`` receives rerank success events.

    Steps:
      * Set ``LITELLM_LOCAL_MODEL_COST_MAP`` and reload ``litellm.model_cost``
        via ``get_model_cost_map(url="")``.
      * Make a live Cohere rerank call.
      * Check that the logger saw a positive ``response_cost`` and the response.

    The environment variable, ``litellm.model_cost`` and ``litellm.callbacks``
    are restored to their previous values afterwards.
    """
    saved_env_flag = os.environ.get("LITELLM_LOCAL_MODEL_COST_MAP")
    saved_model_cost = litellm.model_cost
    saved_callbacks = litellm.callbacks
    try:
        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
        litellm.model_cost = litellm.get_model_cost_map(url="")
        custom_logger = TestLogger()
        litellm.callbacks = [custom_logger]
        response = await litellm.arerank(
            model="cohere/rerank-english-v3.0",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
        )

        # The success callback presumably fires after arerank returns. Poll for it
        # within the same ~5s budget the fixed sleep used, instead of always
        # waiting the full 5s. response_obj is the last attribute TestLogger
        # assigns, so once it is set, kwargs is set too.
        for _ in range(50):
            if custom_logger.response_obj is not None:
                break
            await asyncio.sleep(0.1)

        print("async re rank response: ", response)
        assert custom_logger.kwargs is not None
        response_cost = custom_logger.kwargs.get("response_cost")
        # Check for None first so a missing cost fails as an assertion rather
        # than a TypeError from comparing None with a float.
        assert response_cost is not None
        assert response_cost > 0.0
        assert custom_logger.response_obj is not None
        assert custom_logger.response_obj.results is not None
    finally:
        litellm.callbacks = saved_callbacks
        litellm.model_cost = saved_model_cost
        if saved_env_flag is None:
            os.environ.pop("LITELLM_LOCAL_MODEL_COST_MAP", None)
        else:
            os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = saved_env_flag
def test_complete_base_url_cohere():
    """A global ``litellm.api_base`` is used to build the Cohere rerank URL.

    The HTTP client's ``post`` is replaced with a mock, so no request leaves the
    process. Any exception raised while litellm handles the mocked reply is
    printed and ignored, because the test only inspects the outgoing URL.
    """
    client = HTTPHandler()
    # Save the global api_base and restore it afterwards. Leaving it set would
    # route every later rerank call in the session to localhost:4000.
    original_api_base = litellm.api_base
    litellm.api_base = "http://localhost:4000"
    litellm.set_verbose = True

    text = "Hello there!"
    list_texts = ["Hello there!", "How are you?", "How do you do?"]
    rerank_model = "rerank-multilingual-v3.0"

    try:
        with patch.object(client, "post") as mock_post:
            try:
                litellm.rerank(
                    model=rerank_model,
                    query=text,
                    documents=list_texts,
                    custom_llm_provider="cohere",
                    client=client,
                )
            except Exception as e:
                # Deliberately tolerated: the mocked post returns a MagicMock,
                # which litellm presumably cannot parse as a rerank response.
                print(e)

        print("mock_post.call_args", mock_post.call_args)
        mock_post.assert_called_once()
        assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"]
    finally:
        litellm.api_base = original_api_base
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
    "top_n_1, top_n_2, expect_cache_hit",
    [
        (3, 3, True),
        (3, None, False),
    ],
)
async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
    """Repeated rerank calls are served from litellm's local cache only when identical.

    Two Cohere rerank calls are made, with ``top_n_1`` and then ``top_n_2``. The
    second response must carry a ``cache_key`` in ``_hidden_params`` exactly when
    ``expect_cache_hit`` is True. The previous ``litellm.cache`` is restored
    afterwards.
    """
    from litellm.caching.caching import Cache

    litellm.set_verbose = True
    # Save the global cache and restore it at the end; otherwise later tests in
    # the session would keep running with this local cache installed.
    original_cache = litellm.cache
    litellm.cache = Cache(type="local")
    try:
        for top_n in (top_n_1, top_n_2):
            request = dict(
                model="cohere/rerank-english-v3.0",
                query="hello",
                documents=["hello", "world"],
                top_n=top_n,
            )
            if sync_mode is True:
                response = litellm.rerank(**request)
            else:
                response = await litellm.arerank(**request)
                # Presumably gives the async cache write time to complete before
                # the second call.
                await asyncio.sleep(1)

        # `response` is the second call's result here.
        if expect_cache_hit is True:
            assert "cache_key" in response._hidden_params
        else:
            assert "cache_key" not in response._hidden_params

        print("re rank response: ", response)
        assert response.id is not None
        assert response.results is not None
        assert_response_shape(response, custom_llm_provider="cohere")
    finally:
        litellm.cache = original_cache
def test_rerank_response_assertions():
    """A RerankResponse whose ``meta`` sub-fields are all None passes the generic shape checks.

    With a non-"cohere" provider, assert_response_shape skips the nested
    ``meta`` checks, so ``None`` for ``api_version`` and ``billed_units`` is accepted.
    """
    scored_results = [
        {"index": 2, "relevance_score": 0.9958819150924683, "document": None},
        {"index": 0, "relevance_score": 0.001293411129154265, "document": None},
        {"index": 1, "relevance_score": 7.641685078851879e-05, "document": None},
        {"index": 3, "relevance_score": 7.621097756782547e-05, "document": None},
    ]
    empty_meta = {
        "api_version": None,
        "billed_units": None,
        "tokens": None,
        "warnings": None,
    }
    parsed = RerankResponse(
        id="ab0fcca0-b617-11ef-b292-0242ac110002",
        results=scored_results,
        meta=empty_meta,
    )
    assert_response_shape(parsed, custom_llm_provider="custom")