mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
This commit implements the Hosted VLLM rerank provider integration for LiteLLM. The integration includes: (1) adding Hosted VLLM as a supported rerank provider in the main rerank function; (2) implementing the HostedVLLMRerank handler class for making API requests; and (3) creating a transformation class to convert Hosted VLLM responses to LiteLLM's standardized format. The integration supports both synchronous and asynchronous rerank operations. API credentials can be provided directly or through environment variables (HOSTED_VLLM_API_KEY and HOSTED_VLLM_API_BASE). Notable features: proper error handling for missing credentials; standard response transformation; support for common rerank parameters (top_n, return_documents, etc.); and proper token usage tracking. This expands LiteLLM's rerank provider ecosystem to include Hosted VLLM alongside existing providers like Cohere, Together AI, Azure AI, and Bedrock.
366 lines
14 KiB
Python
366 lines
14 KiB
Python
import asyncio
|
|
import contextvars
|
|
from functools import partial
|
|
from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
|
|
|
|
import litellm
|
|
from litellm._logging import verbose_logger
|
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
|
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
|
|
from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
|
|
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
|
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
|
|
from litellm.llms.hosted_vllm.rerank.handler import HostedVLLMRerank
|
|
from litellm.rerank_api.rerank_utils import get_optional_rerank_params
|
|
from litellm.secret_managers.main import get_secret, get_secret_str
|
|
from litellm.types.rerank import OptionalRerankParams, RerankResponse
|
|
from litellm.types.router import *
|
|
from litellm.utils import ProviderConfigManager, client, exception_type
|
|
|
|
####### ENVIRONMENT VARIABLES ###################
# Provider handler singletons. They are constructed once, when this module is
# imported, and the rerank() entry point below dispatches to them by provider.
together_rerank = TogetherAIRerank()  # "together_ai" branch
bedrock_rerank = BedrockRerankHandler()  # "bedrock" branch
# Shared HTTP handler for the config-driven providers
# (cohere, litellm_proxy, azure_ai, infinity, jina_ai).
base_llm_http_handler = BaseLLMHTTPHandler()
hosted_vllm_rerank = HostedVLLMRerank()  # "hosted_vllm" branch
#################################################
|
|
|
|
|
|
@client
async def arerank(
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
        Literal[
            "cohere",
            "together_ai",
            "azure_ai",
            "infinity",
            "litellm_proxy",
            "hosted_vllm",
        ]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = None,
    max_chunks_per_doc: Optional[int] = None,
    **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
    """
    Async: Reranks a list of documents based on their relevance to the query

    Thin async wrapper around the synchronous ``rerank`` entry point: the
    call is dispatched to the event loop's default executor with
    ``arerank=True`` added to ``kwargs``, and if ``rerank`` hands back a
    coroutine it is awaited here before returning.

    Args:
        model: Name of the rerank model, forwarded to ``rerank``.
        query: Query the documents are ranked against.
        documents: Documents to rank, as strings or dicts.
        custom_llm_provider: Optional explicit provider name; the accepted
            literals mirror the ``rerank`` signature.
        top_n: Forwarded unchanged to ``rerank``.
        rank_fields: Forwarded unchanged to ``rerank``.
        return_documents: Forwarded unchanged to ``rerank``. Note the default
            here is ``None``, and it is passed positionally, so ``rerank``'s
            own default is not applied.
        max_chunks_per_doc: Forwarded unchanged to ``rerank``.
        **kwargs: Any further keyword arguments are forwarded to ``rerank``.

    Returns:
        The value produced by ``rerank`` (awaited first when it is a
        coroutine).

    Raises:
        Any exception raised by ``rerank`` or by the awaited coroutine
        propagates unchanged.
    """
    # We are inside a coroutine, so a loop is guaranteed to be running;
    # get_running_loop() is the non-deprecated way to obtain it.
    loop = asyncio.get_running_loop()

    # Tells rerank() it is being driven from the async entry point.
    kwargs["arerank"] = True

    func = partial(
        rerank,
        model,
        query,
        documents,
        custom_llm_provider,
        top_n,
        rank_fields,
        return_documents,
        max_chunks_per_doc,
        **kwargs,
    )

    # Run inside a copy of the current context so context variables set by
    # the caller remain visible in the executor thread.
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)
    init_response = await loop.run_in_executor(None, func_with_context)

    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response
|
|
|
|
|
|
@client
def rerank(  # noqa: PLR0915
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
        Literal[
            "cohere",
            "together_ai",
            "azure_ai",
            "infinity",
            "litellm_proxy",
            "hosted_vllm",
            "jina_ai",
            "bedrock",
        ]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = True,
    max_chunks_per_doc: Optional[int] = None,
    max_tokens_per_doc: Optional[int] = None,
    **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
    """
    Reranks a list of documents based on their relevance to the query

    Resolves the provider via ``litellm.get_llm_provider``, builds the
    provider rerank config and optional params, records the call on the
    logging object, and then dispatches to the handler for that provider.

    Args:
        model: Name of the rerank model; re-bound to the value returned by
            ``litellm.get_llm_provider``.
        query: Query the documents are ranked against.
        documents: Documents to rank, as strings or dicts.
        custom_llm_provider: Optional explicit provider name passed to
            ``litellm.get_llm_provider``.
        top_n: Passed through to the provider params / handler.
        rank_fields: Passed through to the provider params / handler.
        return_documents: Passed through to the provider params / handler.
        max_chunks_per_doc: Version-specific rerank parameter; its presence
            is reported to the provider config lookup.
        max_tokens_per_doc: Version-specific rerank parameter; its presence
            is reported to the provider config lookup.
        **kwargs: Read for ``headers``, ``litellm_logging_obj``,
            ``litellm_call_id``, ``proxy_server_request``, ``model_info``,
            ``metadata``, ``user``, ``client``, ``drop_params`` and the
            internal ``arerank`` flag (popped). The remainder is used to
            build ``GenericLiteLLMParams`` (api_key, api_base, timeout, ...).

    Returns:
        Whatever the selected provider handler returns. The handlers receive
        ``_is_async``, so presumably this is a coroutine when called through
        ``arerank`` -- see the return annotation.

    Raises:
        Every exception raised inside the dispatch (including the
        ``ValueError``/``Exception`` raised here for missing credentials or
        an unsupported provider) is logged and re-raised as whatever
        ``exception_type`` maps it to.
    """
    headers: Optional[dict] = kwargs.get("headers")  # type: ignore
    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
    litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
    proxy_server_request = kwargs.get("proxy_server_request", None)
    model_info = kwargs.get("model_info", None)
    metadata = kwargs.get("metadata", {})
    user = kwargs.get("user", None)
    client = kwargs.get("client", None)
    try:
        # Internal flag set by arerank(); removed so it is not forwarded as a
        # provider parameter.
        _is_async = kwargs.pop("arerank", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)

        # Params that are unique to specific versions of the client for the rerank call
        unique_version_params = {
            "max_chunks_per_doc": max_chunks_per_doc,
            "max_tokens_per_doc": max_tokens_per_doc,
        }
        present_version_params = [
            k for k, v in unique_version_params.items() if v is not None
        ]

        model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
            litellm.get_llm_provider(
                model=model,
                custom_llm_provider=custom_llm_provider,
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
            )
        )

        rerank_provider_config: BaseRerankConfig = (
            ProviderConfigManager.get_provider_rerank_config(
                model=model,
                provider=litellm.LlmProviders(_custom_llm_provider),
                api_base=optional_params.api_base,
                present_version_params=present_version_params,
            )
        )

        optional_rerank_params: OptionalRerankParams = get_optional_rerank_params(
            rerank_provider_config=rerank_provider_config,
            model=model,
            drop_params=kwargs.get("drop_params") or litellm.drop_params or False,
            query=query,
            documents=documents,
            custom_llm_provider=_custom_llm_provider,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
            max_chunks_per_doc=max_chunks_per_doc,
            max_tokens_per_doc=max_tokens_per_doc,
            non_default_params=kwargs,
        )

        # A timeout supplied as a string is normalised to a float.
        if isinstance(optional_params.timeout, str):
            optional_params.timeout = float(optional_params.timeout)

        model_response = RerankResponse()

        litellm_logging_obj.update_environment_variables(
            model=model,
            user=user,
            optional_params=dict(optional_rerank_params),
            litellm_params={
                "litellm_call_id": litellm_call_id,
                "proxy_server_request": proxy_server_request,
                "model_info": model_info,
                "metadata": metadata,
                "preset_cache_key": None,
                "stream_response": {},
                **optional_params.model_dump(exclude_unset=True),
            },
            custom_llm_provider=_custom_llm_provider,
        )

        # Dispatch on the resolved provider.
        if _custom_llm_provider == "cohere" or _custom_llm_provider == "litellm_proxy":
            api_key: Optional[str] = (
                dynamic_api_key or optional_params.api_key or litellm.api_key
            )

            # The chain ends in a literal default URL, so api_base can never
            # be None here and no missing-base check is needed.
            api_base: Optional[str] = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("COHERE_API_BASE")  # type: ignore
                or "https://api.cohere.com"
            )

            response = base_llm_http_handler.rerank(
                model=model,
                custom_llm_provider=_custom_llm_provider,
                provider_config=rerank_provider_config,
                optional_rerank_params=optional_rerank_params,
                logging_obj=litellm_logging_obj,
                timeout=optional_params.timeout,
                api_key=api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers or litellm.headers or {},
                client=client,
                model_response=model_response,
            )
        elif _custom_llm_provider == "azure_ai":
            api_base = (
                dynamic_api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
                or optional_params.api_base
                or litellm.api_base
                or get_secret("AZURE_AI_API_BASE")  # type: ignore
            )
            response = base_llm_http_handler.rerank(
                model=model,
                custom_llm_provider=_custom_llm_provider,
                optional_rerank_params=optional_rerank_params,
                provider_config=rerank_provider_config,
                logging_obj=litellm_logging_obj,
                timeout=optional_params.timeout,
                api_key=dynamic_api_key or optional_params.api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers or litellm.headers or {},
                client=client,
                model_response=model_response,
            )
        elif _custom_llm_provider == "infinity":
            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret_str("INFINITY_API_BASE")
            )

            if api_base is None:
                raise Exception(
                    "Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var."
                )

            # NOTE(review): the key sent here deliberately does not fall back
            # to litellm.api_key (a previously computed-but-unused local that
            # included it was removed) -- confirm that is the intended policy.
            response = base_llm_http_handler.rerank(
                model=model,
                custom_llm_provider=_custom_llm_provider,
                provider_config=rerank_provider_config,
                optional_rerank_params=optional_rerank_params,
                logging_obj=litellm_logging_obj,
                timeout=optional_params.timeout,
                api_key=dynamic_api_key or optional_params.api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers or litellm.headers or {},
                client=client,
                model_response=model_response,
            )
        elif _custom_llm_provider == "together_ai":
            api_key = (
                dynamic_api_key
                or optional_params.api_key
                or litellm.togetherai_api_key
                or get_secret("TOGETHERAI_API_KEY")  # type: ignore
                or litellm.api_key
            )

            if api_key is None:
                raise ValueError(
                    "TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
                )

            response = together_rerank.rerank(
                model=model,
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
                api_key=api_key,
                _is_async=_is_async,
            )
        elif _custom_llm_provider == "jina_ai":
            if dynamic_api_key is None:
                raise ValueError(
                    "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
                )

            # Falls back to the Jina secret rather than the Bedrock one, which
            # was a copy-paste slip from the bedrock branch.
            # NOTE(review): secret name inferred from JINA_AI_API_KEY above --
            # confirm it matches the name used by get_llm_provider.
            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("JINA_AI_API_BASE")  # type: ignore
            )

            response = base_llm_http_handler.rerank(
                model=model,
                custom_llm_provider=_custom_llm_provider,
                optional_rerank_params=optional_rerank_params,
                logging_obj=litellm_logging_obj,
                provider_config=rerank_provider_config,
                timeout=optional_params.timeout,
                api_key=dynamic_api_key or optional_params.api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers or litellm.headers or {},
                client=client,
                model_response=model_response,
            )
        elif _custom_llm_provider == "bedrock":
            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("BEDROCK_API_BASE")  # type: ignore
            )

            response = bedrock_rerank.rerank(
                model=model,
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
                _is_async=_is_async,
                optional_params=optional_params.model_dump(exclude_unset=True),
                api_base=api_base,
                logging_obj=litellm_logging_obj,
                client=client,
            )
        elif _custom_llm_provider == "hosted_vllm":
            api_key = (
                dynamic_api_key
                or optional_params.api_key
                or get_secret_str("HOSTED_VLLM_API_KEY")
            )

            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or get_secret_str("HOSTED_VLLM_API_BASE")
            )

            if api_base is None:
                raise ValueError(
                    "api_base must be provided for Hosted VLLM rerank. Set in call or via HOSTED_VLLM_API_BASE env var."
                )

            response = hosted_vllm_rerank.rerank(
                model=model,
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
                api_key=api_key,
                api_base=api_base,
                _is_async=_is_async,
            )
        else:
            raise ValueError(f"Unsupported provider: {_custom_llm_provider}")

        return response
    except Exception as e:
        verbose_logger.error(f"Error in rerank: {str(e)}")
        # The caller-supplied provider (possibly None) is used here because
        # the resolved one may not exist if provider resolution itself failed.
        raise exception_type(
            model=model, custom_llm_provider=custom_llm_provider, original_exception=e
        )
|