Add cohere v2/rerank support (#8421) (#8605)

* Add cohere v2/rerank support (#8421)

* Support v2 endpoint cohere rerank

* Add tests and docs

* Make v1 default if old params used

* Update docs

* Update docs pt 2

* Update tests

* Add e2e test

* Clean up code

* Use inheritence for new config

* Fix linting issues (#8608)

* Fix cohere v2 failing test + linting (#8672)

* Fix test and unused imports

* Fix tests

* fix: fix linting errors

* test: handle tgai instability

* fix: skip service unavailable err

* test: print logs for unstable test

* test: skip unreliable tests

---------

Co-authored-by: vibhavbhat <vibhavb00@gmail.com>
This commit is contained in:
Krish Dholakia 2025-02-22 22:25:29 -08:00 committed by GitHub
parent c2aec21b4d
commit 09462ba80c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 257 additions and 40 deletions

View file

@ -108,7 +108,7 @@ response = embedding(
### Usage ### Usage
LiteLLM supports the v1 and v2 clients for Cohere rerank. By default, the `rerank` endpoint uses the v2 client, but you can specify the v1 client by explicitly calling `v1/rerank`
<Tabs> <Tabs>
<TabItem value="sdk" label="LiteLLM SDK Usage"> <TabItem value="sdk" label="LiteLLM SDK Usage">

View file

@ -111,7 +111,7 @@ curl http://0.0.0.0:4000/rerank \
| Provider | Link to Usage | | Provider | Link to Usage |
|-------------|--------------------| |-------------|--------------------|
| Cohere | [Usage](#quick-start) | | Cohere (v1 + v2 clients) | [Usage](#quick-start) |
| Together AI| [Usage](../docs/providers/togetherai) | | Together AI| [Usage](../docs/providers/togetherai) |
| Azure AI| [Usage](../docs/providers/azure_ai) | | Azure AI| [Usage](../docs/providers/azure_ai) |
| Jina AI| [Usage](../docs/providers/jina_ai) | | Jina AI| [Usage](../docs/providers/jina_ai) |

View file

@ -824,6 +824,7 @@ from .llms.predibase.chat.transformation import PredibaseConfig
from .llms.replicate.chat.transformation import ReplicateConfig from .llms.replicate.chat.transformation import ReplicateConfig
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
from .llms.cohere.rerank.transformation import CohereRerankConfig from .llms.cohere.rerank.transformation import CohereRerankConfig
from .llms.cohere.rerank_v2.transformation import CohereRerankV2Config
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
from .llms.infinity.rerank.transformation import InfinityRerankConfig from .llms.infinity.rerank.transformation import InfinityRerankConfig
from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig

View file

@ -855,7 +855,10 @@ def rerank_cost(
try: try:
config = ProviderConfigManager.get_provider_rerank_config( config = ProviderConfigManager.get_provider_rerank_config(
model=model, provider=LlmProviders(custom_llm_provider) model=model,
api_base=None,
present_version_params=[],
provider=LlmProviders(custom_llm_provider),
) )
try: try:

View file

@ -17,7 +17,6 @@ class AzureAIRerankConfig(CohereRerankConfig):
""" """
Azure AI Rerank - Follows the same Spec as Cohere Rerank Azure AI Rerank - Follows the same Spec as Cohere Rerank
""" """
def get_complete_url(self, api_base: Optional[str], model: str) -> str: def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base is None: if api_base is None:
raise ValueError( raise ValueError(

View file

@ -77,6 +77,7 @@ class BaseRerankConfig(ABC):
rank_fields: Optional[List[str]] = None, rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True, return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None, max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
) -> OptionalRerankParams: ) -> OptionalRerankParams:
pass pass

View file

@ -52,6 +52,7 @@ class CohereRerankConfig(BaseRerankConfig):
rank_fields: Optional[List[str]] = None, rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True, return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None, max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
) -> OptionalRerankParams: ) -> OptionalRerankParams:
""" """
Map Cohere rerank params Map Cohere rerank params

View file

@ -0,0 +1,80 @@
from typing import Any, Dict, List, Optional, Union
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest
class CohereRerankV2Config(CohereRerankConfig):
"""
Reference: https://docs.cohere.com/v2/reference/rerank
"""
def __init__(self) -> None:
pass
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base:
# Remove trailing slashes and ensure clean base URL
api_base = api_base.rstrip("/")
if not api_base.endswith("/v2/rerank"):
api_base = f"{api_base}/v2/rerank"
return api_base
return "https://api.cohere.ai/v2/rerank"
def get_supported_cohere_rerank_params(self, model: str) -> list:
return [
"query",
"documents",
"top_n",
"max_tokens_per_doc",
"rank_fields",
"return_documents",
]
def map_cohere_rerank_params(
self,
non_default_params: Optional[dict],
model: str,
drop_params: bool,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[str] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
) -> OptionalRerankParams:
"""
Map Cohere rerank params
No mapping required - returns all supported params
"""
return OptionalRerankParams(
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_tokens_per_doc=max_tokens_per_doc,
)
def transform_rerank_request(
self,
model: str,
optional_rerank_params: OptionalRerankParams,
headers: dict,
) -> dict:
if "query" not in optional_rerank_params:
raise ValueError("query is required for Cohere rerank")
if "documents" not in optional_rerank_params:
raise ValueError("documents is required for Cohere rerank")
rerank_request = RerankRequest(
model=model,
query=optional_rerank_params["query"],
documents=optional_rerank_params["documents"],
top_n=optional_rerank_params.get("top_n", None),
rank_fields=optional_rerank_params.get("rank_fields", None),
return_documents=optional_rerank_params.get("return_documents", None),
max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None),
)
return rerank_request.model_dump(exclude_none=True)

View file

@ -710,6 +710,7 @@ class BaseLLMHTTPHandler:
model: str, model: str,
custom_llm_provider: str, custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj, logging_obj: LiteLLMLoggingObj,
provider_config: BaseRerankConfig,
optional_rerank_params: OptionalRerankParams, optional_rerank_params: OptionalRerankParams,
timeout: Optional[Union[float, httpx.Timeout]], timeout: Optional[Union[float, httpx.Timeout]],
model_response: RerankResponse, model_response: RerankResponse,
@ -720,9 +721,6 @@ class BaseLLMHTTPHandler:
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse: ) -> RerankResponse:
provider_config = ProviderConfigManager.get_provider_rerank_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
# get config from model, custom llm provider # get config from model, custom llm provider
headers = provider_config.validate_environment( headers = provider_config.validate_environment(
api_key=api_key, api_key=api_key,

View file

@ -44,6 +44,7 @@ class JinaAIRerankConfig(BaseRerankConfig):
rank_fields: Optional[List[str]] = None, rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True, return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None, max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
) -> OptionalRerankParams: ) -> OptionalRerankParams:
optional_params = {} optional_params = {}
supported_params = self.get_supported_cohere_rerank_params(model) supported_params = self.get_supported_cohere_rerank_params(model)

View file

@ -239,6 +239,7 @@ class LiteLLMRoutes(enum.Enum):
# rerank # rerank
"/rerank", "/rerank",
"/v1/rerank", "/v1/rerank",
"/v2/rerank"
# realtime # realtime
"/realtime", "/realtime",
"/v1/realtime", "/v1/realtime",

View file

@ -11,7 +11,12 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter() router = APIRouter()
import asyncio import asyncio
@router.post(
"/v2/rerank",
dependencies=[Depends(user_api_key_auth)],
response_class=ORJSONResponse,
tags=["rerank"],
)
@router.post( @router.post(
"/v1/rerank", "/v1/rerank",
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],

View file

@ -81,6 +81,7 @@ def rerank( # noqa: PLR0915
rank_fields: Optional[List[str]] = None, rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True, return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None, max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
**kwargs, **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]: ) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
""" """
@ -97,6 +98,14 @@ def rerank( # noqa: PLR0915
try: try:
_is_async = kwargs.pop("arerank", False) is True _is_async = kwargs.pop("arerank", False) is True
optional_params = GenericLiteLLMParams(**kwargs) optional_params = GenericLiteLLMParams(**kwargs)
# Params that are unique to specific versions of the client for the rerank call
unique_version_params = {
"max_chunks_per_doc": max_chunks_per_doc,
"max_tokens_per_doc": max_tokens_per_doc,
}
present_version_params = [
k for k, v in unique_version_params.items() if v is not None
]
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = ( model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
litellm.get_llm_provider( litellm.get_llm_provider(
@ -111,6 +120,8 @@ def rerank( # noqa: PLR0915
ProviderConfigManager.get_provider_rerank_config( ProviderConfigManager.get_provider_rerank_config(
model=model, model=model,
provider=litellm.LlmProviders(_custom_llm_provider), provider=litellm.LlmProviders(_custom_llm_provider),
api_base=optional_params.api_base,
present_version_params=present_version_params,
) )
) )
@ -125,6 +136,7 @@ def rerank( # noqa: PLR0915
rank_fields=rank_fields, rank_fields=rank_fields,
return_documents=return_documents, return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc, max_chunks_per_doc=max_chunks_per_doc,
max_tokens_per_doc=max_tokens_per_doc,
non_default_params=kwargs, non_default_params=kwargs,
) )
@ -171,6 +183,7 @@ def rerank( # noqa: PLR0915
response = base_llm_http_handler.rerank( response = base_llm_http_handler.rerank(
model=model, model=model,
custom_llm_provider=_custom_llm_provider, custom_llm_provider=_custom_llm_provider,
provider_config=rerank_provider_config,
optional_rerank_params=optional_rerank_params, optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj, logging_obj=litellm_logging_obj,
timeout=optional_params.timeout, timeout=optional_params.timeout,
@ -192,6 +205,7 @@ def rerank( # noqa: PLR0915
model=model, model=model,
custom_llm_provider=_custom_llm_provider, custom_llm_provider=_custom_llm_provider,
optional_rerank_params=optional_rerank_params, optional_rerank_params=optional_rerank_params,
provider_config=rerank_provider_config,
logging_obj=litellm_logging_obj, logging_obj=litellm_logging_obj,
timeout=optional_params.timeout, timeout=optional_params.timeout,
api_key=dynamic_api_key or optional_params.api_key, api_key=dynamic_api_key or optional_params.api_key,
@ -220,6 +234,7 @@ def rerank( # noqa: PLR0915
response = base_llm_http_handler.rerank( response = base_llm_http_handler.rerank(
model=model, model=model,
custom_llm_provider=_custom_llm_provider, custom_llm_provider=_custom_llm_provider,
provider_config=rerank_provider_config,
optional_rerank_params=optional_rerank_params, optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj, logging_obj=litellm_logging_obj,
timeout=optional_params.timeout, timeout=optional_params.timeout,
@ -275,6 +290,7 @@ def rerank( # noqa: PLR0915
custom_llm_provider=_custom_llm_provider, custom_llm_provider=_custom_llm_provider,
optional_rerank_params=optional_rerank_params, optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj, logging_obj=litellm_logging_obj,
provider_config=rerank_provider_config,
timeout=optional_params.timeout, timeout=optional_params.timeout,
api_key=dynamic_api_key or optional_params.api_key, api_key=dynamic_api_key or optional_params.api_key,
api_base=api_base, api_base=api_base,

View file

@ -15,6 +15,7 @@ def get_optional_rerank_params(
rank_fields: Optional[List[str]] = None, rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True, return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None, max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
non_default_params: Optional[dict] = None, non_default_params: Optional[dict] = None,
) -> OptionalRerankParams: ) -> OptionalRerankParams:
all_non_default_params = non_default_params or {} all_non_default_params = non_default_params or {}
@ -28,6 +29,8 @@ def get_optional_rerank_params(
all_non_default_params["return_documents"] = return_documents all_non_default_params["return_documents"] = return_documents
if max_chunks_per_doc is not None: if max_chunks_per_doc is not None:
all_non_default_params["max_chunks_per_doc"] = max_chunks_per_doc all_non_default_params["max_chunks_per_doc"] = max_chunks_per_doc
if max_tokens_per_doc is not None:
all_non_default_params["max_tokens_per_doc"] = max_tokens_per_doc
return rerank_provider_config.map_cohere_rerank_params( return rerank_provider_config.map_cohere_rerank_params(
model=model, model=model,
drop_params=drop_params, drop_params=drop_params,
@ -38,5 +41,6 @@ def get_optional_rerank_params(
rank_fields=rank_fields, rank_fields=rank_fields,
return_documents=return_documents, return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc, max_chunks_per_doc=max_chunks_per_doc,
max_tokens_per_doc=max_tokens_per_doc,
non_default_params=all_non_default_params, non_default_params=all_non_default_params,
) )

View file

@ -18,6 +18,8 @@ class RerankRequest(BaseModel):
rank_fields: Optional[List[str]] = None rank_fields: Optional[List[str]] = None
return_documents: Optional[bool] = None return_documents: Optional[bool] = None
max_chunks_per_doc: Optional[int] = None max_chunks_per_doc: Optional[int] = None
max_tokens_per_doc: Optional[int] = None
class OptionalRerankParams(TypedDict, total=False): class OptionalRerankParams(TypedDict, total=False):
@ -27,6 +29,7 @@ class OptionalRerankParams(TypedDict, total=False):
rank_fields: Optional[List[str]] rank_fields: Optional[List[str]]
return_documents: Optional[bool] return_documents: Optional[bool]
max_chunks_per_doc: Optional[int] max_chunks_per_doc: Optional[int]
max_tokens_per_doc: Optional[int]
class RerankBilledUnits(TypedDict, total=False): class RerankBilledUnits(TypedDict, total=False):

View file

@ -6191,9 +6191,14 @@ class ProviderConfigManager:
def get_provider_rerank_config( def get_provider_rerank_config(
model: str, model: str,
provider: LlmProviders, provider: LlmProviders,
api_base: Optional[str],
present_version_params: List[str],
) -> BaseRerankConfig: ) -> BaseRerankConfig:
if litellm.LlmProviders.COHERE == provider: if litellm.LlmProviders.COHERE == provider:
if should_use_cohere_v1_client(api_base, present_version_params):
return litellm.CohereRerankConfig() return litellm.CohereRerankConfig()
else:
return litellm.CohereRerankV2Config()
elif litellm.LlmProviders.AZURE_AI == provider: elif litellm.LlmProviders.AZURE_AI == provider:
return litellm.AzureAIRerankConfig() return litellm.AzureAIRerankConfig()
elif litellm.LlmProviders.INFINITY == provider: elif litellm.LlmProviders.INFINITY == provider:
@ -6277,6 +6282,12 @@ def get_end_user_id_for_cost_tracking(
return None return None
return end_user_id return end_user_id
def should_use_cohere_v1_client(api_base: Optional[str], present_version_params: List[str]):
if not api_base:
return False
uses_v1_params = ("max_chunks_per_doc" in present_version_params) and ('max_tokens_per_doc' not in present_version_params)
return api_base.endswith("/v1/rerank") or (uses_v1_params and not api_base.endswith("/v2/rerank"))
def is_prompt_caching_valid_prompt( def is_prompt_caching_valid_prompt(
model: str, model: str,

View file

@ -1970,6 +1970,26 @@ def test_get_applied_guardrails(test_case):
# Assert # Assert
assert sorted(result) == sorted(test_case["expected"]) assert sorted(result) == sorted(test_case["expected"])
@pytest.mark.parametrize(
"endpoint, params, expected_bool",
[
("localhost:4000/v1/rerank", ["max_chunks_per_doc"], True),
("localhost:4000/v2/rerank", ["max_chunks_per_doc"], False),
("localhost:4000", ["max_chunks_per_doc"], True),
("localhost:4000/v1/rerank", ["max_tokens_per_doc"], True),
("localhost:4000/v2/rerank", ["max_tokens_per_doc"], False),
("localhost:4000", ["max_tokens_per_doc"], False),
("localhost:4000/v1/rerank", ["max_chunks_per_doc", "max_tokens_per_doc"], True),
("localhost:4000/v2/rerank", ["max_chunks_per_doc", "max_tokens_per_doc"], False),
("localhost:4000", ["max_chunks_per_doc", "max_tokens_per_doc"], False),
],
)
def test_should_use_cohere_v1_client(endpoint, params, expected_bool):
assert(litellm.utils.should_use_cohere_v1_client(endpoint, params) == expected_bool)
def test_add_openai_metadata(): def test_add_openai_metadata():
from litellm.utils import add_openai_metadata from litellm.utils import add_openai_metadata

View file

@ -111,7 +111,9 @@ async def test_basic_rerank(sync_mode):
@pytest.mark.asyncio() @pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.skip(reason="Skipping test due to 503 Service Temporarily Unavailable")
async def test_basic_rerank_together_ai(sync_mode): async def test_basic_rerank_together_ai(sync_mode):
try:
if sync_mode is True: if sync_mode is True:
response = litellm.rerank( response = litellm.rerank(
model="together_ai/Salesforce/Llama-Rank-V1", model="together_ai/Salesforce/Llama-Rank-V1",
@ -140,6 +142,10 @@ async def test_basic_rerank_together_ai(sync_mode):
assert response.results is not None assert response.results is not None
assert_response_shape(response, custom_llm_provider="together_ai") assert_response_shape(response, custom_llm_provider="together_ai")
except Exception as e:
if "Service unavailable" in str(e):
pytest.skip("Skipping test due to 503 Service Temporarily Unavailable")
raise e
@pytest.mark.asyncio() @pytest.mark.asyncio()
@ -184,8 +190,10 @@ async def test_basic_rerank_azure_ai(sync_mode):
@pytest.mark.asyncio() @pytest.mark.asyncio()
async def test_rerank_custom_api_base(): @pytest.mark.parametrize("version", ["v1", "v2"])
async def test_rerank_custom_api_base(version):
mock_response = AsyncMock() mock_response = AsyncMock()
litellm.cohere_key = "test_api_key"
def return_val(): def return_val():
return { return {
@ -208,6 +216,10 @@ async def test_rerank_custom_api_base():
"documents": ["hello", "world"], "documents": ["hello", "world"],
} }
api_base = "https://exampleopenaiendpoint-production.up.railway.app/"
if version == "v1":
api_base += "v1/rerank"
with patch( with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
return_value=mock_response, return_value=mock_response,
@ -217,7 +229,7 @@ async def test_rerank_custom_api_base():
query="hello", query="hello",
documents=["hello", "world"], documents=["hello", "world"],
top_n=3, top_n=3,
api_base="https://exampleopenaiendpoint-production.up.railway.app/", api_base=api_base,
) )
print("async re rank response: ", response) print("async re rank response: ", response)
@ -230,7 +242,8 @@ async def test_rerank_custom_api_base():
print("Arguments passed to API=", args_to_api) print("Arguments passed to API=", args_to_api)
print("url = ", _url) print("url = ", _url)
assert ( assert (
_url == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank" _url
== f"https://exampleopenaiendpoint-production.up.railway.app/{version}/rerank"
) )
request_data = json.loads(args_to_api) request_data = json.loads(args_to_api)
@ -287,6 +300,7 @@ def test_complete_base_url_cohere():
client = HTTPHandler() client = HTTPHandler()
litellm.api_base = "http://localhost:4000" litellm.api_base = "http://localhost:4000"
litellm.cohere_key = "test_api_key"
litellm.set_verbose = True litellm.set_verbose = True
text = "Hello there!" text = "Hello there!"
@ -308,7 +322,8 @@ def test_complete_base_url_cohere():
print("mock_post.call_args", mock_post.call_args) print("mock_post.call_args", mock_post.call_args)
mock_post.assert_called_once() mock_post.assert_called_once()
assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"] # Default to the v2 client when calling the base /rerank
assert "http://localhost:4000/v2/rerank" in mock_post.call_args.kwargs["url"]
@pytest.mark.asyncio() @pytest.mark.asyncio()
@ -395,6 +410,63 @@ def test_rerank_response_assertions():
assert_response_shape(r, custom_llm_provider="custom") assert_response_shape(r, custom_llm_provider="custom")
def test_cohere_rerank_v2_client():
from litellm.llms.custom_httpx.http_handler import HTTPHandler
client = HTTPHandler()
litellm.api_base = "http://localhost:4000"
litellm.set_verbose = True
text = "Hello there!"
list_texts = ["Hello there!", "How are you?", "How do you do?"]
rerank_model = "rerank-multilingual-v3.0"
with patch.object(client, "post") as mock_post:
mock_response = MagicMock()
mock_response.text = json.dumps(
{
"id": "cmpl-mockid",
"results": [
{"index": 0, "relevance_score": 0.95},
{"index": 1, "relevance_score": 0.75},
{"index": 2, "relevance_score": 0.65},
],
"usage": {"prompt_tokens": 100, "total_tokens": 150},
}
)
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json = lambda: json.loads(mock_response.text)
mock_post.return_value = mock_response
response = litellm.rerank(
model=rerank_model,
query=text,
documents=list_texts,
custom_llm_provider="cohere",
max_tokens_per_doc=3,
top_n=2,
api_key="fake-api-key",
client=client,
)
# Ensure Cohere API is called with the expected params
mock_post.assert_called_once()
assert mock_post.call_args.kwargs["url"] == "http://localhost:4000/v2/rerank"
request_data = json.loads(mock_post.call_args.kwargs["data"])
assert request_data["model"] == rerank_model
assert request_data["query"] == text
assert request_data["documents"] == list_texts
assert request_data["max_tokens_per_doc"] == 3
assert request_data["top_n"] == 2
# Ensure litellm response is what we expect
assert response["results"] == mock_response.json()["results"]
@pytest.mark.flaky(retries=3, delay=1) @pytest.mark.flaky(retries=3, delay=1)
def test_rerank_cohere_api(): def test_rerank_cohere_api():
response = litellm.rerank( response = litellm.rerank(

View file

@ -961,7 +961,8 @@ async def test_gemini_embeddings(sync_mode, input):
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1) @pytest.mark.flaky(retries=6, delay=1)
@pytest.mark.skip(reason="Skipping test due to flakyness")
async def test_hf_embedddings_with_optional_params(sync_mode): async def test_hf_embedddings_with_optional_params(sync_mode):
litellm.set_verbose = True litellm.set_verbose = True
@ -992,8 +993,8 @@ async def test_hf_embedddings_with_optional_params(sync_mode):
wait_for_model=True, wait_for_model=True,
client=client, client=client,
) )
except Exception: except Exception as e:
pass print(e)
mock_client.assert_called_once() mock_client.assert_called_once()