(feat) add infinity rerank models (#7321)

* Support Infinity Reranker (custom reranking models) (#7247)

* Support Infinity Reranker

* Clean code

* Included transformation.py

* Clean code

* Added Infinity reranker test

* Clean code

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>

* transform_rerank_response

* update handler.py

* infinity rerank updates

* ci/cd run again

* add infinity unit tests

* docs add instruction on how to add a new provider for rerank

---------

Co-authored-by: Hao Shan <53949959+haoshan98@users.noreply.github.com>
Ishaan Jaff, 2024-12-19 18:30:28 -08:00, committed by GitHub
parent 50204d9a6d
commit 6641e75e0c
11 changed files with 414 additions and 1 deletion

@@ -0,0 +1,21 @@
# Directory Structure
When adding a new provider, create a directory for it that follows this structure:
```
litellm/llms/
└── provider_name/
├── completion/
│ ├── handler.py
│ └── transformation.py
├── chat/
│ ├── handler.py
│ └── transformation.py
├── embed/
│ ├── handler.py
│ └── transformation.py
└── rerank/
├── handler.py
└── transformation.py
```
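
With this layout in place, a provider's rerank config is importable from a predictable path. A minimal sketch for a hypothetical provider named `provider_name` (names are illustrative only):

```python
# Hypothetical import path implied by the directory layout above
from litellm.llms.provider_name.rerank.transformation import ProviderNameRerankConfig
```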

@@ -0,0 +1,81 @@
# Add Rerank Provider
LiteLLM **follows the Cohere Rerank API format** for all rerank providers. Here's how to add a new rerank provider:
## 1. Create a transformation.py file
Create a config class named `<Provider><Endpoint>Config` that inherits from `BaseRerankConfig`:
```python
import httpx

from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse


class YourProviderRerankConfig(BaseRerankConfig):
    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
            "query",
            "documents",
            "top_n",
            # ... other supported params
        ]

    def transform_rerank_request(self, model: str, optional_rerank_params: OptionalRerankParams, headers: dict) -> dict:
        # Map the Cohere-style params onto the RerankRequest spec
        rerank_request = RerankRequest(model=model, **optional_rerank_params)
        return rerank_request.model_dump(exclude_none=True)

    def transform_rerank_response(self, model: str, raw_response: httpx.Response, **kwargs) -> RerankResponse:
        # Transform the provider response into a RerankResponse
        # (full parameter list elided here; see the Infinity config below for the complete signature)
        raw_response_json = raw_response.json()
        return RerankResponse(**raw_response_json)
```
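
To sanity-check the request transformation in isolation, you can instantiate the config directly. A minimal sketch, assuming the hypothetical `YourProviderRerankConfig` above:

```python
config = YourProviderRerankConfig()

body = config.transform_rerank_request(
    model="your-rerank-model",  # hypothetical model name
    optional_rerank_params={"query": "hello", "documents": ["hello", "world"], "top_n": 2},
    headers={},
)
print(body)  # e.g. {'model': 'your-rerank-model', 'query': 'hello', 'documents': ['hello', 'world'], 'top_n': 2}
```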
## 2. Register Your Provider
Add your provider to `ProviderConfigManager.get_provider_rerank_config()` in `litellm/utils.py`:
```python
elif litellm.LlmProviders.YOUR_PROVIDER == provider:
return litellm.YourProviderRerankConfig()
```
## 3. Add Provider to `rerank_api/main.py`
Add a code block to handle the case where your provider is called. Your provider should use the `base_llm_http_handler.rerank` method:
```python
elif _custom_llm_provider == "your_provider":
    ...
    response = base_llm_http_handler.rerank(
        model=model,
        custom_llm_provider=_custom_llm_provider,
        optional_rerank_params=optional_rerank_params,
        logging_obj=litellm_logging_obj,
        timeout=optional_params.timeout,
        api_key=dynamic_api_key or optional_params.api_key,
        api_base=api_base,
        _is_async=_is_async,
        headers=headers or litellm.headers or {},
        client=client,
        model_response=model_response,
    )
    ...
```
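
Once wired up, callers reach the new provider through LiteLLM's usual `provider/model` routing. A minimal sketch (provider name and endpoint are hypothetical):

```python
import litellm

response = litellm.rerank(
    model="your_provider/your-rerank-model",  # hypothetical provider + model
    query="hello",
    documents=["hello", "world"],
    top_n=2,
    api_base="https://api.your-provider.example",  # hypothetical endpoint
)
print(response.results)
```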
## 4. Add Tests
Add a test file to [`tests/llm_translation`](https://github.com/BerriAI/litellm/tree/main/tests/llm_translation)
```python
def test_basic_rerank_cohere():
response = litellm.rerank(
model="cohere/rerank-english-v3.0",
query="hello",
documents=["hello", "world"],
top_n=3,
)
print("re rank response: ", response)
assert response.id is not None
assert response.results is not None
```

@@ -320,6 +320,13 @@ const sidebars = {
        "load_test_rpm",
      ]
    },
    {
      type: "category",
      label: "Adding Providers",
      items: [
        "adding_provider/directory_structure",
        "adding_provider/new_rerank_provider"],
    },
    {
      type: "category",
      label: "Logging & Observability",

@@ -126,6 +126,7 @@ azure_key: Optional[str] = None
anthropic_key: Optional[str] = None
replicate_key: Optional[str] = None
cohere_key: Optional[str] = None
infinity_key: Optional[str] = None
clarifai_key: Optional[str] = None
maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None
@@ -1025,6 +1026,7 @@ from .llms.replicate.chat.transformation import ReplicateConfig
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
from .llms.cohere.rerank.transformation import CohereRerankConfig
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
from .llms.infinity.rerank.transformation import InfinityRerankConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig
from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
from .llms.together_ai.chat import TogetherAIConfig

@@ -0,0 +1,19 @@
import httpx

from litellm.llms.base_llm.chat.transformation import BaseLLMException


class InfinityError(BaseLLMException):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://github.com/michaelfeil/infinity"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            request=self.request,
            response=self.response,
        )  # Call the base class constructor with the parameters it needs
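
A quick sketch of how this exception is raised during response parsing (mirroring the `transform_rerank_response` logic in `transformation.py` below; the failing response here is fabricated for illustration):

```python
import httpx

# Fabricated non-JSON failure response, for illustration only
request = httpx.Request(method="POST", url="https://example.invalid/v1/rerank")
raw_response = httpx.Response(status_code=503, text="upstream overloaded", request=request)

try:
    raw_response.json()
except Exception:
    raise InfinityError(message=raw_response.text, status_code=raw_response.status_code)
```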

@@ -0,0 +1,5 @@
"""
Infinity Rerank - uses `llm_http_handler.py` to make httpx requests.

Request/Response transformation is handled in `transformation.py`.
"""

@@ -0,0 +1,91 @@
"""
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works
"""
import uuid
from typing import List, Optional
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import RerankBilledUnits, RerankResponseMeta, RerankTokens
from litellm.types.utils import RerankResponse
from .common_utils import InfinityError
class InfinityRerankConfig(CohereRerankConfig):
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            api_key = get_secret_str("INFINITY_API_KEY") or litellm.infinity_key

        default_headers = {
            "Authorization": f"bearer {api_key}",
            "accept": "application/json",
            "content-type": "application/json",
        }

        # If 'Authorization' is provided in headers, it overrides the default.
        if "Authorization" in headers:
            default_headers["Authorization"] = headers["Authorization"]

        # Merge other headers, overriding any default ones except Authorization
        return {**default_headers, **headers}
    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Transform Infinity rerank response.

        No transformation required; Infinity follows the Cohere API response format.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise InfinityError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        _billed_units = RerankBilledUnits(**raw_response_json.get("usage", {}))
        _tokens = RerankTokens(
            input_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
            output_tokens=(
                raw_response_json.get("usage", {}).get("total_tokens", 0)
                - raw_response_json.get("usage", {}).get("prompt_tokens", 0)
            ),
        )
        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

        _results: Optional[List[dict]] = raw_response_json.get("results")
        if _results is None:
            raise ValueError(f"No results found in the response={raw_response_json}")

        return RerankResponse(
            id=raw_response_json.get("id") or str(uuid.uuid4()),
            results=_results,  # type: ignore
            meta=rerank_meta,
        )  # Return response
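
Note the usage mapping above: Infinity reports `prompt_tokens` and `total_tokens`, so output tokens are derived by subtraction. A worked example matching the unit tests below:

```python
usage = {"prompt_tokens": 100, "total_tokens": 150}
output_tokens = usage["total_tokens"] - usage["prompt_tokens"]
assert output_tokens == 50  # what the tests assert for meta["tokens"]["output_tokens"]
```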

@@ -76,7 +76,9 @@ def rerank(  # noqa: PLR0915
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
        Literal["cohere", "together_ai", "azure_ai", "infinity"]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = True,
@@ -188,6 +190,37 @@
            or litellm.api_base
            or get_secret("AZURE_AI_API_BASE")  # type: ignore
        )
        response = base_llm_http_handler.rerank(
            model=model,
            custom_llm_provider=_custom_llm_provider,
            optional_rerank_params=optional_rerank_params,
            logging_obj=litellm_logging_obj,
            timeout=optional_params.timeout,
            api_key=dynamic_api_key or optional_params.api_key,
            api_base=api_base,
            _is_async=_is_async,
            headers=headers or litellm.headers or {},
            client=client,
            model_response=model_response,
        )
    elif _custom_llm_provider == "infinity":
        # Implement Infinity rerank logic
        api_key: Optional[str] = (
            dynamic_api_key or optional_params.api_key or litellm.api_key
        )
        api_base: Optional[str] = (
            dynamic_api_base
            or optional_params.api_base
            or litellm.api_base
            or get_secret("INFINITY_API_BASE")  # type: ignore
        )
        if api_base is None:
            raise Exception(
                "Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var."
            )
        response = base_llm_http_handler.rerank(
            model=model,
            custom_llm_provider=_custom_llm_provider,
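
For completeness, a minimal sketch of the `INFINITY_API_BASE` environment fallback handled above (URL taken from the unit tests; an API key is optional for unauthenticated Infinity deployments):

```python
import os

import litellm

# Picked up by the branch above when api_base isn't passed explicitly
os.environ["INFINITY_API_BASE"] = "https://env.infinity.ai"

response = litellm.rerank(
    model="infinity/rerank-model",
    query="hello",
    documents=["hello", "world"],
    top_n=3,
)
```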

@@ -1741,6 +1741,7 @@ class LlmProviders(str, Enum):
    HOSTED_VLLM = "hosted_vllm"
    LM_STUDIO = "lm_studio"
    GALADRIEL = "galadriel"
    INFINITY = "infinity"


class LiteLLMLoggingBaseClass:

@@ -6214,6 +6214,8 @@ class ProviderConfigManager:
            return litellm.CohereRerankConfig()
        elif litellm.LlmProviders.AZURE_AI == provider:
            return litellm.AzureAIRerankConfig()
        elif litellm.LlmProviders.INFINITY == provider:
            return litellm.InfinityRerankConfig()
        return litellm.CohereRerankConfig()

@@ -0,0 +1,151 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import litellm
from test_rerank import assert_response_shape


@pytest.mark.asyncio()
async def test_infinity_rerank():
    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    expected_payload = {
        "model": "rerank-model",
        "query": "hello",
        "top_n": 3,
        "documents": ["hello", "world"],
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="infinity/rerank-model",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
            api_base="https://api.infinity.ai",
        )

        print("async re rank response: ", response)

        # Assert
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", _url)
        assert _url == "https://api.infinity.ai/v1/rerank"

        request_data = json.loads(args_to_api)
        assert request_data["query"] == expected_payload["query"]
        assert request_data["documents"] == expected_payload["documents"]
        assert request_data["top_n"] == expected_payload["top_n"]
        assert request_data["model"] == expected_payload["model"]

        assert response.id is not None
        assert response.results is not None
        assert response.meta["tokens"]["input_tokens"] == 100
        assert (
            response.meta["tokens"]["output_tokens"] == 50
        )  # total_tokens - prompt_tokens

        assert_response_shape(response, custom_llm_provider="infinity")


@pytest.mark.asyncio()
async def test_infinity_rerank_with_env(monkeypatch):
    # Set up mock response
    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    # Set environment variable
    monkeypatch.setenv("INFINITY_API_BASE", "https://env.infinity.ai")

    expected_payload = {
        "model": "rerank-model",
        "query": "hello",
        "top_n": 3,
        "documents": ["hello", "world"],
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="infinity/rerank-model",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
        )

        print("async re rank response: ", response)

        # Assert
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", _url)
        assert _url == "https://env.infinity.ai/v1/rerank"

        request_data = json.loads(args_to_api)
        assert request_data["query"] == expected_payload["query"]
        assert request_data["documents"] == expected_payload["documents"]
        assert request_data["top_n"] == expected_payload["top_n"]
        assert request_data["model"] == expected_payload["model"]

        assert response.id is not None
        assert response.results is not None
        assert response.meta["tokens"]["input_tokens"] == 100
        assert (
            response.meta["tokens"]["output_tokens"] == 50
        )  # total_tokens - prompt_tokens

        assert_response_shape(response, custom_llm_provider="infinity")