Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

Adding reranking for voyage

commit 93fde85437 (parent ebfff975d4)

5 changed files with 90 additions and 21 deletions
@@ -951,6 +951,7 @@ from .llms.topaz.image_variations.transformation import TopazImageVariationConfi
 from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
 from .llms.groq.chat.transformation import GroqChatConfig
 from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
+from .llms.voyage.rerank.transformation import VoyageRerankConfig
 from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
 from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
@@ -9,25 +9,7 @@ from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
 from litellm.types.utils import EmbeddingResponse, Usage
 
-
-class VoyageError(BaseLLMException):
-    def __init__(
-        self,
-        status_code: int,
-        message: str,
-        headers: Union[dict, httpx.Headers] = {},
-    ):
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(
-            method="POST", url="https://api.voyageai.com/v1/embeddings"
-        )
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            status_code=status_code,
-            message=message,
-            headers=headers,
-        )
+from ..common_utils import VoyageError
 
 
 class VoyageEmbeddingConfig(BaseEmbeddingConfig):
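The VoyageError class removed here is re-imported from ..common_utils, so the embedding code and the new rerank transformation can presumably share a single error type. A minimal sketch of what that shared module might contain, reusing the removed body verbatim; the module path and the BaseLLMException import path are assumptions, since neither is shown in this diff:

# litellm/llms/voyage/common_utils.py -- assumed location for the shared error type
from typing import Union

import httpx

# import path assumed; BaseLLMException is litellm's common provider exception base
from litellm.llms.base_llm.chat.transformation import BaseLLMException


class VoyageError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Union[dict, httpx.Headers] = {},
    ):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://api.voyageai.com/v1/embeddings"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(status_code=status_code, message=message, headers=headers)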
@@ -75,7 +75,7 @@ def rerank( # noqa: PLR0915
     query: str,
     documents: List[Union[str, Dict[str, Any]]],
     custom_llm_provider: Optional[
-        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
+        Literal["cohere", "together_ai", "azure_ai", "infinity", "voyage", "litellm_proxy"]
     ] = None,
     top_n: Optional[int] = None,
     rank_fields: Optional[List[str]] = None,
@@ -323,6 +323,33 @@ def rerank( # noqa: PLR0915
             logging_obj=litellm_logging_obj,
             client=client,
         )
+    elif _custom_llm_provider == "voyage":
+        api_key = (
+            dynamic_api_key or optional_params.api_key or litellm.api_key
+        )
+        api_base = (
+            dynamic_api_base
+            or optional_params.api_base
+            or litellm.api_base
+            or get_secret("VOYAGE_API_BASE")  # type: ignore
+        )
+
+        # optional_rerank_params["top_k"] = top_n if top_n is not None else 3
+
+        response = base_llm_http_handler.rerank(
+            model=model,
+            custom_llm_provider=_custom_llm_provider,
+            provider_config=rerank_provider_config,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
+            api_base=api_base,
+            _is_async=_is_async,
+            headers=headers or litellm.headers or {},
+            client=client,
+            model_response=model_response,
+        )
     else:
         raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
 
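With this branch in place, Voyage rerank models should be callable through the normal litellm entry point. A minimal sketch, assuming VOYAGE_API_KEY is set in the environment; the model name is illustrative, and top_k is passed through as in the new test further down:

import litellm

# The "voyage/" prefix routes the request to the new provider branch;
# the model name after the slash is illustrative, not taken from this diff.
response = litellm.rerank(
    model="voyage/rerank-2",
    query="hello",
    documents=["hello", "world", "foo", "bar"],
    top_k=1,
)
print(response.results)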
@@ -6596,6 +6596,8 @@ class ProviderConfigManager:
             return litellm.InfinityRerankConfig()
         elif litellm.LlmProviders.JINA_AI == provider:
             return litellm.JinaAIRerankConfig()
+        elif litellm.LlmProviders.VOYAGE == provider:
+            return litellm.VoyageRerankConfig()
         return litellm.CohereRerankConfig()
 
     @staticmethod
@@ -11,6 +11,7 @@ sys.path.insert(
 
 
 from base_embedding_unit_tests import BaseLLMEmbeddingTest
+from test_rerank import assert_response_shape
 import litellm
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from unittest.mock import patch, MagicMock
@@ -77,4 +78,60 @@ def test_voyage_ai_embedding_prompt_token_mapping():
         assert response.usage.total_tokens == 120
 
     except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+        pytest.fail(f"Error occurred: {e}")
+
+
+### Rerank Tests
+@pytest.mark.asyncio()
+async def test_voyage_ai_rerank():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "id": "cmpl-mockid",
+            "results": [{"index": 0, "relevance_score": 0.95}],
+            "usage": {"total_tokens": 150},
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "model": "rerank-model",
+        "query": "hello",
+        "top_k": 1,
+        "documents": ["hello", "world", "foo", "bar"],
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.arerank(
+            model="voyage/rerank-model",
+            query="hello",
+            documents=["hello", "world", "foo", "bar"],
+            top_k=1,
+        )
+
+        print("async re rank response: ", response)
+
+        # Assert
+        mock_post.assert_called_once()
+        args_to_api = mock_post.call_args.kwargs["data"]
+        _url = mock_post.call_args.kwargs["url"]
+        assert _url == "https://api.voyageai.com/v1/rerank"
+
+        request_data = json.loads(args_to_api)
+        assert request_data["query"] == expected_payload["query"]
+        assert request_data["documents"] == expected_payload["documents"]
+        assert request_data["top_k"] == expected_payload["top_k"]
+        assert request_data["model"] == expected_payload["model"]
+
+        assert response.id is not None
+        assert response.results is not None
+        assert response.meta["tokens"]["total_tokens"] == 150
+
+        assert_response_shape(response, custom_llm_provider="voyage")
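The mocked test pins down the wire format the new transformation is expected to produce: a POST to https://api.voyageai.com/v1/rerank carrying model, query, documents and top_k, answered by results entries with index and relevance_score plus a usage block. A minimal sketch of that call made directly with httpx, assuming bearer auth via VOYAGE_API_KEY and an illustrative model name (neither is asserted in the diff):

import os

import httpx

# Endpoint and payload keys mirror what the test above asserts;
# the auth scheme and model name are assumptions.
resp = httpx.post(
    "https://api.voyageai.com/v1/rerank",
    headers={"Authorization": f"Bearer {os.environ['VOYAGE_API_KEY']}"},
    json={
        "model": "rerank-2",
        "query": "hello",
        "documents": ["hello", "world", "foo", "bar"],
        "top_k": 1,
    },
)
resp.raise_for_status()
print(resp.json())  # expected shape: {"results": [{"index": ..., "relevance_score": ...}], "usage": {...}}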