Adding reranking support for Voyage AI

This commit is contained in:
Prathamesh 2025-04-16 08:49:48 +05:30
parent ebfff975d4
commit 93fde85437
5 changed files with 90 additions and 21 deletions

View file

@@ -951,6 +951,7 @@ from .llms.topaz.image_variations.transformation import TopazImageVariationConfig
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from .llms.groq.chat.transformation import GroqChatConfig
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
from .llms.voyage.rerank.transformation import VoyageRerankConfig
from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
from .llms.mistral.mistral_chat_transformation import MistralConfig
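
The VoyageRerankConfig imported above lives in litellm/llms/voyage/rerank/transformation.py, which this commit view does not expand. Going only by the endpoint and payload that the new test at the end of this commit asserts (POST https://api.voyageai.com/v1/rerank with model, query, documents, and top_k), a rough standalone sketch of the request/response shape follows; the helper name, the direct httpx call, and the bearer-token header are illustrative assumptions, not the actual litellm transformation code.

from typing import Any, Dict, List, Optional, Union
import os

import httpx


def voyage_rerank_sketch(
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    model: str,
    top_k: Optional[int] = None,
) -> dict:
    # Illustrative only: call the Voyage rerank endpoint that the new test mocks,
    # sending a litellm-style top_n value as Voyage's "top_k" field.
    payload = {"model": model, "query": query, "documents": documents}
    if top_k is not None:
        payload["top_k"] = top_k
    response = httpx.post(
        "https://api.voyageai.com/v1/rerank",
        headers={"Authorization": f"Bearer {os.environ['VOYAGE_API_KEY']}"},
        json=payload,
    )
    response.raise_for_status()
    # Expected shape, per the mocked response in the test:
    # {"id": ..., "results": [{"index": 0, "relevance_score": 0.95}], "usage": {"total_tokens": 150}}
    return response.json()

The commented-out top_k line in the rerank() branch further down suggests the top_n -> top_k mapping is handled inside this config rather than by the caller.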

View file

@@ -9,25 +9,7 @@ from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse, Usage
class VoyageError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Union[dict, httpx.Headers] = {},
    ):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://api.voyageai.com/v1/embeddings"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers,
        )
from ..common_utils import VoyageError
class VoyageEmbeddingConfig(BaseEmbeddingConfig):

View file

@@ -75,7 +75,7 @@ def rerank( # noqa: PLR0915
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
        Literal["cohere", "together_ai", "azure_ai", "infinity", "voyage", "litellm_proxy"]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
@@ -323,6 +323,33 @@ def rerank( # noqa: PLR0915
            logging_obj=litellm_logging_obj,
            client=client,
        )
        elif _custom_llm_provider == "voyage":
            api_key = (
                dynamic_api_key or optional_params.api_key or litellm.api_key
            )
            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("VOYAGE_API_BASE")  # type: ignore
            )
            # optional_rerank_params["top_k"] = top_n if top_n is not None else 3
            response = base_llm_http_handler.rerank(
                model=model,
                custom_llm_provider=_custom_llm_provider,
                provider_config=rerank_provider_config,
                optional_rerank_params=optional_rerank_params,
                logging_obj=litellm_logging_obj,
                timeout=optional_params.timeout,
                api_key=dynamic_api_key or optional_params.api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers or litellm.headers or {},
                client=client,
                model_response=model_response,
            )
        else:
            raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
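
With the branch above, Voyage rerank calls go through base_llm_http_handler like the other providers, and the provider is selected by the usual "voyage/" model prefix. Below is a minimal usage sketch mirroring the new test at the end of this commit; the model name is a placeholder taken from that test, and VOYAGE_API_KEY is assumed to be set in the environment (an explicit api_key / api_base would also be picked up by the branch above).

import asyncio

import litellm


async def main():
    # Async rerank routed to Voyage via the "voyage/" prefix.
    response = await litellm.arerank(
        model="voyage/rerank-model",  # placeholder model name, as in the test
        query="hello",
        documents=["hello", "world", "foo", "bar"],
        top_k=1,  # forwarded to Voyage's top_k field
    )
    print(response.results)
    print(response.meta["tokens"]["total_tokens"])


asyncio.run(main())

The synchronous litellm.rerank() entry point shown above accepts the same query/documents arguments.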

View file

@@ -6596,6 +6596,8 @@ class ProviderConfigManager:
            return litellm.InfinityRerankConfig()
        elif litellm.LlmProviders.JINA_AI == provider:
            return litellm.JinaAIRerankConfig()
        elif litellm.LlmProviders.VOYAGE == provider:
            return litellm.VoyageRerankConfig()
        return litellm.CohereRerankConfig()

    @staticmethod

View file

@@ -11,6 +11,7 @@ sys.path.insert(
from base_embedding_unit_tests import BaseLLMEmbeddingTest
from test_rerank import assert_response_shape
import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock
@@ -77,4 +78,60 @@ def test_voyage_ai_embedding_prompt_token_mapping():
        assert response.usage.total_tokens == 120
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
### Rerank Tests


@pytest.mark.asyncio()
async def test_voyage_ai_rerank():
    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95}],
            "usage": {"total_tokens": 150},
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    expected_payload = {
        "model": "rerank-model",
        "query": "hello",
        "top_k": 1,
        "documents": ["hello", "world", "foo", "bar"],
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="voyage/rerank-model",
            query="hello",
            documents=["hello", "world", "foo", "bar"],
            top_k=1,
        )

        print("async rerank response: ", response)

        # Assert
        mock_post.assert_called_once()
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        assert _url == "https://api.voyageai.com/v1/rerank"

        request_data = json.loads(args_to_api)
        assert request_data["query"] == expected_payload["query"]
        assert request_data["documents"] == expected_payload["documents"]
        assert request_data["top_k"] == expected_payload["top_k"]
        assert request_data["model"] == expected_payload["model"]

        assert response.id is not None
        assert response.results is not None
        assert response.meta["tokens"]["total_tokens"] == 150

        assert_response_shape(response, custom_llm_provider="voyage")