diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py
index 0c00ea03c9..4ef523e3a1 100644
--- a/litellm/llms/cohere/rerank.py
+++ b/litellm/llms/cohere/rerank.py
@@ -28,9 +28,8 @@ class CohereRerank(BaseLLM):
         rank_fields: Optional[List[str]] = None,
         return_documents: Optional[bool] = True,
         max_chunks_per_doc: Optional[int] = None,
+        _is_async: Optional[bool] = False,  # New parameter
     ) -> RerankResponse:
-        client = _get_httpx_client()
-
         request_data = RerankRequest(
             model=model,
             query=query,
@@ -43,6 +42,10 @@ class CohereRerank(BaseLLM):
 
         request_data_dict = request_data.dict(exclude_none=True)
 
+        if _is_async:
+            return self.async_rerank(request_data_dict, api_key)  # type: ignore # Call async method
+
+        client = _get_httpx_client()
         response = client.post(
             "https://api.cohere.com/v1/rerank",
             headers={
@@ -55,4 +58,21 @@ class CohereRerank(BaseLLM):
 
         return RerankResponse(**response.json())
 
-    pass
+    async def async_rerank(
+        self,
+        request_data_dict: Dict[str, Any],
+        api_key: str,
+    ) -> RerankResponse:
+        client = _get_async_httpx_client()
+
+        response = await client.post(
+            "https://api.cohere.com/v1/rerank",
+            headers={
+                "accept": "application/json",
+                "content-type": "application/json",
+                "Authorization": f"bearer {api_key}",
+            },
+            json=request_data_dict,
+        )
+
+        return RerankResponse(**response.json())
diff --git a/litellm/llms/togetherai/rerank.py b/litellm/llms/togetherai/rerank.py
index 8a5a466852..32a8cdcfdc 100644
--- a/litellm/llms/togetherai/rerank.py
+++ b/litellm/llms/togetherai/rerank.py
@@ -28,6 +28,7 @@ class TogetherAIRerank(BaseLLM):
         rank_fields: Optional[List[str]] = None,
         return_documents: Optional[bool] = True,
         max_chunks_per_doc: Optional[int] = None,
+        _is_async: Optional[bool] = False,
     ) -> RerankResponse:
         client = _get_httpx_client()
 
@@ -45,6 +46,9 @@ class TogetherAIRerank(BaseLLM):
         if max_chunks_per_doc is not None:
             raise ValueError("TogetherAI does not support max_chunks_per_doc")
 
+        if _is_async:
+            return self.async_rerank(request_data_dict, api_key)  # type: ignore # Call async method
+
         response = client.post(
             "https://api.together.xyz/v1/rerank",
             headers={
@@ -68,4 +72,32 @@ class TogetherAIRerank(BaseLLM):
 
         return response
 
+    async def async_rerank(  # New async method
+        self,
+        request_data_dict: Dict[str, Any],
+        api_key: str,
+    ) -> RerankResponse:
+        client = _get_async_httpx_client()  # Use async client
+
+        response = await client.post(
+            "https://api.together.xyz/v1/rerank",
+            headers={
+                "accept": "application/json",
+                "content-type": "application/json",
+                "authorization": f"Bearer {api_key}",
+            },
+            json=request_data_dict,
+        )
+
+        if response.status_code != 200:
+            raise Exception(response.text)
+
+        _json_response = response.json()
+
+        return RerankResponse(
+            id=_json_response.get("id"),
+            results=_json_response.get("results"),
+            meta=_json_response.get("meta") or {},
+        )  # Return response
+
     pass
diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py
index 6d3a27f549..968b9b562c 100644
--- a/litellm/rerank_api/main.py
+++ b/litellm/rerank_api/main.py
@@ -23,11 +23,14 @@ together_rerank = TogetherAIRerank()
 async def arerank(
     model: str,
     query: str,
-    documents: List[str],
-    custom_llm_provider: Literal["cohere", "together_ai"] = "cohere",
-    top_n: int = 3,
+    documents: List[Union[str, Dict[str, Any]]],
+    custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
+    top_n: Optional[int] = None,
+    rank_fields: Optional[List[str]] = None,
+    return_documents: Optional[bool] = True,
+    max_chunks_per_doc: Optional[int] = None,
     **kwargs,
-) -> Dict[str, Any]:
+) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
     """
     Async: Reranks a list of documents based on their relevance to the query
     """
@@ -36,7 +39,16 @@ async def arerank(
         kwargs["arerank"] = True
 
         func = partial(
-            rerank, model, query, documents, custom_llm_provider, top_n, **kwargs
+            rerank,
+            model,
+            query,
+            documents,
+            custom_llm_provider,
+            top_n,
+            rank_fields,
+            return_documents,
+            max_chunks_per_doc,
+            **kwargs,
         )
 
         ctx = contextvars.copy_context()
@@ -114,6 +126,7 @@ def rerank(
             return_documents=return_documents,
             max_chunks_per_doc=max_chunks_per_doc,
             api_key=cohere_key,
+            _is_async=_is_async,
         )
         pass
     elif _custom_llm_provider == "together_ai":
@@ -140,6 +153,7 @@ def rerank(
             return_documents=return_documents,
             max_chunks_per_doc=max_chunks_per_doc,
             api_key=together_key,
+            _is_async=_is_async,
         )
 
     else:
diff --git a/litellm/tests/test_rerank.py b/litellm/tests/test_rerank.py
index 946bfbb970..a0127063f9 100644
--- a/litellm/tests/test_rerank.py
+++ b/litellm/tests/test_rerank.py
@@ -61,33 +61,67 @@ def assert_response_shape(response, custom_llm_provider):
     )
 
 
-def test_basic_rerank():
-    response = litellm.rerank(
-        model="cohere/rerank-english-v3.0",
-        query="hello",
-        documents=["hello", "world"],
-        top_n=3,
-    )
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_basic_rerank(sync_mode):
+    if sync_mode is True:
+        response = litellm.rerank(
+            model="cohere/rerank-english-v3.0",
+            query="hello",
+            documents=["hello", "world"],
+            top_n=3,
+        )
 
-    print("re rank response: ", response)
+        print("re rank response: ", response)
 
-    assert response.id is not None
-    assert response.results is not None
+        assert response.id is not None
+        assert response.results is not None
 
-    assert_response_shape(response, custom_llm_provider="cohere")
+        assert_response_shape(response, custom_llm_provider="cohere")
+    else:
+        response = await litellm.arerank(
+            model="cohere/rerank-english-v3.0",
+            query="hello",
+            documents=["hello", "world"],
+            top_n=3,
+        )
+
+        print("async re rank response: ", response)
+
+        assert response.id is not None
+        assert response.results is not None
+
+        assert_response_shape(response, custom_llm_provider="cohere")
 
 
-def test_basic_rerank_together_ai():
-    response = litellm.rerank(
-        model="together_ai/Salesforce/Llama-Rank-V1",
-        query="hello",
-        documents=["hello", "world"],
-        top_n=3,
-    )
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_basic_rerank_together_ai(sync_mode):
+    if sync_mode is True:
+        response = litellm.rerank(
+            model="together_ai/Salesforce/Llama-Rank-V1",
+            query="hello",
+            documents=["hello", "world"],
+            top_n=3,
+        )
 
-    print("re rank response: ", response)
+        print("re rank response: ", response)
 
-    assert response.id is not None
-    assert response.results is not None
+        assert response.id is not None
+        assert response.results is not None
 
-    assert_response_shape(response, custom_llm_provider="together_ai")
+        assert_response_shape(response, custom_llm_provider="together_ai")
+    else:
+        response = await litellm.arerank(
+            model="together_ai/Salesforce/Llama-Rank-V1",
+            query="hello",
+            documents=["hello", "world"],
+            top_n=3,
+        )
+
+        print("async re rank response: ", response)
+
+        assert response.id is not None
+        assert response.results is not None
+
+        assert_response_shape(response, custom_llm_provider="together_ai")
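
For reference, a minimal usage sketch of the async path this patch adds, mirroring the new tests; the model name is illustrative and a valid COHERE_API_KEY is assumed to be set in the environment:

import asyncio

import litellm


async def main():
    # arerank() is the new async entrypoint added in litellm/rerank_api/main.py;
    # it dispatches to rerank() with _is_async set, so the provider's
    # async_rerank() coroutine is awaited instead of the sync httpx client.
    response = await litellm.arerank(
        model="cohere/rerank-english-v3.0",  # illustrative model choice
        query="hello",
        documents=["hello", "world"],
        top_n=3,
    )
    print(response.id, response.results)


asyncio.run(main())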