diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py
index a547ea2189..0c00ea03c9 100644
--- a/litellm/llms/cohere/rerank.py
+++ b/litellm/llms/cohere/rerank.py
@@ -4,6 +4,8 @@ Re rank api
 LiteLLM supports the re rank API format, no paramter transformation occurs
 """
 
+from typing import Any, Dict, List, Optional, Union
+
 import httpx
 from pydantic import BaseModel
 
@@ -21,14 +23,26 @@ class CohereRerank(BaseLLM):
         model: str,
         api_key: str,
         query: str,
-        documents: list[str],
-        top_n: int = 3,
+        documents: list[Union[str, Dict[str, Any]]],
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
     ) -> RerankResponse:
         client = _get_httpx_client()
+
         request_data = RerankRequest(
-            model=model, query=query, top_n=top_n, documents=documents
+            model=model,
+            query=query,
+            top_n=top_n,
+            documents=documents,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+            max_chunks_per_doc=max_chunks_per_doc,
         )
+
+        request_data_dict = request_data.dict(exclude_none=True)
+
         response = client.post(
             "https://api.cohere.com/v1/rerank",
             headers={
@@ -36,7 +50,7 @@ class CohereRerank(BaseLLM):
                 "content-type": "application/json",
                 "Authorization": f"bearer {api_key}",
             },
-            json=request_data.dict(),
+            json=request_data_dict,
         )
 
         return RerankResponse(**response.json())
diff --git a/litellm/llms/togetherai/rerank.py b/litellm/llms/togetherai/rerank.py
index b4020fc651..8a5a466852 100644
--- a/litellm/llms/togetherai/rerank.py
+++ b/litellm/llms/togetherai/rerank.py
@@ -4,6 +4,8 @@ Re rank api
 LiteLLM supports the re rank API format, no paramter transformation occurs
 """
 
+from typing import Any, Dict, List, Optional, Union
+
 import httpx
 from pydantic import BaseModel
 
@@ -21,15 +23,28 @@ class TogetherAIRerank(BaseLLM):
         model: str,
         api_key: str,
         query: str,
-        documents: list[str],
-        top_n: int = 3,
+        documents: list[Union[str, Dict[str, Any]]],
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
     ) -> RerankResponse:
         client = _get_httpx_client()
 
         request_data = RerankRequest(
-            model=model, query=query, top_n=top_n, documents=documents
+            model=model,
+            query=query,
+            top_n=top_n,
+            documents=documents,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
         )
+
+        # exclude None values from request_data
+        request_data_dict = request_data.dict(exclude_none=True)
+        if max_chunks_per_doc is not None:
+            raise ValueError("TogetherAI does not support max_chunks_per_doc")
+
         response = client.post(
             "https://api.together.xyz/v1/rerank",
             headers={
@@ -37,10 +52,14 @@ class TogetherAIRerank(BaseLLM):
                 "content-type": "application/json",
                 "authorization": f"Bearer {api_key}",
             },
-            json=request_data.dict(),
+            json=request_data_dict,
         )
 
+        if response.status_code != 200:
+            raise Exception(response.text)
+
         _json_response = response.json()
+
         response = RerankResponse(
             id=_json_response.get("id"),
             results=_json_response.get("results"),
diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py
index bb0094d001..6d3a27f549 100644
--- a/litellm/rerank_api/main.py
+++ b/litellm/rerank_api/main.py
@@ -55,9 +55,12 @@ async def arerank(
 def rerank(
     model: str,
     query: str,
-    documents: List[str],
+    documents: List[Union[str, Dict[str, Any]]],
     custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
-    top_n: int = 3,
+    top_n: Optional[int] = None,
+    rank_fields: Optional[List[str]] = None,
+    return_documents: Optional[bool] = True,
+    max_chunks_per_doc: Optional[int] = None,
     **kwargs,
 ) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
     """
@@ -107,6 +110,9 @@ def rerank(
                 query=query,
                 documents=documents,
                 top_n=top_n,
+                rank_fields=rank_fields,
+                return_documents=return_documents,
+                max_chunks_per_doc=max_chunks_per_doc,
                 api_key=cohere_key,
             )
             pass
@@ -130,6 +136,9 @@ def rerank(
                 query=query,
                 documents=documents,
                 top_n=top_n,
+                rank_fields=rank_fields,
+                return_documents=return_documents,
+                max_chunks_per_doc=max_chunks_per_doc,
                 api_key=together_key,
             )
 
diff --git a/litellm/rerank_api/types.py b/litellm/rerank_api/types.py
index 9d53cf278c..605e25a2ec 100644
--- a/litellm/rerank_api/types.py
+++ b/litellm/rerank_api/types.py
@@ -4,17 +4,22 @@
 https://docs.cohere.com/reference/rerank
 """
 
+from typing import List, Optional, Union
+
 from pydantic import BaseModel
 
 
 class RerankRequest(BaseModel):
     model: str
     query: str
-    top_n: int
-    documents: list[str]
+    top_n: Optional[int] = None
+    documents: List[Union[str, dict]]
+    rank_fields: Optional[List[str]] = None
+    return_documents: Optional[bool] = None
+    max_chunks_per_doc: Optional[int] = None
 
 
 class RerankResponse(BaseModel):
     id: str
-    results: list[dict]  # Contains index and relevance_score
+    results: List[dict]  # Contains index and relevance_score
     meta: dict  # Contains api_version and billed_units
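
Usage sketch (not part of the patch itself): a minimal call against the updated rerank() signature in litellm/rerank_api/main.py, exercising the new optional parameters. The model name, example documents, and the assumption that the Cohere API key is already configured for litellm (e.g. via environment) are illustrative, not taken from this diff.

    from litellm.rerank_api.main import rerank

    # Minimal sketch of the updated signature; top_n, rank_fields,
    # return_documents, and max_chunks_per_doc are now optional.
    # Assumes the Cohere API key is available to litellm; the model
    # name below is illustrative.
    response = rerank(
        model="rerank-english-v3.0",
        custom_llm_provider="cohere",
        query="What is the capital of France?",
        documents=[
            {"text": "Paris is the capital of France."},  # dict documents now accepted
            "Berlin is the capital of Germany.",          # plain strings still work
        ],
        top_n=1,
        return_documents=True,
    )
    print(response.results)  # list of dicts with index and relevance_score

Omitted optional parameters are dropped from the request body via request_data.dict(exclude_none=True), so the provider's own defaults apply; passing max_chunks_per_doc with the together_ai provider raises a ValueError per the TogetherAI change above.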