"""
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.

Why a separate file? To make it easy to see how the transformation works.
"""

import uuid
from typing import List, Optional

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import (
    RerankBilledUnits,
    RerankResponse,
    RerankResponseDocument,
    RerankResponseMeta,
    RerankResponseResult,
    RerankTokens,
)

from .common_utils import InfinityError


class InfinityRerankConfig(CohereRerankConfig):
    """Rerank configuration for Infinity, built on top of the Cohere rerank config.

    Overrides URL construction, request-header construction and response
    parsing; everything else is inherited from ``CohereRerankConfig``.
    """

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """Return the rerank endpoint URL derived from ``api_base``.

        Trailing slashes are stripped and ``/rerank`` is appended unless the
        base already ends with it. ``model`` is not used.

        Raises:
            ValueError: If ``api_base`` is ``None``.
        """
        if api_base is None:
            raise ValueError("api_base is required for Infinity rerank")
        # Remove trailing slashes and ensure clean base URL
        api_base = api_base.rstrip("/")
        if not api_base.endswith("/rerank"):
            api_base = f"{api_base}/rerank"
        return api_base

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Build the request headers for a rerank call.

        When ``api_key`` is not given it is resolved from the
        ``INFINITY_API_KEY`` secret, then from ``litellm.infinity_key``.

        Args:
            headers: Caller-supplied headers. They take precedence over every
                default, including ``Authorization``.
            model: Not used by this method.
            api_key: Explicit API key; skips the secret lookup when provided.

        Returns:
            A new dict with the JSON ``accept``/``content-type`` defaults, a
            bearer ``Authorization`` header when a key could be resolved, and
            the caller's headers merged on top.
        """
        if api_key is None:
            api_key = get_secret_str("INFINITY_API_KEY") or litellm.infinity_key

        default_headers = {
            "accept": "application/json",
            "content-type": "application/json",
        }
        # Only send credentials when a key was actually resolved; otherwise
        # the header would carry the literal string "bearer None".
        if api_key is not None:
            default_headers["Authorization"] = f"bearer {api_key}"

        # Caller-provided headers override any default, Authorization included.
        return {**default_headers, **headers}

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """Convert a raw Infinity rerank HTTP response into a ``RerankResponse``.

        Only ``raw_response`` is read; the remaining parameters are accepted
        for signature compatibility and are neither read nor mutated here
        (so the shared ``{}`` defaults are harmless).

        The ``usage`` object of the payload is forwarded as billed units, and
        token counts are derived from its ``prompt_tokens`` / ``total_tokens``
        entries (output tokens = total - prompt). Each entry of ``results``
        becomes a ``RerankResponseResult``; a truthy ``document`` value is
        wrapped as ``RerankResponseDocument(text=...)``.

        Returns:
            The populated ``RerankResponse``. If the payload carries no ``id``
            a random UUID4 string is used.

        Raises:
            InfinityError: If the response body cannot be decoded as JSON.
            ValueError: If the payload has no ``results`` entry (missing or
                null). An empty list is accepted and yields an empty result.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise InfinityError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        # `or {}` also covers a payload where "usage" is present but null.
        usage = raw_response_json.get("usage") or {}
        prompt_tokens = usage.get("prompt_tokens", 0)
        total_tokens = usage.get("total_tokens", 0)

        _billed_units = RerankBilledUnits(**usage)
        _tokens = RerankTokens(
            input_tokens=prompt_tokens,
            output_tokens=total_tokens - prompt_tokens,
        )
        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

        raw_results = raw_response_json.get("results")
        if raw_results is None:
            raise ValueError(f"No results found in the response={raw_response_json}")

        cohere_results: List[RerankResponseResult] = []
        for result in raw_results:
            _rerank_response = RerankResponseResult(
                index=result.get("index"),
                relevance_score=result.get("relevance_score"),
            )
            document = result.get("document")
            if document:
                _rerank_response["document"] = RerankResponseDocument(text=document)
            cohere_results.append(_rerank_response)

        return RerankResponse(
            id=raw_response_json.get("id") or str(uuid.uuid4()),
            results=cohere_results,
            meta=rerank_meta,
        )