(feat) add infinity rerank models (#7321)

* Support Infinity Reranker (custom reranking models) (#7247)

* Support Infinity Reranker

* Clean code

* Included transformation.py

* Clean code

* Added Infinity reranker test

* Clean code

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>

* transform_rerank_response

* update handler.py

* infinity rerank updates

* ci/cd run again

* add infinity unit tests

* docs add instruction on how to add a new provider for rerank

---------

Co-authored-by: Hao Shan <53949959+haoshan98@users.noreply.github.com>
Ishaan Jaff, 2024-12-19 18:30:28 -08:00, committed by GitHub
parent 50204d9a6d
commit 6641e75e0c
11 changed files with 414 additions and 1 deletion

docs/my-website/docs/adding_provider/directory_structure.md

@@ -0,0 +1,21 @@
# Directory Structure
When adding a new provider, create a directory for it with the following structure:
```
litellm/llms/
└── provider_name/
├── completion/
│ ├── handler.py
│ └── transformation.py
├── chat/
│ ├── handler.py
│ └── transformation.py
├── embed/
│ ├── handler.py
│ └── transformation.py
└── rerank/
├── handler.py
└── transformation.py
```
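This layout maps directly to import paths. For a hypothetical provider named `acme` (name used purely for illustration), the rerank config would live at `litellm/llms/acme/rerank/transformation.py` and be imported as:

```python
# Hypothetical "acme" provider: the directory layout above doubles as the module path
from litellm.llms.acme.rerank.transformation import AcmeRerankConfig
```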

docs/my-website/docs/adding_provider/new_rerank_provider.md

@@ -0,0 +1,81 @@
# Add Rerank Provider
LiteLLM **follows the Cohere Rerank API format** for all rerank providers. Here's how to add a new rerank provider:
## 1. Create a transformation.py file
Create a config class named `<Provider><Endpoint>Config` that inherits from `BaseRerankConfig`:
```python
import httpx

from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse


class YourProviderRerankConfig(BaseRerankConfig):
    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
            "query",
            "documents",
            "top_n",
            # ... other supported params
        ]

    def transform_rerank_request(
        self, model: str, optional_rerank_params: OptionalRerankParams, headers: dict
    ) -> dict:
        # Build the provider request from the optional params per the RerankRequest spec
        rerank_request = RerankRequest(model=model, **optional_rerank_params)
        return rerank_request.model_dump(exclude_none=True)

    def transform_rerank_response(
        self, model: str, raw_response: httpx.Response, ...
    ) -> RerankResponse:
        # Transform the provider response into LiteLLM's RerankResponse
        raw_response_json = raw_response.json()
        return RerankResponse(**raw_response_json)
```
## 2. Register Your Provider
Add your provider to `ProviderConfigManager.get_provider_rerank_config()` in `litellm/utils.py`:
```python
elif litellm.LlmProviders.YOUR_PROVIDER == provider:
return litellm.YourProviderRerankConfig()
```
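Note that `litellm.LlmProviders` must also contain a member for your provider; this commit adds `INFINITY = "infinity"` to that enum (see the `litellm/types/utils.py` diff below). A minimal sketch, with `YOUR_PROVIDER` as a placeholder:

```python
# Sketch: register your provider in the LlmProviders enum
class LlmProviders(str, Enum):
    ...
    YOUR_PROVIDER = "your_provider"
```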
## 3. Add Provider to `rerank_api/main.py`
Add a code block to handle the case where your provider is called. Your provider should use the `base_llm_http_handler.rerank` method:
```python
elif _custom_llm_provider == "your_provider":
...
response = base_llm_http_handler.rerank(
model=model,
custom_llm_provider=_custom_llm_provider,
optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj,
timeout=optional_params.timeout,
api_key=dynamic_api_key or optional_params.api_key,
api_base=api_base,
_is_async=_is_async,
headers=headers or litellm.headers or {},
client=client,
model_response=model_response,
)
...
```
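Before the handler call, resolve `api_key` and `api_base` from the dynamic params, module-level settings, and environment. A sketch following the Infinity pattern in this commit (`YOUR_PROVIDER_API_BASE` is a placeholder env var):

```python
# Resolution order: per-call value > litellm module setting > environment variable
api_key: Optional[str] = dynamic_api_key or optional_params.api_key or litellm.api_key
api_base: Optional[str] = (
    dynamic_api_base
    or optional_params.api_base
    or litellm.api_base
    or get_secret("YOUR_PROVIDER_API_BASE")  # placeholder env var
)
if api_base is None:
    raise Exception("Invalid api base. Set it in the call or via the env var.")
```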
## 4. Add Tests
Add a test file to [`tests/llm_translation`](https://github.com/BerriAI/litellm/tree/main/tests/llm_translation)
```python
import litellm


def test_basic_rerank_cohere():
response = litellm.rerank(
model="cohere/rerank-english-v3.0",
query="hello",
documents=["hello", "world"],
top_n=3,
)
print("re rank response: ", response)
assert response.id is not None
assert response.results is not None
```
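The async path can be covered with `litellm.arerank`; a minimal sketch reusing the model above:

```python
import pytest

import litellm


@pytest.mark.asyncio
async def test_basic_rerank_cohere_async():
    response = await litellm.arerank(
        model="cohere/rerank-english-v3.0",
        query="hello",
        documents=["hello", "world"],
        top_n=3,
    )
    assert response.id is not None
    assert response.results is not None
```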

docs/my-website/sidebars.js

@@ -320,6 +320,13 @@ const sidebars = {
"load_test_rpm",
]
},
{
type: "category",
label: "Adding Providers",
items: [
"adding_provider/directory_structure",
"adding_provider/new_rerank_provider"],
},
{
type: "category",
label: "Logging & Observability",

litellm/__init__.py

@@ -126,6 +126,7 @@ azure_key: Optional[str] = None
anthropic_key: Optional[str] = None
replicate_key: Optional[str] = None
cohere_key: Optional[str] = None
infinity_key: Optional[str] = None
clarifai_key: Optional[str] = None
maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None
@@ -1025,6 +1026,7 @@ from .llms.replicate.chat.transformation import ReplicateConfig
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
from .llms.cohere.rerank.transformation import CohereRerankConfig
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
from .llms.infinity.rerank.transformation import InfinityRerankConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig
from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
from .llms.together_ai.chat import TogetherAIConfig

litellm/llms/infinity/rerank/common_utils.py

@@ -0,0 +1,19 @@
import httpx
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class InfinityError(BaseLLMException):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://github.com/michaelfeil/infinity"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
status_code=status_code,
message=message,
request=self.request,
response=self.response,
) # Call the base class constructor with the parameters it needs

litellm/llms/infinity/rerank/handler.py

@@ -0,0 +1,5 @@
"""
Infinity Rerank - uses `llm_http_handler.py` to make httpx requests.
Request/response transformation is handled in `transformation.py`.
"""

litellm/llms/infinity/rerank/transformation.py

@@ -0,0 +1,91 @@
"""
Transformation logic from Cohere's `/v1/rerank` format to Infinity's `/v1/rerank` format.

Why a separate file? To make it easy to see how the transformation works.
"""
import uuid
from typing import List, Optional
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import RerankBilledUnits, RerankResponseMeta, RerankTokens
from litellm.types.utils import RerankResponse
from .common_utils import InfinityError
class InfinityRerankConfig(CohereRerankConfig):
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
) -> dict:
        if api_key is None:
            api_key = get_secret_str("INFINITY_API_KEY") or litellm.infinity_key
default_headers = {
"Authorization": f"bearer {api_key}",
"accept": "application/json",
"content-type": "application/json",
}
# If 'Authorization' is provided in headers, it overrides the default.
if "Authorization" in headers:
default_headers["Authorization"] = headers["Authorization"]
# Merge other headers, overriding any default ones except Authorization
return {**default_headers, **headers}
def transform_rerank_response(
self,
model: str,
raw_response: httpx.Response,
model_response: RerankResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
request_data: dict = {},
optional_params: dict = {},
litellm_params: dict = {},
) -> RerankResponse:
"""
        Transform the Infinity rerank response.

        No transformation is required; Infinity follows the Cohere API response format.
"""
try:
raw_response_json = raw_response.json()
except Exception:
raise InfinityError(
message=raw_response.text, status_code=raw_response.status_code
)
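        # Infinity's usage block reports prompt_tokens and total_tokens; output
        # tokens are derived below as total - prompt (e.g. 150 - 100 = 50 in the tests).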
_billed_units = RerankBilledUnits(**raw_response_json.get("usage", {}))
_tokens = RerankTokens(
input_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
output_tokens=(
raw_response_json.get("usage", {}).get("total_tokens", 0)
- raw_response_json.get("usage", {}).get("prompt_tokens", 0)
),
)
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
_results: Optional[List[dict]] = raw_response_json.get("results")
if _results is None:
raise ValueError(f"No results found in the response={raw_response_json}")
return RerankResponse(
id=raw_response_json.get("id") or str(uuid.uuid4()),
results=_results, # type: ignore
meta=rerank_meta,
) # Return response
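For reference, a minimal usage sketch for this provider. Infinity is typically self-hosted, so the endpoint below is an assumption (7997 is Infinity's default port); setting `INFINITY_API_BASE` works in place of passing `api_base`:

```python
import litellm

# Assumes a locally running Infinity server; api_base is required for this provider
response = litellm.rerank(
    model="infinity/rerank-model",
    query="hello",
    documents=["hello", "world"],
    top_n=2,
    api_base="http://localhost:7997",  # assumed local endpoint
)
print(response.results)
```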

litellm/rerank_api/main.py

@@ -76,7 +76,9 @@ def rerank( # noqa: PLR0915
model: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[Literal["cohere", "together_ai", "azure_ai"]] = None,
custom_llm_provider: Optional[
Literal["cohere", "together_ai", "azure_ai", "infinity"]
] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
@@ -188,6 +190,37 @@
or litellm.api_base
or get_secret("AZURE_AI_API_BASE") # type: ignore
)
response = base_llm_http_handler.rerank(
model=model,
custom_llm_provider=_custom_llm_provider,
optional_rerank_params=optional_rerank_params,
logging_obj=litellm_logging_obj,
timeout=optional_params.timeout,
api_key=dynamic_api_key or optional_params.api_key,
api_base=api_base,
_is_async=_is_async,
headers=headers or litellm.headers or {},
client=client,
model_response=model_response,
)
elif _custom_llm_provider == "infinity":
# Implement Infinity rerank logic
api_key: Optional[str] = (
dynamic_api_key or optional_params.api_key or litellm.api_key
)
api_base: Optional[str] = (
dynamic_api_base
or optional_params.api_base
or litellm.api_base
or get_secret("INFINITY_API_BASE") # type: ignore
)
if api_base is None:
raise Exception(
"Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var."
)
        response = base_llm_http_handler.rerank(
            model=model,
            custom_llm_provider=_custom_llm_provider,
            optional_rerank_params=optional_rerank_params,
            logging_obj=litellm_logging_obj,
            timeout=optional_params.timeout,
            api_key=dynamic_api_key or optional_params.api_key,
            api_base=api_base,
            _is_async=_is_async,
            headers=headers or litellm.headers or {},
            client=client,
            model_response=model_response,
        )
litellm/types/utils.py

@@ -1741,6 +1741,7 @@ class LlmProviders(str, Enum):
HOSTED_VLLM = "hosted_vllm"
LM_STUDIO = "lm_studio"
GALADRIEL = "galadriel"
INFINITY = "infinity"
class LiteLLMLoggingBaseClass:

litellm/utils.py

@@ -6214,6 +6214,8 @@ class ProviderConfigManager:
return litellm.CohereRerankConfig()
elif litellm.LlmProviders.AZURE_AI == provider:
return litellm.AzureAIRerankConfig()
elif litellm.LlmProviders.INFINITY == provider:
return litellm.InfinityRerankConfig()
return litellm.CohereRerankConfig()

tests/llm_translation/test_infinity.py

@@ -0,0 +1,151 @@
import json
import os
import sys
from unittest.mock import AsyncMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import litellm
from test_rerank import assert_response_shape
@pytest.mark.asyncio()
async def test_infinity_rerank():
mock_response = AsyncMock()
def return_val():
return {
"id": "cmpl-mockid",
"results": [{"index": 0, "relevance_score": 0.95}],
"usage": {"prompt_tokens": 100, "total_tokens": 150},
}
mock_response.json = return_val
mock_response.headers = {"key": "value"}
mock_response.status_code = 200
expected_payload = {
"model": "rerank-model",
"query": "hello",
"top_n": 3,
"documents": ["hello", "world"],
}
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
return_value=mock_response,
) as mock_post:
response = await litellm.arerank(
model="infinity/rerank-model",
query="hello",
documents=["hello", "world"],
top_n=3,
api_base="https://api.infinity.ai",
)
print("async re rank response: ", response)
# Assert
mock_post.assert_called_once()
print("call args", mock_post.call_args)
args_to_api = mock_post.call_args.kwargs["data"]
_url = mock_post.call_args.kwargs["url"]
print("Arguments passed to API=", args_to_api)
print("url = ", _url)
assert _url == "https://api.infinity.ai/v1/rerank"
request_data = json.loads(args_to_api)
assert request_data["query"] == expected_payload["query"]
assert request_data["documents"] == expected_payload["documents"]
assert request_data["top_n"] == expected_payload["top_n"]
assert request_data["model"] == expected_payload["model"]
assert response.id is not None
assert response.results is not None
assert response.meta["tokens"]["input_tokens"] == 100
assert (
response.meta["tokens"]["output_tokens"] == 50
) # total_tokens - prompt_tokens
assert_response_shape(response, custom_llm_provider="infinity")
@pytest.mark.asyncio()
async def test_infinity_rerank_with_env(monkeypatch):
# Set up mock response
mock_response = AsyncMock()
def return_val():
return {
"id": "cmpl-mockid",
"results": [{"index": 0, "relevance_score": 0.95}],
"usage": {"prompt_tokens": 100, "total_tokens": 150},
}
mock_response.json = return_val
mock_response.headers = {"key": "value"}
mock_response.status_code = 200
# Set environment variable
monkeypatch.setenv("INFINITY_API_BASE", "https://env.infinity.ai")
expected_payload = {
"model": "rerank-model",
"query": "hello",
"top_n": 3,
"documents": ["hello", "world"],
}
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
return_value=mock_response,
) as mock_post:
response = await litellm.arerank(
model="infinity/rerank-model",
query="hello",
documents=["hello", "world"],
top_n=3,
)
print("async re rank response: ", response)
# Assert
mock_post.assert_called_once()
print("call args", mock_post.call_args)
args_to_api = mock_post.call_args.kwargs["data"]
_url = mock_post.call_args.kwargs["url"]
print("Arguments passed to API=", args_to_api)
print("url = ", _url)
assert _url == "https://env.infinity.ai/v1/rerank"
request_data = json.loads(args_to_api)
assert request_data["query"] == expected_payload["query"]
assert request_data["documents"] == expected_payload["documents"]
assert request_data["top_n"] == expected_payload["top_n"]
assert request_data["model"] == expected_payload["model"]
assert response.id is not None
assert response.results is not None
assert response.meta["tokens"]["input_tokens"] == 100
assert (
response.meta["tokens"]["output_tokens"] == 50
) # total_tokens - prompt_tokens
assert_response_shape(response, custom_llm_provider="infinity")