Mirror of https://github.com/BerriAI/litellm.git
Add Hosted VLLM rerank provider integration
This commit implements the Hosted VLLM rerank provider integration for LiteLLM. The integration includes:

- Adding Hosted VLLM as a supported rerank provider in the main rerank function
- Implementing the HostedVLLMRerank handler class for making API requests
- Creating a transformation class to convert Hosted VLLM responses to LiteLLM's standardized format

The integration supports both synchronous and asynchronous rerank operations. API credentials can be provided directly or through environment variables (HOSTED_VLLM_API_KEY and HOSTED_VLLM_API_BASE).

Notable features:

- Proper error handling for missing credentials
- Standard response transformation
- Support for common rerank parameters (top_n, return_documents, etc.)
- Proper token usage tracking

This expands LiteLLM's rerank provider ecosystem to include Hosted VLLM alongside existing providers like Cohere, Together AI, Azure AI, and Bedrock.
Parent: 3875df666b
Commit: 0e9b4453e0
3 changed files with 202 additions and 1 deletion
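
Before the diffs, a minimal usage sketch (not part of the commit): the model name and server URL below are placeholders, and it assumes a locally running vLLM server exposing /v1/rerank and that the hosted_vllm/ model prefix routes to this provider, as it does for other LiteLLM providers.

    import litellm

    # Placeholder model and URL; assumes a local vLLM server with /v1/rerank.
    response = litellm.rerank(
        model="hosted_vllm/BAAI/bge-reranker-base",
        query="What is the capital of France?",
        documents=["Paris is the capital of France.", "Berlin is in Germany."],
        top_n=1,
        api_base="http://localhost:8000",  # or set HOSTED_VLLM_API_BASE
        api_key="my-key",                  # or set HOSTED_VLLM_API_KEY
    )
    print(response.results)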
litellm/llms/hosted_vllm/rerank/handler.py (new file, 103 lines)
"""
Rerank API for Hosted VLLM.

LiteLLM supports the rerank API format; no parameter transformation occurs.
"""

from typing import Any, Dict, List, Optional, Union

import litellm
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.llms.hosted_vllm.rerank.transformation import HostedVLLMRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import RerankRequest, RerankResponse


class HostedVLLMRerank(BaseLLM):
    def rerank(
        self,
        model: str,
        api_key: str,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        api_base: Optional[str] = None,
        _is_async: Optional[bool] = False,
    ) -> RerankResponse:
        client = _get_httpx_client()

        request_data = RerankRequest(
            model=model,
            query=query,
            top_n=top_n,
            documents=documents,
            rank_fields=rank_fields,
            return_documents=return_documents,
        )

        # exclude None values from request_data
        request_data_dict = request_data.dict(exclude_none=True)
        if max_chunks_per_doc is not None:
            raise ValueError("Hosted VLLM does not support max_chunks_per_doc")

        # Get API base URL
        api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE")
        if api_base is None:
            raise ValueError("api_base must be provided for Hosted VLLM rerank")

        # Get API key
        api_key = api_key or get_secret_str("HOSTED_VLLM_API_KEY") or "fake-api-key"

        if _is_async:
            return self.async_rerank(request_data_dict, api_key, api_base)  # type: ignore  # Call async method

        response = client.post(
            f"{api_base}/v1/rerank",
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "authorization": f"Bearer {api_key}",
            },
            json=request_data_dict,
        )

        if response.status_code != 200:
            raise Exception(response.text)

        _json_response = response.json()

        return HostedVLLMRerankConfig()._transform_response(_json_response)

    async def async_rerank(  # New async method
        self,
        request_data_dict: Dict[str, Any],
        api_key: str,
        api_base: str,
    ) -> RerankResponse:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.HOSTED_VLLM
        )  # Use async client

        response = await client.post(
            f"{api_base}/v1/rerank",
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "authorization": f"Bearer {api_key}",
            },
            json=request_data_dict,
        )

        if response.status_code != 200:
            raise Exception(response.text)

        _json_response = response.json()

        return HostedVLLMRerankConfig()._transform_response(_json_response)
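
The _is_async branch above returns the coroutine from async_rerank. A sketch of the async path follows, assuming LiteLLM's arerank entrypoint drives the _is_async flag here as it does for other providers (placeholder model and URL again):

    import asyncio

    import litellm

    async def main():
        # Assumption: litellm.arerank flips _is_async so the call goes
        # through HostedVLLMRerank.async_rerank rather than the sync path.
        response = await litellm.arerank(
            model="hosted_vllm/BAAI/bge-reranker-base",  # placeholder model
            query="best pizza in town",
            documents=["Tony's Pizzeria review", "City zoning records"],
            api_base="http://localhost:8000",  # placeholder URL
        )
        print(response.results)

    asyncio.run(main())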
litellm/llms/hosted_vllm/rerank/transformation.py (new file, 65 lines)
@ -0,0 +1,65 @@
|
|||
"""
|
||||
Transformation logic for Hosted VLLM's `/v1/rerank` format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.types.rerank import (
|
||||
RerankBilledUnits,
|
||||
RerankResponse,
|
||||
RerankResponseDocument,
|
||||
RerankResponseMeta,
|
||||
RerankResponseResult,
|
||||
RerankTokens,
|
||||
)
|
||||
|
||||
|
||||
class HostedVLLMRerankConfig:
|
||||
def _transform_response(self, response: dict) -> RerankResponse:
|
||||
# Extract usage information
|
||||
usage_data = response.get("usage", {})
|
||||
_billed_units = RerankBilledUnits(total_tokens=usage_data.get("total_tokens", 0))
|
||||
_tokens = RerankTokens(total_tokens=usage_data.get("total_tokens", 0))
|
||||
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
|
||||
|
||||
# Extract results
|
||||
_results: Optional[List[dict]] = response.get("results")
|
||||
|
||||
if _results is None:
|
||||
raise ValueError(f"No results found in the response={response}")
|
||||
|
||||
rerank_results: List[RerankResponseResult] = []
|
||||
|
||||
for result in _results:
|
||||
# Validate required fields exist
|
||||
if not all(key in result for key in ["index", "relevance_score"]):
|
||||
raise ValueError(f"Missing required fields in the result={result}")
|
||||
|
||||
# Get document data if it exists
|
||||
document_data = result.get("document", {})
|
||||
document = (
|
||||
RerankResponseDocument(text=str(document_data.get("text", "")))
|
||||
if document_data
|
||||
else None
|
||||
)
|
||||
|
||||
# Create typed result
|
||||
rerank_result = RerankResponseResult(
|
||||
index=int(result["index"]),
|
||||
relevance_score=float(result["relevance_score"]),
|
||||
)
|
||||
|
||||
# Only add document if it exists
|
||||
if document:
|
||||
rerank_result["document"] = document
|
||||
|
||||
rerank_results.append(rerank_result)
|
||||
|
||||
return RerankResponse(
|
||||
id=response.get("id") or str(uuid.uuid4()),
|
||||
results=rerank_results,
|
||||
meta=rerank_meta,
|
||||
)
|
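
To make the mapping concrete, here is an illustrative run of _transform_response on a hypothetical raw payload (all values are made up; the field names follow what the code above reads):

    from litellm.llms.hosted_vllm.rerank.transformation import HostedVLLMRerankConfig

    # Hypothetical raw /v1/rerank response body.
    raw = {
        "id": "rerank-abc123",
        "results": [
            {
                "index": 0,
                "relevance_score": 0.91,
                "document": {"text": "Paris is the capital of France."},
            },
            {"index": 1, "relevance_score": 0.12},  # no document key
        ],
        "usage": {"total_tokens": 42},
    }

    transformed = HostedVLLMRerankConfig()._transform_response(raw)
    # transformed.results[0] carries index, relevance_score, and the document
    # text; the second result omits "document" entirely. transformed.meta
    # reports 42 total_tokens in both billed_units and tokens, and the "id"
    # is passed through unchanged.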
Main rerank function (34 additions, 1 deletion):

@@ -10,6 +10,7 @@ from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
 from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
+from litellm.llms.hosted_vllm.rerank.handler import HostedVLLMRerank
 from litellm.rerank_api.rerank_utils import get_optional_rerank_params
 from litellm.secret_managers.main import get_secret, get_secret_str
 from litellm.types.rerank import OptionalRerankParams, RerankResponse

@@ -21,6 +22,7 @@ from litellm.utils import ProviderConfigManager, client, exception_type
 together_rerank = TogetherAIRerank()
 bedrock_rerank = BedrockRerankHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
+hosted_vllm_rerank = HostedVLLMRerank()
 #################################################


@@ -75,7 +77,7 @@ def rerank(  # noqa: PLR0915
     query: str,
     documents: List[Union[str, Dict[str, Any]]],
     custom_llm_provider: Optional[
-        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
+        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy", "hosted_vllm"]
     ] = None,
     top_n: Optional[int] = None,
     rank_fields: Optional[List[str]] = None,

@@ -321,6 +323,37 @@ def rerank(  # noqa: PLR0915
             logging_obj=litellm_logging_obj,
             client=client,
         )
+    elif _custom_llm_provider == "hosted_vllm":
+        # Implement Hosted VLLM rerank logic
+        api_key = (
+            dynamic_api_key
+            or optional_params.api_key
+            or get_secret_str("HOSTED_VLLM_API_KEY")
+        )
+
+        api_base = (
+            dynamic_api_base
+            or optional_params.api_base
+            or get_secret_str("HOSTED_VLLM_API_BASE")
+        )
+
+        if api_base is None:
+            raise ValueError(
+                "api_base must be provided for Hosted VLLM rerank. Set in call or via HOSTED_VLLM_API_BASE env var."
+            )
+
+        response = hosted_vllm_rerank.rerank(
+            model=model,
+            query=query,
+            documents=documents,
+            top_n=top_n,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+            max_chunks_per_doc=max_chunks_per_doc,
+            api_key=api_key,
+            api_base=api_base,
+            _is_async=_is_async,
+        )
     else:
         raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
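
One behavior worth noting in this branch: api_base is mandatory (passed in the call or via HOSTED_VLLM_API_BASE), while api_key is effectively optional because the handler falls back to "fake-api-key". A sketch of the environment-variable path, with placeholder values:

    import os

    import litellm

    # Placeholder values; with no explicit api_base/api_key arguments,
    # the hosted_vllm branch falls back to these environment variables.
    os.environ["HOSTED_VLLM_API_BASE"] = "http://localhost:8000"
    os.environ["HOSTED_VLLM_API_KEY"] = "sk-placeholder"  # optional

    response = litellm.rerank(
        model="hosted_vllm/BAAI/bge-reranker-base",  # placeholder model
        query="capital of France",
        documents=["Paris is the capital of France.", "Berlin is in Germany."],
    )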
|
Loading…
Add table
Add a link
Reference in a new issue