mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Adding reranking for voyage
This commit is contained in:
parent
ebfff975d4
commit
93fde85437
5 changed files with 90 additions and 21 deletions
|
@ -951,6 +951,7 @@ from .llms.topaz.image_variations.transformation import TopazImageVariationConfi
|
||||||
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
|
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
|
||||||
from .llms.groq.chat.transformation import GroqChatConfig
|
from .llms.groq.chat.transformation import GroqChatConfig
|
||||||
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
|
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
|
||||||
|
from .llms.voyage.rerank.transformation import VoyageRerankConfig
|
||||||
from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
|
from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
|
||||||
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
|
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
|
||||||
from .llms.mistral.mistral_chat_transformation import MistralConfig
|
from .llms.mistral.mistral_chat_transformation import MistralConfig
|
||||||
|
|
|
@ -9,25 +9,7 @@ from litellm.secret_managers.main import get_secret_str
|
||||||
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
||||||
from litellm.types.utils import EmbeddingResponse, Usage
|
from litellm.types.utils import EmbeddingResponse, Usage
|
||||||
|
|
||||||
|
from ..common_utils import VoyageError
|
||||||
class VoyageError(BaseLLMException):
    """Exception type for errors reported while calling the Voyage AI API.

    Records the status code and message on the instance and builds a
    placeholder ``httpx`` request/response pair before delegating to
    ``BaseLLMException``.

    Args:
        status_code: HTTP status code to report.
        message: Error message text.
        headers: Headers to forward to the base exception. When omitted, a
            new empty dict is used.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Union[dict, httpx.Headers, None] = None,
    ):
        # A literal ``{}`` default is created once and shared by every call
        # that omits ``headers``; build a fresh dict per instance instead.
        if headers is None:
            headers = {}
        self.status_code = status_code
        self.message = message
        # NOTE: the placeholder request always points at the embeddings
        # endpoint, whichever Voyage call actually failed.
        self.request = httpx.Request(
            method="POST", url="https://api.voyageai.com/v1/embeddings"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers,
        )
|
|
||||||
|
|
||||||
|
|
||||||
class VoyageEmbeddingConfig(BaseEmbeddingConfig):
|
class VoyageEmbeddingConfig(BaseEmbeddingConfig):
|
||||||
|
|
|
@ -75,7 +75,7 @@ def rerank( # noqa: PLR0915
|
||||||
query: str,
|
query: str,
|
||||||
documents: List[Union[str, Dict[str, Any]]],
|
documents: List[Union[str, Dict[str, Any]]],
|
||||||
custom_llm_provider: Optional[
|
custom_llm_provider: Optional[
|
||||||
Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
|
Literal["cohere", "together_ai", "azure_ai", "infinity", "voyage", "litellm_proxy"]
|
||||||
] = None,
|
] = None,
|
||||||
top_n: Optional[int] = None,
|
top_n: Optional[int] = None,
|
||||||
rank_fields: Optional[List[str]] = None,
|
rank_fields: Optional[List[str]] = None,
|
||||||
|
@ -323,6 +323,33 @@ def rerank( # noqa: PLR0915
|
||||||
logging_obj=litellm_logging_obj,
|
logging_obj=litellm_logging_obj,
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
|
elif _custom_llm_provider == "voyage":
|
||||||
|
api_key = (
|
||||||
|
dynamic_api_key or optional_params.api_key or litellm.api_key
|
||||||
|
)
|
||||||
|
api_base = (
|
||||||
|
dynamic_api_base
|
||||||
|
or optional_params.api_base
|
||||||
|
or litellm.api_base
|
||||||
|
or get_secret("VOYAGE_API_BASE") # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
# optional_rerank_params["top_k"] = top_n if top_n is not None else 3
|
||||||
|
|
||||||
|
response = base_llm_http_handler.rerank(
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider=_custom_llm_provider,
|
||||||
|
provider_config=rerank_provider_config,
|
||||||
|
optional_rerank_params=optional_rerank_params,
|
||||||
|
logging_obj=litellm_logging_obj,
|
||||||
|
timeout=optional_params.timeout,
|
||||||
|
api_key=dynamic_api_key or optional_params.api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
_is_async=_is_async,
|
||||||
|
headers=headers or litellm.headers or {},
|
||||||
|
client=client,
|
||||||
|
model_response=model_response,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
|
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
|
||||||
|
|
||||||
|
|
|
@ -6596,6 +6596,8 @@ class ProviderConfigManager:
|
||||||
return litellm.InfinityRerankConfig()
|
return litellm.InfinityRerankConfig()
|
||||||
elif litellm.LlmProviders.JINA_AI == provider:
|
elif litellm.LlmProviders.JINA_AI == provider:
|
||||||
return litellm.JinaAIRerankConfig()
|
return litellm.JinaAIRerankConfig()
|
||||||
|
elif litellm.LlmProviders.VOYAGE == provider:
|
||||||
|
return litellm.VoyageRerankConfig()
|
||||||
return litellm.CohereRerankConfig()
|
return litellm.CohereRerankConfig()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -11,6 +11,7 @@ sys.path.insert(
|
||||||
|
|
||||||
|
|
||||||
from base_embedding_unit_tests import BaseLLMEmbeddingTest
|
from base_embedding_unit_tests import BaseLLMEmbeddingTest
|
||||||
|
from test_rerank import assert_response_shape
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
|
@ -77,4 +78,60 @@ def test_voyage_ai_embedding_prompt_token_mapping():
|
||||||
assert response.usage.total_tokens == 120
|
assert response.usage.total_tokens == 120
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
### Rerank Tests
|
||||||
|
@pytest.mark.asyncio()
async def test_voyage_ai_rerank():
    """Exercise ``litellm.arerank`` with a ``voyage/`` model against a patched HTTP client.

    ``AsyncHTTPHandler.post`` is patched to return a canned response. The test
    then checks the URL and JSON body handed to ``post`` and the fields of the
    rerank response built from the canned reply.
    """
    expected_payload = {
        "model": "rerank-model",
        "query": "hello",
        "top_k": 1,
        "documents": ["hello", "world", "foo", "bar"],
    }

    # Canned HTTP reply; ``json`` is a plain callable returning a new dict each call.
    fake_response = AsyncMock()
    fake_response.status_code = 200
    fake_response.headers = {"key": "value"}
    fake_response.json = lambda: {
        "id": "cmpl-mockid",
        "results": [{"index": 0, "relevance_score": 0.95}],
        "usage": {"total_tokens": 150},
    }

    post_target = "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
    with patch(post_target, return_value=fake_response) as mock_post:
        response = await litellm.arerank(
            model="voyage/rerank-model",
            query="hello",
            documents=["hello", "world", "foo", "bar"],
            top_k=1,
        )

        print("async re rank response: ", response)

        # The handler must have issued exactly one POST.
        mock_post.assert_called_once()
        sent_kwargs = mock_post.call_args.kwargs
        raw_body = sent_kwargs["data"]
        assert sent_kwargs["url"] == "https://api.voyageai.com/v1/rerank"

        # Compare the serialized request body field by field.
        request_data = json.loads(raw_body)
        for field in ("query", "documents", "top_k", "model"):
            assert request_data[field] == expected_payload[field]

        # The canned reply must surface on the rerank response object.
        assert response.id is not None
        assert response.results is not None
        assert response.meta["tokens"]["total_tokens"] == 150

        assert_response_shape(response, custom_llm_provider="voyage")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue