diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py
new file mode 100644
index 0000000000..a547ea2189
--- /dev/null
+++ b/litellm/llms/cohere/rerank.py
@@ -0,0 +1,44 @@
+"""
+Re rank API
+
+LiteLLM supports the re rank API format, no parameter transformation occurs
+"""
+
+import httpx
+from pydantic import BaseModel
+
+from litellm.llms.base import BaseLLM
+from litellm.llms.custom_httpx.http_handler import (
+    _get_async_httpx_client,
+    _get_httpx_client,
+)
+from litellm.rerank_api.types import RerankRequest, RerankResponse
+
+
+class CohereRerank(BaseLLM):
+    def rerank(
+        self,
+        model: str,
+        api_key: str,
+        query: str,
+        documents: list[str],
+        top_n: int = 3,
+    ) -> RerankResponse:
+        client = _get_httpx_client()
+        request_data = RerankRequest(
+            model=model, query=query, top_n=top_n, documents=documents
+        )
+
+        response = client.post(
+            "https://api.cohere.com/v1/rerank",
+            headers={
+                "accept": "application/json",
+                "content-type": "application/json",
+                "Authorization": f"bearer {api_key}",
+            },
+            json=request_data.dict(),
+        )
+
+        return RerankResponse(**response.json())
+
+    pass
diff --git a/litellm/llms/togetherai/rerank.py b/litellm/llms/togetherai/rerank.py
new file mode 100644
index 0000000000..b4020fc651
--- /dev/null
+++ b/litellm/llms/togetherai/rerank.py
@@ -0,0 +1,52 @@
+"""
+Re rank API
+
+LiteLLM supports the re rank API format, no parameter transformation occurs
+"""
+
+import httpx
+from pydantic import BaseModel
+
+from litellm.llms.base import BaseLLM
+from litellm.llms.custom_httpx.http_handler import (
+    _get_async_httpx_client,
+    _get_httpx_client,
+)
+from litellm.rerank_api.types import RerankRequest, RerankResponse
+
+
+class TogetherAIRerank(BaseLLM):
+    def rerank(
+        self,
+        model: str,
+        api_key: str,
+        query: str,
+        documents: list[str],
+        top_n: int = 3,
+    ) -> RerankResponse:
+        client = _get_httpx_client()
+
+        request_data = RerankRequest(
+            model=model, query=query, top_n=top_n, documents=documents
+        )
+
+        response = client.post(
+            "https://api.together.xyz/v1/rerank",
+            headers={
+                "accept": "application/json",
+                "content-type": "application/json",
+                "authorization": f"Bearer {api_key}",
+            },
+            json=request_data.dict(),
+        )
+
+        _json_response = response.json()
+        response = RerankResponse(
+            id=_json_response.get("id"),
+            results=_json_response.get("results"),
+            meta=_json_response.get("meta") or {},
+        )
+
+        return response
+
+    pass
diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py
index c65dca503e..bb0094d001 100644
--- a/litellm/rerank_api/main.py
+++ b/litellm/rerank_api/main.py
@@ -7,6 +7,7 @@ import litellm
 from litellm import get_secret
 from litellm._logging import verbose_logger
 from litellm.llms.cohere.rerank import CohereRerank
+from litellm.llms.togetherai.rerank import TogetherAIRerank
 from litellm.types.router import *
 from litellm.utils import supports_httpx_timeout
 
@@ -15,6 +16,7 @@ from .types import RerankRequest, RerankResponse
 ####### ENVIRONMENT VARIABLES ###################
 # Initialize any necessary instances or variables here
 cohere_rerank = CohereRerank()
+together_rerank = TogetherAIRerank()
 #################################################
 
 
@@ -54,7 +56,7 @@ def rerank(
     model: str,
     query: str,
     documents: List[str],
-    custom_llm_provider: Literal["cohere", "together_ai"] = "cohere",
+    custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
     top_n: int = 3,
     **kwargs,
 ) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
@@ -65,11 +67,21 @@ def rerank(
     _is_async = kwargs.pop("arerank", False) is True
     optional_params = GenericLiteLLMParams(**kwargs)
 
+    model, _custom_llm_provider, dynamic_api_key, api_base = (
+        litellm.get_llm_provider(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            api_base=optional_params.api_base,
+            api_key=optional_params.api_key,
+        )
+    )
+
     # Implement rerank logic here based on the custom_llm_provider
-    if custom_llm_provider == "cohere":
+    if _custom_llm_provider == "cohere":
         # Implement Cohere rerank logic
         cohere_key = (
-            optional_params.api_key
+            dynamic_api_key
+            or optional_params.api_key
             or litellm.cohere_key
             or get_secret("COHERE_API_KEY")
             or get_secret("CO_API_KEY")
@@ -98,11 +110,31 @@ def rerank(
             api_key=cohere_key,
         )
         pass
-    elif custom_llm_provider == "together_ai":
+    elif _custom_llm_provider == "together_ai":
         # Implement Together AI rerank logic
-        pass
+        together_key = (
+            dynamic_api_key
+            or optional_params.api_key
+            or litellm.togetherai_api_key
+            or get_secret("TOGETHERAI_API_KEY")
+            or litellm.api_key
+        )
+
+        if together_key is None:
+            raise ValueError(
+                "TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
+            )
+
+        response = together_rerank.rerank(
+            model=model,
+            query=query,
+            documents=documents,
+            top_n=top_n,
+            api_key=together_key,
+        )
+
     else:
-        raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
+        raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
 
     # Placeholder return
     return response
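
For reference, a minimal usage sketch of the new TogetherAI rerank path, assuming rerank() is imported directly from litellm/rerank_api/main.py as added in this diff; the model name is illustrative, and TOGETHERAI_API_KEY is expected in the environment:

from litellm.rerank_api.main import rerank

# Routes through TogetherAIRerank via the explicit provider argument; the key is
# resolved from TOGETHERAI_API_KEY / litellm.togetherai_api_key as shown in the diff.
response = rerank(
    model="example-rerank-model",  # illustrative model name, not a confirmed identifier
    query="What is the capital of France?",
    documents=[
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
    ],
    custom_llm_provider="together_ai",
    top_n=2,
)
print(response.results)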