(feat) add infinity rerank models (#7321)
* Support Infinity Reranker (custom reranking models) (#7247)
* Support Infinity Reranker
* Clean code
* Included transformation.py
* Clean code
* Added Infinity reranker test
* Clean code
* transform_rerank_response
* update handler.py
* infinity rerank updates
* ci/cd run again
* add infinity unit tests
* docs add instruction on how to add a new provider for rerank

Co-authored-by: Hao Shan <53949959+haoshan98@users.noreply.github.com>
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
This commit is contained in: parent 50204d9a6d, commit 6641e75e0c.
11 changed files with 414 additions and 1 deletion.
docs/my-website/docs/adding_provider/directory_structure.md (new file, 21 lines)
@@ -0,0 +1,21 @@

# Directory Structure

When adding a new provider, you need to create a directory for the provider that follows this structure:

```
litellm/llms/
└── provider_name/
    ├── completion/
    │   ├── handler.py
    │   └── transformation.py
    ├── chat/
    │   ├── handler.py
    │   └── transformation.py
    ├── embed/
    │   ├── handler.py
    │   └── transformation.py
    └── rerank/
        ├── handler.py
        └── transformation.py
```
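When a provider routes through LiteLLM's shared HTTP handler, `handler.py` can stay minimal; the Infinity rerank handler added later in this commit is nothing but a docstring, with all request/response logic living in `transformation.py`:

```python
"""
<Provider> Rerank - uses `llm_http_handler.py` to make httpx requests

Request/Response transformation is handled in `transformation.py`
"""
```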
docs/my-website/docs/adding_provider/new_rerank_provider.md (new file, 81 lines)
@@ -0,0 +1,81 @@

# Add Rerank Provider

LiteLLM **follows the Cohere Rerank API format** for all rerank providers. Here's how to add a new rerank provider:
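For orientation, a rerank exchange in this format looks roughly like the following sketch (field values are illustrative, mirrored from the unit tests later in this commit):

```python
# Illustrative Cohere-format rerank request/response; values are made up.
request = {
    "model": "rerank-model",
    "query": "hello",
    "documents": ["hello", "world"],
    "top_n": 3,
}
response = {
    "id": "cmpl-mockid",
    "results": [{"index": 0, "relevance_score": 0.95}],  # one entry per ranked document
    "meta": {"tokens": {"input_tokens": 100, "output_tokens": 50}},
}
```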
## 1. Create a transformation.py file

Create a config class named `<Provider><Endpoint>Config` that inherits from `BaseRerankConfig`:

```python
import httpx

from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse


class YourProviderRerankConfig(BaseRerankConfig):
    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
            "query",
            "documents",
            "top_n",
            # ... other supported params
        ]

    def transform_rerank_request(self, model: str, optional_rerank_params: OptionalRerankParams, headers: dict) -> dict:
        # Transform the request to the RerankRequest spec, e.g.:
        rerank_request = RerankRequest(model=model, **optional_rerank_params)
        return rerank_request.model_dump(exclude_none=True)

    def transform_rerank_response(self, model: str, raw_response: httpx.Response, ...) -> RerankResponse:
        # Transform the provider response to RerankResponse
        raw_response_json = raw_response.json()
        return RerankResponse(**raw_response_json)
```
## 2. Register Your Provider

Add your provider to `litellm.utils.get_provider_rerank_config()`:

```python
elif litellm.LlmProviders.YOUR_PROVIDER == provider:
    return litellm.YourProviderRerankConfig()
```
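The surrounding function is a plain `elif` chain over `litellm.LlmProviders` that falls back to the Cohere config (see the `ProviderConfigManager` hunk later in this commit). A minimal sketch, with an illustrative signature:

```python
# Sketch of the dispatch this step extends; the signature is illustrative,
# but the elif chain mirrors the ProviderConfigManager diff below.
def get_provider_rerank_config(model: str, provider: litellm.LlmProviders):
    if litellm.LlmProviders.COHERE == provider:
        return litellm.CohereRerankConfig()
    elif litellm.LlmProviders.AZURE_AI == provider:
        return litellm.AzureAIRerankConfig()
    elif litellm.LlmProviders.YOUR_PROVIDER == provider:
        return litellm.YourProviderRerankConfig()
    return litellm.CohereRerankConfig()  # default: Cohere-compatible config
```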
## 3. Add Provider to `rerank_api/main.py`

Add a code block to handle when your provider is called. Your provider should use the `base_llm_http_handler.rerank` method:

```python
elif _custom_llm_provider == "your_provider":
    ...
    response = base_llm_http_handler.rerank(
        model=model,
        custom_llm_provider=_custom_llm_provider,
        optional_rerank_params=optional_rerank_params,
        logging_obj=litellm_logging_obj,
        timeout=optional_params.timeout,
        api_key=dynamic_api_key or optional_params.api_key,
        api_base=api_base,
        _is_async=_is_async,
        headers=headers or litellm.headers or {},
        client=client,
        model_response=model_response,
    )
    ...
```
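Once the provider is registered and routed, a call like the following should reach it (the `your_provider` prefix, model name, and `api_base` below are placeholders):

```python
import litellm

# Hypothetical call once "your_provider" is wired up end to end;
# the model name and api_base are placeholders.
response = litellm.rerank(
    model="your_provider/your-rerank-model",
    query="hello",
    documents=["hello", "world"],
    top_n=3,
    api_base="https://your-provider.example.com",
)
print(response.results)
```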
## 4. Add Tests

Add a test file to [`tests/llm_translation`](https://github.com/BerriAI/litellm/tree/main/tests/llm_translation):

```python
def test_basic_rerank_cohere():
    response = litellm.rerank(
        model="cohere/rerank-english-v3.0",
        query="hello",
        documents=["hello", "world"],
        top_n=3,
    )

    print("rerank response: ", response)

    assert response.id is not None
    assert response.results is not None
```
@@ -320,6 +320,13 @@ const sidebars = {

```js
        "load_test_rpm",
      ]
    },
    {
      type: "category",
      label: "Adding Providers",
      items: [
        "adding_provider/directory_structure",
        "adding_provider/new_rerank_provider"],
    },
    {
      type: "category",
      label: "Logging & Observability",
```
@@ -126,6 +126,7 @@ azure_key: Optional[str] = None

```python
anthropic_key: Optional[str] = None
replicate_key: Optional[str] = None
cohere_key: Optional[str] = None
infinity_key: Optional[str] = None
clarifai_key: Optional[str] = None
maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None
```

@@ -1025,6 +1026,7 @@ from .llms.replicate.chat.transformation import ReplicateConfig

```python
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
from .llms.cohere.rerank.transformation import CohereRerankConfig
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
from .llms.infinity.rerank.transformation import InfinityRerankConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig
from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
from .llms.together_ai.chat import TogetherAIConfig
```
litellm/llms/infinity/rerank/common_utils.py (new file, 19 lines)
@@ -0,0 +1,19 @@

```python
import httpx

from litellm.llms.base_llm.chat.transformation import BaseLLMException


class InfinityError(BaseLLMException):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://github.com/michaelfeil/infinity"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            request=self.request,
            response=self.response,
        )  # Call the base class constructor with the parameters it needs
```
litellm/llms/infinity/rerank/handler.py (new file, 5 lines)
@@ -0,0 +1,5 @@

```python
"""
Infinity Rerank - uses `llm_http_handler.py` to make httpx requests

Request/Response transformation is handled in `transformation.py`
"""
```
litellm/llms/infinity/rerank/transformation.py (new file, 91 lines)
@@ -0,0 +1,91 @@

```python
"""
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.

Why a separate file? To make it easy to see how the transformation works.
"""

import uuid
from typing import List, Optional

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import RerankBilledUnits, RerankResponseMeta, RerankTokens
from litellm.types.utils import RerankResponse

from .common_utils import InfinityError


class InfinityRerankConfig(CohereRerankConfig):
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            api_key = get_secret_str("INFINITY_API_KEY") or litellm.infinity_key

        default_headers = {
            "Authorization": f"bearer {api_key}",
            "accept": "application/json",
            "content-type": "application/json",
        }

        # If 'Authorization' is provided in headers, it overrides the default.
        if "Authorization" in headers:
            default_headers["Authorization"] = headers["Authorization"]

        # Merge other headers, overriding any default ones except Authorization
        return {**default_headers, **headers}

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Transform the Infinity rerank response.

        No transformation is required; Infinity follows the Cohere API response format.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise InfinityError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        _billed_units = RerankBilledUnits(**raw_response_json.get("usage", {}))
        _tokens = RerankTokens(
            input_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
            output_tokens=(
                raw_response_json.get("usage", {}).get("total_tokens", 0)
                - raw_response_json.get("usage", {}).get("prompt_tokens", 0)
            ),
        )
        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

        _results: Optional[List[dict]] = raw_response_json.get("results")

        if _results is None:
            raise ValueError(f"No results found in the response={raw_response_json}")

        return RerankResponse(
            id=raw_response_json.get("id") or str(uuid.uuid4()),
            results=_results,  # type: ignore
            meta=rerank_meta,
        )
```
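As a quick sanity check of the usage math above (payload values taken from the unit tests below):

```python
# Infinity reports prompt_tokens and total_tokens, so completion-side
# tokens are derived by subtraction, exactly as in transform_rerank_response.
usage = {"prompt_tokens": 100, "total_tokens": 150}
input_tokens = usage.get("prompt_tokens", 0)                 # 100
output_tokens = usage.get("total_tokens", 0) - input_tokens  # 50
assert (input_tokens, output_tokens) == (100, 50)
```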
@@ -76,7 +76,9 @@ def rerank(  # noqa: PLR0915

```diff
     model: str,
     query: str,
     documents: List[Union[str, Dict[str, Any]]],
-    custom_llm_provider: Optional[Literal["cohere", "together_ai", "azure_ai"]] = None,
+    custom_llm_provider: Optional[
+        Literal["cohere", "together_ai", "azure_ai", "infinity"]
+    ] = None,
     top_n: Optional[int] = None,
     rank_fields: Optional[List[str]] = None,
     return_documents: Optional[bool] = True,
```

@@ -188,6 +190,37 @@ def rerank(  # noqa: PLR0915

```python
            or litellm.api_base
            or get_secret("AZURE_AI_API_BASE")  # type: ignore
        )
        response = base_llm_http_handler.rerank(
            model=model,
            custom_llm_provider=_custom_llm_provider,
            optional_rerank_params=optional_rerank_params,
            logging_obj=litellm_logging_obj,
            timeout=optional_params.timeout,
            api_key=dynamic_api_key or optional_params.api_key,
            api_base=api_base,
            _is_async=_is_async,
            headers=headers or litellm.headers or {},
            client=client,
            model_response=model_response,
        )
    elif _custom_llm_provider == "infinity":
        # Implement Infinity rerank logic
        api_key: Optional[str] = (
            dynamic_api_key or optional_params.api_key or litellm.api_key
        )

        api_base: Optional[str] = (
            dynamic_api_base
            or optional_params.api_base
            or litellm.api_base
            or get_secret("INFINITY_API_BASE")  # type: ignore
        )

        if api_base is None:
            raise Exception(
                "Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var."
            )

        response = base_llm_http_handler.rerank(
            model=model,
            custom_llm_provider=_custom_llm_provider,
```
@@ -1741,6 +1741,7 @@ class LlmProviders(str, Enum):

```python
    HOSTED_VLLM = "hosted_vllm"
    LM_STUDIO = "lm_studio"
    GALADRIEL = "galadriel"
    INFINITY = "infinity"


class LiteLLMLoggingBaseClass:
```
@@ -6214,6 +6214,8 @@ class ProviderConfigManager:

```python
            return litellm.CohereRerankConfig()
        elif litellm.LlmProviders.AZURE_AI == provider:
            return litellm.AzureAIRerankConfig()
        elif litellm.LlmProviders.INFINITY == provider:
            return litellm.InfinityRerankConfig()
        return litellm.CohereRerankConfig()
```
tests/llm_translation/test_infinity.py (new file, 151 lines)
@@ -0,0 +1,151 @@

```python
import json
import os
import sys
from unittest.mock import AsyncMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import litellm
from test_rerank import assert_response_shape


@pytest.mark.asyncio()
async def test_infinity_rerank():
    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    expected_payload = {
        "model": "rerank-model",
        "query": "hello",
        "top_n": 3,
        "documents": ["hello", "world"],
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="infinity/rerank-model",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
            api_base="https://api.infinity.ai",
        )

        print("async rerank response: ", response)

        # Assert
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", _url)
        assert _url == "https://api.infinity.ai/v1/rerank"

        request_data = json.loads(args_to_api)
        assert request_data["query"] == expected_payload["query"]
        assert request_data["documents"] == expected_payload["documents"]
        assert request_data["top_n"] == expected_payload["top_n"]
        assert request_data["model"] == expected_payload["model"]

        assert response.id is not None
        assert response.results is not None
        assert response.meta["tokens"]["input_tokens"] == 100
        assert (
            response.meta["tokens"]["output_tokens"] == 50
        )  # total_tokens - prompt_tokens

        assert_response_shape(response, custom_llm_provider="infinity")


@pytest.mark.asyncio()
async def test_infinity_rerank_with_env(monkeypatch):
    # Set up mock response
    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    # Set environment variable
    monkeypatch.setenv("INFINITY_API_BASE", "https://env.infinity.ai")

    expected_payload = {
        "model": "rerank-model",
        "query": "hello",
        "top_n": 3,
        "documents": ["hello", "world"],
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="infinity/rerank-model",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
        )

        print("async rerank response: ", response)

        # Assert
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", _url)
        assert _url == "https://env.infinity.ai/v1/rerank"

        request_data = json.loads(args_to_api)
        assert request_data["query"] == expected_payload["query"]
        assert request_data["documents"] == expected_payload["documents"]
        assert request_data["top_n"] == expected_payload["top_n"]
        assert request_data["model"] == expected_payload["model"]

        assert response.id is not None
        assert response.results is not None
        assert response.meta["tokens"]["input_tokens"] == 100
        assert (
            response.meta["tokens"]["output_tokens"] == 50
        )  # total_tokens - prompt_tokens

        assert_response_shape(response, custom_llm_provider="infinity")
```