litellm-mirror/litellm/llms/bedrock/rerank/transformation.py
Krish Dholakia 9b7ebb6a7d
build(pyproject.toml): add new dev dependencies - for type checking (#9631)
* build(pyproject.toml): add new dev dependencies - for type checking

* build: reformat files to fit black

* ci: reformat to fit black

* ci(test-litellm.yml): make tests run clear

* build(pyproject.toml): add ruff

* fix: fix ruff checks

* build(mypy/): fix mypy linting errors

* fix(hashicorp_secret_manager.py): fix passing cert for tls auth

* build(mypy/): resolve all mypy errors

* test: update test

* fix: fix black formatting

* build(pre-commit-config.yaml): use poetry run black

* fix(proxy_server.py): fix linting error

* fix: fix ruff safe representation error
2025-03-29 11:02:13 -07:00

118 lines
4 KiB
Python

"""
Translates from Cohere's `/v1/rerank` input format to Bedrock's `/rerank` input format.

Why a separate file? To make it easy to see how the transformation works.
"""
import uuid
from typing import List, Optional, Union
from litellm.types.llms.bedrock import (
BedrockRerankBedrockRerankingConfiguration,
BedrockRerankConfiguration,
BedrockRerankInlineDocumentSource,
BedrockRerankModelConfiguration,
BedrockRerankQuery,
BedrockRerankRequest,
BedrockRerankSource,
BedrockRerankTextDocument,
BedrockRerankTextQuery,
)
from litellm.types.rerank import (
RerankBilledUnits,
RerankRequest,
RerankResponse,
RerankResponseMeta,
RerankResponseResult,
RerankTokens,
)
class BedrockRerankConfig:
    """
    Translate litellm's Cohere-style rerank request into a Bedrock rerank
    request, and a Bedrock rerank response back into a ``RerankResponse``.
    """

    def _transform_sources(
        self, documents: List[Union[str, dict]]
    ) -> List[BedrockRerankSource]:
        """
        Wrap every input document as an inline Bedrock rerank source.

        Plain strings become ``TEXT`` inline documents holding the string;
        any other entry is forwarded unchanged as a ``JSON`` inline document.

        Args:
            documents: The documents carried by the incoming ``RerankRequest``.

        Returns:
            One ``BedrockRerankSource`` of type ``INLINE`` per input document,
            in the same order as ``documents``.
        """

        def _to_inline(doc: Union[str, dict]) -> BedrockRerankInlineDocumentSource:
            # Non-string documents are passed through as-is for Bedrock's JSON form.
            if not isinstance(doc, str):
                return BedrockRerankInlineDocumentSource(jsonDocument=doc, type="JSON")
            return BedrockRerankInlineDocumentSource(
                textDocument=BedrockRerankTextDocument(text=doc),
                type="TEXT",
            )

        return [
            BedrockRerankSource(inlineDocumentSource=_to_inline(doc), type="INLINE")
            for doc in documents
        ]

    def _transform_request(self, request_data: RerankRequest) -> BedrockRerankRequest:
        """
        Build the Bedrock rerank request body from a ``RerankRequest``.

        Args:
            request_data: The Cohere-style rerank request. Its ``model`` value
                is passed through as the Bedrock ``modelArn``.

        Returns:
            A ``BedrockRerankRequest`` with one text query, a
            ``BEDROCK_RERANKING_MODEL`` configuration and one inline source per
            document.
        """
        sources = self._transform_sources(request_data.documents)

        # The request's single query string becomes the only entry in `queries`.
        text_query = BedrockRerankQuery(
            textQuery=BedrockRerankTextQuery(text=request_data.query),
            type="TEXT",
        )

        # A falsy top_n (None or 0) falls back to ranking every document.
        result_count = request_data.top_n or len(request_data.documents)

        model_settings = BedrockRerankBedrockRerankingConfiguration(
            modelConfiguration=BedrockRerankModelConfiguration(
                modelArn=request_data.model
            ),
            numberOfResults=result_count,
        )

        return BedrockRerankRequest(
            queries=[text_query],
            rerankingConfiguration=BedrockRerankConfiguration(
                bedrockRerankingConfiguration=model_settings,
                type="BEDROCK_RERANKING_MODEL",
            ),
            sources=sources,
        )

    def _transform_response(self, response: dict) -> RerankResponse:
        """
        Convert a Bedrock rerank response dict into a ``RerankResponse``.

        Example input:
            {"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}

        Args:
            response: The decoded Bedrock response body.

        Returns:
            A ``RerankResponse`` whose results carry each entry's ``index`` and
            ``relevanceScore`` (as ``relevance_score``). The response's ``id``
            is used when present and truthy; otherwise a random UUID4 string
            is generated.

        Raises:
            ValueError: If ``results`` is missing or empty.
        """
        # NOTE: the 1-search-unit default is used only when "usage" is absent;
        # a present "usage" dict is unpacked into both billed units and tokens.
        billed_units = RerankBilledUnits(**response.get("usage", {"search_units": 1}))
        token_usage = RerankTokens(**response.get("usage", {}))
        meta = RerankResponseMeta(billed_units=billed_units, tokens=token_usage)

        raw_results = response.get("results")
        if not raw_results:
            raise ValueError(f"No results found in the response={response}")

        ranked: List[RerankResponseResult] = []
        for entry in raw_results:
            ranked.append(
                RerankResponseResult(
                    index=entry.get("index"),
                    relevance_score=entry.get("relevanceScore"),
                )
            )

        return RerankResponse(
            id=response.get("id") or str(uuid.uuid4()),
            results=ranked,
            meta=meta,
        )