Add cohere v2/rerank support (#8421) (#8605)

* Add cohere v2/rerank support (#8421)

* Support v2 endpoint cohere rerank

* Add tests and docs

* Make v1 default if old params used

* Update docs

* Update docs pt 2

* Update tests

* Add e2e test

* Clean up code

* Use inheritance for new config

* Fix linting issues (#8608)

* Fix cohere v2 failing test + linting (#8672)

* Fix test and unused imports

* Fix tests

* fix: fix linting errors

* test: handle tgai instability

* fix: skip service unavailable err

* test: print logs for unstable test

* test: skip unreliable tests

---------

Co-authored-by: vibhavbhat <vibhavb00@gmail.com>
This commit is contained in:
Krish Dholakia 2025-02-22 22:25:29 -08:00 committed by GitHub
parent c2aec21b4d
commit 09462ba80c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 257 additions and 40 deletions

View file

@@ -81,6 +81,7 @@ def rerank( # noqa: PLR0915
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
**kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
"""
@@ -97,6 +98,14 @@ def rerank( # noqa: PLR0915
try:
_is_async = kwargs.pop("arerank", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
# Params that are unique to specific versions of the client for the rerank call
unique_version_params = {
"max_chunks_per_doc": max_chunks_per_doc,
"max_tokens_per_doc": max_tokens_per_doc,
}
present_version_params = [
k for k, v in unique_version_params.items() if v is not None
]
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
litellm.get_llm_provider(
@@ -111,6 +120,8 @@ def rerank( # noqa: PLR0915
ProviderConfigManager.get_provider_rerank_config(
model=model,
provider=litellm.LlmProviders(_custom_llm_provider),
api_base=optional_params.api_base,
present_version_params=present_version_params,
)
)
@@ -125,6 +136,7 @@ def rerank( # noqa: PLR0915
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
max_tokens_per_doc=max_tokens_per_doc,
non_default_params=kwargs,
)
@@ -171,6 +183,7 @@ def rerank( # noqa: PLR0915
response = base_llm_http_handler.rerank(
model=model,
custom_llm_provider=_custom_llm_provider,
provider_config=rerank_provider_config,
optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj,
timeout=optional_params.timeout,
@@ -192,6 +205,7 @@ def rerank( # noqa: PLR0915
model=model,
custom_llm_provider=_custom_llm_provider,
optional_rerank_params=optional_rerank_params,
provider_config=rerank_provider_config,
logging_obj=litellm_logging_obj,
timeout=optional_params.timeout,
api_key=dynamic_api_key or optional_params.api_key,
@@ -220,6 +234,7 @@ def rerank( # noqa: PLR0915
response = base_llm_http_handler.rerank(
model=model,
custom_llm_provider=_custom_llm_provider,
provider_config=rerank_provider_config,
optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj,
timeout=optional_params.timeout,
@@ -275,6 +290,7 @@ def rerank( # noqa: PLR0915
custom_llm_provider=_custom_llm_provider,
optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj,
provider_config=rerank_provider_config,
timeout=optional_params.timeout,
api_key=dynamic_api_key or optional_params.api_key,
api_base=api_base,