diff --git a/litellm/__init__.py b/litellm/__init__.py index 59a88abfd1..0bd192d84f 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1023,6 +1023,8 @@ from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig from .llms.predibase.chat.transformation import PredibaseConfig from .llms.replicate.chat.transformation import ReplicateConfig from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig +from .llms.cohere.rerank.transformation import CohereRerankConfig +from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig from .llms.clarifai.chat.transformation import ClarifaiConfig from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config from .llms.together_ai.chat import TogetherAIConfig diff --git a/litellm/llms/azure_ai/rerank/__init__.py b/litellm/llms/azure_ai/rerank/__init__.py deleted file mode 100644 index a25d34b1c4..0000000000 --- a/litellm/llms/azure_ai/rerank/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .handler import AzureAIRerank diff --git a/litellm/llms/azure_ai/rerank/handler.py b/litellm/llms/azure_ai/rerank/handler.py index 60edfd296f..57e7cefd23 100644 --- a/litellm/llms/azure_ai/rerank/handler.py +++ b/litellm/llms/azure_ai/rerank/handler.py @@ -1,127 +1,5 @@ -from typing import Any, Dict, List, Optional, Union +""" +Azure AI Rerank - uses `llm_http_handler.py` to make httpx requests -import httpx - -from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj -from litellm.llms.cohere.rerank import CohereRerank -from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.types.rerank import RerankResponse - - -class AzureAIRerank(CohereRerank): - - def get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]: - if azure_model_group is None: - return None - if azure_model_group == "offer-cohere-rerank-mul-paygo": - return "azure_ai/cohere-rerank-v3-multilingual" - if azure_model_group == "offer-cohere-rerank-eng-paygo": - return "azure_ai/cohere-rerank-v3-english" - return azure_model_group - - async def async_azure_rerank( - self, - model: str, - api_key: str, - api_base: str, - query: str, - documents: List[Union[str, Dict[str, Any]]], - headers: Optional[dict], - litellm_logging_obj: LiteLLMLoggingObj, - top_n: Optional[int] = None, - rank_fields: Optional[List[str]] = None, - return_documents: Optional[bool] = True, - max_chunks_per_doc: Optional[int] = None, - ): - returned_response: RerankResponse = await super().rerank( # type: ignore - model=model, - api_key=api_key, - api_base=api_base, - query=query, - documents=documents, - top_n=top_n, - rank_fields=rank_fields, - return_documents=return_documents, - max_chunks_per_doc=max_chunks_per_doc, - _is_async=True, - headers=headers, - litellm_logging_obj=litellm_logging_obj, - ) - - # get base model - additional_headers = ( - returned_response._hidden_params.get("additional_headers") or {} - ) - - base_model = self.get_base_model( - additional_headers.get("llm_provider-azureml-model-group") - ) - returned_response._hidden_params["model"] = base_model - - return returned_response - - def rerank( - self, - model: str, - api_key: str, - api_base: str, - query: str, - documents: List[Union[str, Dict[str, Any]]], - headers: Optional[dict], - litellm_logging_obj: LiteLLMLoggingObj, - top_n: Optional[int] = None, - rank_fields: Optional[List[str]] = None, - return_documents: Optional[bool] = True, - max_chunks_per_doc: Optional[int] = None, - _is_async: 
Optional[bool] = False, - client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, - ) -> RerankResponse: - - if headers is None: - headers = {"Authorization": "Bearer {}".format(api_key)} - else: - headers = {**headers, "Authorization": "Bearer {}".format(api_key)} - - # Assuming api_base is a string representing the base URL - api_base_url = httpx.URL(api_base) - - # Replace the path with '/v1/rerank' if it doesn't already end with it - if not api_base_url.path.endswith("/v1/rerank"): - api_base = str(api_base_url.copy_with(path="/v1/rerank")) - - if _is_async: - return self.async_azure_rerank( # type: ignore - model=model, - api_key=api_key, - api_base=api_base, - query=query, - documents=documents, - top_n=top_n, - rank_fields=rank_fields, - return_documents=return_documents, - max_chunks_per_doc=max_chunks_per_doc, - headers=headers, - litellm_logging_obj=litellm_logging_obj, - ) - else: - returned_response = super().rerank( - model=model, - api_key=api_key, - api_base=api_base, - query=query, - documents=documents, - top_n=top_n, - rank_fields=rank_fields, - return_documents=return_documents, - max_chunks_per_doc=max_chunks_per_doc, - _is_async=_is_async, - headers=headers, - litellm_logging_obj=litellm_logging_obj, - ) - - # get base model - base_model = self.get_base_model( - returned_response._hidden_params.get("llm_provider-azureml-model-group") - ) - returned_response._hidden_params["model"] = base_model - return returned_response +Request/Response transformation is handled in `transformation.py` +""" diff --git a/litellm/llms/azure_ai/rerank/transformation.py b/litellm/llms/azure_ai/rerank/transformation.py index b5aad0ca21..4465e0d70a 100644 --- a/litellm/llms/azure_ai/rerank/transformation.py +++ b/litellm/llms/azure_ai/rerank/transformation.py @@ -1,3 +1,91 @@ """ Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format. """ + +from typing import Optional + +import httpx + +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.cohere.rerank.transformation import CohereRerankConfig +from litellm.secret_managers.main import get_secret_str +from litellm.types.utils import RerankResponse + + +class AzureAIRerankConfig(CohereRerankConfig): + """ + Azure AI Rerank - Follows the same Spec as Cohere Rerank + """ + + def get_complete_url(self, api_base: Optional[str], model: str) -> str: + if api_base is None: + raise ValueError( + "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var." + ) + if not api_base.endswith("/v1/rerank"): + api_base = f"{api_base}/v1/rerank" + return api_base + + def validate_environment( + self, + headers: dict, + model: str, + api_key: Optional[str] = None, + ) -> dict: + if api_key is None: + api_key = get_secret_str("AZURE_AI_API_KEY") or litellm.azure_key + + if api_key is None: + raise ValueError( + "Azure AI API key is required. Please set 'AZURE_AI_API_KEY' or 'litellm.azure_key'" + ) + + default_headers = { + "Authorization": f"Bearer {api_key}", + "accept": "application/json", + "content-type": "application/json", + } + + # If 'Authorization' is provided in headers, it overrides the default. 
+ if "Authorization" in headers: + default_headers["Authorization"] = headers["Authorization"] + + # Merge other headers, overriding any default ones except Authorization + return {**default_headers, **headers} + + def transform_rerank_response( + self, + model: str, + raw_response: httpx.Response, + model_response: RerankResponse, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str] = None, + request_data: dict = {}, + optional_params: dict = {}, + litellm_params: dict = {}, + ) -> RerankResponse: + rerank_response = super().transform_rerank_response( + model=model, + raw_response=raw_response, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + request_data=request_data, + optional_params=optional_params, + litellm_params=litellm_params, + ) + base_model = self._get_base_model( + rerank_response._hidden_params.get("llm_provider-azureml-model-group") + ) + rerank_response._hidden_params["model"] = base_model + return rerank_response + + def _get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]: + if azure_model_group is None: + return None + if azure_model_group == "offer-cohere-rerank-mul-paygo": + return "azure_ai/cohere-rerank-v3-multilingual" + if azure_model_group == "offer-cohere-rerank-eng-paygo": + return "azure_ai/cohere-rerank-v3-english" + return azure_model_group diff --git a/litellm/llms/base_llm/rerank/transformation.py b/litellm/llms/base_llm/rerank/transformation.py new file mode 100644 index 0000000000..d956c9a555 --- /dev/null +++ b/litellm/llms/base_llm/rerank/transformation.py @@ -0,0 +1,86 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import httpx + +from litellm.types.rerank import OptionalRerankParams, RerankResponse + +from ..chat.transformation import BaseLLMException + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class BaseRerankConfig(ABC): + @abstractmethod + def validate_environment( + self, + headers: dict, + model: str, + api_key: Optional[str] = None, + ) -> dict: + pass + + @abstractmethod + def transform_rerank_request( + self, + model: str, + optional_rerank_params: OptionalRerankParams, + headers: dict, + ) -> dict: + return {} + + @abstractmethod + def transform_rerank_response( + self, + model: str, + raw_response: httpx.Response, + model_response: RerankResponse, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str] = None, + request_data: dict = {}, + optional_params: dict = {}, + litellm_params: dict = {}, + ) -> RerankResponse: + return model_response + + @abstractmethod + def get_complete_url(self, api_base: Optional[str], model: str) -> str: + """ + OPTIONAL + + Get the complete url for the request + + Some providers need `model` in `api_base` + """ + return api_base or "" + + @abstractmethod + def get_supported_cohere_rerank_params(self, model: str) -> list: + pass + + @abstractmethod + def map_cohere_rerank_params( + self, + non_default_params: Optional[dict], + model: str, + drop_params: bool, + query: str, + documents: List[Union[str, Dict[str, Any]]], + custom_llm_provider: Optional[str] = None, + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + ) -> OptionalRerankParams: + pass + + @abstractmethod + def get_error_class( + self, error_message: str, status_code: int, 
headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + pass diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py deleted file mode 100644 index 8de2dfbb41..0000000000 --- a/litellm/llms/cohere/rerank.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -Re rank api - -LiteLLM supports the re rank API format, no paramter transformation occurs -""" - -from typing import Any, Dict, List, Optional, Union - -import httpx - -import litellm -from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj -from litellm.llms.base import BaseLLM -from litellm.llms.custom_httpx.http_handler import ( - AsyncHTTPHandler, - HTTPHandler, - _get_httpx_client, - get_async_httpx_client, -) -from litellm.types.rerank import RerankRequest, RerankResponse - - -class CohereRerank(BaseLLM): - def validate_environment(self, api_key: str, headers: Optional[dict]) -> dict: - default_headers = { - "accept": "application/json", - "content-type": "application/json", - "Authorization": f"bearer {api_key}", - } - - if headers is None: - return default_headers - - # If 'Authorization' is provided in headers, it overrides the default. - if "Authorization" in headers: - default_headers["Authorization"] = headers["Authorization"] - - # Merge other headers, overriding any default ones except Authorization - return {**default_headers, **headers} - - def ensure_rerank_endpoint(self, api_base: str) -> str: - """ - Ensures the `/v1/rerank` endpoint is appended to the given `api_base`. - If `/v1/rerank` is already present, the original URL is returned. - - :param api_base: The base API URL. - :return: A URL with `/v1/rerank` appended if missing. - """ - # Parse the base URL to ensure proper structure - url = httpx.URL(api_base) - - # Check if the URL already ends with `/v1/rerank` - if not url.path.endswith("/v1/rerank"): - url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank") - - return str(url) - - def rerank( - self, - model: str, - api_key: str, - api_base: str, - query: str, - documents: List[Union[str, Dict[str, Any]]], - headers: Optional[dict], - litellm_logging_obj: LiteLLMLoggingObj, - top_n: Optional[int] = None, - rank_fields: Optional[List[str]] = None, - return_documents: Optional[bool] = True, - max_chunks_per_doc: Optional[int] = None, - _is_async: Optional[bool] = False, # New parameter - client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, - ) -> RerankResponse: - headers = self.validate_environment(api_key=api_key, headers=headers) - api_base = self.ensure_rerank_endpoint(api_base) - request_data = RerankRequest( - model=model, - query=query, - top_n=top_n, - documents=documents, - rank_fields=rank_fields, - return_documents=return_documents, - max_chunks_per_doc=max_chunks_per_doc, - ) - - request_data_dict = request_data.dict(exclude_none=True) - ## LOGGING - litellm_logging_obj.pre_call( - input=request_data_dict, - api_key=api_key, - additional_args={ - "complete_input_dict": request_data_dict, - "api_base": api_base, - "headers": headers, - }, - ) - - if _is_async: - return self.async_rerank(request_data=request_data, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method - - if client is not None and isinstance(client, HTTPHandler): - client = client - else: - client = _get_httpx_client() - - response = client.post( - url=api_base, - headers=headers, - json=request_data_dict, - ) - - returned_response = RerankResponse(**response.json()) - - _response_headers = response.headers - - llm_response_headers = { - 
"{}-{}".format("llm_provider", k): v for k, v in _response_headers.items() - } - returned_response._hidden_params["additional_headers"] = llm_response_headers - - return returned_response - - async def async_rerank( - self, - request_data: RerankRequest, - api_key: str, - api_base: str, - headers: dict, - client: Optional[AsyncHTTPHandler] = None, - ) -> RerankResponse: - request_data_dict = request_data.dict(exclude_none=True) - - client = client or get_async_httpx_client( - llm_provider=litellm.LlmProviders.COHERE - ) - - response = await client.post( - api_base, - headers=headers, - json=request_data_dict, - ) - - returned_response = RerankResponse(**response.json()) - - _response_headers = dict(response.headers) - - llm_response_headers = { - "{}-{}".format("llm_provider", k): v for k, v in _response_headers.items() - } - returned_response._hidden_params["additional_headers"] = llm_response_headers - returned_response._hidden_params["model"] = request_data.model - - return returned_response diff --git a/litellm/llms/cohere/rerank/handler.py b/litellm/llms/cohere/rerank/handler.py new file mode 100644 index 0000000000..e94f1859a7 --- /dev/null +++ b/litellm/llms/cohere/rerank/handler.py @@ -0,0 +1,5 @@ +""" +Cohere Rerank - uses `llm_http_handler.py` to make httpx requests + +Request/Response transformation is handled in `transformation.py` +""" diff --git a/litellm/llms/cohere/rerank/transformation.py b/litellm/llms/cohere/rerank/transformation.py new file mode 100644 index 0000000000..e0836a71f7 --- /dev/null +++ b/litellm/llms/cohere/rerank/transformation.py @@ -0,0 +1,150 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig +from litellm.secret_managers.main import get_secret_str +from litellm.types.rerank import OptionalRerankParams, RerankRequest +from litellm.types.utils import RerankResponse + +from ..common_utils import CohereError + + +class CohereRerankConfig(BaseRerankConfig): + """ + Reference: https://docs.cohere.com/v2/reference/rerank + """ + + def __init__(self) -> None: + pass + + def get_complete_url(self, api_base: Optional[str], model: str) -> str: + if api_base: + # Remove trailing slashes and ensure clean base URL + api_base = api_base.rstrip("/") + if not api_base.endswith("/v1/rerank"): + api_base = f"{api_base}/v1/rerank" + return api_base + return "https://api.cohere.ai/v1/rerank" + + def get_supported_cohere_rerank_params(self, model: str) -> list: + return [ + "query", + "documents", + "top_n", + "max_chunks_per_doc", + "rank_fields", + "return_documents", + ] + + def map_cohere_rerank_params( + self, + non_default_params: Optional[dict], + model: str, + drop_params: bool, + query: str, + documents: List[Union[str, Dict[str, Any]]], + custom_llm_provider: Optional[str] = None, + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + ) -> OptionalRerankParams: + """ + Map Cohere rerank params + + No mapping required - returns all supported params + """ + return OptionalRerankParams( + query=query, + documents=documents, + top_n=top_n, + rank_fields=rank_fields, + return_documents=return_documents, + max_chunks_per_doc=max_chunks_per_doc, + ) + + def validate_environment( + self, + headers: 
dict, + model: str, + api_key: Optional[str] = None, + ) -> dict: + if api_key is None: + api_key = ( + get_secret_str("COHERE_API_KEY") + or get_secret_str("CO_API_KEY") + or litellm.cohere_key + ) + + if api_key is None: + raise ValueError( + "Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'" + ) + + default_headers = { + "Authorization": f"bearer {api_key}", + "accept": "application/json", + "content-type": "application/json", + } + + # If 'Authorization' is provided in headers, it overrides the default. + if "Authorization" in headers: + default_headers["Authorization"] = headers["Authorization"] + + # Merge other headers, overriding any default ones except Authorization + return {**default_headers, **headers} + + def transform_rerank_request( + self, + model: str, + optional_rerank_params: OptionalRerankParams, + headers: dict, + ) -> dict: + if "query" not in optional_rerank_params: + raise ValueError("query is required for Cohere rerank") + if "documents" not in optional_rerank_params: + raise ValueError("documents is required for Cohere rerank") + rerank_request = RerankRequest( + model=model, + query=optional_rerank_params["query"], + documents=optional_rerank_params["documents"], + top_n=optional_rerank_params.get("top_n", None), + rank_fields=optional_rerank_params.get("rank_fields", None), + return_documents=optional_rerank_params.get("return_documents", None), + max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None), + ) + return rerank_request.model_dump(exclude_none=True) + + def transform_rerank_response( + self, + model: str, + raw_response: httpx.Response, + model_response: RerankResponse, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str] = None, + request_data: dict = {}, + optional_params: dict = {}, + litellm_params: dict = {}, + ) -> RerankResponse: + """ + Transform Cohere rerank response + + No transformation required, litellm follows cohere API response format + """ + try: + raw_response_json = raw_response.json() + except Exception: + raise CohereError( + message=raw_response.text, status_code=raw_response.status_code + ) + + return RerankResponse(**raw_response_json) + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + return CohereError(message=error_message, status_code=status_code) diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 277c698b91..e3ccccd60d 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -9,12 +9,14 @@ import litellm.types import litellm.types.utils from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig +from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, _get_httpx_client, get_async_httpx_client, ) +from litellm.types.rerank import OptionalRerankParams, RerankResponse from litellm.types.utils import EmbeddingResponse from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager @@ -524,7 +526,138 @@ class BaseLLMHTTPHandler: request_data=request_data, ) - def _handle_error(self, e: Exception, provider_config: BaseConfig): + def rerank( + self, + model: str, + custom_llm_provider: str, + logging_obj: LiteLLMLoggingObj, + 
optional_rerank_params: OptionalRerankParams, + timeout: Optional[Union[float, httpx.Timeout]], + model_response: RerankResponse, + _is_async: bool = False, + headers: dict = {}, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + ) -> RerankResponse: + + provider_config = ProviderConfigManager.get_provider_rerank_config( + model=model, provider=litellm.LlmProviders(custom_llm_provider) + ) + # get config from model, custom llm provider + headers = provider_config.validate_environment( + api_key=api_key, + headers=headers, + model=model, + ) + + api_base = provider_config.get_complete_url( + api_base=api_base, + model=model, + ) + + data = provider_config.transform_rerank_request( + model=model, + optional_rerank_params=optional_rerank_params, + headers=headers, + ) + + ## LOGGING + logging_obj.pre_call( + input=optional_rerank_params.get("query", ""), + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + + if _is_async is True: + return self.arerank( # type: ignore + model=model, + request_data=data, + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + logging_obj=logging_obj, + model_response=model_response, + api_base=api_base, + headers=headers, + api_key=api_key, + timeout=timeout, + client=client, + ) + + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client() + else: + sync_httpx_client = client + + try: + response = sync_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout, + ) + except Exception as e: + raise self._handle_error( + e=e, + provider_config=provider_config, + ) + + return provider_config.transform_rerank_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + request_data=data, + ) + + async def arerank( + self, + model: str, + request_data: dict, + custom_llm_provider: str, + provider_config: BaseRerankConfig, + logging_obj: LiteLLMLoggingObj, + model_response: RerankResponse, + api_base: str, + headers: dict, + api_key: Optional[str] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + ) -> RerankResponse: + + if client is None or not isinstance(client, AsyncHTTPHandler): + async_httpx_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders(custom_llm_provider) + ) + else: + async_httpx_client = client + try: + response = await async_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(request_data), + timeout=timeout, + ) + except Exception as e: + raise self._handle_error(e=e, provider_config=provider_config) + + return provider_config.transform_rerank_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + request_data=request_data, + ) + + def _handle_error( + self, e: Exception, provider_config: Union[BaseConfig, BaseRerankConfig] + ): status_code = getattr(e, "status_code", 500) error_headers = getattr(e, "headers", None) error_text = getattr(e, "text", str(e)) diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index 0acdfb0da3..72de2ca8ed 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -6,23 +6,23 @@ from typing import Any, Coroutine, Dict, List, Literal, Optional, Union import litellm from 
litellm._logging import verbose_logger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj -from litellm.llms.azure_ai.rerank import AzureAIRerank +from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler -from litellm.llms.cohere.rerank import CohereRerank +from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler from litellm.llms.jina_ai.rerank.handler import JinaAIRerank from litellm.llms.together_ai.rerank.handler import TogetherAIRerank +from litellm.rerank_api.rerank_utils import get_optional_rerank_params from litellm.secret_managers.main import get_secret -from litellm.types.rerank import RerankResponse +from litellm.types.rerank import OptionalRerankParams, RerankResponse from litellm.types.router import * -from litellm.utils import client, exception_type +from litellm.utils import ProviderConfigManager, client, exception_type ####### ENVIRONMENT VARIABLES ################### # Initialize any necessary instances or variables here -cohere_rerank = CohereRerank() together_rerank = TogetherAIRerank() -azure_ai_rerank = AzureAIRerank() jina_ai_rerank = JinaAIRerank() bedrock_rerank = BedrockRerankHandler() +base_llm_http_handler = BaseLLMHTTPHandler() ################################################# @@ -107,18 +107,36 @@ def rerank( # noqa: PLR0915 ) ) - model_params_dict = { - "top_n": top_n, - "rank_fields": rank_fields, - "return_documents": return_documents, - "max_chunks_per_doc": max_chunks_per_doc, - "documents": documents, - } + rerank_provider_config: BaseRerankConfig = ( + ProviderConfigManager.get_provider_rerank_config( + model=model, + provider=litellm.LlmProviders(_custom_llm_provider), + ) + ) + + optional_rerank_params: OptionalRerankParams = get_optional_rerank_params( + rerank_provider_config=rerank_provider_config, + model=model, + drop_params=kwargs.get("drop_params") or litellm.drop_params or False, + query=query, + documents=documents, + custom_llm_provider=_custom_llm_provider, + top_n=top_n, + rank_fields=rank_fields, + return_documents=return_documents, + max_chunks_per_doc=max_chunks_per_doc, + non_default_params=kwargs, + ) + + if isinstance(optional_params.timeout, str): + optional_params.timeout = float(optional_params.timeout) + + model_response = RerankResponse() litellm_logging_obj.update_environment_variables( model=model, user=user, - optional_params=model_params_dict, + optional_params=optional_rerank_params, litellm_params={ "litellm_call_id": litellm_call_id, "proxy_server_request": proxy_server_request, @@ -135,19 +153,9 @@ def rerank( # noqa: PLR0915 if _custom_llm_provider == "cohere": # Implement Cohere rerank logic api_key: Optional[str] = ( - dynamic_api_key - or optional_params.api_key - or litellm.cohere_key - or get_secret("COHERE_API_KEY") # type: ignore - or get_secret("CO_API_KEY") # type: ignore - or litellm.api_key + dynamic_api_key or optional_params.api_key or litellm.api_key ) - if api_key is None: - raise ValueError( - "Cohere API key is required, please set 'COHERE_API_KEY' in your environment" - ) - api_base: Optional[str] = ( dynamic_api_base or optional_params.api_base @@ -160,23 +168,18 @@ def rerank( # noqa: PLR0915 raise Exception( "Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var." 
) - - headers = headers or litellm.headers or {} - - response = cohere_rerank.rerank( + response = base_llm_http_handler.rerank( model=model, - query=query, - documents=documents, - top_n=top_n, - rank_fields=rank_fields, - return_documents=return_documents, - max_chunks_per_doc=max_chunks_per_doc, - api_key=api_key, + custom_llm_provider=_custom_llm_provider, + optional_rerank_params=optional_rerank_params, + logging_obj=litellm_logging_obj, + timeout=optional_params.timeout, + api_key=dynamic_api_key or optional_params.api_key, api_base=api_base, _is_async=_is_async, - headers=headers, - litellm_logging_obj=litellm_logging_obj, + headers=headers or litellm.headers or {}, client=client, + model_response=model_response, ) elif _custom_llm_provider == "azure_ai": api_base = ( @@ -185,47 +188,18 @@ def rerank( # noqa: PLR0915 or litellm.api_base or get_secret("AZURE_AI_API_BASE") # type: ignore ) - # set API KEY - api_key = ( - dynamic_api_key - or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there - or litellm.openai_key - or get_secret("AZURE_AI_API_KEY") # type: ignore - ) - - headers = headers or litellm.headers or {} - - if api_key is None: - raise ValueError( - "Azure AI API key is required, please set 'AZURE_AI_API_KEY' in your environment" - ) - - if api_base is None: - raise Exception( - "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var." - ) - - ## LOAD CONFIG - if set - config = litellm.OpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - response = azure_ai_rerank.rerank( + response = base_llm_http_handler.rerank( model=model, - query=query, - documents=documents, - top_n=top_n, - rank_fields=rank_fields, - return_documents=return_documents, - max_chunks_per_doc=max_chunks_per_doc, - api_key=api_key, + custom_llm_provider=_custom_llm_provider, + optional_rerank_params=optional_rerank_params, + logging_obj=litellm_logging_obj, + timeout=optional_params.timeout, + api_key=dynamic_api_key or optional_params.api_key, api_base=api_base, _is_async=_is_async, - headers=headers, - litellm_logging_obj=litellm_logging_obj, + headers=headers or litellm.headers or {}, + client=client, + model_response=model_response, ) elif _custom_llm_provider == "together_ai": # Implement Together AI rerank logic diff --git a/litellm/rerank_api/rerank_utils.py b/litellm/rerank_api/rerank_utils.py new file mode 100644 index 0000000000..c3e5fda56e --- /dev/null +++ b/litellm/rerank_api/rerank_utils.py @@ -0,0 +1,31 @@ +from typing import Any, Dict, List, Optional, Union + +from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig +from litellm.types.rerank import OptionalRerankParams + + +def get_optional_rerank_params( + rerank_provider_config: BaseRerankConfig, + model: str, + drop_params: bool, + query: str, + documents: List[Union[str, Dict[str, Any]]], + custom_llm_provider: Optional[str] = None, + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + non_default_params: Optional[dict] = None, +) -> OptionalRerankParams: + return rerank_provider_config.map_cohere_rerank_params( + model=model, + drop_params=drop_params, + query=query, + documents=documents, + custom_llm_provider=custom_llm_provider, + 
top_n=top_n, + rank_fields=rank_fields, + return_documents=return_documents, + max_chunks_per_doc=max_chunks_per_doc, + non_default_params=non_default_params, + ) diff --git a/litellm/types/rerank.py b/litellm/types/rerank.py index 8a2332fe36..019d90a0ae 100644 --- a/litellm/types/rerank.py +++ b/litellm/types/rerank.py @@ -20,6 +20,15 @@ class RerankRequest(BaseModel): max_chunks_per_doc: Optional[int] = None +class OptionalRerankParams(TypedDict, total=False): + query: str + top_n: Optional[int] + documents: List[Union[str, dict]] + rank_fields: Optional[List[str]] + return_documents: Optional[bool] + max_chunks_per_doc: Optional[int] + + class RerankBilledUnits(TypedDict, total=False): search_units: int total_tokens: int @@ -42,8 +51,10 @@ class RerankResponseResult(TypedDict): class RerankResponse(BaseModel): - id: str - results: List[RerankResponseResult] # Contains index and relevance_score + id: Optional[str] = None + results: Optional[List[RerankResponseResult]] = ( + None # Contains index and relevance_score + ) meta: Optional[RerankResponseMeta] = None # Contains api_version and billed_units # Define private attributes using PrivateAttr diff --git a/litellm/utils.py b/litellm/utils.py index 360d310936..a16075ebc2 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -171,6 +171,7 @@ from openai import OpenAIError as OriginalError from litellm.llms.base_llm.chat.transformation import BaseConfig from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig +from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig from ._logging import verbose_logger from .caching.caching import ( @@ -6204,6 +6205,17 @@ class ProviderConfigManager: return litellm.VoyageEmbeddingConfig() raise ValueError(f"Provider {provider} does not support embedding config") + @staticmethod + def get_provider_rerank_config( + model: str, + provider: LlmProviders, + ) -> BaseRerankConfig: + if litellm.LlmProviders.COHERE == provider: + return litellm.CohereRerankConfig() + elif litellm.LlmProviders.AZURE_AI == provider: + return litellm.AzureAIRerankConfig() + return litellm.CohereRerankConfig() + def get_end_user_id_for_cost_tracking( litellm_params: dict, diff --git a/tests/local_testing/test_rerank.py b/tests/llm_translation/test_rerank.py similarity index 79% rename from tests/local_testing/test_rerank.py rename to tests/llm_translation/test_rerank.py index 5fca6f1354..48a6fe0ca2 100644 --- a/tests/local_testing/test_rerank.py +++ b/tests/llm_translation/test_rerank.py @@ -66,6 +66,7 @@ def assert_response_shape(response, custom_llm_provider): @pytest.mark.asyncio() @pytest.mark.parametrize("sync_mode", [True, False]) async def test_basic_rerank(sync_mode): + litellm.set_verbose = True if sync_mode is True: response = litellm.rerank( model="cohere/rerank-english-v3.0", @@ -95,6 +96,8 @@ async def test_basic_rerank(sync_mode): assert_response_shape(response, custom_llm_provider="cohere") + print("response", response.model_dump_json(indent=4)) + @pytest.mark.asyncio() @pytest.mark.parametrize("sync_mode", [True, False]) @@ -191,8 +194,8 @@ async def test_rerank_custom_api_base(): expected_payload = { "model": "Salesforce/Llama-Rank-V1", "query": "hello", - "documents": ["hello", "world"], "top_n": 3, + "documents": ["hello", "world"], } with patch( @@ -211,15 +214,21 @@ async def test_rerank_custom_api_base(): # Assert mock_post.assert_called_once() - _url, kwargs = mock_post.call_args - args_to_api = kwargs["json"] + print("call args", mock_post.call_args) + args_to_api = 
mock_post.call_args.kwargs["data"] + _url = mock_post.call_args.kwargs["url"] print("Arguments passed to API=", args_to_api) print("url = ", _url) assert ( - _url[0] - == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank" + _url == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank" ) - assert args_to_api == expected_payload + + request_data = json.loads(args_to_api) + assert request_data["query"] == expected_payload["query"] + assert request_data["documents"] == expected_payload["documents"] + assert request_data["top_n"] == expected_payload["top_n"] + assert request_data["model"] == expected_payload["model"] + assert response.id is not None assert response.results is not None @@ -290,3 +299,58 @@ def test_complete_base_url_cohere(): print("mock_post.call_args", mock_post.call_args) mock_post.assert_called_once() assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"] + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.parametrize( + "top_n_1, top_n_2, expect_cache_hit", + [ + (3, 3, True), + (3, None, False), + ], +) +async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit): + from litellm.caching.caching import Cache + + litellm.set_verbose = True + litellm.cache = Cache(type="local") + + if sync_mode is True: + for idx in range(2): + if idx == 0: + top_n = top_n_1 + else: + top_n = top_n_2 + response = litellm.rerank( + model="cohere/rerank-english-v3.0", + query="hello", + documents=["hello", "world"], + top_n=top_n, + ) + else: + for idx in range(2): + if idx == 0: + top_n = top_n_1 + else: + top_n = top_n_2 + response = await litellm.arerank( + model="cohere/rerank-english-v3.0", + query="hello", + documents=["hello", "world"], + top_n=top_n, + ) + + await asyncio.sleep(1) + + if expect_cache_hit is True: + assert "cache_key" in response._hidden_params + else: + assert "cache_key" not in response._hidden_params + + print("re rank response: ", response) + + assert response.id is not None + assert response.results is not None + + assert_response_shape(response, custom_llm_provider="cohere") diff --git a/tests/local_testing/cache_unit_tests.py b/tests/local_testing/cache_unit_tests.py index da56c773f3..8b82bbdfe6 100644 --- a/tests/local_testing/cache_unit_tests.py +++ b/tests/local_testing/cache_unit_tests.py @@ -7,7 +7,6 @@ import traceback import uuid from dotenv import load_dotenv -from test_rerank import assert_response_shape load_dotenv() import os diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index 18f7700c75..a8452249e9 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ -5,7 +5,6 @@ import traceback import uuid from dotenv import load_dotenv -from test_rerank import assert_response_shape load_dotenv() import os @@ -2132,59 +2131,6 @@ def test_logging_turn_off_message_logging_streaming(): assert mock_client.call_args.args[0].choices[0].message.content == "hello" -@pytest.mark.asyncio() -@pytest.mark.parametrize("sync_mode", [True, False]) -@pytest.mark.parametrize( - "top_n_1, top_n_2, expect_cache_hit", - [ - (3, 3, True), - (3, None, False), - ], -) -async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit): - litellm.set_verbose = True - litellm.cache = Cache(type="local") - - if sync_mode is True: - for idx in range(2): - if idx == 0: - top_n = top_n_1 - else: - top_n = top_n_2 - response = litellm.rerank( - model="cohere/rerank-english-v3.0", - 
query="hello", - documents=["hello", "world"], - top_n=top_n, - ) - else: - for idx in range(2): - if idx == 0: - top_n = top_n_1 - else: - top_n = top_n_2 - response = await litellm.arerank( - model="cohere/rerank-english-v3.0", - query="hello", - documents=["hello", "world"], - top_n=top_n, - ) - - await asyncio.sleep(1) - - if expect_cache_hit is True: - assert "cache_key" in response._hidden_params - else: - assert "cache_key" not in response._hidden_params - - print("re rank response: ", response) - - assert response.id is not None - assert response.results is not None - - assert_response_shape(response, custom_llm_provider="cohere") - - def test_basic_caching_import(): from litellm.caching import Cache diff --git a/tests/local_testing/test_caching_handler.py b/tests/local_testing/test_caching_handler.py index b2c8022649..c2466000c4 100644 --- a/tests/local_testing/test_caching_handler.py +++ b/tests/local_testing/test_caching_handler.py @@ -5,8 +5,6 @@ import traceback import uuid from dotenv import load_dotenv -from test_rerank import assert_response_shape - load_dotenv() sys.path.insert( diff --git a/tests/local_testing/test_dual_cache.py b/tests/local_testing/test_dual_cache.py index e81424a9ff..657e00dae1 100644 --- a/tests/local_testing/test_dual_cache.py +++ b/tests/local_testing/test_dual_cache.py @@ -5,7 +5,6 @@ import traceback import uuid from dotenv import load_dotenv -from test_rerank import assert_response_shape load_dotenv() import os diff --git a/tests/local_testing/test_unit_test_caching.py b/tests/local_testing/test_unit_test_caching.py index 5f8f41ba54..52007698ee 100644 --- a/tests/local_testing/test_unit_test_caching.py +++ b/tests/local_testing/test_unit_test_caching.py @@ -5,8 +5,6 @@ import traceback import uuid from dotenv import load_dotenv -from test_rerank import assert_response_shape - load_dotenv() sys.path.insert(