Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
(code refactor) - Add BaseRerankConfig. Use BaseRerankConfig for cohere/rerank and azure_ai/rerank (#7319)
* add base rerank config
* working sync cohere rerank
* update rerank types
* update base rerank config
* remove old rerank
* add new cohere handler.py
* add cohere rerank transform
* add get_provider_rerank_config
* add rerank to base llm http handler
* add rerank utils
* add arerank to llm http handler.py
* add AzureAIRerankConfig
* updates rerank config
* update test rerank
* fix unused imports
* update get_provider_rerank_config
* test_basic_rerank_caching
* fix unused import
* test rerank
This commit is contained in: parent 51f3fc65f7, commit b0738fd439
19 changed files with 645 additions and 425 deletions
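For context, the public `litellm.rerank` API is unchanged by this refactor; a minimal call, mirroring the updated tests further down in this diff, looks like:

    import litellm

    # assumes COHERE_API_KEY is set in the environment
    response = litellm.rerank(
        model="cohere/rerank-english-v3.0",
        query="hello",
        documents=["hello", "world"],
        top_n=3,
    )
    print(response.results)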
@@ -1023,6 +1023,8 @@ from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
 from .llms.predibase.chat.transformation import PredibaseConfig
 from .llms.replicate.chat.transformation import ReplicateConfig
 from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
+from .llms.cohere.rerank.transformation import CohereRerankConfig
+from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
 from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
 from .llms.together_ai.chat import TogetherAIConfig
@@ -1 +0,0 @@
-from .handler import AzureAIRerank
@@ -1,127 +1,5 @@
-from typing import Any, Dict, List, Optional, Union
-
-import httpx
-
-from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.cohere.rerank import CohereRerank
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.rerank import RerankResponse
-
-
-class AzureAIRerank(CohereRerank):
-    def get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
-        if azure_model_group is None:
-            return None
-        if azure_model_group == "offer-cohere-rerank-mul-paygo":
-            return "azure_ai/cohere-rerank-v3-multilingual"
-        if azure_model_group == "offer-cohere-rerank-eng-paygo":
-            return "azure_ai/cohere-rerank-v3-english"
-        return azure_model_group
-
-    async def async_azure_rerank(
-        self,
-        model: str,
-        api_key: str,
-        api_base: str,
-        query: str,
-        documents: List[Union[str, Dict[str, Any]]],
-        headers: Optional[dict],
-        litellm_logging_obj: LiteLLMLoggingObj,
-        top_n: Optional[int] = None,
-        rank_fields: Optional[List[str]] = None,
-        return_documents: Optional[bool] = True,
-        max_chunks_per_doc: Optional[int] = None,
-    ):
-        returned_response: RerankResponse = await super().rerank(  # type: ignore
-            model=model,
-            api_key=api_key,
-            api_base=api_base,
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
-            _is_async=True,
-            headers=headers,
-            litellm_logging_obj=litellm_logging_obj,
-        )
-
-        # get base model
-        additional_headers = (
-            returned_response._hidden_params.get("additional_headers") or {}
-        )
-
-        base_model = self.get_base_model(
-            additional_headers.get("llm_provider-azureml-model-group")
-        )
-        returned_response._hidden_params["model"] = base_model
-
-        return returned_response
-
-    def rerank(
-        self,
-        model: str,
-        api_key: str,
-        api_base: str,
-        query: str,
-        documents: List[Union[str, Dict[str, Any]]],
-        headers: Optional[dict],
-        litellm_logging_obj: LiteLLMLoggingObj,
-        top_n: Optional[int] = None,
-        rank_fields: Optional[List[str]] = None,
-        return_documents: Optional[bool] = True,
-        max_chunks_per_doc: Optional[int] = None,
-        _is_async: Optional[bool] = False,
-        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
-    ) -> RerankResponse:
-
-        if headers is None:
-            headers = {"Authorization": "Bearer {}".format(api_key)}
-        else:
-            headers = {**headers, "Authorization": "Bearer {}".format(api_key)}
-
-        # Assuming api_base is a string representing the base URL
-        api_base_url = httpx.URL(api_base)
-
-        # Replace the path with '/v1/rerank' if it doesn't already end with it
-        if not api_base_url.path.endswith("/v1/rerank"):
-            api_base = str(api_base_url.copy_with(path="/v1/rerank"))
-
-        if _is_async:
-            return self.async_azure_rerank(  # type: ignore
-                model=model,
-                api_key=api_key,
-                api_base=api_base,
-                query=query,
-                documents=documents,
-                top_n=top_n,
-                rank_fields=rank_fields,
-                return_documents=return_documents,
-                max_chunks_per_doc=max_chunks_per_doc,
-                headers=headers,
-                litellm_logging_obj=litellm_logging_obj,
-            )
-        else:
-            returned_response = super().rerank(
-                model=model,
-                api_key=api_key,
-                api_base=api_base,
-                query=query,
-                documents=documents,
-                top_n=top_n,
-                rank_fields=rank_fields,
-                return_documents=return_documents,
-                max_chunks_per_doc=max_chunks_per_doc,
-                _is_async=_is_async,
-                headers=headers,
-                litellm_logging_obj=litellm_logging_obj,
-            )
-
-            # get base model
-            base_model = self.get_base_model(
-                returned_response._hidden_params.get("llm_provider-azureml-model-group")
-            )
-            returned_response._hidden_params["model"] = base_model
-            return returned_response
+"""
+Azure AI Rerank - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
@@ -1,3 +1,91 @@
 """
 Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
 """
+
+from typing import Optional
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.utils import RerankResponse
+
+
+class AzureAIRerankConfig(CohereRerankConfig):
+    """
+    Azure AI Rerank - Follows the same Spec as Cohere Rerank
+    """
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base is None:
+            raise ValueError(
+                "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
+            )
+        if not api_base.endswith("/v1/rerank"):
+            api_base = f"{api_base}/v1/rerank"
+        return api_base
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = get_secret_str("AZURE_AI_API_KEY") or litellm.azure_key
+
+        if api_key is None:
+            raise ValueError(
+                "Azure AI API key is required. Please set 'AZURE_AI_API_KEY' or 'litellm.azure_key'"
+            )
+
+        default_headers = {
+            "Authorization": f"Bearer {api_key}",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+
+        # If 'Authorization' is provided in headers, it overrides the default.
+        if "Authorization" in headers:
+            default_headers["Authorization"] = headers["Authorization"]
+
+        # Merge other headers, overriding any default ones except Authorization
+        return {**default_headers, **headers}
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        rerank_response = super().transform_rerank_response(
+            model=model,
+            raw_response=raw_response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=request_data,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+        )
+        base_model = self._get_base_model(
+            rerank_response._hidden_params.get("llm_provider-azureml-model-group")
+        )
+        rerank_response._hidden_params["model"] = base_model
+        return rerank_response
+
+    def _get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
+        if azure_model_group is None:
+            return None
+        if azure_model_group == "offer-cohere-rerank-mul-paygo":
+            return "azure_ai/cohere-rerank-v3-multilingual"
+        if azure_model_group == "offer-cohere-rerank-eng-paygo":
+            return "azure_ai/cohere-rerank-v3-english"
+        return azure_model_group
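The `_get_base_model` mapping above normalizes Azure's marketplace model-group header back to a base model name; a quick illustration (unrecognized group names simply pass through):

    config = AzureAIRerankConfig()
    assert config._get_base_model("offer-cohere-rerank-mul-paygo") == "azure_ai/cohere-rerank-v3-multilingual"
    assert config._get_base_model("offer-cohere-rerank-eng-paygo") == "azure_ai/cohere-rerank-v3-english"
    assert config._get_base_model("some-other-group") == "some-other-group"  # passthrough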
litellm/llms/base_llm/rerank/transformation.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import httpx
+
+from litellm.types.rerank import OptionalRerankParams, RerankResponse
+
+from ..chat.transformation import BaseLLMException
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class BaseRerankConfig(ABC):
+    @abstractmethod
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        pass
+
+    @abstractmethod
+    def transform_rerank_request(
+        self,
+        model: str,
+        optional_rerank_params: OptionalRerankParams,
+        headers: dict,
+    ) -> dict:
+        return {}
+
+    @abstractmethod
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        return model_response
+
+    @abstractmethod
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        """
+        OPTIONAL
+
+        Get the complete url for the request
+
+        Some providers need `model` in `api_base`
+        """
+        return api_base or ""
+
+    @abstractmethod
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        pass
+
+    @abstractmethod
+    def map_cohere_rerank_params(
+        self,
+        non_default_params: Optional[dict],
+        model: str,
+        drop_params: bool,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        custom_llm_provider: Optional[str] = None,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+    ) -> OptionalRerankParams:
+        pass
+
+    @abstractmethod
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        pass
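The shared HTTP handler (see `llm_http_handler.py` below) drives these hooks in a fixed order: validate_environment, get_complete_url, map_cohere_rerank_params, transform_rerank_request, the POST itself, then transform_rerank_response. A sketch of that order, using the CohereRerankConfig added later in this diff:

    from litellm.llms.cohere.rerank.transformation import CohereRerankConfig

    config = CohereRerankConfig()
    # assumes COHERE_API_KEY is set, so validate_environment can build auth headers
    headers = config.validate_environment(headers={}, model="rerank-english-v3.0")
    url = config.get_complete_url(api_base=None, model="rerank-english-v3.0")
    params = config.map_cohere_rerank_params(
        non_default_params=None,
        model="rerank-english-v3.0",
        drop_params=False,
        query="hello",
        documents=["hello", "world"],
        top_n=2,
    )
    body = config.transform_rerank_request(
        model="rerank-english-v3.0", optional_rerank_params=params, headers=headers
    )
    # the handler POSTs `body` to `url` with `headers`, then routes the raw
    # httpx.Response through config.transform_rerank_response()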
@@ -1,153 +0,0 @@
-"""
-Re rank api
-
-LiteLLM supports the re rank API format, no paramter transformation occurs
-"""
-
-from typing import Any, Dict, List, Optional, Union
-
-import httpx
-
-import litellm
-from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.base import BaseLLM
-from litellm.llms.custom_httpx.http_handler import (
-    AsyncHTTPHandler,
-    HTTPHandler,
-    _get_httpx_client,
-    get_async_httpx_client,
-)
-from litellm.types.rerank import RerankRequest, RerankResponse
-
-
-class CohereRerank(BaseLLM):
-    def validate_environment(self, api_key: str, headers: Optional[dict]) -> dict:
-        default_headers = {
-            "accept": "application/json",
-            "content-type": "application/json",
-            "Authorization": f"bearer {api_key}",
-        }
-
-        if headers is None:
-            return default_headers
-
-        # If 'Authorization' is provided in headers, it overrides the default.
-        if "Authorization" in headers:
-            default_headers["Authorization"] = headers["Authorization"]
-
-        # Merge other headers, overriding any default ones except Authorization
-        return {**default_headers, **headers}
-
-    def ensure_rerank_endpoint(self, api_base: str) -> str:
-        """
-        Ensures the `/v1/rerank` endpoint is appended to the given `api_base`.
-        If `/v1/rerank` is already present, the original URL is returned.
-
-        :param api_base: The base API URL.
-        :return: A URL with `/v1/rerank` appended if missing.
-        """
-        # Parse the base URL to ensure proper structure
-        url = httpx.URL(api_base)
-
-        # Check if the URL already ends with `/v1/rerank`
-        if not url.path.endswith("/v1/rerank"):
-            url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank")
-
-        return str(url)
-
-    def rerank(
-        self,
-        model: str,
-        api_key: str,
-        api_base: str,
-        query: str,
-        documents: List[Union[str, Dict[str, Any]]],
-        headers: Optional[dict],
-        litellm_logging_obj: LiteLLMLoggingObj,
-        top_n: Optional[int] = None,
-        rank_fields: Optional[List[str]] = None,
-        return_documents: Optional[bool] = True,
-        max_chunks_per_doc: Optional[int] = None,
-        _is_async: Optional[bool] = False,  # New parameter
-        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
-    ) -> RerankResponse:
-        headers = self.validate_environment(api_key=api_key, headers=headers)
-        api_base = self.ensure_rerank_endpoint(api_base)
-        request_data = RerankRequest(
-            model=model,
-            query=query,
-            top_n=top_n,
-            documents=documents,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
-        )
-
-        request_data_dict = request_data.dict(exclude_none=True)
-        ## LOGGING
-        litellm_logging_obj.pre_call(
-            input=request_data_dict,
-            api_key=api_key,
-            additional_args={
-                "complete_input_dict": request_data_dict,
-                "api_base": api_base,
-                "headers": headers,
-            },
-        )
-
-        if _is_async:
-            return self.async_rerank(request_data=request_data, api_key=api_key, api_base=api_base, headers=headers)  # type: ignore # Call async method
-
-        if client is not None and isinstance(client, HTTPHandler):
-            client = client
-        else:
-            client = _get_httpx_client()
-
-        response = client.post(
-            url=api_base,
-            headers=headers,
-            json=request_data_dict,
-        )
-
-        returned_response = RerankResponse(**response.json())
-
-        _response_headers = response.headers
-
-        llm_response_headers = {
-            "{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
-        }
-        returned_response._hidden_params["additional_headers"] = llm_response_headers
-
-        return returned_response
-
-    async def async_rerank(
-        self,
-        request_data: RerankRequest,
-        api_key: str,
-        api_base: str,
-        headers: dict,
-        client: Optional[AsyncHTTPHandler] = None,
-    ) -> RerankResponse:
-        request_data_dict = request_data.dict(exclude_none=True)
-
-        client = client or get_async_httpx_client(
-            llm_provider=litellm.LlmProviders.COHERE
-        )
-
-        response = await client.post(
-            api_base,
-            headers=headers,
-            json=request_data_dict,
-        )
-
-        returned_response = RerankResponse(**response.json())
-
-        _response_headers = dict(response.headers)
-
-        llm_response_headers = {
-            "{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
-        }
-        returned_response._hidden_params["additional_headers"] = llm_response_headers
-        returned_response._hidden_params["model"] = request_data.model
-
-        return returned_response
litellm/llms/cohere/rerank/handler.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+"""
+Cohere Rerank - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
litellm/llms/cohere/rerank/transformation.py (new file, 150 lines)
@@ -0,0 +1,150 @@
+from typing import Any, Dict, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.rerank import OptionalRerankParams, RerankRequest
+from litellm.types.utils import RerankResponse
+
+from ..common_utils import CohereError
+
+
+class CohereRerankConfig(BaseRerankConfig):
+    """
+    Reference: https://docs.cohere.com/v2/reference/rerank
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base:
+            # Remove trailing slashes and ensure clean base URL
+            api_base = api_base.rstrip("/")
+            if not api_base.endswith("/v1/rerank"):
+                api_base = f"{api_base}/v1/rerank"
+            return api_base
+        return "https://api.cohere.ai/v1/rerank"
+
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        return [
+            "query",
+            "documents",
+            "top_n",
+            "max_chunks_per_doc",
+            "rank_fields",
+            "return_documents",
+        ]
+
+    def map_cohere_rerank_params(
+        self,
+        non_default_params: Optional[dict],
+        model: str,
+        drop_params: bool,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        custom_llm_provider: Optional[str] = None,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+    ) -> OptionalRerankParams:
+        """
+        Map Cohere rerank params
+
+        No mapping required - returns all supported params
+        """
+        return OptionalRerankParams(
+            query=query,
+            documents=documents,
+            top_n=top_n,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+            max_chunks_per_doc=max_chunks_per_doc,
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = (
+                get_secret_str("COHERE_API_KEY")
+                or get_secret_str("CO_API_KEY")
+                or litellm.cohere_key
+            )
+
+        if api_key is None:
+            raise ValueError(
+                "Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'"
+            )
+
+        default_headers = {
+            "Authorization": f"bearer {api_key}",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+
+        # If 'Authorization' is provided in headers, it overrides the default.
+        if "Authorization" in headers:
+            default_headers["Authorization"] = headers["Authorization"]
+
+        # Merge other headers, overriding any default ones except Authorization
+        return {**default_headers, **headers}
+
+    def transform_rerank_request(
+        self,
+        model: str,
+        optional_rerank_params: OptionalRerankParams,
+        headers: dict,
+    ) -> dict:
+        if "query" not in optional_rerank_params:
+            raise ValueError("query is required for Cohere rerank")
+        if "documents" not in optional_rerank_params:
+            raise ValueError("documents is required for Cohere rerank")
+        rerank_request = RerankRequest(
+            model=model,
+            query=optional_rerank_params["query"],
+            documents=optional_rerank_params["documents"],
+            top_n=optional_rerank_params.get("top_n", None),
+            rank_fields=optional_rerank_params.get("rank_fields", None),
+            return_documents=optional_rerank_params.get("return_documents", None),
+            max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
+        )
+        return rerank_request.model_dump(exclude_none=True)
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        """
+        Transform Cohere rerank response
+
+        No transformation required, litellm follows cohere API response format
+        """
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise CohereError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        return RerankResponse(**raw_response_json)
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return CohereError(message=error_message, status_code=status_code)
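The URL handling in `get_complete_url` above can be sanity-checked in isolation; this mirrors `test_complete_base_url_cohere` further down:

    config = CohereRerankConfig()
    assert config.get_complete_url(None, "rerank-english-v3.0") == "https://api.cohere.ai/v1/rerank"
    assert config.get_complete_url("http://localhost:4000", "rerank-english-v3.0") == "http://localhost:4000/v1/rerank"
    # already-complete URLs are left untouched
    assert config.get_complete_url("http://localhost:4000/v1/rerank", "rerank-english-v3.0") == "http://localhost:4000/v1/rerank"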
@@ -9,12 +9,14 @@ import litellm.types
 import litellm.types.utils
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
 )
+from litellm.types.rerank import OptionalRerankParams, RerankResponse
 from litellm.types.utils import EmbeddingResponse
 from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
@@ -524,7 +526,138 @@ class BaseLLMHTTPHandler:
             request_data=request_data,
         )
 
-    def _handle_error(self, e: Exception, provider_config: BaseConfig):
+    def rerank(
+        self,
+        model: str,
+        custom_llm_provider: str,
+        logging_obj: LiteLLMLoggingObj,
+        optional_rerank_params: OptionalRerankParams,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        model_response: RerankResponse,
+        _is_async: bool = False,
+        headers: dict = {},
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> RerankResponse:
+
+        provider_config = ProviderConfigManager.get_provider_rerank_config(
+            model=model, provider=litellm.LlmProviders(custom_llm_provider)
+        )
+        # get config from model, custom llm provider
+        headers = provider_config.validate_environment(
+            api_key=api_key,
+            headers=headers,
+            model=model,
+        )
+
+        api_base = provider_config.get_complete_url(
+            api_base=api_base,
+            model=model,
+        )
+
+        data = provider_config.transform_rerank_request(
+            model=model,
+            optional_rerank_params=optional_rerank_params,
+            headers=headers,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=optional_rerank_params.get("query", ""),
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
+        if _is_async is True:
+            return self.arerank(  # type: ignore
+                model=model,
+                request_data=data,
+                custom_llm_provider=custom_llm_provider,
+                provider_config=provider_config,
+                logging_obj=logging_obj,
+                model_response=model_response,
+                api_base=api_base,
+                headers=headers,
+                api_key=api_key,
+                timeout=timeout,
+                client=client,
+            )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
+
+        try:
+            response = sync_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(
+                e=e,
+                provider_config=provider_config,
+            )
+
+        return provider_config.transform_rerank_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+        )
+
+    async def arerank(
+        self,
+        model: str,
+        request_data: dict,
+        custom_llm_provider: str,
+        provider_config: BaseRerankConfig,
+        logging_obj: LiteLLMLoggingObj,
+        model_response: RerankResponse,
+        api_base: str,
+        headers: dict,
+        api_key: Optional[str] = None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> RerankResponse:
+
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
+        try:
+            response = await async_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(request_data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(e=e, provider_config=provider_config)
+
+        return provider_config.transform_rerank_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=request_data,
+        )
+
+    def _handle_error(
+        self, e: Exception, provider_config: Union[BaseConfig, BaseRerankConfig]
+    ):
         status_code = getattr(e, "status_code", 500)
         error_headers = getattr(e, "headers", None)
         error_text = getattr(e, "text", str(e))
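For orientation, `litellm/rerank_api/main.py` (later in this diff) calls this entry point roughly as follows; a condensed sketch with the logging and provider-resolution setup omitted:

    # condensed sketch of the call site in rerank_api/main.py
    response = base_llm_http_handler.rerank(
        model=model,
        custom_llm_provider="cohere",
        optional_rerank_params=optional_rerank_params,
        logging_obj=litellm_logging_obj,
        timeout=optional_params.timeout,
        api_key=dynamic_api_key or optional_params.api_key,
        api_base=api_base,
        _is_async=_is_async,
        headers=headers or litellm.headers or {},
        client=client,
        model_response=RerankResponse(),
    )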
@@ -6,23 +6,23 @@ from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.azure_ai.rerank import AzureAIRerank
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
-from litellm.llms.cohere.rerank import CohereRerank
+from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
 from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
+from litellm.rerank_api.rerank_utils import get_optional_rerank_params
 from litellm.secret_managers.main import get_secret
-from litellm.types.rerank import RerankResponse
+from litellm.types.rerank import OptionalRerankParams, RerankResponse
 from litellm.types.router import *
-from litellm.utils import client, exception_type
+from litellm.utils import ProviderConfigManager, client, exception_type
 
 ####### ENVIRONMENT VARIABLES ###################
 # Initialize any necessary instances or variables here
-cohere_rerank = CohereRerank()
 together_rerank = TogetherAIRerank()
-azure_ai_rerank = AzureAIRerank()
 jina_ai_rerank = JinaAIRerank()
 bedrock_rerank = BedrockRerankHandler()
+base_llm_http_handler = BaseLLMHTTPHandler()
 #################################################
@@ -107,18 +107,36 @@ def rerank(  # noqa: PLR0915
         )
     )
 
-    model_params_dict = {
-        "top_n": top_n,
-        "rank_fields": rank_fields,
-        "return_documents": return_documents,
-        "max_chunks_per_doc": max_chunks_per_doc,
-        "documents": documents,
-    }
+    rerank_provider_config: BaseRerankConfig = (
+        ProviderConfigManager.get_provider_rerank_config(
+            model=model,
+            provider=litellm.LlmProviders(_custom_llm_provider),
+        )
+    )
+
+    optional_rerank_params: OptionalRerankParams = get_optional_rerank_params(
+        rerank_provider_config=rerank_provider_config,
+        model=model,
+        drop_params=kwargs.get("drop_params") or litellm.drop_params or False,
+        query=query,
+        documents=documents,
+        custom_llm_provider=_custom_llm_provider,
+        top_n=top_n,
+        rank_fields=rank_fields,
+        return_documents=return_documents,
+        max_chunks_per_doc=max_chunks_per_doc,
+        non_default_params=kwargs,
+    )
+
+    if isinstance(optional_params.timeout, str):
+        optional_params.timeout = float(optional_params.timeout)
+
+    model_response = RerankResponse()
 
     litellm_logging_obj.update_environment_variables(
         model=model,
         user=user,
-        optional_params=model_params_dict,
+        optional_params=optional_rerank_params,
         litellm_params={
             "litellm_call_id": litellm_call_id,
             "proxy_server_request": proxy_server_request,
@@ -135,19 +153,9 @@ def rerank(  # noqa: PLR0915
     if _custom_llm_provider == "cohere":
         # Implement Cohere rerank logic
         api_key: Optional[str] = (
-            dynamic_api_key
-            or optional_params.api_key
-            or litellm.cohere_key
-            or get_secret("COHERE_API_KEY")  # type: ignore
-            or get_secret("CO_API_KEY")  # type: ignore
-            or litellm.api_key
+            dynamic_api_key or optional_params.api_key or litellm.api_key
         )
 
-        if api_key is None:
-            raise ValueError(
-                "Cohere API key is required, please set 'COHERE_API_KEY' in your environment"
-            )
-
         api_base: Optional[str] = (
             dynamic_api_base
             or optional_params.api_base
@@ -160,23 +168,18 @@ def rerank(  # noqa: PLR0915
             raise Exception(
                 "Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var."
             )
 
-        headers = headers or litellm.headers or {}
-
-        response = cohere_rerank.rerank(
+        response = base_llm_http_handler.rerank(
             model=model,
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
-            api_key=api_key,
+            custom_llm_provider=_custom_llm_provider,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
             api_base=api_base,
             _is_async=_is_async,
-            headers=headers,
-            litellm_logging_obj=litellm_logging_obj,
+            headers=headers or litellm.headers or {},
             client=client,
+            model_response=model_response,
         )
     elif _custom_llm_provider == "azure_ai":
         api_base = (
@@ -185,47 +188,18 @@ def rerank(  # noqa: PLR0915
             or litellm.api_base
             or get_secret("AZURE_AI_API_BASE")  # type: ignore
         )
-        # set API KEY
-        api_key = (
-            dynamic_api_key
-            or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
-            or litellm.openai_key
-            or get_secret("AZURE_AI_API_KEY")  # type: ignore
-        )
-
-        headers = headers or litellm.headers or {}
-
-        if api_key is None:
-            raise ValueError(
-                "Azure AI API key is required, please set 'AZURE_AI_API_KEY' in your environment"
-            )
-
-        if api_base is None:
-            raise Exception(
-                "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
-            )
-
-        ## LOAD CONFIG - if set
-        config = litellm.OpenAIConfig.get_config()
-        for k, v in config.items():
-            if (
-                k not in optional_params
-            ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
-                optional_params[k] = v
-
-        response = azure_ai_rerank.rerank(
+        response = base_llm_http_handler.rerank(
             model=model,
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
-            api_key=api_key,
+            custom_llm_provider=_custom_llm_provider,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
             api_base=api_base,
             _is_async=_is_async,
-            headers=headers,
-            litellm_logging_obj=litellm_logging_obj,
+            headers=headers or litellm.headers or {},
             client=client,
+            model_response=model_response,
         )
     elif _custom_llm_provider == "together_ai":
         # Implement Together AI rerank logic
litellm/rerank_api/rerank_utils.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+from typing import Any, Dict, List, Optional, Union
+
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.types.rerank import OptionalRerankParams
+
+
+def get_optional_rerank_params(
+    rerank_provider_config: BaseRerankConfig,
+    model: str,
+    drop_params: bool,
+    query: str,
+    documents: List[Union[str, Dict[str, Any]]],
+    custom_llm_provider: Optional[str] = None,
+    top_n: Optional[int] = None,
+    rank_fields: Optional[List[str]] = None,
+    return_documents: Optional[bool] = True,
+    max_chunks_per_doc: Optional[int] = None,
+    non_default_params: Optional[dict] = None,
+) -> OptionalRerankParams:
+    return rerank_provider_config.map_cohere_rerank_params(
+        model=model,
+        drop_params=drop_params,
+        query=query,
+        documents=documents,
+        custom_llm_provider=custom_llm_provider,
+        top_n=top_n,
+        rank_fields=rank_fields,
+        return_documents=return_documents,
+        max_chunks_per_doc=max_chunks_per_doc,
+        non_default_params=non_default_params,
+    )
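Despite the Cohere-centric name, `map_cohere_rerank_params` is simply the provider hook this helper defers to; for Cohere it is a pass-through. An illustrative call:

    from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
    from litellm.rerank_api.rerank_utils import get_optional_rerank_params

    params = get_optional_rerank_params(
        rerank_provider_config=CohereRerankConfig(),
        model="rerank-english-v3.0",
        drop_params=False,
        query="hello",
        documents=["hello", "world"],
        top_n=3,
    )
    # -> {"query": "hello", "documents": ["hello", "world"], "top_n": 3,
    #     "rank_fields": None, "return_documents": True, "max_chunks_per_doc": None}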
@@ -20,6 +20,15 @@ class RerankRequest(BaseModel):
     max_chunks_per_doc: Optional[int] = None
 
 
+class OptionalRerankParams(TypedDict, total=False):
+    query: str
+    top_n: Optional[int]
+    documents: List[Union[str, dict]]
+    rank_fields: Optional[List[str]]
+    return_documents: Optional[bool]
+    max_chunks_per_doc: Optional[int]
+
+
 class RerankBilledUnits(TypedDict, total=False):
     search_units: int
     total_tokens: int
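Because the TypedDict is declared with `total=False`, callers may construct it with any subset of keys, e.g.:

    # a partial params dict is valid under total=False
    params: OptionalRerankParams = {"query": "hello", "documents": ["hello", "world"]}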
@@ -42,8 +51,10 @@ class RerankResponseResult(TypedDict):
 
 
 class RerankResponse(BaseModel):
-    id: str
-    results: List[RerankResponseResult]  # Contains index and relevance_score
+    id: Optional[str] = None
+    results: Optional[List[RerankResponseResult]] = (
+        None  # Contains index and relevance_score
+    )
     meta: Optional[RerankResponseMeta] = None  # Contains api_version and billed_units
 
     # Define private attributes using PrivateAttr
@@ -171,6 +171,7 @@ from openai import OpenAIError as OriginalError
 
 from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 
 from ._logging import verbose_logger
 from .caching.caching import (
@@ -6204,6 +6205,17 @@ class ProviderConfigManager:
         return litellm.VoyageEmbeddingConfig()
         raise ValueError(f"Provider {provider} does not support embedding config")
 
+    @staticmethod
+    def get_provider_rerank_config(
+        model: str,
+        provider: LlmProviders,
+    ) -> BaseRerankConfig:
+        if litellm.LlmProviders.COHERE == provider:
+            return litellm.CohereRerankConfig()
+        elif litellm.LlmProviders.AZURE_AI == provider:
+            return litellm.AzureAIRerankConfig()
+        return litellm.CohereRerankConfig()
+
 
 def get_end_user_id_for_cost_tracking(
     litellm_params: dict,
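Note that the dispatch above falls back to the Cohere config for providers without a dedicated rerank config. A quick illustration:

    import litellm
    from litellm.utils import ProviderConfigManager

    config = ProviderConfigManager.get_provider_rerank_config(
        model="rerank-english-v3.0", provider=litellm.LlmProviders.COHERE
    )
    print(type(config).__name__)  # CohereRerankConfig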
@@ -66,6 +66,7 @@ def assert_response_shape(response, custom_llm_provider):
 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [True, False])
 async def test_basic_rerank(sync_mode):
+    litellm.set_verbose = True
     if sync_mode is True:
         response = litellm.rerank(
             model="cohere/rerank-english-v3.0",
@@ -95,6 +96,8 @@ async def test_basic_rerank(sync_mode):
 
     assert_response_shape(response, custom_llm_provider="cohere")
 
+    print("response", response.model_dump_json(indent=4))
+
 
 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [True, False])
@@ -191,8 +194,8 @@ async def test_rerank_custom_api_base():
     expected_payload = {
         "model": "Salesforce/Llama-Rank-V1",
         "query": "hello",
-        "documents": ["hello", "world"],
         "top_n": 3,
+        "documents": ["hello", "world"],
     }
 
     with patch(
@@ -211,15 +214,21 @@ async def test_rerank_custom_api_base():
 
     # Assert
     mock_post.assert_called_once()
-    _url, kwargs = mock_post.call_args
-    args_to_api = kwargs["json"]
+    print("call args", mock_post.call_args)
+    args_to_api = mock_post.call_args.kwargs["data"]
+    _url = mock_post.call_args.kwargs["url"]
     print("Arguments passed to API=", args_to_api)
     print("url = ", _url)
     assert (
-        _url[0]
-        == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
+        _url == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
     )
-    assert args_to_api == expected_payload
+
+    request_data = json.loads(args_to_api)
+    assert request_data["query"] == expected_payload["query"]
+    assert request_data["documents"] == expected_payload["documents"]
+    assert request_data["top_n"] == expected_payload["top_n"]
+    assert request_data["model"] == expected_payload["model"]
 
     assert response.id is not None
     assert response.results is not None
@@ -290,3 +299,58 @@ def test_complete_base_url_cohere():
     print("mock_post.call_args", mock_post.call_args)
     mock_post.assert_called_once()
     assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"]
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+    "top_n_1, top_n_2, expect_cache_hit",
+    [
+        (3, 3, True),
+        (3, None, False),
+    ],
+)
+async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
+    from litellm.caching.caching import Cache
+
+    litellm.set_verbose = True
+    litellm.cache = Cache(type="local")
+
+    if sync_mode is True:
+        for idx in range(2):
+            if idx == 0:
+                top_n = top_n_1
+            else:
+                top_n = top_n_2
+            response = litellm.rerank(
+                model="cohere/rerank-english-v3.0",
+                query="hello",
+                documents=["hello", "world"],
+                top_n=top_n,
+            )
+    else:
+        for idx in range(2):
+            if idx == 0:
+                top_n = top_n_1
+            else:
+                top_n = top_n_2
+            response = await litellm.arerank(
+                model="cohere/rerank-english-v3.0",
+                query="hello",
+                documents=["hello", "world"],
+                top_n=top_n,
+            )
+
+    await asyncio.sleep(1)
+
+    if expect_cache_hit is True:
+        assert "cache_key" in response._hidden_params
+    else:
+        assert "cache_key" not in response._hidden_params
+
+    print("re rank response: ", response)
+
+    assert response.id is not None
+    assert response.results is not None
+
+    assert_response_shape(response, custom_llm_provider="cohere")
@@ -7,7 +7,6 @@ import traceback
 import uuid
 
 from dotenv import load_dotenv
-from test_rerank import assert_response_shape
 
 load_dotenv()
 import os
@@ -5,7 +5,6 @@ import traceback
 import uuid
 
 from dotenv import load_dotenv
-from test_rerank import assert_response_shape
 
 load_dotenv()
 import os
@@ -2132,59 +2131,6 @@ def test_logging_turn_off_message_logging_streaming():
     assert mock_client.call_args.args[0].choices[0].message.content == "hello"
 
 
-@pytest.mark.asyncio()
-@pytest.mark.parametrize("sync_mode", [True, False])
-@pytest.mark.parametrize(
-    "top_n_1, top_n_2, expect_cache_hit",
-    [
-        (3, 3, True),
-        (3, None, False),
-    ],
-)
-async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
-    litellm.set_verbose = True
-    litellm.cache = Cache(type="local")
-
-    if sync_mode is True:
-        for idx in range(2):
-            if idx == 0:
-                top_n = top_n_1
-            else:
-                top_n = top_n_2
-            response = litellm.rerank(
-                model="cohere/rerank-english-v3.0",
-                query="hello",
-                documents=["hello", "world"],
-                top_n=top_n,
-            )
-    else:
-        for idx in range(2):
-            if idx == 0:
-                top_n = top_n_1
-            else:
-                top_n = top_n_2
-            response = await litellm.arerank(
-                model="cohere/rerank-english-v3.0",
-                query="hello",
-                documents=["hello", "world"],
-                top_n=top_n,
-            )
-
-    await asyncio.sleep(1)
-
-    if expect_cache_hit is True:
-        assert "cache_key" in response._hidden_params
-    else:
-        assert "cache_key" not in response._hidden_params
-
-    print("re rank response: ", response)
-
-    assert response.id is not None
-    assert response.results is not None
-
-    assert_response_shape(response, custom_llm_provider="cohere")
-
-
 def test_basic_caching_import():
     from litellm.caching import Cache
 
@@ -5,8 +5,6 @@ import traceback
 import uuid
 
 from dotenv import load_dotenv
-from test_rerank import assert_response_shape
-
 
 load_dotenv()
 sys.path.insert(
@@ -5,7 +5,6 @@ import traceback
 import uuid
 
 from dotenv import load_dotenv
-from test_rerank import assert_response_shape
 
 load_dotenv()
 import os
@@ -5,8 +5,6 @@ import traceback
 import uuid
 
 from dotenv import load_dotenv
-from test_rerank import assert_response_shape
-
 
 load_dotenv()
 sys.path.insert(