diff --git a/litellm/llms/together_ai/rerank.py b/litellm/llms/together_ai/rerank/handler.py
similarity index 84%
rename from litellm/llms/together_ai/rerank.py
rename to litellm/llms/together_ai/rerank/handler.py
index 1be73af2d..3e6d5d667 100644
--- a/litellm/llms/together_ai/rerank.py
+++ b/litellm/llms/together_ai/rerank/handler.py
@@ -15,7 +15,14 @@ from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
     get_async_httpx_client,
 )
-from litellm.types.rerank import RerankRequest, RerankResponse
+from litellm.llms.together_ai.rerank.transformation import TogetherAIRerankConfig
+from litellm.types.rerank import (
+    RerankBilledUnits,
+    RerankRequest,
+    RerankResponse,
+    RerankResponseMeta,
+    RerankTokens,
+)
 
 
 class TogetherAIRerank(BaseLLM):
@@ -65,13 +72,7 @@ class TogetherAIRerank(BaseLLM):
 
         _json_response = response.json()
 
-        response = RerankResponse(
-            id=_json_response.get("id"),
-            results=_json_response.get("results"),
-            meta=_json_response.get("meta") or {},
-        )
-
-        return response
+        return TogetherAIRerankConfig()._transform_response(_json_response)
 
     async def async_rerank(  # New async method
         self,
@@ -97,10 +98,4 @@ class TogetherAIRerank(BaseLLM):
 
         _json_response = response.json()
 
-        return RerankResponse(
-            id=_json_response.get("id"),
-            results=_json_response.get("results"),
-            meta=_json_response.get("meta") or {},
-        )  # Return response
-
-        pass
+        return TogetherAIRerankConfig()._transform_response(_json_response)
diff --git a/litellm/llms/together_ai/rerank/transformation.py b/litellm/llms/together_ai/rerank/transformation.py
new file mode 100644
index 000000000..b2024b5cd
--- /dev/null
+++ b/litellm/llms/together_ai/rerank/transformation.py
@@ -0,0 +1,52 @@
+"""
+Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+import uuid
+from typing import List, Optional
+
+from litellm.types.rerank import (
+    RerankBilledUnits,
+    RerankResponse,
+    RerankResponseMeta,
+    RerankTokens,
+)
+
+
+class TogetherAIRerankConfig:
+    """Build litellm `RerankResponse` objects from Together AI rerank JSON."""
+
+    def _transform_response(self, response: dict) -> RerankResponse:
+        """
+        Convert a decoded Together AI rerank response into a `RerankResponse`.
+
+        Args:
+            response: JSON body of the rerank call, already parsed into a dict.
+
+        Returns:
+            A `RerankResponse` whose `meta` is built from the `usage` entry and
+            whose `id` falls back to a random UUID when the body's id is missing or empty.
+
+        Raises:
+            ValueError: if the body has no `results` entry (missing or null).
+        """
+        # `usage` may be missing or an explicit null. `.get("usage", {})` only
+        # covers the missing case and would hand None to `**` below (TypeError).
+        _usage: dict = response.get("usage") or {}
+
+        _billed_units = RerankBilledUnits(**_usage)
+        _tokens = RerankTokens(**_usage)
+        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
+
+        _results: Optional[List[dict]] = response.get("results")
+
+        if _results is None:
+            raise ValueError(f"No results found in the response={response}")
+
+        return RerankResponse(
+            id=response.get("id") or str(uuid.uuid4()),
+            results=_results,
+            meta=rerank_meta,
+        )  # Return response
diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py
index 70353acad..9cc8a8c1d 100644
--- a/litellm/rerank_api/main.py
+++ b/litellm/rerank_api/main.py
@@ -9,7 +9,7 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.llms.azure_ai.rerank import AzureAIRerank
 from litellm.llms.cohere.rerank import CohereRerank
 from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
-from litellm.llms.together_ai.rerank import TogetherAIRerank
+from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
 from litellm.secret_managers.main import get_secret
 from litellm.types.rerank import RerankRequest, RerankResponse
 from litellm.types.router import *