mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
* Add cohere v2/rerank support (#8421) * Support v2 endpoint cohere rerank * Add tests and docs * Make v1 default if old params used * Update docs * Update docs pt 2 * Update tests * Add e2e test * Clean up code * Use inheritence for new config * Fix linting issues (#8608) * Fix cohere v2 failing test + linting (#8672) * Fix test and unused imports * Fix tests * fix: fix linting errors * test: handle tgai instability * fix: skip service unavailable err * test: print logs for unstable test * test: skip unreliable tests --------- Co-authored-by: vibhavbhat <vibhavb00@gmail.com>
This commit is contained in:
parent
c2aec21b4d
commit
09462ba80c
19 changed files with 257 additions and 40 deletions
|
@ -108,7 +108,7 @@ response = embedding(
|
|||
|
||||
### Usage
|
||||
|
||||
|
||||
LiteLLM supports the v1 and v2 clients for Cohere rerank. By default, the `rerank` endpoint uses the v2 client, but you can specify the v1 client by explicitly calling `v1/rerank`
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="LiteLLM SDK Usage">
|
||||
|
|
|
@ -111,7 +111,7 @@ curl http://0.0.0.0:4000/rerank \
|
|||
|
||||
| Provider | Link to Usage |
|
||||
|-------------|--------------------|
|
||||
| Cohere | [Usage](#quick-start) |
|
||||
| Cohere (v1 + v2 clients) | [Usage](#quick-start) |
|
||||
| Together AI| [Usage](../docs/providers/togetherai) |
|
||||
| Azure AI| [Usage](../docs/providers/azure_ai) |
|
||||
| Jina AI| [Usage](../docs/providers/jina_ai) |
|
||||
|
|
|
@ -824,6 +824,7 @@ from .llms.predibase.chat.transformation import PredibaseConfig
|
|||
from .llms.replicate.chat.transformation import ReplicateConfig
|
||||
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
|
||||
from .llms.cohere.rerank.transformation import CohereRerankConfig
|
||||
from .llms.cohere.rerank_v2.transformation import CohereRerankV2Config
|
||||
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
|
||||
from .llms.infinity.rerank.transformation import InfinityRerankConfig
|
||||
from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
|
||||
|
|
|
@ -855,7 +855,10 @@ def rerank_cost(
|
|||
|
||||
try:
|
||||
config = ProviderConfigManager.get_provider_rerank_config(
|
||||
model=model, provider=LlmProviders(custom_llm_provider)
|
||||
model=model,
|
||||
api_base=None,
|
||||
present_version_params=[],
|
||||
provider=LlmProviders(custom_llm_provider),
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
@ -17,7 +17,6 @@ class AzureAIRerankConfig(CohereRerankConfig):
|
|||
"""
|
||||
Azure AI Rerank - Follows the same Spec as Cohere Rerank
|
||||
"""
|
||||
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
|
|
|
@ -77,6 +77,7 @@ class BaseRerankConfig(ABC):
|
|||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
max_tokens_per_doc: Optional[int] = None,
|
||||
) -> OptionalRerankParams:
|
||||
pass
|
||||
|
||||
|
|
|
@ -52,6 +52,7 @@ class CohereRerankConfig(BaseRerankConfig):
|
|||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
max_tokens_per_doc: Optional[int] = None,
|
||||
) -> OptionalRerankParams:
|
||||
"""
|
||||
Map Cohere rerank params
|
||||
|
@ -147,4 +148,4 @@ class CohereRerankConfig(BaseRerankConfig):
|
|||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return CohereError(message=error_message, status_code=status_code)
|
||||
return CohereError(message=error_message, status_code=status_code)
|
80
litellm/llms/cohere/rerank_v2/transformation.py
Normal file
80
litellm/llms/cohere/rerank_v2/transformation.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
|
||||
from litellm.types.rerank import OptionalRerankParams, RerankRequest
|
||||
|
||||
class CohereRerankV2Config(CohereRerankConfig):
|
||||
"""
|
||||
Reference: https://docs.cohere.com/v2/reference/rerank
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
if api_base:
|
||||
# Remove trailing slashes and ensure clean base URL
|
||||
api_base = api_base.rstrip("/")
|
||||
if not api_base.endswith("/v2/rerank"):
|
||||
api_base = f"{api_base}/v2/rerank"
|
||||
return api_base
|
||||
return "https://api.cohere.ai/v2/rerank"
|
||||
|
||||
def get_supported_cohere_rerank_params(self, model: str) -> list:
|
||||
return [
|
||||
"query",
|
||||
"documents",
|
||||
"top_n",
|
||||
"max_tokens_per_doc",
|
||||
"rank_fields",
|
||||
"return_documents",
|
||||
]
|
||||
|
||||
def map_cohere_rerank_params(
|
||||
self,
|
||||
non_default_params: Optional[dict],
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
query: str,
|
||||
documents: List[Union[str, Dict[str, Any]]],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
top_n: Optional[int] = None,
|
||||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
max_tokens_per_doc: Optional[int] = None,
|
||||
) -> OptionalRerankParams:
|
||||
"""
|
||||
Map Cohere rerank params
|
||||
|
||||
No mapping required - returns all supported params
|
||||
"""
|
||||
return OptionalRerankParams(
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n,
|
||||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
max_tokens_per_doc=max_tokens_per_doc,
|
||||
)
|
||||
|
||||
def transform_rerank_request(
|
||||
self,
|
||||
model: str,
|
||||
optional_rerank_params: OptionalRerankParams,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
if "query" not in optional_rerank_params:
|
||||
raise ValueError("query is required for Cohere rerank")
|
||||
if "documents" not in optional_rerank_params:
|
||||
raise ValueError("documents is required for Cohere rerank")
|
||||
rerank_request = RerankRequest(
|
||||
model=model,
|
||||
query=optional_rerank_params["query"],
|
||||
documents=optional_rerank_params["documents"],
|
||||
top_n=optional_rerank_params.get("top_n", None),
|
||||
rank_fields=optional_rerank_params.get("rank_fields", None),
|
||||
return_documents=optional_rerank_params.get("return_documents", None),
|
||||
max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None),
|
||||
)
|
||||
return rerank_request.model_dump(exclude_none=True)
|
|
@ -710,6 +710,7 @@ class BaseLLMHTTPHandler:
|
|||
model: str,
|
||||
custom_llm_provider: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
provider_config: BaseRerankConfig,
|
||||
optional_rerank_params: OptionalRerankParams,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
model_response: RerankResponse,
|
||||
|
@ -719,10 +720,7 @@ class BaseLLMHTTPHandler:
|
|||
api_base: Optional[str] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
) -> RerankResponse:
|
||||
|
||||
provider_config = ProviderConfigManager.get_provider_rerank_config(
|
||||
model=model, provider=litellm.LlmProviders(custom_llm_provider)
|
||||
)
|
||||
|
||||
# get config from model, custom llm provider
|
||||
headers = provider_config.validate_environment(
|
||||
api_key=api_key,
|
||||
|
|
|
@ -44,6 +44,7 @@ class JinaAIRerankConfig(BaseRerankConfig):
|
|||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
max_tokens_per_doc: Optional[int] = None,
|
||||
) -> OptionalRerankParams:
|
||||
optional_params = {}
|
||||
supported_params = self.get_supported_cohere_rerank_params(model)
|
||||
|
|
|
@ -239,6 +239,7 @@ class LiteLLMRoutes(enum.Enum):
|
|||
# rerank
|
||||
"/rerank",
|
||||
"/v1/rerank",
|
||||
"/v2/rerank"
|
||||
# realtime
|
||||
"/realtime",
|
||||
"/v1/realtime",
|
||||
|
|
|
@ -11,7 +11,12 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
|||
router = APIRouter()
|
||||
import asyncio
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v2/rerank",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_class=ORJSONResponse,
|
||||
tags=["rerank"],
|
||||
)
|
||||
@router.post(
|
||||
"/v1/rerank",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
|
|
|
@ -81,6 +81,7 @@ def rerank( # noqa: PLR0915
|
|||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
max_tokens_per_doc: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
|
||||
"""
|
||||
|
@ -97,6 +98,14 @@ def rerank( # noqa: PLR0915
|
|||
try:
|
||||
_is_async = kwargs.pop("arerank", False) is True
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
# Params that are unique to specific versions of the client for the rerank call
|
||||
unique_version_params = {
|
||||
"max_chunks_per_doc": max_chunks_per_doc,
|
||||
"max_tokens_per_doc": max_tokens_per_doc,
|
||||
}
|
||||
present_version_params = [
|
||||
k for k, v in unique_version_params.items() if v is not None
|
||||
]
|
||||
|
||||
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
|
||||
litellm.get_llm_provider(
|
||||
|
@ -111,6 +120,8 @@ def rerank( # noqa: PLR0915
|
|||
ProviderConfigManager.get_provider_rerank_config(
|
||||
model=model,
|
||||
provider=litellm.LlmProviders(_custom_llm_provider),
|
||||
api_base=optional_params.api_base,
|
||||
present_version_params=present_version_params,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -125,6 +136,7 @@ def rerank( # noqa: PLR0915
|
|||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
max_chunks_per_doc=max_chunks_per_doc,
|
||||
max_tokens_per_doc=max_tokens_per_doc,
|
||||
non_default_params=kwargs,
|
||||
)
|
||||
|
||||
|
@ -171,6 +183,7 @@ def rerank( # noqa: PLR0915
|
|||
response = base_llm_http_handler.rerank(
|
||||
model=model,
|
||||
custom_llm_provider=_custom_llm_provider,
|
||||
provider_config=rerank_provider_config,
|
||||
optional_rerank_params=optional_rerank_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
timeout=optional_params.timeout,
|
||||
|
@ -192,6 +205,7 @@ def rerank( # noqa: PLR0915
|
|||
model=model,
|
||||
custom_llm_provider=_custom_llm_provider,
|
||||
optional_rerank_params=optional_rerank_params,
|
||||
provider_config=rerank_provider_config,
|
||||
logging_obj=litellm_logging_obj,
|
||||
timeout=optional_params.timeout,
|
||||
api_key=dynamic_api_key or optional_params.api_key,
|
||||
|
@ -220,6 +234,7 @@ def rerank( # noqa: PLR0915
|
|||
response = base_llm_http_handler.rerank(
|
||||
model=model,
|
||||
custom_llm_provider=_custom_llm_provider,
|
||||
provider_config=rerank_provider_config,
|
||||
optional_rerank_params=optional_rerank_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
timeout=optional_params.timeout,
|
||||
|
@ -275,6 +290,7 @@ def rerank( # noqa: PLR0915
|
|||
custom_llm_provider=_custom_llm_provider,
|
||||
optional_rerank_params=optional_rerank_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
provider_config=rerank_provider_config,
|
||||
timeout=optional_params.timeout,
|
||||
api_key=dynamic_api_key or optional_params.api_key,
|
||||
api_base=api_base,
|
||||
|
|
|
@ -15,6 +15,7 @@ def get_optional_rerank_params(
|
|||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
max_tokens_per_doc: Optional[int] = None,
|
||||
non_default_params: Optional[dict] = None,
|
||||
) -> OptionalRerankParams:
|
||||
all_non_default_params = non_default_params or {}
|
||||
|
@ -28,6 +29,8 @@ def get_optional_rerank_params(
|
|||
all_non_default_params["return_documents"] = return_documents
|
||||
if max_chunks_per_doc is not None:
|
||||
all_non_default_params["max_chunks_per_doc"] = max_chunks_per_doc
|
||||
if max_tokens_per_doc is not None:
|
||||
all_non_default_params["max_tokens_per_doc"] = max_tokens_per_doc
|
||||
return rerank_provider_config.map_cohere_rerank_params(
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
|
@ -38,5 +41,6 @@ def get_optional_rerank_params(
|
|||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
max_chunks_per_doc=max_chunks_per_doc,
|
||||
max_tokens_per_doc=max_tokens_per_doc,
|
||||
non_default_params=all_non_default_params,
|
||||
)
|
||||
|
|
|
@ -18,6 +18,8 @@ class RerankRequest(BaseModel):
|
|||
rank_fields: Optional[List[str]] = None
|
||||
return_documents: Optional[bool] = None
|
||||
max_chunks_per_doc: Optional[int] = None
|
||||
max_tokens_per_doc: Optional[int] = None
|
||||
|
||||
|
||||
|
||||
class OptionalRerankParams(TypedDict, total=False):
|
||||
|
@ -27,6 +29,7 @@ class OptionalRerankParams(TypedDict, total=False):
|
|||
rank_fields: Optional[List[str]]
|
||||
return_documents: Optional[bool]
|
||||
max_chunks_per_doc: Optional[int]
|
||||
max_tokens_per_doc: Optional[int]
|
||||
|
||||
|
||||
class RerankBilledUnits(TypedDict, total=False):
|
||||
|
|
|
@ -6191,9 +6191,14 @@ class ProviderConfigManager:
|
|||
def get_provider_rerank_config(
|
||||
model: str,
|
||||
provider: LlmProviders,
|
||||
api_base: Optional[str],
|
||||
present_version_params: List[str],
|
||||
) -> BaseRerankConfig:
|
||||
if litellm.LlmProviders.COHERE == provider:
|
||||
return litellm.CohereRerankConfig()
|
||||
if should_use_cohere_v1_client(api_base, present_version_params):
|
||||
return litellm.CohereRerankConfig()
|
||||
else:
|
||||
return litellm.CohereRerankV2Config()
|
||||
elif litellm.LlmProviders.AZURE_AI == provider:
|
||||
return litellm.AzureAIRerankConfig()
|
||||
elif litellm.LlmProviders.INFINITY == provider:
|
||||
|
@ -6277,6 +6282,12 @@ def get_end_user_id_for_cost_tracking(
|
|||
return None
|
||||
return end_user_id
|
||||
|
||||
def should_use_cohere_v1_client(api_base: Optional[str], present_version_params: List[str]):
|
||||
if not api_base:
|
||||
return False
|
||||
uses_v1_params = ("max_chunks_per_doc" in present_version_params) and ('max_tokens_per_doc' not in present_version_params)
|
||||
return api_base.endswith("/v1/rerank") or (uses_v1_params and not api_base.endswith("/v2/rerank"))
|
||||
|
||||
|
||||
def is_prompt_caching_valid_prompt(
|
||||
model: str,
|
||||
|
|
|
@ -1970,6 +1970,26 @@ def test_get_applied_guardrails(test_case):
|
|||
# Assert
|
||||
assert sorted(result) == sorted(test_case["expected"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint, params, expected_bool",
|
||||
[
|
||||
("localhost:4000/v1/rerank", ["max_chunks_per_doc"], True),
|
||||
("localhost:4000/v2/rerank", ["max_chunks_per_doc"], False),
|
||||
("localhost:4000", ["max_chunks_per_doc"], True),
|
||||
|
||||
("localhost:4000/v1/rerank", ["max_tokens_per_doc"], True),
|
||||
("localhost:4000/v2/rerank", ["max_tokens_per_doc"], False),
|
||||
("localhost:4000", ["max_tokens_per_doc"], False),
|
||||
|
||||
("localhost:4000/v1/rerank", ["max_chunks_per_doc", "max_tokens_per_doc"], True),
|
||||
("localhost:4000/v2/rerank", ["max_chunks_per_doc", "max_tokens_per_doc"], False),
|
||||
("localhost:4000", ["max_chunks_per_doc", "max_tokens_per_doc"], False),
|
||||
|
||||
],
|
||||
)
|
||||
def test_should_use_cohere_v1_client(endpoint, params, expected_bool):
|
||||
assert(litellm.utils.should_use_cohere_v1_client(endpoint, params) == expected_bool)
|
||||
|
||||
|
||||
def test_add_openai_metadata():
|
||||
from litellm.utils import add_openai_metadata
|
||||
|
|
|
@ -111,35 +111,41 @@ async def test_basic_rerank(sync_mode):
|
|||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.skip(reason="Skipping test due to 503 Service Temporarily Unavailable")
|
||||
async def test_basic_rerank_together_ai(sync_mode):
|
||||
if sync_mode is True:
|
||||
response = litellm.rerank(
|
||||
model="together_ai/Salesforce/Llama-Rank-V1",
|
||||
query="hello",
|
||||
documents=["hello", "world"],
|
||||
top_n=3,
|
||||
)
|
||||
try:
|
||||
if sync_mode is True:
|
||||
response = litellm.rerank(
|
||||
model="together_ai/Salesforce/Llama-Rank-V1",
|
||||
query="hello",
|
||||
documents=["hello", "world"],
|
||||
top_n=3,
|
||||
)
|
||||
|
||||
print("re rank response: ", response)
|
||||
print("re rank response: ", response)
|
||||
|
||||
assert response.id is not None
|
||||
assert response.results is not None
|
||||
assert response.id is not None
|
||||
assert response.results is not None
|
||||
|
||||
assert_response_shape(response, custom_llm_provider="together_ai")
|
||||
else:
|
||||
response = await litellm.arerank(
|
||||
model="together_ai/Salesforce/Llama-Rank-V1",
|
||||
query="hello",
|
||||
documents=["hello", "world"],
|
||||
top_n=3,
|
||||
)
|
||||
assert_response_shape(response, custom_llm_provider="together_ai")
|
||||
else:
|
||||
response = await litellm.arerank(
|
||||
model="together_ai/Salesforce/Llama-Rank-V1",
|
||||
query="hello",
|
||||
documents=["hello", "world"],
|
||||
top_n=3,
|
||||
)
|
||||
|
||||
print("async re rank response: ", response)
|
||||
print("async re rank response: ", response)
|
||||
|
||||
assert response.id is not None
|
||||
assert response.results is not None
|
||||
assert response.id is not None
|
||||
assert response.results is not None
|
||||
|
||||
assert_response_shape(response, custom_llm_provider="together_ai")
|
||||
assert_response_shape(response, custom_llm_provider="together_ai")
|
||||
except Exception as e:
|
||||
if "Service unavailable" in str(e):
|
||||
pytest.skip("Skipping test due to 503 Service Temporarily Unavailable")
|
||||
raise e
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
|
@ -184,8 +190,10 @@ async def test_basic_rerank_azure_ai(sync_mode):
|
|||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_rerank_custom_api_base():
|
||||
@pytest.mark.parametrize("version", ["v1", "v2"])
|
||||
async def test_rerank_custom_api_base(version):
|
||||
mock_response = AsyncMock()
|
||||
litellm.cohere_key = "test_api_key"
|
||||
|
||||
def return_val():
|
||||
return {
|
||||
|
@ -208,6 +216,10 @@ async def test_rerank_custom_api_base():
|
|||
"documents": ["hello", "world"],
|
||||
}
|
||||
|
||||
api_base = "https://exampleopenaiendpoint-production.up.railway.app/"
|
||||
if version == "v1":
|
||||
api_base += "v1/rerank"
|
||||
|
||||
with patch(
|
||||
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
|
||||
return_value=mock_response,
|
||||
|
@ -217,7 +229,7 @@ async def test_rerank_custom_api_base():
|
|||
query="hello",
|
||||
documents=["hello", "world"],
|
||||
top_n=3,
|
||||
api_base="https://exampleopenaiendpoint-production.up.railway.app/",
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
print("async re rank response: ", response)
|
||||
|
@ -230,7 +242,8 @@ async def test_rerank_custom_api_base():
|
|||
print("Arguments passed to API=", args_to_api)
|
||||
print("url = ", _url)
|
||||
assert (
|
||||
_url == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
|
||||
_url
|
||||
== f"https://exampleopenaiendpoint-production.up.railway.app/{version}/rerank"
|
||||
)
|
||||
|
||||
request_data = json.loads(args_to_api)
|
||||
|
@ -287,6 +300,7 @@ def test_complete_base_url_cohere():
|
|||
|
||||
client = HTTPHandler()
|
||||
litellm.api_base = "http://localhost:4000"
|
||||
litellm.cohere_key = "test_api_key"
|
||||
litellm.set_verbose = True
|
||||
|
||||
text = "Hello there!"
|
||||
|
@ -308,7 +322,8 @@ def test_complete_base_url_cohere():
|
|||
|
||||
print("mock_post.call_args", mock_post.call_args)
|
||||
mock_post.assert_called_once()
|
||||
assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"]
|
||||
# Default to the v2 client when calling the base /rerank
|
||||
assert "http://localhost:4000/v2/rerank" in mock_post.call_args.kwargs["url"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
|
@ -395,6 +410,63 @@ def test_rerank_response_assertions():
|
|||
assert_response_shape(r, custom_llm_provider="custom")
|
||||
|
||||
|
||||
def test_cohere_rerank_v2_client():
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
client = HTTPHandler()
|
||||
litellm.api_base = "http://localhost:4000"
|
||||
litellm.set_verbose = True
|
||||
|
||||
text = "Hello there!"
|
||||
list_texts = ["Hello there!", "How are you?", "How do you do?"]
|
||||
|
||||
rerank_model = "rerank-multilingual-v3.0"
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = json.dumps(
|
||||
{
|
||||
"id": "cmpl-mockid",
|
||||
"results": [
|
||||
{"index": 0, "relevance_score": 0.95},
|
||||
{"index": 1, "relevance_score": 0.75},
|
||||
{"index": 2, "relevance_score": 0.65},
|
||||
],
|
||||
"usage": {"prompt_tokens": 100, "total_tokens": 150},
|
||||
}
|
||||
)
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
mock_response.json = lambda: json.loads(mock_response.text)
|
||||
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
response = litellm.rerank(
|
||||
model=rerank_model,
|
||||
query=text,
|
||||
documents=list_texts,
|
||||
custom_llm_provider="cohere",
|
||||
max_tokens_per_doc=3,
|
||||
top_n=2,
|
||||
api_key="fake-api-key",
|
||||
client=client,
|
||||
)
|
||||
|
||||
# Ensure Cohere API is called with the expected params
|
||||
mock_post.assert_called_once()
|
||||
assert mock_post.call_args.kwargs["url"] == "http://localhost:4000/v2/rerank"
|
||||
|
||||
request_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert request_data["model"] == rerank_model
|
||||
assert request_data["query"] == text
|
||||
assert request_data["documents"] == list_texts
|
||||
assert request_data["max_tokens_per_doc"] == 3
|
||||
assert request_data["top_n"] == 2
|
||||
|
||||
# Ensure litellm response is what we expect
|
||||
assert response["results"] == mock_response.json()["results"]
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
def test_rerank_cohere_api():
|
||||
response = litellm.rerank(
|
||||
|
|
|
@ -961,7 +961,8 @@ async def test_gemini_embeddings(sync_mode, input):
|
|||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
@pytest.mark.flaky(retries=6, delay=1)
|
||||
@pytest.mark.skip(reason="Skipping test due to flakyness")
|
||||
async def test_hf_embedddings_with_optional_params(sync_mode):
|
||||
litellm.set_verbose = True
|
||||
|
||||
|
@ -992,8 +993,8 @@ async def test_hf_embedddings_with_optional_params(sync_mode):
|
|||
wait_for_model=True,
|
||||
client=client,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_client.assert_called_once()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue