(code refactor) - Add BaseRerankConfig. Use BaseRerankConfig for cohere/rerank and azure_ai/rerank (#7319)

* add base rerank config

* working sync cohere rerank

* update rerank types

* update base rerank config

* remove old rerank

* add new cohere handler.py

* add cohere rerank transform

* add get_provider_rerank_config

* add rerank to base llm http handler

* add rerank utils

* add arerank to llm http handler.py

* add AzureAIRerankConfig

* updates rerank config

* update test rerank

* fix unused imports

* update get_provider_rerank_config

* test_basic_rerank_caching

* fix unused import

* test rerank
This commit is contained in:
Ishaan Jaff 2024-12-19 17:03:34 -08:00 committed by GitHub
parent a790d43116
commit 5f15b0aa20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 645 additions and 425 deletions

View file

@ -1023,6 +1023,8 @@ from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase.chat.transformation import PredibaseConfig
from .llms.replicate.chat.transformation import ReplicateConfig
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
from .llms.cohere.rerank.transformation import CohereRerankConfig
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig
from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
from .llms.together_ai.chat import TogetherAIConfig

View file

@ -1 +0,0 @@
from .handler import AzureAIRerank

View file

@ -1,127 +1,5 @@
from typing import Any, Dict, List, Optional, Union
"""
Azure AI Rerank - uses `llm_http_handler.py` to make httpx requests
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.rerank import RerankResponse
class AzureAIRerank(CohereRerank):
def get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
if azure_model_group is None:
return None
if azure_model_group == "offer-cohere-rerank-mul-paygo":
return "azure_ai/cohere-rerank-v3-multilingual"
if azure_model_group == "offer-cohere-rerank-eng-paygo":
return "azure_ai/cohere-rerank-v3-english"
return azure_model_group
async def async_azure_rerank(
self,
model: str,
api_key: str,
api_base: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
headers: Optional[dict],
litellm_logging_obj: LiteLLMLoggingObj,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
):
returned_response: RerankResponse = await super().rerank( # type: ignore
model=model,
api_key=api_key,
api_base=api_base,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
_is_async=True,
headers=headers,
litellm_logging_obj=litellm_logging_obj,
)
# get base model
additional_headers = (
returned_response._hidden_params.get("additional_headers") or {}
)
base_model = self.get_base_model(
additional_headers.get("llm_provider-azureml-model-group")
)
returned_response._hidden_params["model"] = base_model
return returned_response
def rerank(
self,
model: str,
api_key: str,
api_base: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
headers: Optional[dict],
litellm_logging_obj: LiteLLMLoggingObj,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
if headers is None:
headers = {"Authorization": "Bearer {}".format(api_key)}
else:
headers = {**headers, "Authorization": "Bearer {}".format(api_key)}
# Assuming api_base is a string representing the base URL
api_base_url = httpx.URL(api_base)
# Replace the path with '/v1/rerank' if it doesn't already end with it
if not api_base_url.path.endswith("/v1/rerank"):
api_base = str(api_base_url.copy_with(path="/v1/rerank"))
if _is_async:
return self.async_azure_rerank( # type: ignore
model=model,
api_key=api_key,
api_base=api_base,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
headers=headers,
litellm_logging_obj=litellm_logging_obj,
)
else:
returned_response = super().rerank(
model=model,
api_key=api_key,
api_base=api_base,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
_is_async=_is_async,
headers=headers,
litellm_logging_obj=litellm_logging_obj,
)
# get base model
base_model = self.get_base_model(
returned_response._hidden_params.get("llm_provider-azureml-model-group")
)
returned_response._hidden_params["model"] = base_model
return returned_response
Request/Response transformation is handled in `transformation.py`
"""

View file

@ -1,3 +1,91 @@
"""
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
"""
from typing import Optional
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import RerankResponse
class AzureAIRerankConfig(CohereRerankConfig):
"""
Azure AI Rerank - Follows the same Spec as Cohere Rerank
"""
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base is None:
raise ValueError(
"Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
)
if not api_base.endswith("/v1/rerank"):
api_base = f"{api_base}/v1/rerank"
return api_base
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
) -> dict:
if api_key is None:
api_key = get_secret_str("AZURE_AI_API_KEY") or litellm.azure_key
if api_key is None:
raise ValueError(
"Azure AI API key is required. Please set 'AZURE_AI_API_KEY' or 'litellm.azure_key'"
)
default_headers = {
"Authorization": f"Bearer {api_key}",
"accept": "application/json",
"content-type": "application/json",
}
# If 'Authorization' is provided in headers, it overrides the default.
if "Authorization" in headers:
default_headers["Authorization"] = headers["Authorization"]
# Merge other headers, overriding any default ones except Authorization
return {**default_headers, **headers}
def transform_rerank_response(
self,
model: str,
raw_response: httpx.Response,
model_response: RerankResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
request_data: dict = {},
optional_params: dict = {},
litellm_params: dict = {},
) -> RerankResponse:
rerank_response = super().transform_rerank_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=request_data,
optional_params=optional_params,
litellm_params=litellm_params,
)
base_model = self._get_base_model(
rerank_response._hidden_params.get("llm_provider-azureml-model-group")
)
rerank_response._hidden_params["model"] = base_model
return rerank_response
def _get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
if azure_model_group is None:
return None
if azure_model_group == "offer-cohere-rerank-mul-paygo":
return "azure_ai/cohere-rerank-v3-multilingual"
if azure_model_group == "offer-cohere-rerank-eng-paygo":
return "azure_ai/cohere-rerank-v3-english"
return azure_model_group

View file

@ -0,0 +1,86 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
from litellm.types.rerank import OptionalRerankParams, RerankResponse
from ..chat.transformation import BaseLLMException
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseRerankConfig(ABC):
@abstractmethod
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
) -> dict:
pass
@abstractmethod
def transform_rerank_request(
self,
model: str,
optional_rerank_params: OptionalRerankParams,
headers: dict,
) -> dict:
return {}
@abstractmethod
def transform_rerank_response(
self,
model: str,
raw_response: httpx.Response,
model_response: RerankResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
request_data: dict = {},
optional_params: dict = {},
litellm_params: dict = {},
) -> RerankResponse:
return model_response
@abstractmethod
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
"""
OPTIONAL
Get the complete url for the request
Some providers need `model` in `api_base`
"""
return api_base or ""
@abstractmethod
def get_supported_cohere_rerank_params(self, model: str) -> list:
pass
@abstractmethod
def map_cohere_rerank_params(
self,
non_default_params: Optional[dict],
model: str,
drop_params: bool,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[str] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
) -> OptionalRerankParams:
pass
@abstractmethod
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
pass

View file

@ -1,153 +0,0 @@
"""
Re rank api
LiteLLM supports the re rank API format, no paramter transformation occurs
"""
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.rerank import RerankRequest, RerankResponse
class CohereRerank(BaseLLM):
def validate_environment(self, api_key: str, headers: Optional[dict]) -> dict:
default_headers = {
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"bearer {api_key}",
}
if headers is None:
return default_headers
# If 'Authorization' is provided in headers, it overrides the default.
if "Authorization" in headers:
default_headers["Authorization"] = headers["Authorization"]
# Merge other headers, overriding any default ones except Authorization
return {**default_headers, **headers}
def ensure_rerank_endpoint(self, api_base: str) -> str:
"""
Ensures the `/v1/rerank` endpoint is appended to the given `api_base`.
If `/v1/rerank` is already present, the original URL is returned.
:param api_base: The base API URL.
:return: A URL with `/v1/rerank` appended if missing.
"""
# Parse the base URL to ensure proper structure
url = httpx.URL(api_base)
# Check if the URL already ends with `/v1/rerank`
if not url.path.endswith("/v1/rerank"):
url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank")
return str(url)
def rerank(
self,
model: str,
api_key: str,
api_base: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
headers: Optional[dict],
litellm_logging_obj: LiteLLMLoggingObj,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False, # New parameter
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
headers = self.validate_environment(api_key=api_key, headers=headers)
api_base = self.ensure_rerank_endpoint(api_base)
request_data = RerankRequest(
model=model,
query=query,
top_n=top_n,
documents=documents,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
)
request_data_dict = request_data.dict(exclude_none=True)
## LOGGING
litellm_logging_obj.pre_call(
input=request_data_dict,
api_key=api_key,
additional_args={
"complete_input_dict": request_data_dict,
"api_base": api_base,
"headers": headers,
},
)
if _is_async:
return self.async_rerank(request_data=request_data, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method
if client is not None and isinstance(client, HTTPHandler):
client = client
else:
client = _get_httpx_client()
response = client.post(
url=api_base,
headers=headers,
json=request_data_dict,
)
returned_response = RerankResponse(**response.json())
_response_headers = response.headers
llm_response_headers = {
"{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
}
returned_response._hidden_params["additional_headers"] = llm_response_headers
return returned_response
async def async_rerank(
self,
request_data: RerankRequest,
api_key: str,
api_base: str,
headers: dict,
client: Optional[AsyncHTTPHandler] = None,
) -> RerankResponse:
request_data_dict = request_data.dict(exclude_none=True)
client = client or get_async_httpx_client(
llm_provider=litellm.LlmProviders.COHERE
)
response = await client.post(
api_base,
headers=headers,
json=request_data_dict,
)
returned_response = RerankResponse(**response.json())
_response_headers = dict(response.headers)
llm_response_headers = {
"{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
}
returned_response._hidden_params["additional_headers"] = llm_response_headers
returned_response._hidden_params["model"] = request_data.model
return returned_response

View file

@ -0,0 +1,5 @@
"""
Cohere Rerank - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""

View file

@ -0,0 +1,150 @@
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import OptionalRerankParams, RerankRequest
from litellm.types.utils import RerankResponse
from ..common_utils import CohereError
class CohereRerankConfig(BaseRerankConfig):
"""
Reference: https://docs.cohere.com/v2/reference/rerank
"""
def __init__(self) -> None:
pass
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base:
# Remove trailing slashes and ensure clean base URL
api_base = api_base.rstrip("/")
if not api_base.endswith("/v1/rerank"):
api_base = f"{api_base}/v1/rerank"
return api_base
return "https://api.cohere.ai/v1/rerank"
def get_supported_cohere_rerank_params(self, model: str) -> list:
return [
"query",
"documents",
"top_n",
"max_chunks_per_doc",
"rank_fields",
"return_documents",
]
def map_cohere_rerank_params(
self,
non_default_params: Optional[dict],
model: str,
drop_params: bool,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[str] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
) -> OptionalRerankParams:
"""
Map Cohere rerank params
No mapping required - returns all supported params
"""
return OptionalRerankParams(
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
)
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
) -> dict:
if api_key is None:
api_key = (
get_secret_str("COHERE_API_KEY")
or get_secret_str("CO_API_KEY")
or litellm.cohere_key
)
if api_key is None:
raise ValueError(
"Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'"
)
default_headers = {
"Authorization": f"bearer {api_key}",
"accept": "application/json",
"content-type": "application/json",
}
# If 'Authorization' is provided in headers, it overrides the default.
if "Authorization" in headers:
default_headers["Authorization"] = headers["Authorization"]
# Merge other headers, overriding any default ones except Authorization
return {**default_headers, **headers}
def transform_rerank_request(
self,
model: str,
optional_rerank_params: OptionalRerankParams,
headers: dict,
) -> dict:
if "query" not in optional_rerank_params:
raise ValueError("query is required for Cohere rerank")
if "documents" not in optional_rerank_params:
raise ValueError("documents is required for Cohere rerank")
rerank_request = RerankRequest(
model=model,
query=optional_rerank_params["query"],
documents=optional_rerank_params["documents"],
top_n=optional_rerank_params.get("top_n", None),
rank_fields=optional_rerank_params.get("rank_fields", None),
return_documents=optional_rerank_params.get("return_documents", None),
max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
)
return rerank_request.model_dump(exclude_none=True)
def transform_rerank_response(
self,
model: str,
raw_response: httpx.Response,
model_response: RerankResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
request_data: dict = {},
optional_params: dict = {},
litellm_params: dict = {},
) -> RerankResponse:
"""
Transform Cohere rerank response
No transformation required, litellm follows cohere API response format
"""
try:
raw_response_json = raw_response.json()
except Exception:
raise CohereError(
message=raw_response.text, status_code=raw_response.status_code
)
return RerankResponse(**raw_response_json)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(message=error_message, status_code=status_code)

View file

@ -9,12 +9,14 @@ import litellm.types
import litellm.types.utils
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.rerank import OptionalRerankParams, RerankResponse
from litellm.types.utils import EmbeddingResponse
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
@ -524,7 +526,138 @@ class BaseLLMHTTPHandler:
request_data=request_data,
)
def _handle_error(self, e: Exception, provider_config: BaseConfig):
def rerank(
self,
model: str,
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
optional_rerank_params: OptionalRerankParams,
timeout: Optional[Union[float, httpx.Timeout]],
model_response: RerankResponse,
_is_async: bool = False,
headers: dict = {},
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
provider_config = ProviderConfigManager.get_provider_rerank_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
)
api_base = provider_config.get_complete_url(
api_base=api_base,
model=model,
)
data = provider_config.transform_rerank_request(
model=model,
optional_rerank_params=optional_rerank_params,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=optional_rerank_params.get("query", ""),
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
if _is_async is True:
return self.arerank( # type: ignore
model=model,
request_data=data,
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
logging_obj=logging_obj,
model_response=model_response,
api_base=api_base,
headers=headers,
api_key=api_key,
timeout=timeout,
client=client,
)
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
try:
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
)
except Exception as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
return provider_config.transform_rerank_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
)
async def arerank(
self,
model: str,
request_data: dict,
custom_llm_provider: str,
provider_config: BaseRerankConfig,
logging_obj: LiteLLMLoggingObj,
model_response: RerankResponse,
api_base: str,
headers: dict,
api_key: Optional[str] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
if client is None or not isinstance(client, AsyncHTTPHandler):
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider)
)
else:
async_httpx_client = client
try:
response = await async_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(request_data),
timeout=timeout,
)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
return provider_config.transform_rerank_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=request_data,
)
def _handle_error(
self, e: Exception, provider_config: Union[BaseConfig, BaseRerankConfig]
):
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))

View file

@ -6,23 +6,23 @@ from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.rerank_api.rerank_utils import get_optional_rerank_params
from litellm.secret_managers.main import get_secret
from litellm.types.rerank import RerankResponse
from litellm.types.rerank import OptionalRerankParams, RerankResponse
from litellm.types.router import *
from litellm.utils import client, exception_type
from litellm.utils import ProviderConfigManager, client, exception_type
####### ENVIRONMENT VARIABLES ###################
# Initialize any necessary instances or variables here
cohere_rerank = CohereRerank()
together_rerank = TogetherAIRerank()
azure_ai_rerank = AzureAIRerank()
jina_ai_rerank = JinaAIRerank()
bedrock_rerank = BedrockRerankHandler()
base_llm_http_handler = BaseLLMHTTPHandler()
#################################################
@ -107,18 +107,36 @@ def rerank( # noqa: PLR0915
)
)
model_params_dict = {
"top_n": top_n,
"rank_fields": rank_fields,
"return_documents": return_documents,
"max_chunks_per_doc": max_chunks_per_doc,
"documents": documents,
}
rerank_provider_config: BaseRerankConfig = (
ProviderConfigManager.get_provider_rerank_config(
model=model,
provider=litellm.LlmProviders(_custom_llm_provider),
)
)
optional_rerank_params: OptionalRerankParams = get_optional_rerank_params(
rerank_provider_config=rerank_provider_config,
model=model,
drop_params=kwargs.get("drop_params") or litellm.drop_params or False,
query=query,
documents=documents,
custom_llm_provider=_custom_llm_provider,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
non_default_params=kwargs,
)
if isinstance(optional_params.timeout, str):
optional_params.timeout = float(optional_params.timeout)
model_response = RerankResponse()
litellm_logging_obj.update_environment_variables(
model=model,
user=user,
optional_params=model_params_dict,
optional_params=optional_rerank_params,
litellm_params={
"litellm_call_id": litellm_call_id,
"proxy_server_request": proxy_server_request,
@ -135,17 +153,7 @@ def rerank( # noqa: PLR0915
if _custom_llm_provider == "cohere":
# Implement Cohere rerank logic
api_key: Optional[str] = (
dynamic_api_key
or optional_params.api_key
or litellm.cohere_key
or get_secret("COHERE_API_KEY") # type: ignore
or get_secret("CO_API_KEY") # type: ignore
or litellm.api_key
)
if api_key is None:
raise ValueError(
"Cohere API key is required, please set 'COHERE_API_KEY' in your environment"
dynamic_api_key or optional_params.api_key or litellm.api_key
)
api_base: Optional[str] = (
@ -160,23 +168,18 @@ def rerank( # noqa: PLR0915
raise Exception(
"Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var."
)
headers = headers or litellm.headers or {}
response = cohere_rerank.rerank(
response = base_llm_http_handler.rerank(
model=model,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
api_key=api_key,
custom_llm_provider=_custom_llm_provider,
optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj,
timeout=optional_params.timeout,
api_key=dynamic_api_key or optional_params.api_key,
api_base=api_base,
_is_async=_is_async,
headers=headers,
litellm_logging_obj=litellm_logging_obj,
headers=headers or litellm.headers or {},
client=client,
model_response=model_response,
)
elif _custom_llm_provider == "azure_ai":
api_base = (
@ -185,47 +188,18 @@ def rerank( # noqa: PLR0915
or litellm.api_base
or get_secret("AZURE_AI_API_BASE") # type: ignore
)
# set API KEY
api_key = (
dynamic_api_key
or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or get_secret("AZURE_AI_API_KEY") # type: ignore
)
headers = headers or litellm.headers or {}
if api_key is None:
raise ValueError(
"Azure AI API key is required, please set 'AZURE_AI_API_KEY' in your environment"
)
if api_base is None:
raise Exception(
"Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
)
## LOAD CONFIG - if set
config = litellm.OpenAIConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
response = azure_ai_rerank.rerank(
response = base_llm_http_handler.rerank(
model=model,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
api_key=api_key,
custom_llm_provider=_custom_llm_provider,
optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj,
timeout=optional_params.timeout,
api_key=dynamic_api_key or optional_params.api_key,
api_base=api_base,
_is_async=_is_async,
headers=headers,
litellm_logging_obj=litellm_logging_obj,
headers=headers or litellm.headers or {},
client=client,
model_response=model_response,
)
elif _custom_llm_provider == "together_ai":
# Implement Together AI rerank logic

View file

@ -0,0 +1,31 @@
from typing import Any, Dict, List, Optional, Union
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.types.rerank import OptionalRerankParams
def get_optional_rerank_params(
rerank_provider_config: BaseRerankConfig,
model: str,
drop_params: bool,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[str] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
non_default_params: Optional[dict] = None,
) -> OptionalRerankParams:
return rerank_provider_config.map_cohere_rerank_params(
model=model,
drop_params=drop_params,
query=query,
documents=documents,
custom_llm_provider=custom_llm_provider,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
non_default_params=non_default_params,
)

View file

@ -20,6 +20,15 @@ class RerankRequest(BaseModel):
max_chunks_per_doc: Optional[int] = None
class OptionalRerankParams(TypedDict, total=False):
query: str
top_n: Optional[int]
documents: List[Union[str, dict]]
rank_fields: Optional[List[str]]
return_documents: Optional[bool]
max_chunks_per_doc: Optional[int]
class RerankBilledUnits(TypedDict, total=False):
search_units: int
total_tokens: int
@ -42,8 +51,10 @@ class RerankResponseResult(TypedDict):
class RerankResponse(BaseModel):
id: str
results: List[RerankResponseResult] # Contains index and relevance_score
id: Optional[str] = None
results: Optional[List[RerankResponseResult]] = (
None # Contains index and relevance_score
)
meta: Optional[RerankResponseMeta] = None # Contains api_version and billed_units
# Define private attributes using PrivateAttr

View file

@ -171,6 +171,7 @@ from openai import OpenAIError as OriginalError
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from ._logging import verbose_logger
from .caching.caching import (
@ -6204,6 +6205,17 @@ class ProviderConfigManager:
return litellm.VoyageEmbeddingConfig()
raise ValueError(f"Provider {provider} does not support embedding config")
@staticmethod
def get_provider_rerank_config(
model: str,
provider: LlmProviders,
) -> BaseRerankConfig:
if litellm.LlmProviders.COHERE == provider:
return litellm.CohereRerankConfig()
elif litellm.LlmProviders.AZURE_AI == provider:
return litellm.AzureAIRerankConfig()
return litellm.CohereRerankConfig()
def get_end_user_id_for_cost_tracking(
litellm_params: dict,

View file

@ -66,6 +66,7 @@ def assert_response_shape(response, custom_llm_provider):
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank(sync_mode):
litellm.set_verbose = True
if sync_mode is True:
response = litellm.rerank(
model="cohere/rerank-english-v3.0",
@ -95,6 +96,8 @@ async def test_basic_rerank(sync_mode):
assert_response_shape(response, custom_llm_provider="cohere")
print("response", response.model_dump_json(indent=4))
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
@ -191,8 +194,8 @@ async def test_rerank_custom_api_base():
expected_payload = {
"model": "Salesforce/Llama-Rank-V1",
"query": "hello",
"documents": ["hello", "world"],
"top_n": 3,
"documents": ["hello", "world"],
}
with patch(
@ -211,15 +214,21 @@ async def test_rerank_custom_api_base():
# Assert
mock_post.assert_called_once()
_url, kwargs = mock_post.call_args
args_to_api = kwargs["json"]
print("call args", mock_post.call_args)
args_to_api = mock_post.call_args.kwargs["data"]
_url = mock_post.call_args.kwargs["url"]
print("Arguments passed to API=", args_to_api)
print("url = ", _url)
assert (
_url[0]
== "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
_url == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
)
assert args_to_api == expected_payload
request_data = json.loads(args_to_api)
assert request_data["query"] == expected_payload["query"]
assert request_data["documents"] == expected_payload["documents"]
assert request_data["top_n"] == expected_payload["top_n"]
assert request_data["model"] == expected_payload["model"]
assert response.id is not None
assert response.results is not None
@ -290,3 +299,58 @@ def test_complete_base_url_cohere():
print("mock_post.call_args", mock_post.call_args)
mock_post.assert_called_once()
assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"]
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
"top_n_1, top_n_2, expect_cache_hit",
[
(3, 3, True),
(3, None, False),
],
)
async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
from litellm.caching.caching import Cache
litellm.set_verbose = True
litellm.cache = Cache(type="local")
if sync_mode is True:
for idx in range(2):
if idx == 0:
top_n = top_n_1
else:
top_n = top_n_2
response = litellm.rerank(
model="cohere/rerank-english-v3.0",
query="hello",
documents=["hello", "world"],
top_n=top_n,
)
else:
for idx in range(2):
if idx == 0:
top_n = top_n_1
else:
top_n = top_n_2
response = await litellm.arerank(
model="cohere/rerank-english-v3.0",
query="hello",
documents=["hello", "world"],
top_n=top_n,
)
await asyncio.sleep(1)
if expect_cache_hit is True:
assert "cache_key" in response._hidden_params
else:
assert "cache_key" not in response._hidden_params
print("re rank response: ", response)
assert response.id is not None
assert response.results is not None
assert_response_shape(response, custom_llm_provider="cohere")

View file

@ -7,7 +7,6 @@ import traceback
import uuid
from dotenv import load_dotenv
from test_rerank import assert_response_shape
load_dotenv()
import os

View file

@ -5,7 +5,6 @@ import traceback
import uuid
from dotenv import load_dotenv
from test_rerank import assert_response_shape
load_dotenv()
import os
@ -2132,59 +2131,6 @@ def test_logging_turn_off_message_logging_streaming():
assert mock_client.call_args.args[0].choices[0].message.content == "hello"
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
"top_n_1, top_n_2, expect_cache_hit",
[
(3, 3, True),
(3, None, False),
],
)
async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
litellm.set_verbose = True
litellm.cache = Cache(type="local")
if sync_mode is True:
for idx in range(2):
if idx == 0:
top_n = top_n_1
else:
top_n = top_n_2
response = litellm.rerank(
model="cohere/rerank-english-v3.0",
query="hello",
documents=["hello", "world"],
top_n=top_n,
)
else:
for idx in range(2):
if idx == 0:
top_n = top_n_1
else:
top_n = top_n_2
response = await litellm.arerank(
model="cohere/rerank-english-v3.0",
query="hello",
documents=["hello", "world"],
top_n=top_n,
)
await asyncio.sleep(1)
if expect_cache_hit is True:
assert "cache_key" in response._hidden_params
else:
assert "cache_key" not in response._hidden_params
print("re rank response: ", response)
assert response.id is not None
assert response.results is not None
assert_response_shape(response, custom_llm_provider="cohere")
def test_basic_caching_import():
from litellm.caching import Cache

View file

@ -5,8 +5,6 @@ import traceback
import uuid
from dotenv import load_dotenv
from test_rerank import assert_response_shape
load_dotenv()
sys.path.insert(

View file

@ -5,7 +5,6 @@ import traceback
import uuid
from dotenv import load_dotenv
from test_rerank import assert_response_shape
load_dotenv()
import os

View file

@ -5,8 +5,6 @@ import traceback
import uuid
from dotenv import load_dotenv
from test_rerank import assert_response_shape
load_dotenv()
sys.path.insert(