mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
* Add cohere v2/rerank support (#8421) * Support v2 endpoint cohere rerank * Add tests and docs * Make v1 default if old params used * Update docs * Update docs pt 2 * Update tests * Add e2e test * Clean up code * Use inheritance for new config * Fix linting issues (#8608) * Fix cohere v2 failing test + linting (#8672) * Fix test and unused imports * Fix tests * fix: fix linting errors * test: handle tgai instability * fix: skip service unavailable err * test: print logs for unstable test * test: skip unreliable tests --------- Co-authored-by: vibhavbhat <vibhavb00@gmail.com>
This commit is contained in:
parent
c2aec21b4d
commit
09462ba80c
19 changed files with 257 additions and 40 deletions
|
@ -81,6 +81,7 @@ def rerank( # noqa: PLR0915
|
|||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
max_tokens_per_doc: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
|
||||
"""
|
||||
|
@ -97,6 +98,14 @@ def rerank( # noqa: PLR0915
|
|||
try:
|
||||
_is_async = kwargs.pop("arerank", False) is True
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
# Params that are unique to specific versions of the client for the rerank call
|
||||
unique_version_params = {
|
||||
"max_chunks_per_doc": max_chunks_per_doc,
|
||||
"max_tokens_per_doc": max_tokens_per_doc,
|
||||
}
|
||||
present_version_params = [
|
||||
k for k, v in unique_version_params.items() if v is not None
|
||||
]
|
||||
|
||||
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
|
||||
litellm.get_llm_provider(
|
||||
|
@ -111,6 +120,8 @@ def rerank( # noqa: PLR0915
|
|||
ProviderConfigManager.get_provider_rerank_config(
|
||||
model=model,
|
||||
provider=litellm.LlmProviders(_custom_llm_provider),
|
||||
api_base=optional_params.api_base,
|
||||
present_version_params=present_version_params,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -125,6 +136,7 @@ def rerank( # noqa: PLR0915
|
|||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
max_chunks_per_doc=max_chunks_per_doc,
|
||||
max_tokens_per_doc=max_tokens_per_doc,
|
||||
non_default_params=kwargs,
|
||||
)
|
||||
|
||||
|
@ -171,6 +183,7 @@ def rerank( # noqa: PLR0915
|
|||
response = base_llm_http_handler.rerank(
|
||||
model=model,
|
||||
custom_llm_provider=_custom_llm_provider,
|
||||
provider_config=rerank_provider_config,
|
||||
optional_rerank_params=optional_rerank_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
timeout=optional_params.timeout,
|
||||
|
@ -192,6 +205,7 @@ def rerank( # noqa: PLR0915
|
|||
model=model,
|
||||
custom_llm_provider=_custom_llm_provider,
|
||||
optional_rerank_params=optional_rerank_params,
|
||||
provider_config=rerank_provider_config,
|
||||
logging_obj=litellm_logging_obj,
|
||||
timeout=optional_params.timeout,
|
||||
api_key=dynamic_api_key or optional_params.api_key,
|
||||
|
@ -220,6 +234,7 @@ def rerank( # noqa: PLR0915
|
|||
response = base_llm_http_handler.rerank(
|
||||
model=model,
|
||||
custom_llm_provider=_custom_llm_provider,
|
||||
provider_config=rerank_provider_config,
|
||||
optional_rerank_params=optional_rerank_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
timeout=optional_params.timeout,
|
||||
|
@ -275,6 +290,7 @@ def rerank( # noqa: PLR0915
|
|||
custom_llm_provider=_custom_llm_provider,
|
||||
optional_rerank_params=optional_rerank_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
provider_config=rerank_provider_config,
|
||||
timeout=optional_params.timeout,
|
||||
api_key=dynamic_api_key or optional_params.api_key,
|
||||
api_base=api_base,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue