mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
93 lines
2.6 KiB
Python
93 lines
2.6 KiB
Python
import json
|
|
import os
|
|
import sys
|
|
import traceback
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv()
|
|
import io
|
|
import os
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../..")
|
|
) # Adds the parent directory to the system path
|
|
|
|
import os
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
import litellm
|
|
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
|
|
|
|
|
def assert_response_shape(response, custom_llm_provider):
|
|
expected_response_shape = {"id": str, "results": list, "meta": dict}
|
|
|
|
expected_results_shape = {"index": int, "relevance_score": float}
|
|
|
|
expected_meta_shape = {"api_version": dict, "billed_units": dict}
|
|
|
|
expected_api_version_shape = {"version": str}
|
|
|
|
expected_billed_units_shape = {"search_units": int}
|
|
|
|
assert isinstance(response.id, expected_response_shape["id"])
|
|
assert isinstance(response.results, expected_response_shape["results"])
|
|
for result in response.results:
|
|
assert isinstance(result["index"], expected_results_shape["index"])
|
|
assert isinstance(
|
|
result["relevance_score"], expected_results_shape["relevance_score"]
|
|
)
|
|
assert isinstance(response.meta, expected_response_shape["meta"])
|
|
|
|
if custom_llm_provider == "cohere":
|
|
|
|
assert isinstance(
|
|
response.meta["api_version"], expected_meta_shape["api_version"]
|
|
)
|
|
assert isinstance(
|
|
response.meta["api_version"]["version"],
|
|
expected_api_version_shape["version"],
|
|
)
|
|
assert isinstance(
|
|
response.meta["billed_units"], expected_meta_shape["billed_units"]
|
|
)
|
|
assert isinstance(
|
|
response.meta["billed_units"]["search_units"],
|
|
expected_billed_units_shape["search_units"],
|
|
)
|
|
|
|
|
|
def test_basic_rerank():
|
|
response = litellm.rerank(
|
|
model="cohere/rerank-english-v3.0",
|
|
query="hello",
|
|
documents=["hello", "world"],
|
|
top_n=3,
|
|
)
|
|
|
|
print("re rank response: ", response)
|
|
|
|
assert response.id is not None
|
|
assert response.results is not None
|
|
|
|
assert_response_shape(response, custom_llm_provider="cohere")
|
|
|
|
|
|
def test_basic_rerank_together_ai():
|
|
response = litellm.rerank(
|
|
model="together_ai/Salesforce/Llama-Rank-V1",
|
|
query="hello",
|
|
documents=["hello", "world"],
|
|
top_n=3,
|
|
)
|
|
|
|
print("re rank response: ", response)
|
|
|
|
assert response.id is not None
|
|
assert response.results is not None
|
|
|
|
assert_response_shape(response, custom_llm_provider="together_ai")
|