Prathamesh Saraf 2025-04-24 00:54:03 -07:00 committed by GitHub
commit 72a4af86e8
8 changed files with 284 additions and 22 deletions

View file

@ -952,6 +952,7 @@ from .llms.topaz.image_variations.transformation import TopazImageVariationConfi
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from .llms.groq.chat.transformation import GroqChatConfig
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
from .llms.voyage.rerank.transformation import VoyageRerankConfig
from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
from .llms.mistral.mistral_chat_transformation import MistralConfig

View file

@ -0,0 +1,26 @@
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from typing import Union

import httpx


class VoyageError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Union[dict, httpx.Headers] = {},
    ):
        self.status_code = status_code
        self.message = message
        # Synthesize an httpx request/response pair so the exception always
        # carries both, even when raised before a request was sent.
        self.request = httpx.Request(
            method="POST", url="https://api.voyageai.com/v1/embeddings"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers,
            request=self.request,
            response=self.response,
        )
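How the new exception is meant to surface, as a minimal sketch (the helper below is illustrative, not part of the commit). Note the synthetic request URL is hard-coded to the embeddings route, so rerank failures carry that URL as well:

import httpx

from litellm.llms.voyage.common_utils import VoyageError

# Hypothetical helper: turn a failed Voyage HTTP response into a VoyageError.
def raise_for_voyage_error(resp: httpx.Response) -> None:
    if resp.status_code >= 400:
        raise VoyageError(
            status_code=resp.status_code,
            message=resp.text,
            headers=resp.headers,  # forwarded so callers can inspect rate-limit headers
        )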

View file

@ -9,25 +9,7 @@ from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse, Usage
class VoyageError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Union[dict, httpx.Headers] = {},
    ):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://api.voyageai.com/v1/embeddings"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers,
        )
from ..common_utils import VoyageError
class VoyageEmbeddingConfig(BaseEmbeddingConfig):

View file

@ -0,0 +1,156 @@
import uuid
from typing import Any, Dict, List, Optional, Union

import httpx

from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.voyage.common_utils import VoyageError
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import (
    OptionalRerankParams,
    RerankBilledUnits,
    RerankRequest,
    RerankResponse,
    RerankResponseDocument,
    RerankResponseMeta,
    RerankResponseResult,
    RerankTokens,
)


class VoyageRerankConfig(BaseRerankConfig):
    def __init__(self) -> None:
        pass
    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        if api_base:
            # Remove trailing slashes and ensure a clean base URL
            api_base = api_base.rstrip("/")
            if not api_base.endswith("/v1/rerank"):
                api_base = f"{api_base}/v1/rerank"
            return api_base
        return "https://api.voyageai.com/v1/rerank"
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            api_key = (
                get_secret_str("VOYAGE_API_KEY")
                or get_secret_str("VOYAGE_AI_API_KEY")
                or get_secret_str("VOYAGE_AI_TOKEN")
            )
        return {
            "Authorization": f"Bearer {api_key}",
            "content-type": "application/json",
        }
    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
            "query",
            "documents",
            "top_k",
            "return_documents",
        ]
    def map_cohere_rerank_params(
        self,
        non_default_params: dict,
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """
        Map Cohere-style rerank params to Voyage rerank params.
        """
        optional_params = {}
        supported_params = self.get_supported_cohere_rerank_params(model)
        for k, v in non_default_params.items():
            if k in supported_params:
                optional_params[k] = v
        # The Voyage API uses top_k instead of top_n: map top_n to top_k
        # when provided, and explicitly null out top_n so it is not forwarded.
        if top_n is not None:
            optional_params["top_k"] = top_n
            optional_params["top_n"] = None
        return OptionalRerankParams(
            **optional_params,
        )
    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: OptionalRerankParams,
        headers: dict,
    ) -> dict:
        # Transform the request into the RerankRequest spec
        if "query" not in optional_rerank_params:
            raise ValueError("query is required for Voyage rerank")
        if "documents" not in optional_rerank_params:
            raise ValueError("documents is required for Voyage rerank")
        rerank_request = RerankRequest(
            model=model,
            query=optional_rerank_params["query"],
            documents=optional_rerank_params["documents"],
            # The Voyage API uses top_k instead of top_n
            top_k=optional_rerank_params.get("top_k", None),
            return_documents=optional_rerank_params.get("return_documents", None),
        )
        return rerank_request.model_dump(exclude_none=True)
    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Transform the Voyage rerank response.

        Minimal transformation required: litellm follows the Voyage API
        response format, so this mostly repackages usage and results.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise VoyageError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        _billed_units = RerankBilledUnits(**raw_response_json.get("usage", {}))
        # Voyage reports only total_tokens; missing prompt_tokens defaults to 0,
        # so output_tokens falls back to the total.
        _tokens = RerankTokens(
            input_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
            output_tokens=(
                raw_response_json.get("usage", {}).get("total_tokens", 0)
                - raw_response_json.get("usage", {}).get("prompt_tokens", 0)
            ),
        )
        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

        voyage_results: List[RerankResponseResult] = []
        if raw_response_json.get("data"):
            for result in raw_response_json.get("data"):
                _rerank_response = RerankResponseResult(
                    index=result.get("index"),
                    relevance_score=result.get("relevance_score"),
                )
                if result.get("document"):
                    _rerank_response["document"] = RerankResponseDocument(
                        text=result.get("document")
                    )
                voyage_results.append(_rerank_response)
        # voyage_results is initialized to a list, so check for emptiness
        # rather than None.
        if not voyage_results:
            raise ValueError(f"No results found in the response={raw_response_json}")

        return RerankResponse(
            id=raw_response_json.get("id") or str(uuid.uuid4()),
            results=voyage_results,
            meta=rerank_meta,
        )
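Taken together, the request-side hooks remap the Cohere-style top_n to Voyage's top_k and serialize with exclude_none, so no null fields leak onto the wire. A quick sketch with illustrative values ("rerank-2" is a placeholder model name):

from litellm.llms.voyage.rerank.transformation import VoyageRerankConfig

config = VoyageRerankConfig()

# Cohere-style params in; top_n=2 becomes Voyage's top_k.
params = config.map_cohere_rerank_params(
    non_default_params={"query": "capital of the US?", "documents": ["a", "b", "c"]},
    model="rerank-2",
    drop_params=False,
    query="capital of the US?",
    documents=["a", "b", "c"],
    top_n=2,
)
body = config.transform_rerank_request(
    model="rerank-2", optional_rerank_params=params, headers={}
)
print(body)
# {'model': 'rerank-2', 'query': 'capital of the US?', 'documents': ['a', 'b', 'c'], 'top_k': 2}
print(config.get_complete_url(api_base=None, model="rerank-2"))
# https://api.voyageai.com/v1/rerank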

View file

@ -75,7 +75,7 @@ def rerank( # noqa: PLR0915
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
        Literal["cohere", "together_ai", "azure_ai", "infinity", "voyage", "litellm_proxy"]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
@ -323,6 +323,33 @@ def rerank( # noqa: PLR0915
            logging_obj=litellm_logging_obj,
            client=client,
        )
    elif _custom_llm_provider == "voyage":
        api_key = (
            dynamic_api_key or optional_params.api_key or litellm.api_key
        )
        api_base = (
            dynamic_api_base
            or optional_params.api_base
            or litellm.api_base
            or get_secret("VOYAGE_API_BASE")  # type: ignore
        )
        response = base_llm_http_handler.rerank(
            model=model,
            custom_llm_provider=_custom_llm_provider,
            provider_config=rerank_provider_config,
            optional_rerank_params=optional_rerank_params,
            logging_obj=litellm_logging_obj,
            timeout=optional_params.timeout,
            api_key=api_key,  # resolved above, including the litellm.api_key fallback
            api_base=api_base,
            _is_async=_is_async,
            headers=headers or litellm.headers or {},
            client=client,
            model_response=model_response,
        )
    else:
        raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
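With this branch in place, Voyage is reachable through the public rerank entrypoints like any other provider. A hedged usage sketch (model name and key are placeholders):

import litellm

response = litellm.rerank(
    model="voyage/rerank-2",  # placeholder model name
    query="What is the capital of the United States?",
    documents=[
        "Carson City is the capital city of the American state of Nevada.",
        "Washington, D.C. is the capital of the United States.",
    ],
    top_n=1,  # remapped to Voyage's top_k by VoyageRerankConfig
    api_key="pa-...",  # or set VOYAGE_API_KEY in the environment
)
for result in response.results:
    print(result["index"], result["relevance_score"])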

View file

@ -19,6 +19,7 @@ class RerankRequest(BaseModel):
    return_documents: Optional[bool] = None
    max_chunks_per_doc: Optional[int] = None
    max_tokens_per_doc: Optional[int] = None
    top_k: Optional[int] = None

class OptionalRerankParams(TypedDict, total=False):
@ -29,7 +30,7 @@ class OptionalRerankParams(TypedDict, total=False):
    return_documents: Optional[bool]
    max_chunks_per_doc: Optional[int]
    max_tokens_per_doc: Optional[int]
    top_k: Optional[int]

class RerankBilledUnits(TypedDict, total=False):
    search_units: Optional[int]
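Because RerankRequest serializes with model_dump(exclude_none=True), adding top_k alongside the Cohere-style fields does not leak nulls into other providers' payloads. A quick sketch:

from litellm.types.rerank import RerankRequest

req = RerankRequest(model="rerank-2", query="q", documents=["a", "b"], top_k=1)
# Only explicitly set fields survive serialization:
print(req.model_dump(exclude_none=True))
# {'model': 'rerank-2', 'query': 'q', 'documents': ['a', 'b'], 'top_k': 1}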

View file

@ -6604,6 +6604,8 @@ class ProviderConfigManager:
            return litellm.InfinityRerankConfig()
        elif litellm.LlmProviders.JINA_AI == provider:
            return litellm.JinaAIRerankConfig()
        elif litellm.LlmProviders.VOYAGE == provider:
            return litellm.VoyageRerankConfig()
        return litellm.CohereRerankConfig()

    @staticmethod

View file

@ -11,6 +11,7 @@ sys.path.insert(
from base_embedding_unit_tests import BaseLLMEmbeddingTest
from test_rerank import assert_response_shape
import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import AsyncMock, MagicMock, patch
@ -77,4 +78,70 @@ def test_voyage_ai_embedding_prompt_token_mapping():
        assert response.usage.total_tokens == 120
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
### Rerank Tests
@pytest.mark.asyncio()
async def test_voyage_ai_rerank():
    mock_response = AsyncMock()

    def return_val():
        # The transformation reads reranked hits from the "data" key.
        return {
            "id": "cmpl-mockid",
            "data": [{"index": 2, "relevance_score": 0.84375}],
            "usage": {"total_tokens": 150},
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200
    expected_payload = {
        "model": "rerank-model",
        "query": "What is the capital of the United States?",
        # The Voyage API uses top_k instead of top_n
        "top_k": 1,
        "documents": [
            "Carson City is the capital city of the American state of Nevada.",
            "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
            "Washington, D.C. is the capital of the United States.",
            "Capital punishment has existed in the United States since before it was a country.",
        ],
    }
    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="voyage/rerank-model",
            query="What is the capital of the United States?",
            documents=[
                "Carson City is the capital city of the American state of Nevada.",
                "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
                "Washington, D.C. is the capital of the United States.",
                "Capital punishment has existed in the United States since before it was a country.",
            ],
            top_n=1,  # converted to top_k internally
            api_base="https://api.voyageai.ai",
        )
        print("async rerank response: ", response)

        # Assert
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", _url)
        assert _url == "https://api.voyageai.ai/v1/rerank"

        request_data = json.loads(args_to_api)
        print("request data to voyage ai", json.dumps(request_data, indent=4))
        assert request_data["query"] == expected_payload["query"]
        assert request_data["documents"] == expected_payload["documents"]
        assert request_data["top_k"] == expected_payload["top_k"]
        assert request_data["model"] == expected_payload["model"]

        assert response.id is not None
        assert response.results is not None
        assert response.meta["tokens"]["output_tokens"] == 150

        assert_response_shape(response, custom_llm_provider="voyage")
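The import hunk above also brings in HTTPHandler, but only the async client is exercised; a synchronous twin could patch the sync client the same way. A sketch under the same mocking assumptions (not part of the commit):

def test_voyage_ai_rerank_sync():
    mock_response = MagicMock()
    mock_response.json.return_value = {
        "id": "cmpl-mockid",
        "data": [{"index": 0, "relevance_score": 0.9}],
        "usage": {"total_tokens": 150},
    }
    mock_response.status_code = 200
    mock_response.headers = {"key": "value"}

    client = HTTPHandler()
    with patch.object(client, "post", return_value=mock_response) as mock_post:
        response = litellm.rerank(
            model="voyage/rerank-model",
            query="What is the capital of the United States?",
            documents=["Washington, D.C. is the capital of the United States."],
            top_n=1,
            client=client,
        )
        mock_post.assert_called_once()
        request_data = json.loads(mock_post.call_args.kwargs["data"])
        assert request_data["top_k"] == 1  # top_n converted to Voyage's top_k
        assert response.results is not None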