(feat) add infinity rerank models (#7321)

* Support Infinity Reranker (custom reranking models) (#7247)

* Support Infinity Reranker

* Clean code

* Included transformation.py

* Clean code

* Added Infinity reranker test

* Clean code

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>

* transform_rerank_response

* update handler.py

* infinity rerank updates

* ci/cd run again

* add infinity unit tests

* docs add instruction on how to add a new provider for rerank

---------

Co-authored-by: Hao Shan <53949959+haoshan98@users.noreply.github.com>
This commit is contained in:
Ishaan Jaff 2024-12-19 18:30:28 -08:00 committed by GitHub
parent 741500e089
commit 617ac63d14
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 414 additions and 1 deletions

View file

@@ -76,7 +76,9 @@ def rerank( # noqa: PLR0915
model: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[Literal["cohere", "together_ai", "azure_ai"]] = None,
custom_llm_provider: Optional[
Literal["cohere", "together_ai", "azure_ai", "infinity"]
] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
@@ -188,6 +190,37 @@ def rerank( # noqa: PLR0915
or litellm.api_base
or get_secret("AZURE_AI_API_BASE") # type: ignore
)
response = base_llm_http_handler.rerank(
model=model,
custom_llm_provider=_custom_llm_provider,
optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj,
timeout=optional_params.timeout,
api_key=dynamic_api_key or optional_params.api_key,
api_base=api_base,
_is_async=_is_async,
headers=headers or litellm.headers or {},
client=client,
model_response=model_response,
)
elif _custom_llm_provider == "infinity":
# Implement Infinity rerank logic
api_key: Optional[str] = (
dynamic_api_key or optional_params.api_key or litellm.api_key
)
api_base: Optional[str] = (
dynamic_api_base
or optional_params.api_base
or litellm.api_base
or get_secret("INFINITY_API_BASE") # type: ignore
)
if api_base is None:
raise Exception(
"Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var."
)
response = base_llm_http_handler.rerank(
model=model,
custom_llm_provider=_custom_llm_provider,