Add Hosted VLLM rerank provider integration

This commit implements the Hosted VLLM rerank provider integration for LiteLLM. The integration includes:

- Adding Hosted VLLM as a supported rerank provider in the main rerank function
- Implementing the HostedVLLMRerank handler class for making API requests
- Creating a transformation class that converts Hosted VLLM responses to LiteLLM's standardized format

The integration supports both synchronous and asynchronous rerank operations. API credentials can be provided directly in the call or through environment variables (HOSTED_VLLM_API_KEY and HOSTED_VLLM_API_BASE).

Notable features:

- Proper error handling for missing credentials
- Standard response transformation
- Support for common rerank parameters (top_n, return_documents, etc.)
- Token usage tracking

This expands LiteLLM's rerank provider ecosystem to include Hosted VLLM alongside existing providers such as Cohere, Together AI, Azure AI, and Bedrock.
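As a usage sketch (the model name and URL below are illustrative placeholders, not part of this commit):

import litellm

response = litellm.rerank(
    model="hosted_vllm/my-reranker",  # hypothetical model name
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is the capital of Germany."],
    top_n=1,
    custom_llm_provider="hosted_vllm",
    api_base="http://localhost:8000",  # or set HOSTED_VLLM_API_BASE
    api_key="sk-...",                  # or set HOSTED_VLLM_API_KEY
)
print(response.results)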
Philip D'Souza, 2025-03-14 14:15:45 +00:00
parent 3875df666b
commit 0e9b4453e0
3 changed files with 202 additions and 1 deletion

litellm/llms/hosted_vllm/rerank/handler.py

@@ -0,0 +1,103 @@
"""
Re rank api for Hosted VLLM
LiteLLM supports the re rank API format, no parameter transformation occurs
"""
from typing import Any, Dict, List, Optional, Union
import litellm
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.hosted_vllm.rerank.transformation import HostedVLLMRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import RerankRequest, RerankResponse
class HostedVLLMRerank(BaseLLM):
def rerank(
self,
model: str,
api_key: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
api_base: Optional[str] = None,
_is_async: Optional[bool] = False,
) -> RerankResponse:
client = _get_httpx_client()
request_data = RerankRequest(
model=model,
query=query,
top_n=top_n,
documents=documents,
rank_fields=rank_fields,
return_documents=return_documents,
)
# exclude None values from request_data
request_data_dict = request_data.dict(exclude_none=True)
if max_chunks_per_doc is not None:
raise ValueError("Hosted VLLM does not support max_chunks_per_doc")
# Get API base URL
api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE")
if api_base is None:
raise ValueError("api_base must be provided for Hosted VLLM rerank")
# Get API key
api_key = api_key or get_secret_str("HOSTED_VLLM_API_KEY") or "fake-api-key"
if _is_async:
return self.async_rerank(request_data_dict, api_key, api_base) # type: ignore # Call async method
response = client.post(
f"{api_base}/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return HostedVLLMRerankConfig()._transform_response(_json_response)
async def async_rerank( # New async method
self,
request_data_dict: Dict[str, Any],
api_key: str,
api_base: str,
) -> RerankResponse:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HOSTED_VLLM
) # Use async client
response = await client.post(
f"{api_base}/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return HostedVLLMRerankConfig()._transform_response(_json_response)
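The async path can be exercised through LiteLLM's async entry point; a minimal sketch, assuming a local vLLM server and that litellm.arerank mirrors the synchronous rerank (model name and URL are placeholders):

import asyncio

import litellm

async def main():
    # _is_async=True routes this through HostedVLLMRerank.async_rerank
    response = await litellm.arerank(
        model="hosted_vllm/my-reranker",  # hypothetical model name
        query="best pizza in town",
        documents=["A pizzeria downtown.", "A hardware store."],
        custom_llm_provider="hosted_vllm",
        api_base="http://localhost:8000",
    )
    print(response.results)

asyncio.run(main())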

litellm/llms/hosted_vllm/rerank/transformation.py

@@ -0,0 +1,65 @@
"""
Transformation logic for Hosted VLLM's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works
"""
import uuid
from typing import List, Optional
from litellm.types.rerank import (
RerankBilledUnits,
RerankResponse,
RerankResponseDocument,
RerankResponseMeta,
RerankResponseResult,
RerankTokens,
)
class HostedVLLMRerankConfig:
def _transform_response(self, response: dict) -> RerankResponse:
# Extract usage information
usage_data = response.get("usage", {})
_billed_units = RerankBilledUnits(total_tokens=usage_data.get("total_tokens", 0))
_tokens = RerankTokens(total_tokens=usage_data.get("total_tokens", 0))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
# Extract results
_results: Optional[List[dict]] = response.get("results")
if _results is None:
raise ValueError(f"No results found in the response={response}")
rerank_results: List[RerankResponseResult] = []
for result in _results:
# Validate required fields exist
if not all(key in result for key in ["index", "relevance_score"]):
raise ValueError(f"Missing required fields in the result={result}")
# Get document data if it exists
document_data = result.get("document", {})
document = (
RerankResponseDocument(text=str(document_data.get("text", "")))
if document_data
else None
)
# Create typed result
rerank_result = RerankResponseResult(
index=int(result["index"]),
relevance_score=float(result["relevance_score"]),
)
# Only add document if it exists
if document:
rerank_result["document"] = document
rerank_results.append(rerank_result)
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
results=rerank_results,
meta=rerank_meta,
)
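To see the transformation in isolation, here is a sketch with an assumed /v1/rerank response shape (all field values are illustrative):

sample_response = {
    "id": "rerank-abc123",  # optional; a uuid4 is generated when absent
    "results": [
        {"index": 0, "relevance_score": 0.91, "document": {"text": "Paris is the capital of France."}},
        {"index": 1, "relevance_score": 0.07},  # no document when return_documents is false
    ],
    "usage": {"total_tokens": 42},
}

rerank_response = HostedVLLMRerankConfig()._transform_response(sample_response)
print(rerank_response.results)  # two results, only the first carrying a document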

litellm/rerank_api/main.py

@@ -10,6 +10,7 @@ from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
 from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
+from litellm.llms.hosted_vllm.rerank.handler import HostedVLLMRerank
 from litellm.rerank_api.rerank_utils import get_optional_rerank_params
 from litellm.secret_managers.main import get_secret, get_secret_str
 from litellm.types.rerank import OptionalRerankParams, RerankResponse

@@ -21,6 +22,7 @@ from litellm.utils import ProviderConfigManager, client, exception_type
 together_rerank = TogetherAIRerank()
 bedrock_rerank = BedrockRerankHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
+hosted_vllm_rerank = HostedVLLMRerank()

 #################################################

@@ -75,7 +77,7 @@ def rerank(  # noqa: PLR0915
     query: str,
     documents: List[Union[str, Dict[str, Any]]],
     custom_llm_provider: Optional[
-        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
+        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy", "hosted_vllm"]
     ] = None,
     top_n: Optional[int] = None,
     rank_fields: Optional[List[str]] = None,

@@ -321,6 +323,37 @@
             logging_obj=litellm_logging_obj,
             client=client,
         )
+    elif _custom_llm_provider == "hosted_vllm":
+        # Hosted VLLM rerank logic
+        api_key = (
+            dynamic_api_key
+            or optional_params.api_key
+            or get_secret_str("HOSTED_VLLM_API_KEY")
+        )
+        api_base = (
+            dynamic_api_base
+            or optional_params.api_base
+            or get_secret_str("HOSTED_VLLM_API_BASE")
+        )
+        if api_base is None:
+            raise ValueError(
+                "api_base must be provided for Hosted VLLM rerank. Set it in the call or via the HOSTED_VLLM_API_BASE env var."
+            )
+        response = hosted_vllm_rerank.rerank(
+            model=model,
+            query=query,
+            documents=documents,
+            top_n=top_n,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+            max_chunks_per_doc=max_chunks_per_doc,
+            api_key=api_key,
+            api_base=api_base,
+            _is_async=_is_async,
+        )
     else:
         raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
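The credential resolution above prefers per-call values over environment variables. A sketch of the env-var-only path, assuming the hosted_vllm/ model prefix resolves the provider as it does for other LiteLLM routes (values are illustrative):

import os

os.environ["HOSTED_VLLM_API_BASE"] = "http://localhost:8000"
os.environ["HOSTED_VLLM_API_KEY"] = "sk-..."  # optional; the handler falls back to a placeholder key

import litellm

response = litellm.rerank(
    model="hosted_vllm/my-reranker",  # hypothetical model name
    query="which document mentions pricing?",
    documents=["Our pricing starts at $10/month.", "Contact support for help."],
)
print(response.results)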