Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 18:24:20 +00:00

Merge 6ad3d6e57c into b82af5b826
Commit 72a4af86e8
8 changed files with 284 additions and 22 deletions
@@ -952,6 +952,7 @@ from .llms.topaz.image_variations.transformation import TopazImageVariationConfig
 from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
 from .llms.groq.chat.transformation import GroqChatConfig
 from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
+from .llms.voyage.rerank.transformation import VoyageRerankConfig
 from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
 from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
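With this import in place, the new config is reachable from the package top level, which the ProviderConfigManager change further down relies on. A quick sanity check (illustrative, not part of the diff):

    import litellm

    cfg = litellm.VoyageRerankConfig()
    print(type(cfg).__name__)  # -> VoyageRerankConfig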
litellm/llms/voyage/common_utils.py (new file, 26 lines)

@@ -0,0 +1,26 @@
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from typing import Union
+
+import httpx
+
+
+class VoyageError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Union[dict, httpx.Headers] = {},
+    ):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url="https://api.voyageai.com/v1/embeddings"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            headers=headers,
+            request=self.request,
+            response=self.response,
+        )
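A minimal sketch of how callers can surface this shared error type (illustrative only; assumes the import path above):

    from litellm.llms.voyage.common_utils import VoyageError

    try:
        # Simulate a failed Voyage API call.
        raise VoyageError(status_code=429, message="rate limit exceeded")
    except VoyageError as err:
        print(err.status_code, err.message)  # 429 rate limit exceeded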
litellm/llms/voyage/embedding/transformation.py

@@ -9,25 +9,7 @@ from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
 from litellm.types.utils import EmbeddingResponse, Usage
-
-
-class VoyageError(BaseLLMException):
-    def __init__(
-        self,
-        status_code: int,
-        message: str,
-        headers: Union[dict, httpx.Headers] = {},
-    ):
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(
-            method="POST", url="https://api.voyageai.com/v1/embeddings"
-        )
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            status_code=status_code,
-            message=message,
-            headers=headers,
-        )
+from ..common_utils import VoyageError


 class VoyageEmbeddingConfig(BaseEmbeddingConfig):
litellm/llms/voyage/rerank/transformation.py (new file, 156 lines)

@@ -0,0 +1,156 @@
+import uuid
+from typing import Any, Dict, List, Optional, Union
+
+import httpx
+
+from litellm.secret_managers.main import get_secret_str
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse
+from litellm.llms.voyage.common_utils import VoyageError
+from litellm.types.rerank import (
+    RerankBilledUnits,
+    RerankResponseDocument,
+    RerankResponseMeta,
+    RerankResponseResult,
+    RerankTokens,
+)
+
+
+class VoyageRerankConfig(BaseRerankConfig):
+    def __init__(self) -> None:
+        pass
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base:
+            # Remove trailing slashes and ensure a clean base URL
+            api_base = api_base.rstrip("/")
+            if not api_base.endswith("/v1/rerank"):
+                api_base = f"{api_base}/v1/rerank"
+            return api_base
+        return "https://api.voyageai.com/v1/rerank"
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = (
+                get_secret_str("VOYAGE_API_KEY")
+                or get_secret_str("VOYAGE_AI_API_KEY")
+                or get_secret_str("VOYAGE_AI_TOKEN")
+            )
+        return {
+            "Authorization": f"Bearer {api_key}",
+            "content-type": "application/json",
+        }
+
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        return [
+            "query",
+            "documents",
+            "top_k",
+            "return_documents",
+        ]
+
+    def map_cohere_rerank_params(
+        self,
+        non_default_params: dict,
+        model: str,
+        drop_params: bool,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        custom_llm_provider: Optional[str] = None,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+        max_tokens_per_doc: Optional[int] = None,
+    ) -> OptionalRerankParams:
+        """
+        Map Cohere-style rerank params to Voyage rerank params.
+        """
+        optional_params = {}
+        supported_params = self.get_supported_cohere_rerank_params(model)
+        for k, v in non_default_params.items():
+            if k in supported_params:
+                optional_params[k] = v
+
+        # The Voyage API uses top_k instead of top_n:
+        # carry top_n over to top_k when it is set.
+        if top_n is not None:
+            optional_params["top_k"] = top_n
+            optional_params["top_n"] = None
+
+        return OptionalRerankParams(
+            **optional_params,
+        )
+
+    def transform_rerank_request(
+        self,
+        model: str,
+        optional_rerank_params: OptionalRerankParams,
+        headers: dict,
+    ) -> dict:
+        # Transform the request to the RerankRequest spec
+        if "query" not in optional_rerank_params:
+            raise ValueError("query is required for Voyage rerank")
+        if "documents" not in optional_rerank_params:
+            raise ValueError("documents is required for Voyage rerank")
+        rerank_request = RerankRequest(
+            model=model,
+            query=optional_rerank_params["query"],
+            documents=optional_rerank_params["documents"],
+            # Voyage API uses top_k instead of top_n
+            top_k=optional_rerank_params.get("top_k", None),
+            return_documents=optional_rerank_params.get("return_documents", None),
+        )
+        return rerank_request.model_dump(exclude_none=True)
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        """
+        Transform the Voyage rerank response.
+
+        No transformation is required; litellm follows the Voyage API response format.
+        """
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise VoyageError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+        _billed_units = RerankBilledUnits(**raw_response_json.get("usage", {}))
+        _tokens = RerankTokens(
+            input_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
+            output_tokens=(
+                raw_response_json.get("usage", {}).get("total_tokens", 0)
+                - raw_response_json.get("usage", {}).get("prompt_tokens", 0)
+            ),
+        )
+        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
+
+        voyage_results: List[RerankResponseResult] = []
+        if raw_response_json.get("data"):
+            for result in raw_response_json.get("data"):
+                _rerank_response = RerankResponseResult(
+                    index=result.get("index"),
+                    relevance_score=result.get("relevance_score"),
+                )
+                if result.get("document"):
+                    _rerank_response["document"] = RerankResponseDocument(
+                        text=result.get("document")
+                    )
+                voyage_results.append(_rerank_response)
+        if not voyage_results:
+            raise ValueError(f"No results found in the response={raw_response_json}")
+
+        return RerankResponse(
+            id=raw_response_json.get("id") or str(uuid.uuid4()),
+            results=voyage_results,
+            meta=rerank_meta,
+        )
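Taken together, get_complete_url, map_cohere_rerank_params, and transform_rerank_request compose into the request body sent to Voyage. A rough sketch of that flow (illustrative, not part of the commit; values are placeholders):

    from litellm.llms.voyage.rerank.transformation import VoyageRerankConfig

    cfg = VoyageRerankConfig()

    # Trailing slashes are stripped and /v1/rerank is appended exactly once.
    assert cfg.get_complete_url("https://api.voyageai.com/", "rerank-model") == "https://api.voyageai.com/v1/rerank"
    assert cfg.get_complete_url(None, "rerank-model") == "https://api.voyageai.com/v1/rerank"

    # Cohere-style top_n is translated to Voyage's top_k.
    params = cfg.map_cohere_rerank_params(
        non_default_params={"query": "capital of the US?", "documents": ["doc a", "doc b"]},
        model="rerank-model",
        drop_params=False,
        query="capital of the US?",
        documents=["doc a", "doc b"],
        top_n=1,
    )
    body = cfg.transform_rerank_request("rerank-model", params, headers={})
    # body == {"model": "rerank-model", "query": "capital of the US?",
    #          "documents": ["doc a", "doc b"], "top_k": 1}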
litellm/rerank_api/main.py

@@ -75,7 +75,7 @@ def rerank(  # noqa: PLR0915
     query: str,
     documents: List[Union[str, Dict[str, Any]]],
     custom_llm_provider: Optional[
-        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
+        Literal["cohere", "together_ai", "azure_ai", "infinity", "voyage", "litellm_proxy"]
     ] = None,
     top_n: Optional[int] = None,
     rank_fields: Optional[List[str]] = None,
@@ -323,6 +323,33 @@ def rerank(  # noqa: PLR0915
             logging_obj=litellm_logging_obj,
             client=client,
         )
+    elif _custom_llm_provider == "voyage":
+        api_key = (
+            dynamic_api_key or optional_params.api_key or litellm.api_key
+        )
+        api_base = (
+            dynamic_api_base
+            or optional_params.api_base
+            or litellm.api_base
+            or get_secret("VOYAGE_API_BASE")  # type: ignore
+        )
+
+        response = base_llm_http_handler.rerank(
+            model=model,
+            custom_llm_provider=_custom_llm_provider,
+            provider_config=rerank_provider_config,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=api_key,
+            api_base=api_base,
+            _is_async=_is_async,
+            headers=headers or litellm.headers or {},
+            client=client,
+            model_response=model_response,
+        )
     else:
         raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
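With this branch in place, a Voyage rerank call goes through the shared HTTP handler end to end. A minimal usage sketch (the model id and key are placeholders, not taken from this diff; check Voyage's docs for current model names):

    import os
    import litellm

    os.environ["VOYAGE_API_KEY"] = "pa-..."  # hypothetical key

    response = litellm.rerank(
        model="voyage/rerank-2",  # assumed model id
        query="What is the capital of the United States?",
        documents=[
            "Carson City is the capital city of the American state of Nevada.",
            "Washington, D.C. is the capital of the United States.",
        ],
        top_n=1,  # translated to Voyage's top_k by VoyageRerankConfig
    )
    print(response.results)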
litellm/types/rerank.py

@@ -19,6 +19,7 @@ class RerankRequest(BaseModel):
     return_documents: Optional[bool] = None
     max_chunks_per_doc: Optional[int] = None
     max_tokens_per_doc: Optional[int] = None
+    top_k: Optional[int] = None


 class OptionalRerankParams(TypedDict, total=False):

@@ -29,7 +30,7 @@ class OptionalRerankParams(TypedDict, total=False):
     return_documents: Optional[bool]
     max_chunks_per_doc: Optional[int]
     max_tokens_per_doc: Optional[int]
+    top_k: Optional[int]


 class RerankBilledUnits(TypedDict, total=False):
     search_units: Optional[int]
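The new top_k field flows straight through pydantic serialization, which is how transform_rerank_request builds the wire payload. A small sketch (illustrative; assumes the surrounding field definitions in this file):

    from litellm.types.rerank import RerankRequest

    req = RerankRequest(model="rerank-model", query="q", documents=["a", "b"], top_k=1)
    print(req.model_dump(exclude_none=True))
    # -> {'model': 'rerank-model', 'query': 'q', 'documents': ['a', 'b'], 'top_k': 1}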
litellm/utils.py

@@ -6604,6 +6604,8 @@ class ProviderConfigManager:
             return litellm.InfinityRerankConfig()
         elif litellm.LlmProviders.JINA_AI == provider:
             return litellm.JinaAIRerankConfig()
+        elif litellm.LlmProviders.VOYAGE == provider:
+            return litellm.VoyageRerankConfig()
         return litellm.CohereRerankConfig()

     @staticmethod
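The fall-through return means any rerank provider without an explicit branch silently gets the Cohere config. A standalone illustration of the pattern this hunk extends (the helper name below is hypothetical; the real logic lives inside ProviderConfigManager):

    import litellm

    def pick_rerank_config(provider):  # hypothetical helper mirroring the elif chain
        if litellm.LlmProviders.VOYAGE == provider:
            return litellm.VoyageRerankConfig()
        return litellm.CohereRerankConfig()  # default when no branch matches

    cfg = pick_rerank_config(litellm.LlmProviders.VOYAGE)
    print(type(cfg).__name__)  # -> VoyageRerankConfig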
tests/llm_translation/test_voyage_ai.py

@@ -11,6 +11,7 @@ sys.path.insert(

 from base_embedding_unit_tests import BaseLLMEmbeddingTest
+from test_rerank import assert_response_shape
 import litellm
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from unittest.mock import patch, MagicMock
@@ -77,4 +78,70 @@ def test_voyage_ai_embedding_prompt_token_mapping():
         assert response.usage.total_tokens == 120

     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+### Rerank Tests
+@pytest.mark.asyncio()
+async def test_voyage_ai_rerank():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "id": "cmpl-mockid",
+            # The Voyage API returns reranked entries under "data",
+            # which is the key transform_rerank_response reads.
+            "data": [{"index": 2, "relevance_score": 0.84375}],
+            "usage": {"total_tokens": 150},
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "model": "rerank-model",
+        "query": "What is the capital of the United States?",
+        # Voyage API uses top_k instead of top_n
+        "top_k": 1,
+        "documents": [
+            "Carson City is the capital city of the American state of Nevada.",
+            "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
+            "Washington, D.C. is the capital of the United States.",
+            "Capital punishment has existed in the United States since before it was a country.",
+        ],
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.arerank(
+            model="voyage/rerank-model",
+            query="What is the capital of the United States?",
+            documents=[
+                "Carson City is the capital city of the American state of Nevada.",
+                "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
+                "Washington, D.C. is the capital of the United States.",
+                "Capital punishment has existed in the United States since before it was a country.",
+            ],
+            top_n=1,  # converted to top_k internally
+            api_base="https://api.voyageai.ai",
+        )
+
+        print("async rerank response: ", response)
+
+        # Assert the request that was sent
+        mock_post.assert_called_once()
+        print("call args", mock_post.call_args)
+        args_to_api = mock_post.call_args.kwargs["data"]
+        _url = mock_post.call_args.kwargs["url"]
+        print("Arguments passed to API=", args_to_api)
+        print("url = ", _url)
+        assert _url == "https://api.voyageai.ai/v1/rerank"
+
+        request_data = json.loads(args_to_api)
+        print("request data to voyage ai", json.dumps(request_data, indent=4))
+        assert request_data["query"] == expected_payload["query"]
+        assert request_data["documents"] == expected_payload["documents"]
+        assert request_data["top_k"] == expected_payload["top_k"]
+        assert request_data["model"] == expected_payload["model"]
+
+        assert response.id is not None
+        assert response.results is not None
+        assert response.meta["tokens"]["output_tokens"] == 150
+        assert_response_shape(response, custom_llm_provider="voyage")
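To run just the new rerank test locally (the file path is assumed from litellm's test layout):

    pytest tests/llm_translation/test_voyage_ai.py -k test_voyage_ai_rerank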