(code refactor) - Add BaseRerankConfig. Use BaseRerankConfig for cohere/rerank and azure_ai/rerank (#7319)

* add base rerank config

* working sync cohere rerank

* update rerank types

* update base rerank config

* remove old rerank

* add new cohere handler.py

* add cohere rerank transform

* add get_provider_rerank_config

* add rerank to base llm http handler

* add rerank utils

* add arerank to llm http handler.py

* add AzureAIRerankConfig

* updates rerank config

* update test rerank

* fix unused imports

* update get_provider_rerank_config

* test_basic_rerank_caching

* fix unused import

* test rerank
Ishaan Jaff 2024-12-19 17:03:34 -08:00 committed by GitHub
parent a790d43116
commit 5f15b0aa20
19 changed files with 645 additions and 425 deletions
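
At a high level, the refactor replaces the per-provider rerank handlers with a shared pipeline: `litellm.rerank()` resolves a provider-specific `BaseRerankConfig` through `ProviderConfigManager`, maps the Cohere-style parameters, and sends the request via `BaseLLMHTTPHandler`. The public API is unchanged; a minimal sketch of a call that exercises the new path (assuming `COHERE_API_KEY` is set in the environment):

```python
import litellm

# After this commit, this call is routed through
# BaseLLMHTTPHandler + CohereRerankConfig instead of the old CohereRerank handler.
response = litellm.rerank(
    model="cohere/rerank-english-v3.0",
    query="hello",
    documents=["hello", "world"],
    top_n=3,
)
print(response.results)  # each result carries an index and a relevance_score
```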


@@ -1023,6 +1023,8 @@ from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
 from .llms.predibase.chat.transformation import PredibaseConfig
 from .llms.replicate.chat.transformation import ReplicateConfig
 from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
+from .llms.cohere.rerank.transformation import CohereRerankConfig
+from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
 from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
 from .llms.together_ai.chat import TogetherAIConfig


@@ -1 +0,0 @@
-from .handler import AzureAIRerank


@@ -1,127 +1,5 @@
-from typing import Any, Dict, List, Optional, Union
-
-import httpx
-
-from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.cohere.rerank import CohereRerank
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.rerank import RerankResponse
-
-
-class AzureAIRerank(CohereRerank):
-    def get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
-        if azure_model_group is None:
-            return None
-        if azure_model_group == "offer-cohere-rerank-mul-paygo":
-            return "azure_ai/cohere-rerank-v3-multilingual"
-        if azure_model_group == "offer-cohere-rerank-eng-paygo":
-            return "azure_ai/cohere-rerank-v3-english"
-        return azure_model_group
-
-    async def async_azure_rerank(
-        self,
-        model: str,
-        api_key: str,
-        api_base: str,
-        query: str,
-        documents: List[Union[str, Dict[str, Any]]],
-        headers: Optional[dict],
-        litellm_logging_obj: LiteLLMLoggingObj,
-        top_n: Optional[int] = None,
-        rank_fields: Optional[List[str]] = None,
-        return_documents: Optional[bool] = True,
-        max_chunks_per_doc: Optional[int] = None,
-    ):
-        returned_response: RerankResponse = await super().rerank(  # type: ignore
-            model=model,
-            api_key=api_key,
-            api_base=api_base,
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
-            _is_async=True,
-            headers=headers,
-            litellm_logging_obj=litellm_logging_obj,
-        )
-
-        # get base model
-        additional_headers = (
-            returned_response._hidden_params.get("additional_headers") or {}
-        )
-        base_model = self.get_base_model(
-            additional_headers.get("llm_provider-azureml-model-group")
-        )
-        returned_response._hidden_params["model"] = base_model
-        return returned_response
-
-    def rerank(
-        self,
-        model: str,
-        api_key: str,
-        api_base: str,
-        query: str,
-        documents: List[Union[str, Dict[str, Any]]],
-        headers: Optional[dict],
-        litellm_logging_obj: LiteLLMLoggingObj,
-        top_n: Optional[int] = None,
-        rank_fields: Optional[List[str]] = None,
-        return_documents: Optional[bool] = True,
-        max_chunks_per_doc: Optional[int] = None,
-        _is_async: Optional[bool] = False,
-        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
-    ) -> RerankResponse:
-
-        if headers is None:
-            headers = {"Authorization": "Bearer {}".format(api_key)}
-        else:
-            headers = {**headers, "Authorization": "Bearer {}".format(api_key)}
-
-        # Assuming api_base is a string representing the base URL
-        api_base_url = httpx.URL(api_base)
-
-        # Replace the path with '/v1/rerank' if it doesn't already end with it
-        if not api_base_url.path.endswith("/v1/rerank"):
-            api_base = str(api_base_url.copy_with(path="/v1/rerank"))
-
-        if _is_async:
-            return self.async_azure_rerank(  # type: ignore
-                model=model,
-                api_key=api_key,
-                api_base=api_base,
-                query=query,
-                documents=documents,
-                top_n=top_n,
-                rank_fields=rank_fields,
-                return_documents=return_documents,
-                max_chunks_per_doc=max_chunks_per_doc,
-                headers=headers,
-                litellm_logging_obj=litellm_logging_obj,
-            )
-        else:
-            returned_response = super().rerank(
-                model=model,
-                api_key=api_key,
-                api_base=api_base,
-                query=query,
-                documents=documents,
-                top_n=top_n,
-                rank_fields=rank_fields,
-                return_documents=return_documents,
-                max_chunks_per_doc=max_chunks_per_doc,
-                _is_async=_is_async,
-                headers=headers,
-                litellm_logging_obj=litellm_logging_obj,
-            )
-
-            # get base model
-            base_model = self.get_base_model(
-                returned_response._hidden_params.get("llm_provider-azureml-model-group")
-            )
-            returned_response._hidden_params["model"] = base_model
-            return returned_response
+"""
+Azure AI Rerank - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""


@@ -1,3 +1,91 @@
 """
 Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
 """
+
+from typing import Optional
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.utils import RerankResponse
+
+
+class AzureAIRerankConfig(CohereRerankConfig):
+    """
+    Azure AI Rerank - Follows the same Spec as Cohere Rerank
+    """
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base is None:
+            raise ValueError(
+                "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
+            )
+        if not api_base.endswith("/v1/rerank"):
+            api_base = f"{api_base}/v1/rerank"
+        return api_base
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = get_secret_str("AZURE_AI_API_KEY") or litellm.azure_key
+
+        if api_key is None:
+            raise ValueError(
+                "Azure AI API key is required. Please set 'AZURE_AI_API_KEY' or 'litellm.azure_key'"
+            )
+
+        default_headers = {
+            "Authorization": f"Bearer {api_key}",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+
+        # If 'Authorization' is provided in headers, it overrides the default.
+        if "Authorization" in headers:
+            default_headers["Authorization"] = headers["Authorization"]
+
+        # Merge other headers, overriding any default ones except Authorization
+        return {**default_headers, **headers}
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        rerank_response = super().transform_rerank_response(
+            model=model,
+            raw_response=raw_response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=request_data,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+        )
+        base_model = self._get_base_model(
+            rerank_response._hidden_params.get("llm_provider-azureml-model-group")
+        )
+        rerank_response._hidden_params["model"] = base_model
+        return rerank_response
+
+    def _get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
+        if azure_model_group is None:
+            return None
+        if azure_model_group == "offer-cohere-rerank-mul-paygo":
+            return "azure_ai/cohere-rerank-v3-multilingual"
+        if azure_model_group == "offer-cohere-rerank-eng-paygo":
+            return "azure_ai/cohere-rerank-v3-english"
+        return azure_model_group
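
The `_get_base_model` hook exists because Azure AI returns the marketplace offer id in the `llm_provider-azureml-model-group` response header; mapping it back to a canonical model name keeps downstream cost tracking consistent. A small illustrative check of the mapping (hypothetical usage, not part of the diff):

```python
from litellm.llms.azure_ai.rerank.transformation import AzureAIRerankConfig

config = AzureAIRerankConfig()
assert config._get_base_model("offer-cohere-rerank-eng-paygo") == "azure_ai/cohere-rerank-v3-english"
assert config._get_base_model("offer-cohere-rerank-mul-paygo") == "azure_ai/cohere-rerank-v3-multilingual"
assert config._get_base_model("some-other-group") == "some-other-group"  # unknown groups pass through
```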


@@ -0,0 +1,86 @@
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import httpx
+
+from litellm.types.rerank import OptionalRerankParams, RerankResponse
+
+from ..chat.transformation import BaseLLMException
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class BaseRerankConfig(ABC):
+    @abstractmethod
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        pass
+
+    @abstractmethod
+    def transform_rerank_request(
+        self,
+        model: str,
+        optional_rerank_params: OptionalRerankParams,
+        headers: dict,
+    ) -> dict:
+        return {}
+
+    @abstractmethod
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        return model_response
+
+    @abstractmethod
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        """
+        OPTIONAL
+
+        Get the complete url for the request
+
+        Some providers need `model` in `api_base`
+        """
+        return api_base or ""
+
+    @abstractmethod
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        pass
+
+    @abstractmethod
+    def map_cohere_rerank_params(
+        self,
+        non_default_params: Optional[dict],
+        model: str,
+        drop_params: bool,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        custom_llm_provider: Optional[str] = None,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+    ) -> OptionalRerankParams:
+        pass
+
+    @abstractmethod
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        pass
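
With the abstract base in place, adding a rerank provider means implementing these hooks and nothing else. A rough sketch of what a hypothetical provider config could look like (all names are illustrative, not part of this commit; the `BaseLLMException` fallback is an assumption, real configs raise provider-specific errors):

```python
from typing import Any, Dict, List, Optional, Union

import httpx

from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankResponse


class MyRerankConfig(BaseRerankConfig):
    def validate_environment(self, headers: dict, model: str, api_key: Optional[str] = None) -> dict:
        # attach auth; real configs also fall back to env vars / litellm settings
        return {"Authorization": f"Bearer {api_key}", **headers}

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        return f"{(api_base or 'https://example.invalid').rstrip('/')}/v1/rerank"

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return ["query", "documents", "top_n"]

    def map_cohere_rerank_params(
        self, non_default_params: Optional[dict], model: str, drop_params: bool,
        query: str, documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None, top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None, return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        # keep only the params this provider supports
        return OptionalRerankParams(query=query, documents=documents, top_n=top_n)

    def transform_rerank_request(self, model: str, optional_rerank_params: OptionalRerankParams, headers: dict) -> dict:
        return {"model": model, **optional_rerank_params}

    def transform_rerank_response(
        self, model: str, raw_response: httpx.Response, model_response: RerankResponse,
        logging_obj, api_key: Optional[str] = None, request_data: dict = {},
        optional_params: dict = {}, litellm_params: dict = {},
    ) -> RerankResponse:
        return RerankResponse(**raw_response.json())

    def get_error_class(self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]) -> BaseLLMException:
        return BaseLLMException(message=error_message, status_code=status_code)
```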


@@ -1,153 +0,0 @@
-"""
-Re rank api
-
-LiteLLM supports the re rank API format, no paramter transformation occurs
-"""
-
-from typing import Any, Dict, List, Optional, Union
-
-import httpx
-
-import litellm
-from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.base import BaseLLM
-from litellm.llms.custom_httpx.http_handler import (
-    AsyncHTTPHandler,
-    HTTPHandler,
-    _get_httpx_client,
-    get_async_httpx_client,
-)
-from litellm.types.rerank import RerankRequest, RerankResponse
-
-
-class CohereRerank(BaseLLM):
-    def validate_environment(self, api_key: str, headers: Optional[dict]) -> dict:
-        default_headers = {
-            "accept": "application/json",
-            "content-type": "application/json",
-            "Authorization": f"bearer {api_key}",
-        }
-
-        if headers is None:
-            return default_headers
-
-        # If 'Authorization' is provided in headers, it overrides the default.
-        if "Authorization" in headers:
-            default_headers["Authorization"] = headers["Authorization"]
-
-        # Merge other headers, overriding any default ones except Authorization
-        return {**default_headers, **headers}
-
-    def ensure_rerank_endpoint(self, api_base: str) -> str:
-        """
-        Ensures the `/v1/rerank` endpoint is appended to the given `api_base`.
-        If `/v1/rerank` is already present, the original URL is returned.
-
-        :param api_base: The base API URL.
-        :return: A URL with `/v1/rerank` appended if missing.
-        """
-        # Parse the base URL to ensure proper structure
-        url = httpx.URL(api_base)
-
-        # Check if the URL already ends with `/v1/rerank`
-        if not url.path.endswith("/v1/rerank"):
-            url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank")
-
-        return str(url)
-
-    def rerank(
-        self,
-        model: str,
-        api_key: str,
-        api_base: str,
-        query: str,
-        documents: List[Union[str, Dict[str, Any]]],
-        headers: Optional[dict],
-        litellm_logging_obj: LiteLLMLoggingObj,
-        top_n: Optional[int] = None,
-        rank_fields: Optional[List[str]] = None,
-        return_documents: Optional[bool] = True,
-        max_chunks_per_doc: Optional[int] = None,
-        _is_async: Optional[bool] = False,  # New parameter
-        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
-    ) -> RerankResponse:
-        headers = self.validate_environment(api_key=api_key, headers=headers)
-        api_base = self.ensure_rerank_endpoint(api_base)
-        request_data = RerankRequest(
-            model=model,
-            query=query,
-            top_n=top_n,
-            documents=documents,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
-        )
-
-        request_data_dict = request_data.dict(exclude_none=True)
-        ## LOGGING
-        litellm_logging_obj.pre_call(
-            input=request_data_dict,
-            api_key=api_key,
-            additional_args={
-                "complete_input_dict": request_data_dict,
-                "api_base": api_base,
-                "headers": headers,
-            },
-        )
-
-        if _is_async:
-            return self.async_rerank(request_data=request_data, api_key=api_key, api_base=api_base, headers=headers)  # type: ignore # Call async method
-
-        if client is not None and isinstance(client, HTTPHandler):
-            client = client
-        else:
-            client = _get_httpx_client()
-
-        response = client.post(
-            url=api_base,
-            headers=headers,
-            json=request_data_dict,
-        )
-
-        returned_response = RerankResponse(**response.json())
-
-        _response_headers = response.headers
-        llm_response_headers = {
-            "{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
-        }
-        returned_response._hidden_params["additional_headers"] = llm_response_headers
-
-        return returned_response
-
-    async def async_rerank(
-        self,
-        request_data: RerankRequest,
-        api_key: str,
-        api_base: str,
-        headers: dict,
-        client: Optional[AsyncHTTPHandler] = None,
-    ) -> RerankResponse:
-        request_data_dict = request_data.dict(exclude_none=True)
-
-        client = client or get_async_httpx_client(
-            llm_provider=litellm.LlmProviders.COHERE
-        )
-
-        response = await client.post(
-            api_base,
-            headers=headers,
-            json=request_data_dict,
-        )
-
-        returned_response = RerankResponse(**response.json())
-
-        _response_headers = dict(response.headers)
-        llm_response_headers = {
-            "{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
-        }
-        returned_response._hidden_params["additional_headers"] = llm_response_headers
-        returned_response._hidden_params["model"] = request_data.model
-
-        return returned_response


@@ -0,0 +1,5 @@
+"""
+Cohere Rerank - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""


@@ -0,0 +1,150 @@
+from typing import Any, Dict, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.rerank import OptionalRerankParams, RerankRequest
+from litellm.types.utils import RerankResponse
+
+from ..common_utils import CohereError
+
+
+class CohereRerankConfig(BaseRerankConfig):
+    """
+    Reference: https://docs.cohere.com/v2/reference/rerank
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base:
+            # Remove trailing slashes and ensure clean base URL
+            api_base = api_base.rstrip("/")
+            if not api_base.endswith("/v1/rerank"):
+                api_base = f"{api_base}/v1/rerank"
+            return api_base
+        return "https://api.cohere.ai/v1/rerank"
+
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        return [
+            "query",
+            "documents",
+            "top_n",
+            "max_chunks_per_doc",
+            "rank_fields",
+            "return_documents",
+        ]
+
+    def map_cohere_rerank_params(
+        self,
+        non_default_params: Optional[dict],
+        model: str,
+        drop_params: bool,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        custom_llm_provider: Optional[str] = None,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+    ) -> OptionalRerankParams:
+        """
+        Map Cohere rerank params
+
+        No mapping required - returns all supported params
+        """
+        return OptionalRerankParams(
+            query=query,
+            documents=documents,
+            top_n=top_n,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+            max_chunks_per_doc=max_chunks_per_doc,
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = (
+                get_secret_str("COHERE_API_KEY")
+                or get_secret_str("CO_API_KEY")
+                or litellm.cohere_key
+            )
+
+        if api_key is None:
+            raise ValueError(
+                "Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'"
+            )
+
+        default_headers = {
+            "Authorization": f"bearer {api_key}",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+
+        # If 'Authorization' is provided in headers, it overrides the default.
+        if "Authorization" in headers:
+            default_headers["Authorization"] = headers["Authorization"]
+
+        # Merge other headers, overriding any default ones except Authorization
+        return {**default_headers, **headers}
+
+    def transform_rerank_request(
+        self,
+        model: str,
+        optional_rerank_params: OptionalRerankParams,
+        headers: dict,
+    ) -> dict:
+        if "query" not in optional_rerank_params:
+            raise ValueError("query is required for Cohere rerank")
+        if "documents" not in optional_rerank_params:
+            raise ValueError("documents is required for Cohere rerank")
+        rerank_request = RerankRequest(
+            model=model,
+            query=optional_rerank_params["query"],
+            documents=optional_rerank_params["documents"],
+            top_n=optional_rerank_params.get("top_n", None),
+            rank_fields=optional_rerank_params.get("rank_fields", None),
+            return_documents=optional_rerank_params.get("return_documents", None),
+            max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
+        )
+        return rerank_request.model_dump(exclude_none=True)
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        """
+        Transform Cohere rerank response
+
+        No transformation required, litellm follows cohere API response format
+        """
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise CohereError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        return RerankResponse(**raw_response_json)
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return CohereError(message=error_message, status_code=status_code)
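
Since LiteLLM already speaks Cohere's wire format, the Cohere config is mostly pass-through. A rough illustration of the param-mapping plus request-transform round trip (assuming the methods behave as shown in the diff above):

```python
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig

config = CohereRerankConfig()
params = config.map_cohere_rerank_params(
    non_default_params=None,
    model="rerank-english-v3.0",
    drop_params=False,
    query="hello",
    documents=["hello", "world"],
    top_n=3,
)
body = config.transform_rerank_request(
    model="rerank-english-v3.0", optional_rerank_params=params, headers={}
)
# body == {"model": "rerank-english-v3.0", "query": "hello",
#          "documents": ["hello", "world"], "top_n": 3, "return_documents": True}
```

`return_documents` survives because it defaults to `True` in `map_cohere_rerank_params`, and `model_dump(exclude_none=True)` only strips keys whose value is `None`.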


@@ -9,12 +9,14 @@ import litellm.types
 import litellm.types.utils
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
 )
+from litellm.types.rerank import OptionalRerankParams, RerankResponse
 from litellm.types.utils import EmbeddingResponse
 from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
@@ -524,7 +526,138 @@ class BaseLLMHTTPHandler:
             request_data=request_data,
         )

+    def rerank(
+        self,
+        model: str,
+        custom_llm_provider: str,
+        logging_obj: LiteLLMLoggingObj,
+        optional_rerank_params: OptionalRerankParams,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        model_response: RerankResponse,
+        _is_async: bool = False,
+        headers: dict = {},
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> RerankResponse:
+
+        provider_config = ProviderConfigManager.get_provider_rerank_config(
+            model=model, provider=litellm.LlmProviders(custom_llm_provider)
+        )
+        # get config from model, custom llm provider
+        headers = provider_config.validate_environment(
+            api_key=api_key,
+            headers=headers,
+            model=model,
+        )
+
+        api_base = provider_config.get_complete_url(
+            api_base=api_base,
+            model=model,
+        )
+
+        data = provider_config.transform_rerank_request(
+            model=model,
+            optional_rerank_params=optional_rerank_params,
+            headers=headers,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=optional_rerank_params.get("query", ""),
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
+        if _is_async is True:
+            return self.arerank(  # type: ignore
+                model=model,
+                request_data=data,
+                custom_llm_provider=custom_llm_provider,
+                provider_config=provider_config,
+                logging_obj=logging_obj,
+                model_response=model_response,
+                api_base=api_base,
+                headers=headers,
+                api_key=api_key,
+                timeout=timeout,
+                client=client,
+            )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
+
+        try:
+            response = sync_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(
+                e=e,
+                provider_config=provider_config,
+            )
+
+        return provider_config.transform_rerank_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+        )
+
+    async def arerank(
+        self,
+        model: str,
+        request_data: dict,
+        custom_llm_provider: str,
+        provider_config: BaseRerankConfig,
+        logging_obj: LiteLLMLoggingObj,
+        model_response: RerankResponse,
+        api_base: str,
+        headers: dict,
+        api_key: Optional[str] = None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> RerankResponse:
+
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
+
+        try:
+            response = await async_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(request_data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(e=e, provider_config=provider_config)
+
+        return provider_config.transform_rerank_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=request_data,
+        )
+
-    def _handle_error(self, e: Exception, provider_config: BaseConfig):
+    def _handle_error(
+        self, e: Exception, provider_config: Union[BaseConfig, BaseRerankConfig]
+    ):
         status_code = getattr(e, "status_code", 500)
         error_headers = getattr(e, "headers", None)
         error_text = getattr(e, "text", str(e))


@@ -6,23 +6,23 @@ from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.azure_ai.rerank import AzureAIRerank
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
-from litellm.llms.cohere.rerank import CohereRerank
+from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
 from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
+from litellm.rerank_api.rerank_utils import get_optional_rerank_params
 from litellm.secret_managers.main import get_secret
-from litellm.types.rerank import RerankResponse
+from litellm.types.rerank import OptionalRerankParams, RerankResponse
 from litellm.types.router import *
-from litellm.utils import client, exception_type
+from litellm.utils import ProviderConfigManager, client, exception_type

 ####### ENVIRONMENT VARIABLES ###################
 # Initialize any necessary instances or variables here
-cohere_rerank = CohereRerank()
 together_rerank = TogetherAIRerank()
-azure_ai_rerank = AzureAIRerank()
 jina_ai_rerank = JinaAIRerank()
 bedrock_rerank = BedrockRerankHandler()
+base_llm_http_handler = BaseLLMHTTPHandler()
 #################################################
@@ -107,18 +107,36 @@
         )
     )

-    model_params_dict = {
-        "top_n": top_n,
-        "rank_fields": rank_fields,
-        "return_documents": return_documents,
-        "max_chunks_per_doc": max_chunks_per_doc,
-        "documents": documents,
-    }
+    rerank_provider_config: BaseRerankConfig = (
+        ProviderConfigManager.get_provider_rerank_config(
+            model=model,
+            provider=litellm.LlmProviders(_custom_llm_provider),
+        )
+    )
+
+    optional_rerank_params: OptionalRerankParams = get_optional_rerank_params(
+        rerank_provider_config=rerank_provider_config,
+        model=model,
+        drop_params=kwargs.get("drop_params") or litellm.drop_params or False,
+        query=query,
+        documents=documents,
+        custom_llm_provider=_custom_llm_provider,
+        top_n=top_n,
+        rank_fields=rank_fields,
+        return_documents=return_documents,
+        max_chunks_per_doc=max_chunks_per_doc,
+        non_default_params=kwargs,
+    )
+
+    if isinstance(optional_params.timeout, str):
+        optional_params.timeout = float(optional_params.timeout)
+
+    model_response = RerankResponse()

     litellm_logging_obj.update_environment_variables(
         model=model,
         user=user,
-        optional_params=model_params_dict,
+        optional_params=optional_rerank_params,
         litellm_params={
             "litellm_call_id": litellm_call_id,
             "proxy_server_request": proxy_server_request,
@@ -135,19 +153,9 @@
     if _custom_llm_provider == "cohere":
         # Implement Cohere rerank logic
         api_key: Optional[str] = (
-            dynamic_api_key
-            or optional_params.api_key
-            or litellm.cohere_key
-            or get_secret("COHERE_API_KEY")  # type: ignore
-            or get_secret("CO_API_KEY")  # type: ignore
-            or litellm.api_key
+            dynamic_api_key or optional_params.api_key or litellm.api_key
         )
-
-        if api_key is None:
-            raise ValueError(
-                "Cohere API key is required, please set 'COHERE_API_KEY' in your environment"
-            )
-
         api_base: Optional[str] = (
             dynamic_api_base
             or optional_params.api_base
@@ -160,23 +168,18 @@
             raise Exception(
                 "Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var."
             )
-
-        headers = headers or litellm.headers or {}
-
-        response = cohere_rerank.rerank(
+        response = base_llm_http_handler.rerank(
             model=model,
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
-            api_key=api_key,
+            custom_llm_provider=_custom_llm_provider,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
             api_base=api_base,
             _is_async=_is_async,
-            headers=headers,
-            litellm_logging_obj=litellm_logging_obj,
+            headers=headers or litellm.headers or {},
             client=client,
+            model_response=model_response,
         )
     elif _custom_llm_provider == "azure_ai":
         api_base = (
@@ -185,47 +188,18 @@
             or litellm.api_base
             or get_secret("AZURE_AI_API_BASE")  # type: ignore
         )
-        # set API KEY
-        api_key = (
-            dynamic_api_key
-            or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
-            or litellm.openai_key
-            or get_secret("AZURE_AI_API_KEY")  # type: ignore
-        )
-
-        headers = headers or litellm.headers or {}
-
-        if api_key is None:
-            raise ValueError(
-                "Azure AI API key is required, please set 'AZURE_AI_API_KEY' in your environment"
-            )
-
-        if api_base is None:
-            raise Exception(
-                "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
-            )
-
-        ## LOAD CONFIG - if set
-        config = litellm.OpenAIConfig.get_config()
-        for k, v in config.items():
-            if (
-                k not in optional_params
-            ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
-                optional_params[k] = v
-
-        response = azure_ai_rerank.rerank(
+        response = base_llm_http_handler.rerank(
             model=model,
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
-            api_key=api_key,
+            custom_llm_provider=_custom_llm_provider,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
             api_base=api_base,
             _is_async=_is_async,
-            headers=headers,
-            litellm_logging_obj=litellm_logging_obj,
+            headers=headers or litellm.headers or {},
+            client=client,
+            model_response=model_response,
         )
     elif _custom_llm_provider == "together_ai":
         # Implement Together AI rerank logic


@@ -0,0 +1,31 @@
+from typing import Any, Dict, List, Optional, Union
+
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.types.rerank import OptionalRerankParams
+
+
+def get_optional_rerank_params(
+    rerank_provider_config: BaseRerankConfig,
+    model: str,
+    drop_params: bool,
+    query: str,
+    documents: List[Union[str, Dict[str, Any]]],
+    custom_llm_provider: Optional[str] = None,
+    top_n: Optional[int] = None,
+    rank_fields: Optional[List[str]] = None,
+    return_documents: Optional[bool] = True,
+    max_chunks_per_doc: Optional[int] = None,
+    non_default_params: Optional[dict] = None,
+) -> OptionalRerankParams:
+    return rerank_provider_config.map_cohere_rerank_params(
+        model=model,
+        drop_params=drop_params,
+        query=query,
+        documents=documents,
+        custom_llm_provider=custom_llm_provider,
+        top_n=top_n,
+        rank_fields=rank_fields,
+        return_documents=return_documents,
+        max_chunks_per_doc=max_chunks_per_doc,
+        non_default_params=non_default_params,
+    )
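
`get_optional_rerank_params` is a thin named wrapper over the provider config, so callers (e.g. the proxy) can build params without importing provider classes directly. Illustrative usage, assuming the Cohere config from this commit:

```python
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.rerank_api.rerank_utils import get_optional_rerank_params

params = get_optional_rerank_params(
    rerank_provider_config=CohereRerankConfig(),
    model="rerank-english-v3.0",
    drop_params=False,
    query="hello",
    documents=["hello", "world"],
    top_n=3,
)
# -> OptionalRerankParams with query, documents, top_n set
```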


@@ -20,6 +20,15 @@ class RerankRequest(BaseModel):
     max_chunks_per_doc: Optional[int] = None


+class OptionalRerankParams(TypedDict, total=False):
+    query: str
+    top_n: Optional[int]
+    documents: List[Union[str, dict]]
+    rank_fields: Optional[List[str]]
+    return_documents: Optional[bool]
+    max_chunks_per_doc: Optional[int]
+
+
 class RerankBilledUnits(TypedDict, total=False):
     search_units: int
     total_tokens: int

@@ -42,8 +51,10 @@ class RerankResponseResult(TypedDict):
 class RerankResponse(BaseModel):
-    id: str
-    results: List[RerankResponseResult]  # Contains index and relevance_score
+    id: Optional[str] = None
+    results: Optional[List[RerankResponseResult]] = (
+        None  # Contains index and relevance_score
+    )
     meta: Optional[RerankResponseMeta] = None  # Contains api_version and billed_units

     # Define private attributes using PrivateAttr
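
`OptionalRerankParams` being a `total=False` TypedDict means partial param sets are valid, and `transform_rerank_request` drops unset keys; making `RerankResponse.id` and `results` optional is what lets `main.py` above pre-construct an empty `model_response = RerankResponse()` before the HTTP call. A tiny illustration:

```python
from litellm.types.rerank import OptionalRerankParams, RerankResponse

params = OptionalRerankParams(query="hello", documents=["hello", "world"], top_n=3)
# rank_fields / return_documents / max_chunks_per_doc simply stay absent

empty = RerankResponse()  # legal now that id/results default to None
```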


@@ -171,6 +171,7 @@ from openai import OpenAIError as OriginalError

 from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig

 from ._logging import verbose_logger
 from .caching.caching import (

@@ -6204,6 +6205,17 @@ class ProviderConfigManager:
             return litellm.VoyageEmbeddingConfig()
         raise ValueError(f"Provider {provider} does not support embedding config")

+    @staticmethod
+    def get_provider_rerank_config(
+        model: str,
+        provider: LlmProviders,
+    ) -> BaseRerankConfig:
+        if litellm.LlmProviders.COHERE == provider:
+            return litellm.CohereRerankConfig()
+        elif litellm.LlmProviders.AZURE_AI == provider:
+            return litellm.AzureAIRerankConfig()
+        return litellm.CohereRerankConfig()
+

 def get_end_user_id_for_cost_tracking(
     litellm_params: dict,
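
Provider dispatch is centralized in `get_provider_rerank_config`, with the Cohere config as the fallback since LiteLLM treats Cohere's rerank spec as the canonical format. Illustrative usage:

```python
import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_rerank_config(
    model="rerank-english-v3.0",
    provider=litellm.LlmProviders.COHERE,
)
# -> litellm.CohereRerankConfig(); LlmProviders.AZURE_AI yields AzureAIRerankConfig
```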


@@ -66,6 +66,7 @@ def assert_response_shape(response, custom_llm_provider):
 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [True, False])
 async def test_basic_rerank(sync_mode):
+    litellm.set_verbose = True
     if sync_mode is True:
         response = litellm.rerank(
             model="cohere/rerank-english-v3.0",

@@ -95,6 +96,8 @@
     assert_response_shape(response, custom_llm_provider="cohere")

+    print("response", response.model_dump_json(indent=4))
+

 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [True, False])

@@ -191,8 +194,8 @@
     expected_payload = {
         "model": "Salesforce/Llama-Rank-V1",
         "query": "hello",
-        "documents": ["hello", "world"],
         "top_n": 3,
+        "documents": ["hello", "world"],
     }

     with patch(

@@ -211,15 +214,21 @@
     # Assert
     mock_post.assert_called_once()
-    _url, kwargs = mock_post.call_args
-    args_to_api = kwargs["json"]
+    print("call args", mock_post.call_args)
+    args_to_api = mock_post.call_args.kwargs["data"]
+    _url = mock_post.call_args.kwargs["url"]
     print("Arguments passed to API=", args_to_api)
     print("url = ", _url)
     assert (
-        _url[0]
-        == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
+        _url == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
     )
-    assert args_to_api == expected_payload
+
+    request_data = json.loads(args_to_api)
+    assert request_data["query"] == expected_payload["query"]
+    assert request_data["documents"] == expected_payload["documents"]
+    assert request_data["top_n"] == expected_payload["top_n"]
+    assert request_data["model"] == expected_payload["model"]

     assert response.id is not None
     assert response.results is not None
@@ -290,3 +299,58 @@
     print("mock_post.call_args", mock_post.call_args)
     mock_post.assert_called_once()
     assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"]
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+    "top_n_1, top_n_2, expect_cache_hit",
+    [
+        (3, 3, True),
+        (3, None, False),
+    ],
+)
+async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
+    from litellm.caching.caching import Cache
+
+    litellm.set_verbose = True
+    litellm.cache = Cache(type="local")
+
+    if sync_mode is True:
+        for idx in range(2):
+            if idx == 0:
+                top_n = top_n_1
+            else:
+                top_n = top_n_2
+            response = litellm.rerank(
+                model="cohere/rerank-english-v3.0",
+                query="hello",
+                documents=["hello", "world"],
+                top_n=top_n,
+            )
+    else:
+        for idx in range(2):
+            if idx == 0:
+                top_n = top_n_1
+            else:
+                top_n = top_n_2
+            response = await litellm.arerank(
+                model="cohere/rerank-english-v3.0",
+                query="hello",
+                documents=["hello", "world"],
+                top_n=top_n,
+            )
+            await asyncio.sleep(1)
+
+    if expect_cache_hit is True:
+        assert "cache_key" in response._hidden_params
+    else:
+        assert "cache_key" not in response._hidden_params
+
+    print("re rank response: ", response)
+
+    assert response.id is not None
+    assert response.results is not None
+
+    assert_response_shape(response, custom_llm_provider="cohere")


@@ -7,7 +7,6 @@ import traceback
 import uuid

 from dotenv import load_dotenv
-from test_rerank import assert_response_shape

 load_dotenv()
 import os


@@ -5,7 +5,6 @@ import traceback
 import uuid

 from dotenv import load_dotenv
-from test_rerank import assert_response_shape

 load_dotenv()
 import os
@@ -2132,59 +2131,6 @@
     assert mock_client.call_args.args[0].choices[0].message.content == "hello"


-@pytest.mark.asyncio()
-@pytest.mark.parametrize("sync_mode", [True, False])
-@pytest.mark.parametrize(
-    "top_n_1, top_n_2, expect_cache_hit",
-    [
-        (3, 3, True),
-        (3, None, False),
-    ],
-)
-async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
-    litellm.set_verbose = True
-    litellm.cache = Cache(type="local")
-
-    if sync_mode is True:
-        for idx in range(2):
-            if idx == 0:
-                top_n = top_n_1
-            else:
-                top_n = top_n_2
-            response = litellm.rerank(
-                model="cohere/rerank-english-v3.0",
-                query="hello",
-                documents=["hello", "world"],
-                top_n=top_n,
-            )
-    else:
-        for idx in range(2):
-            if idx == 0:
-                top_n = top_n_1
-            else:
-                top_n = top_n_2
-            response = await litellm.arerank(
-                model="cohere/rerank-english-v3.0",
-                query="hello",
-                documents=["hello", "world"],
-                top_n=top_n,
-            )
-            await asyncio.sleep(1)
-
-    if expect_cache_hit is True:
-        assert "cache_key" in response._hidden_params
-    else:
-        assert "cache_key" not in response._hidden_params
-
-    print("re rank response: ", response)
-
-    assert response.id is not None
-    assert response.results is not None
-
-    assert_response_shape(response, custom_llm_provider="cohere")
-
-
 def test_basic_caching_import():
     from litellm.caching import Cache


@@ -5,8 +5,6 @@ import traceback
 import uuid

 from dotenv import load_dotenv
-from test_rerank import assert_response_shape
-
 load_dotenv()

 sys.path.insert(


@@ -5,7 +5,6 @@ import traceback
 import uuid

 from dotenv import load_dotenv
-from test_rerank import assert_response_shape

 load_dotenv()
 import os


@@ -5,8 +5,6 @@ import traceback
 import uuid

 from dotenv import load_dotenv
-from test_rerank import assert_response_shape
-
 load_dotenv()

 sys.path.insert(