Adding reranking for Voyage

Prathamesh 2025-04-16 08:49:48 +05:30
parent ebfff975d4
commit 93fde85437
5 changed files with 90 additions and 21 deletions

@@ -951,6 +951,7 @@ from .llms.topaz.image_variations.transformation import TopazImageVariationConfig
 from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
 from .llms.groq.chat.transformation import GroqChatConfig
 from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
+from .llms.voyage.rerank.transformation import VoyageRerankConfig
 from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
 from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
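Judging by the relative .llms imports, this first hunk is the package root module (presumably litellm/__init__.py), so the new rerank config is re-exported at the top level next to the existing embedding config. A quick sanity check under that assumption:

import litellm

# Both Voyage configs should now be reachable from the package root
# (assumes the hunk above is litellm/__init__.py).
assert hasattr(litellm, "VoyageEmbeddingConfig")
assert hasattr(litellm, "VoyageRerankConfig")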

@@ -9,25 +9,7 @@ from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
 from litellm.types.utils import EmbeddingResponse, Usage
+from ..common_utils import VoyageError
-class VoyageError(BaseLLMException):
-    def __init__(
-        self,
-        status_code: int,
-        message: str,
-        headers: Union[dict, httpx.Headers] = {},
-    ):
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(
-            method="POST", url="https://api.voyageai.com/v1/embeddings"
-        )
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            status_code=status_code,
-            message=message,
-            headers=headers,
-        )
 
 class VoyageEmbeddingConfig(BaseEmbeddingConfig):
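The error class moves out of the embedding transformation into a shared common_utils module, so the embedding path and the new rerank path can raise the same exception. A minimal sketch of what that shared class presumably looks like, reconstructed from the lines removed above (the BaseLLMException import path and whether the hardcoded embeddings URL was kept or parameterized are assumptions):

# Sketch of litellm/llms/voyage/common_utils.py, reconstructed from the removed class.
from typing import Union

import httpx

# Import path assumed from litellm's base_llm layout.
from litellm.llms.base_llm.chat.transformation import BaseLLMException


class VoyageError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Union[dict, httpx.Headers] = {},
    ):
        self.status_code = status_code
        self.message = message
        # The removed class hardcoded the embeddings endpoint; whether the
        # shared version keeps it as a default is an assumption here.
        self.request = httpx.Request(
            method="POST", url="https://api.voyageai.com/v1/embeddings"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(status_code=status_code, message=message, headers=headers)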

@@ -75,7 +75,7 @@ def rerank( # noqa: PLR0915
     query: str,
     documents: List[Union[str, Dict[str, Any]]],
     custom_llm_provider: Optional[
-        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
+        Literal["cohere", "together_ai", "azure_ai", "infinity", "voyage", "litellm_proxy"]
     ] = None,
     top_n: Optional[int] = None,
     rank_fields: Optional[List[str]] = None,
@@ -323,6 +323,33 @@ def rerank( # noqa: PLR0915
             logging_obj=litellm_logging_obj,
             client=client,
         )
+    elif _custom_llm_provider == "voyage":
+        api_key = (
+            dynamic_api_key or optional_params.api_key or litellm.api_key
+        )
+        api_base = (
+            dynamic_api_base
+            or optional_params.api_base
+            or litellm.api_base
+            or get_secret("VOYAGE_API_BASE")  # type: ignore
+        )
+        # optional_rerank_params["top_k"] = top_n if top_n is not None else 3
+        response = base_llm_http_handler.rerank(
+            model=model,
+            custom_llm_provider=_custom_llm_provider,
+            provider_config=rerank_provider_config,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=api_key,
+            api_base=api_base,
+            _is_async=_is_async,
+            headers=headers or litellm.headers or {},
+            client=client,
+            model_response=model_response,
+        )
     else:
         raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
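With this branch in place, Voyage reranking goes through the same top-level API as the other providers. A minimal usage sketch; the model name "rerank-2" and the VOYAGE_API_KEY env-var convention are assumptions, not taken from this diff:

import os

import litellm

# Assumed setup: the provider config reads VOYAGE_API_KEY; the exact
# model name is illustrative, not confirmed by this commit.
os.environ["VOYAGE_API_KEY"] = "sk-..."

response = litellm.rerank(
    model="voyage/rerank-2",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "London is in England."],
    top_n=1,
)
print(response.results[0])

Note that the test below passes top_k (Voyage's native parameter) rather than the signature's top_n; the commented-out mapping line in the hunk above suggests the top_n-to-top_k translation was still in flux in this commit.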

@@ -6596,6 +6596,8 @@ class ProviderConfigManager:
             return litellm.InfinityRerankConfig()
         elif litellm.LlmProviders.JINA_AI == provider:
             return litellm.JinaAIRerankConfig()
+        elif litellm.LlmProviders.VOYAGE == provider:
+            return litellm.VoyageRerankConfig()
         return litellm.CohereRerankConfig()
 
     @staticmethod
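The lookup falls through to CohereRerankConfig when no provider matches, so before this hunk a voyage/ model would silently have been given the Cohere request shape. A hypothetical standalone sketch of the same dispatch pattern, to make the fallback concrete (names here are illustrative stand-ins, not litellm's API):

from enum import Enum


class Provider(str, Enum):  # hypothetical stand-in for litellm.LlmProviders
    COHERE = "cohere"
    JINA_AI = "jina_ai"
    VOYAGE = "voyage"


def get_rerank_config(provider: Provider) -> str:
    # Mirrors the elif chain above: explicit matches first,
    # then an unconditional Cohere default as the final return.
    if provider == Provider.JINA_AI:
        return "JinaAIRerankConfig"
    elif provider == Provider.VOYAGE:
        return "VoyageRerankConfig"
    return "CohereRerankConfig"


assert get_rerank_config(Provider.VOYAGE) == "VoyageRerankConfig"
assert get_rerank_config(Provider.COHERE) == "CohereRerankConfig"  # default fallback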

@@ -11,6 +11,7 @@ sys.path.insert(
 from base_embedding_unit_tests import BaseLLMEmbeddingTest
+from test_rerank import assert_response_shape
 import litellm
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from unittest.mock import patch, MagicMock
@@ -77,4 +78,60 @@ def test_voyage_ai_embedding_prompt_token_mapping():
         assert response.usage.total_tokens == 120
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+### Rerank Tests
+@pytest.mark.asyncio()
+async def test_voyage_ai_rerank():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "id": "cmpl-mockid",
+            "results": [{"index": 0, "relevance_score": 0.95}],
+            "usage": {"total_tokens": 150},
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "model": "rerank-model",
+        "query": "hello",
+        "top_k": 1,
+        "documents": ["hello", "world", "foo", "bar"],
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.arerank(
+            model="voyage/rerank-model",
+            query="hello",
+            documents=["hello", "world", "foo", "bar"],
+            top_k=1,
+        )
+
+        print("async re rank response: ", response)
+
+        # Assert
+        mock_post.assert_called_once()
+        args_to_api = mock_post.call_args.kwargs["data"]
+        _url = mock_post.call_args.kwargs["url"]
+        assert _url == "https://api.voyageai.com/v1/rerank"
+
+        request_data = json.loads(args_to_api)
+        assert request_data["query"] == expected_payload["query"]
+        assert request_data["documents"] == expected_payload["documents"]
+        assert request_data["top_k"] == expected_payload["top_k"]
+        assert request_data["model"] == expected_payload["model"]
+
+        assert response.id is not None
+        assert response.results is not None
+        assert response.meta["tokens"]["total_tokens"] == 150
+
+        assert_response_shape(response, custom_llm_provider="voyage")
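The test pins down the wire format: a POST to https://api.voyageai.com/v1/rerank whose JSON body carries model, query, documents, and Voyage's native top_k. An equivalent raw request, as a sketch; the URL and body come straight from the test's asserts, while the Authorization scheme is an assumption the mocked test never checks:

import httpx

# Sketch of the request the mocked test expects litellm to produce.
resp = httpx.post(
    "https://api.voyageai.com/v1/rerank",
    headers={"Authorization": "Bearer sk-..."},  # auth scheme assumed
    json={
        "model": "rerank-model",
        "query": "hello",
        "documents": ["hello", "world", "foo", "bar"],
        "top_k": 1,
    },
)
print(resp.json())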