From a80b2aebbb7348fcda3d35718a5857797c498de1 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 27 Aug 2024 18:25:51 -0700 Subject: [PATCH] add test for rerank on custom api base --- litellm/llms/cohere/rerank.py | 8 ++++-- litellm/rerank_api/main.py | 5 ++-- litellm/tests/test_rerank.py | 52 +++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py index a2a7476df8..97cd7e3998 100644 --- a/litellm/llms/cohere/rerank.py +++ b/litellm/llms/cohere/rerank.py @@ -22,6 +22,7 @@ class CohereRerank(BaseLLM): self, model: str, api_key: str, + api_base: str, query: str, documents: List[Union[str, Dict[str, Any]]], top_n: Optional[int] = None, @@ -43,11 +44,11 @@ class CohereRerank(BaseLLM): request_data_dict = request_data.dict(exclude_none=True) if _is_async: - return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method + return self.async_rerank(request_data_dict, api_key, api_base) # type: ignore # Call async method client = _get_httpx_client() response = client.post( - "https://api.cohere.com/v1/rerank", + api_base, headers={ "accept": "application/json", "content-type": "application/json", @@ -62,11 +63,12 @@ class CohereRerank(BaseLLM): self, request_data_dict: Dict[str, Any], api_key: str, + api_base: str, ) -> RerankResponse: client = _get_async_httpx_client() response = await client.post( - "https://api.cohere.com/v1/rerank", + api_base, headers={ "accept": "application/json", "content-type": "application/json", diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index 968b9b562c..2e69f28180 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -27,7 +27,7 @@ async def arerank( custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None, top_n: Optional[int] = None, rank_fields: Optional[List[str]] = None, - return_documents: Optional[bool] = True, + return_documents: Optional[bool] = None, max_chunks_per_doc: Optional[int] = None, **kwargs, ) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]: @@ -112,7 +112,7 @@ def rerank( optional_params.api_base or litellm.api_base or get_secret("COHERE_API_BASE") - or "https://api.cohere.ai/v1/generate" + or "https://api.cohere.com/v1/rerank" ) headers: Dict = litellm.headers or {} @@ -126,6 +126,7 @@ def rerank( return_documents=return_documents, max_chunks_per_doc=max_chunks_per_doc, api_key=cohere_key, + api_base=api_base, _is_async=_is_async, ) pass diff --git a/litellm/tests/test_rerank.py b/litellm/tests/test_rerank.py index a0127063f9..4e70424bc3 100644 --- a/litellm/tests/test_rerank.py +++ b/litellm/tests/test_rerank.py @@ -125,3 +125,55 @@ async def test_basic_rerank_together_ai(sync_mode): assert response.results is not None assert_response_shape(response, custom_llm_provider="together_ai") + + +@pytest.mark.asyncio() +async def test_rerank_custom_api_base(): + mock_response = AsyncMock() + + def return_val(): + return { + "id": "cmpl-mockid", + "results": [{"index": 0, "relevance_score": 0.95}], + "meta": { + "api_version": {"version": "1.0"}, + "billed_units": {"search_units": 1}, + }, + } + + mock_response.json = return_val + mock_response.status_code = 200 + + expected_payload = { + "model": "Salesforce/Llama-Rank-V1", + "query": "hello", + "documents": ["hello", "world"], + "top_n": 3, + } + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + return_value=mock_response, + ) as mock_post: + response = await litellm.arerank( + model="cohere/Salesforce/Llama-Rank-V1", + query="hello", + documents=["hello", "world"], + top_n=3, + api_base="https://exampleopenaiendpoint-production.up.railway.app/", + ) + + print("async re rank response: ", response) + + # Assert + mock_post.assert_called_once() + _url, kwargs = mock_post.call_args + args_to_api = kwargs["json"] + print("Arguments passed to API=", args_to_api) + print("url = ", _url) + assert _url[0] == "https://exampleopenaiendpoint-production.up.railway.app/" + assert args_to_api == expected_payload + assert response.id is not None + assert response.results is not None + + assert_response_shape(response, custom_llm_provider="cohere")