Add cohere v2/rerank support (#8421) (#8605)

* Add cohere v2/rerank support (#8421)

* Support v2 endpoint cohere rerank

* Add tests and docs

* Make v1 default if old params used

* Update docs

* Update docs pt 2

* Update tests

* Add e2e test

* Clean up code

* Use inheritence for new config

* Fix linting issues (#8608)

* Fix cohere v2 failing test + linting (#8672)

* Fix test and unused imports

* Fix tests

* fix: fix linting errors

* test: handle tgai instability

* fix: skip service unavailable err

* test: print logs for unstable test

* test: skip unreliable tests

---------

Co-authored-by: vibhavbhat <vibhavb00@gmail.com>
This commit is contained in:
Krish Dholakia 2025-02-22 22:25:29 -08:00 committed by GitHub
parent c2aec21b4d
commit 09462ba80c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 257 additions and 40 deletions

View file

@ -6191,9 +6191,14 @@ class ProviderConfigManager:
def get_provider_rerank_config(
model: str,
provider: LlmProviders,
api_base: Optional[str],
present_version_params: List[str],
) -> BaseRerankConfig:
if litellm.LlmProviders.COHERE == provider:
return litellm.CohereRerankConfig()
if should_use_cohere_v1_client(api_base, present_version_params):
return litellm.CohereRerankConfig()
else:
return litellm.CohereRerankV2Config()
elif litellm.LlmProviders.AZURE_AI == provider:
return litellm.AzureAIRerankConfig()
elif litellm.LlmProviders.INFINITY == provider:
@ -6277,6 +6282,12 @@ def get_end_user_id_for_cost_tracking(
return None
return end_user_id
def should_use_cohere_v1_client(api_base: Optional[str], present_version_params: List[str]):
if not api_base:
return False
uses_v1_params = ("max_chunks_per_doc" in present_version_params) and ('max_tokens_per_doc' not in present_version_params)
return api_base.endswith("/v1/rerank") or (uses_v1_params and not api_base.endswith("/v2/rerank"))
def is_prompt_caching_valid_prompt(
model: str,