forked from phoenix/litellm-mirror
260 lines
7.8 KiB
Python
260 lines
7.8 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
import traceback
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv()
|
|
import io
|
|
import os
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../..")
|
|
) # Adds the parent directory to the system path
|
|
|
|
import os
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
import litellm
|
|
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
|
from litellm.integrations.custom_logger import CustomLogger
|
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
|
|
|
|
|
def assert_response_shape(response, custom_llm_provider):
    """Assert that a rerank response has the expected field types.

    Checks performed for every provider:
      * ``response.id`` is a ``str``
      * ``response.results`` is a ``list`` whose items each have an ``int``
        ``"index"`` and a ``float`` ``"relevance_score"``
      * ``response.meta`` is a ``dict``

    Only when ``custom_llm_provider == "cohere"`` the nested metadata is
    checked as well: ``meta["api_version"]`` must be a dict with a ``str``
    ``"version"``, and ``meta["billed_units"]`` must be a dict with an
    ``int`` ``"search_units"``.

    Args:
        response: Object exposing ``id``, ``results`` and ``meta`` attributes.
        custom_llm_provider: Provider name; selects whether the nested
            metadata checks run.

    Raises:
        AssertionError: If any checked value has the wrong type.
        KeyError: If a result or the metadata lacks one of the keys above.
    """
    assert isinstance(response.id, str)
    assert isinstance(response.results, list)

    for entry in response.results:
        assert isinstance(entry["index"], int)
        assert isinstance(entry["relevance_score"], float)

    assert isinstance(response.meta, dict)

    # The nested metadata layout is only verified for cohere responses.
    if custom_llm_provider != "cohere":
        return

    api_version = response.meta["api_version"]
    assert isinstance(api_version, dict)
    assert isinstance(api_version["version"], str)

    billed_units = response.meta["billed_units"]
    assert isinstance(billed_units, dict)
    assert isinstance(billed_units["search_units"], int)
|
|
|
|
|
|
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank(sync_mode):
    """Rerank two documents with ``cohere/rerank-english-v3.0``.

    Parametrized over ``sync_mode``: ``True`` calls ``litellm.rerank``,
    ``False`` awaits ``litellm.arerank``. Both paths send the same request
    and apply the same assertions.

    NOTE(review): nothing is mocked here, so this presumably reaches the
    real provider and needs credentials in the environment -- confirm.
    """
    request = {
        "model": "cohere/rerank-english-v3.0",
        "query": "hello",
        "documents": ["hello", "world"],
        "top_n": 3,
    }

    if sync_mode is True:
        response = litellm.rerank(**request)
        print("re rank response: ", response)
    else:
        response = await litellm.arerank(**request)
        print("async re rank response: ", response)

    assert response.id is not None
    assert response.results is not None

    assert_response_shape(response, custom_llm_provider="cohere")
|
|
|
|
|
|
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank_together_ai(sync_mode):
    """Rerank two documents with ``together_ai/Salesforce/Llama-Rank-V1``.

    Parametrized over ``sync_mode``: ``True`` calls ``litellm.rerank``,
    ``False`` awaits ``litellm.arerank``. Both paths send the same request
    and apply the same assertions.

    NOTE(review): nothing is mocked here, so this presumably reaches the
    real provider and needs credentials in the environment -- confirm.
    """
    request = {
        "model": "together_ai/Salesforce/Llama-Rank-V1",
        "query": "hello",
        "documents": ["hello", "world"],
        "top_n": 3,
    }

    if sync_mode is True:
        response = litellm.rerank(**request)
        print("re rank response: ", response)
    else:
        response = await litellm.arerank(**request)
        print("async re rank response: ", response)

    assert response.id is not None
    assert response.results is not None

    assert_response_shape(response, custom_llm_provider="together_ai")
|
|
|
|
|
|
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank_azure_ai(sync_mode):
    """Rerank two documents with ``azure_ai/Cohere-rerank-v3-multilingual-ko``.

    Parametrized over ``sync_mode``: ``True`` calls ``litellm.rerank``,
    ``False`` awaits ``litellm.arerank``. The API key and base URL are read
    from the ``AZURE_AI_COHERE_API_KEY`` and ``AZURE_AI_COHERE_API_BASE``
    environment variables.

    NOTE(review): nothing is mocked here, so this presumably reaches the
    real endpoint -- confirm.
    """
    # Verbose logging is switched on for this test and, as before, is not
    # switched back off afterwards.
    litellm.set_verbose = True

    request = {
        "model": "azure_ai/Cohere-rerank-v3-multilingual-ko",
        "query": "hello",
        "documents": ["hello", "world"],
        "top_n": 3,
        "api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
        "api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
    }

    if sync_mode is True:
        response = litellm.rerank(**request)
        print("re rank response: ", response)
    else:
        response = await litellm.arerank(**request)
        print("async re rank response: ", response)

    assert response.id is not None
    assert response.results is not None

    # Label the shape check with the provider actually under test; this was
    # previously a copy-pasted "together_ai".
    assert_response_shape(response, custom_llm_provider="azure_ai")
|
|
|
|
|
|
@pytest.mark.asyncio()
async def test_rerank_custom_api_base():
    """Check that ``arerank`` posts to a caller-supplied ``api_base``.

    ``AsyncHTTPHandler.post`` is patched to return a canned response, so no
    request leaves the process. The test then verifies that:
      * ``post`` was called exactly once,
      * its first positional argument is the custom base URL,
      * its ``json`` keyword argument equals the expected payload, with the
        ``cohere/`` prefix removed from the model name,
      * the parsed response passes the cohere shape checks.
    """

    def canned_json():
        # Build a fresh dict on every call, like a real ``.json()`` would.
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95}],
            "meta": {
                "api_version": {"version": "1.0"},
                "billed_units": {"search_units": 1},
            },
        }

    fake_http_response = AsyncMock()
    fake_http_response.json = canned_json
    fake_http_response.headers = {"key": "value"}
    fake_http_response.status_code = 200

    expected_payload = {
        "model": "Salesforce/Llama-Rank-V1",
        "query": "hello",
        "documents": ["hello", "world"],
        "top_n": 3,
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=fake_http_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="cohere/Salesforce/Llama-Rank-V1",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
            api_base="https://exampleopenaiendpoint-production.up.railway.app/",
        )

        print("async re rank response: ", response)

        # Assert
        mock_post.assert_called_once()
        # call_args unpacks into (positional-args tuple, keyword-args dict).
        positional_args, keyword_args = mock_post.call_args
        sent_json = keyword_args["json"]
        print("Arguments passed to API=", sent_json)
        print("url = ", positional_args)
        assert (
            positional_args[0]
            == "https://exampleopenaiendpoint-production.up.railway.app/"
        )
        assert sent_json == expected_payload
        assert response.id is not None
        assert response.results is not None

        assert_response_shape(response, custom_llm_provider="cohere")
|
|
|
|
|
|
class TestLogger(CustomLogger):
    """Callback that records the arguments of the last async success event.

    After ``async_log_success_event`` has run, ``kwargs`` and
    ``response_obj`` hold the values it was called with; both are ``None``
    until then.
    """

    # The class name starts with "Test" and the class defines __init__, so
    # pytest would try to collect it and emit a collection warning. This
    # marks it as a helper, not a test class.
    __test__ = False

    def __init__(self):
        # Last-seen callback arguments; set before the base initializer runs,
        # matching the original ordering.
        self.kwargs = None
        self.response_obj = None
        super().__init__()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Print and store ``kwargs`` and ``response_obj``.

        ``start_time`` and ``end_time`` are accepted but not used.
        """
        print("in success event for rerank, kwargs = ", kwargs)
        print("in success event for rerank, response_obj = ", response_obj)
        self.kwargs = kwargs
        self.response_obj = response_obj
|
|
|
|
|
|
@pytest.mark.asyncio()
async def test_rerank_custom_callbacks():
    """Check that a custom logger receives the rerank success event.

    Registers a ``TestLogger`` in ``litellm.callbacks``, awaits
    ``litellm.arerank`` and then asserts that the logger captured kwargs
    with a positive ``response_cost`` and a response object with results.

    NOTE(review): nothing is mocked here, so this presumably reaches the
    real provider and needs credentials in the environment -- confirm.
    """
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")

    custom_logger = TestLogger()

    # Put the previous callback list back afterwards so the logger does not
    # leak into tests that run later in the same session.
    previous_callbacks = litellm.callbacks
    litellm.callbacks = [custom_logger]
    try:
        response = await litellm.arerank(
            model="cohere/rerank-english-v3.0",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
        )

        # The success event may arrive after arerank returns. Poll for it,
        # for at most ~5 seconds (the original fixed wait), and stop as soon
        # as it has arrived.
        for _ in range(50):
            if custom_logger.kwargs is not None:
                break
            await asyncio.sleep(0.1)
    finally:
        litellm.callbacks = previous_callbacks

    print("async re rank response: ", response)
    assert custom_logger.kwargs is not None

    # Check for a missing cost first, so it fails as an assertion rather than
    # as a TypeError from comparing None with a float.
    response_cost = custom_logger.kwargs.get("response_cost")
    assert response_cost is not None
    assert response_cost > 0.0

    assert custom_logger.response_obj is not None
    assert custom_logger.response_obj.results is not None
|