fix(handler.py): refactor together ai rerank call

This commit is contained in:
Krrish Dholakia 2024-11-14 18:54:08 +05:30
parent 8bf0f57b59
commit fef31a1c01
3 changed files with 45 additions and 16 deletions

View file

@ -15,7 +15,14 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.rerank import RerankRequest, RerankResponse
from litellm.llms.together_ai.rerank.transformation import TogetherAIRerankConfig
from litellm.types.rerank import (
RerankBilledUnits,
RerankRequest,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
class TogetherAIRerank(BaseLLM):
@ -65,13 +72,7 @@ class TogetherAIRerank(BaseLLM):
_json_response = response.json()
response = RerankResponse(
id=_json_response.get("id"),
results=_json_response.get("results"),
meta=_json_response.get("meta") or {},
)
return response
return TogetherAIRerankConfig()._transform_response(_json_response)
async def async_rerank( # New async method
self,
@ -97,10 +98,4 @@ class TogetherAIRerank(BaseLLM):
_json_response = response.json()
return RerankResponse(
id=_json_response.get("id"),
results=_json_response.get("results"),
meta=_json_response.get("meta") or {},
) # Return response
pass
return TogetherAIRerankConfig()._transform_response(_json_response)

View file

@ -0,0 +1,34 @@
"""
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works
"""
import uuid
from typing import List, Optional
from litellm.types.rerank import (
RerankBilledUnits,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
class TogetherAIRerankConfig:
def _transform_response(self, response: dict) -> RerankResponse:
_billed_units = RerankBilledUnits(**response.get("usage", {}))
_tokens = RerankTokens(**response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
_results: Optional[List[dict]] = response.get("results")
if _results is None:
raise ValueError(f"No results found in the response={response}")
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
results=_results,
meta=rerank_meta,
) # Return response

View file

@ -9,7 +9,7 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
from litellm.llms.together_ai.rerank import TogetherAIRerank
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.secret_managers.main import get_secret
from litellm.types.rerank import RerankRequest, RerankResponse
from litellm.types.router import *