[Feat] Add infinity embedding support (contributor pr) (#10196)

* Feature - infinity support for #8764 (#10009)

* Added support for infinity embeddings

* Added test cases

* Fixed tests and api base

* Updated docs and tests

* Removed unused import

* Updated signature


* Updated validate params

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>

* fix InfinityEmbeddingConfig

---------

Co-authored-by: Prathamesh Saraf <pratamesh1867@gmail.com>
Ishaan Jaff 2025-04-21 20:01:29 -07:00 committed by GitHub
parent 0c2f705417
commit 104e4cb1bc
12 changed files with 529 additions and 22 deletions


@@ -20,6 +20,8 @@ REPLICATE_API_TOKEN = ""
ANTHROPIC_API_KEY = ""
# Infisical
INFISICAL_TOKEN = ""
# INFINITY
INFINITY_API_KEY = ""
# Development Configs
LITELLM_MASTER_KEY = "sk-1234"

.gitignore vendored

@@ -86,4 +86,5 @@ litellm/proxy/db/migrations/0_init/migration.sql
litellm/proxy/db/migrations/*
litellm/proxy/migrations/*config.yaml
litellm/proxy/migrations/*
config.yaml
tests/litellm/litellm_core_utils/llm_cost_calc/log.txt


@@ -3,18 +3,17 @@ import TabItem from '@theme/TabItem';
# Infinity
| Property | Details |
| ------------------------- | ---------------------------------------------------------------------------------------------------------- |
| Description               | Infinity is a high-throughput, low-latency REST API for serving text embeddings, reranking models, and CLIP |
| Provider Route on LiteLLM | `infinity/` |
| Supported Operations | `/rerank`, `/embeddings` |
| Link to Provider Doc | [Infinity ↗](https://github.com/michaelfeil/infinity) |
## **Usage - LiteLLM Python SDK**
```python
from litellm import rerank, embedding
import os
os.environ["INFINITY_API_BASE"] = "http://localhost:8080"
@@ -39,8 +38,8 @@ model_list:
  - model_name: custom-infinity-rerank
    litellm_params:
      model: infinity/rerank
      api_base: https://localhost:8080
      api_key: os.environ/INFINITY_API_KEY
```
Start litellm
@@ -51,7 +50,9 @@ litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
## Test request:
### Rerank
```bash
curl http://0.0.0.0:4000/rerank \
@@ -70,15 +71,14 @@ curl http://0.0.0.0:4000/rerank \
}'
```
## Supported Cohere Rerank API Params
| Param | Type | Description |
| ------------------ | ----------- | ----------------------------------------------- |
| `query` | `str` | The query to rerank the documents against |
| `documents` | `list[str]` | The documents to rerank |
| `top_n` | `int` | The number of documents to return |
| `return_documents` | `bool` | Whether to return the documents in the response |
### Usage - Return Documents
@@ -138,6 +138,7 @@ response = rerank(
    raw_scores=True, # 👈 PROVIDER-SPECIFIC PARAM
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
@@ -161,7 +162,7 @@ litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
```bash
curl http://0.0.0.0:4000/rerank \
@@ -179,6 +180,121 @@ curl http://0.0.0.0:4000/rerank \
"raw_scores": True # 👈 PROVIDER-SPECIFIC PARAM
}'
```
</TabItem>
</Tabs>
## Embeddings
LiteLLM provides an OpenAI API-compatible `/embeddings` endpoint for embedding calls.
**Setup**
Add this to your litellm proxy config.yaml
```yaml
model_list:
  - model_name: custom-infinity-embedding
    litellm_params:
      model: infinity/provider/custom-embedding-v1
      api_base: http://localhost:8080
      api_key: os.environ/INFINITY_API_KEY
```
### Test request:
```bash
curl http://0.0.0.0:4000/embeddings \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d '{
"model": "custom-infinity-embedding",
"input": ["hello"]
}'
```
#### Supported Embedding API Params
| Param | Type | Description |
| ----------------- | ----------- | ----------------------------------------------------------- |
| `model` | `str` | The embedding model to use |
| `input` | `list[str]` | The text inputs to generate embeddings for |
| `encoding_format` | `str` | The format to return embeddings in (e.g. "float", "base64") |
| `modality` | `str` | The type of input (e.g. "text", "image", "audio") |
| `dimensions` | `int` | The output embedding size; mapped to Infinity's `output_dimension` |
### Usage - Basic Examples
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import embedding
import os

os.environ["INFINITY_API_BASE"] = "http://localhost:8080"

response = embedding(
    model="infinity/bge-small",
    input=["good morning from litellm"]
)

print(response.data[0]['embedding'])
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```bash
curl http://0.0.0.0:4000/embeddings \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d '{
"model": "custom-infinity-embedding",
"input": ["hello"]
}'
```
</TabItem>
</Tabs>
### Usage - OpenAI Client
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from openai import OpenAI

client = OpenAI(
    api_key="<LITELLM_MASTER_KEY>",
    base_url="<LITELLM_URL>"
)

response = client.embeddings.create(
    model="bge-small",
    input=["The food was delicious and the waiter..."],
    encoding_format="float"
)

print(response.data[0].embedding)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```bash
curl http://0.0.0.0:4000/embeddings \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d '{
"model": "bge-small",
"input": ["The food was delicious and the waiter..."],
"encoding_format": "float"
}'
```
</TabItem>
</Tabs>
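
### Usage - Provider-Specific Params

The OpenAI-style `dimensions` parameter is translated to Infinity's `output_dimension` by the embedding config added in this commit, and `modality` is passed through as-is. A minimal SDK sketch exercising both; the model name and `api_base` are illustrative placeholders mirroring the tests added below:

```python
from litellm import embedding

# Sketch only: "custom-model/embedding-v1" and the api_base are placeholder values.
# dimensions=512 is sent to Infinity as output_dimension=512.
response = embedding(
    model="infinity/custom-model/embedding-v1",
    input=["hello world"],
    dimensions=512,
    encoding_format="float",
    modality="text",
    api_base="http://localhost:8080",
)
print(response.data[0]["embedding"])
```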


@@ -415,6 +415,7 @@ deepseek_models: List = []
azure_ai_models: List = []
jina_ai_models: List = []
voyage_models: List = []
infinity_models: List = []
databricks_models: List = []
cloudflare_models: List = []
codestral_models: List = []
@@ -556,6 +557,8 @@ def add_known_models():
            azure_ai_models.append(key)
        elif value.get("litellm_provider") == "voyage":
            voyage_models.append(key)
        elif value.get("litellm_provider") == "infinity":
            infinity_models.append(key)
        elif value.get("litellm_provider") == "databricks":
            databricks_models.append(key)
        elif value.get("litellm_provider") == "cloudflare":
@@ -644,6 +647,7 @@ model_list = (
    + deepseek_models
    + azure_ai_models
    + voyage_models
    + infinity_models
    + databricks_models
    + cloudflare_models
    + codestral_models
@@ -699,6 +703,7 @@ models_by_provider: dict = {
    "mistral": mistral_chat_models,
    "azure_ai": azure_ai_models,
    "voyage": voyage_models,
    "infinity": infinity_models,
    "databricks": databricks_models,
    "cloudflare": cloudflare_models,
    "codestral": codestral_models,
@@ -946,6 +951,7 @@ from .llms.topaz.image_variations.transformation import TopazImageVariationConfi
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from .llms.groq.chat.transformation import GroqChatConfig
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
from .llms.mistral.mistral_chat_transformation import MistralConfig
from .llms.openai.responses.transformation import OpenAIResponsesAPIConfig


@@ -221,6 +221,8 @@ def get_supported_openai_params( # noqa: PLR0915
        return litellm.PredibaseConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "voyage":
        return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "infinity":
        return litellm.InfinityEmbeddingConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "triton":
        if request_type == "embeddings":
            return litellm.TritonEmbeddingConfig().get_supported_openai_params(


@@ -1,10 +1,16 @@
from typing import Union

import httpx

from litellm.llms.base_llm.chat.transformation import BaseLLMException


class InfinityError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Union[dict, httpx.Headers] = {}
    ):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
@@ -16,4 +22,5 @@ class InfinityError(BaseLLMException):
            message=message,
            request=self.request,
            response=self.response,
            headers=headers,
        ) # Call the base class constructor with the parameters it needs


@@ -0,0 +1,5 @@
"""
Infinity Embedding - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""


@@ -0,0 +1,141 @@
from typing import List, Optional, Union

import httpx

from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse, Usage

from ..common_utils import InfinityError


class InfinityEmbeddingConfig(BaseEmbeddingConfig):
    """
    Reference: https://infinity.modal.michaelfeil.eu/docs
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        if api_base is None:
            raise ValueError("api_base is required for Infinity embeddings")

        # Remove trailing slashes and ensure clean base URL
        api_base = api_base.rstrip("/")
        if not api_base.endswith("/embeddings"):
            api_base = f"{api_base}/embeddings"
        return api_base

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            api_key = get_secret_str("INFINITY_API_KEY")

        default_headers = {
            "Authorization": f"Bearer {api_key}",
            "accept": "application/json",
            "Content-Type": "application/json",
        }

        # If 'Authorization' is provided in headers, it overrides the default.
        if "Authorization" in headers:
            default_headers["Authorization"] = headers["Authorization"]

        # Merge other headers, overriding any default ones except Authorization
        return {**default_headers, **headers}

    def get_supported_openai_params(self, model: str) -> list:
        return [
            "encoding_format",
            "modality",
            "dimensions",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params to Infinity params

        Reference: https://infinity.modal.michaelfeil.eu/docs
        """
        if "encoding_format" in non_default_params:
            optional_params["encoding_format"] = non_default_params["encoding_format"]
        if "modality" in non_default_params:
            optional_params["modality"] = non_default_params["modality"]
        if "dimensions" in non_default_params:
            optional_params["output_dimension"] = non_default_params["dimensions"]

        return optional_params

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        return {
            "input": input,
            "model": model,
            **optional_params,
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise InfinityError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        # model_response.usage
        model_response.model = raw_response_json.get("model")
        model_response.data = raw_response_json.get("data")
        model_response.object = raw_response_json.get("object")

        usage = Usage(
            prompt_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
            total_tokens=raw_response_json.get("usage", {}).get("total_tokens", 0),
        )
        model_response.usage = usage
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return InfinityError(
            message=error_message, status_code=status_code, headers=headers
        )
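
On the response side, `transform_embedding_response` above copies `data`/`model`/`object` from the raw JSON and builds a `Usage` object from the `usage` block. A small sketch under the assumption that `logging_obj` is unused on this path (the method body never touches it), with illustrative values:

```python
import httpx

from litellm.llms.infinity.embedding.transformation import InfinityEmbeddingConfig
from litellm.types.utils import EmbeddingResponse

# Illustrative raw Infinity response body
raw = httpx.Response(
    200,
    json={
        "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
        "usage": {"prompt_tokens": 100, "total_tokens": 150},
        "model": "custom-model/embedding-v1",  # placeholder model name
        "object": "list",
    },
)

resp = InfinityEmbeddingConfig().transform_embedding_response(
    model="custom-model/embedding-v1",
    raw_response=raw,
    model_response=EmbeddingResponse(),
    logging_obj=None,  # type: ignore[arg-type]  # unused by this implementation
)
print(resp.usage.prompt_tokens, resp.usage.total_tokens)  # 100 150
```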


@@ -22,7 +22,7 @@ from litellm.types.rerank import (
    RerankTokens,
)

from ..common_utils import InfinityError


class InfinityRerankConfig(CohereRerankConfig):


@@ -3884,6 +3884,21 @@ def embedding( # noqa: PLR0915
            aembedding=aembedding,
            litellm_params={},
        )
    elif custom_llm_provider == "infinity":
        response = base_llm_http_handler.embedding(
            model=model,
            input=input,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            api_key=api_key,
            logging_obj=logging,
            timeout=timeout,
            model_response=EmbeddingResponse(),
            optional_params=optional_params,
            client=client,
            aembedding=aembedding,
            litellm_params={},
        )
    elif custom_llm_provider == "watsonx":
        credentials = IBMWatsonXMixin.get_watsonx_credentials(
            optional_params=optional_params, api_key=api_key, api_base=api_base


@@ -2735,6 +2735,21 @@ def get_optional_params_embeddings( # noqa: PLR0915
        )
        final_params = {**optional_params, **kwargs}
        return final_params
    elif custom_llm_provider == "infinity":
        supported_params = get_supported_openai_params(
            model=model,
            custom_llm_provider="infinity",
            request_type="embeddings",
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.InfinityEmbeddingConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params={},
            model=model,
            drop_params=drop_params if drop_params is not None else False,
        )
        final_params = {**optional_params, **kwargs}
        return final_params
    elif custom_llm_provider == "fireworks_ai":
        supported_params = get_supported_openai_params(
            model=model,
@@ -5120,6 +5135,11 @@ def validate_environment( # noqa: PLR0915
                keys_in_environment = True
            else:
                missing_keys.append("VOYAGE_API_KEY")
        elif custom_llm_provider == "infinity":
            if "INFINITY_API_KEY" in os.environ:
                keys_in_environment = True
            else:
                missing_keys.append("INFINITY_API_KEY")
        elif custom_llm_provider == "fireworks_ai":
            if (
                "FIREWORKS_AI_API_KEY" in os.environ
@@ -6554,6 +6574,8 @@ class ProviderConfigManager:
            return litellm.TritonEmbeddingConfig()
        elif litellm.LlmProviders.WATSONX == provider:
            return litellm.IBMWatsonXEmbeddingConfig()
        elif litellm.LlmProviders.INFINITY == provider:
            return litellm.InfinityEmbeddingConfig()
        raise ValueError(f"Provider {provider.value} does not support embedding config")

    @staticmethod


@@ -15,7 +15,7 @@ import json
import os
import sys
from datetime import datetime
from unittest.mock import patch, MagicMock, AsyncMock
import pytest
@@ -25,6 +25,10 @@ sys.path.insert(
from test_rerank import assert_response_shape
import litellm
from base_embedding_unit_tests import BaseLLMEmbeddingTest
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from litellm.types.utils import EmbeddingResponse, Usage
@pytest.mark.asyncio()
async def test_infinity_rerank():
@@ -182,3 +186,189 @@ async def test_infinity_rerank_with_env(monkeypatch):
    ) # total_tokens - prompt_tokens

    assert_response_shape(response, custom_llm_provider="infinity")
#### Embedding Tests
@pytest.mark.asyncio()
async def test_infinity_embedding():
    mock_response = AsyncMock()

    def return_val():
        return {
            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
            "model": "custom-model/embedding-v1",
            "object": "list"
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    expected_payload = {
        "model": "custom-model/embedding-v1",
        "input": ["hello world"],
        "encoding_format": "float",
        "output_dimension": 512
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.aembedding(
            model="infinity/custom-model/embedding-v1",
            input=["hello world"],
            dimensions=512,
            encoding_format="float",
            api_base="https://api.infinity.ai/embeddings",
        )

        # Assert
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        assert _url == "https://api.infinity.ai/embeddings"

        request_data = json.loads(args_to_api)
        assert request_data["input"] == expected_payload["input"]
        assert request_data["model"] == expected_payload["model"]
        assert request_data["output_dimension"] == expected_payload["output_dimension"]
        assert request_data["encoding_format"] == expected_payload["encoding_format"]

        assert response.data is not None
        assert response.usage.prompt_tokens == 100
        assert response.usage.total_tokens == 150
        assert response.model == "custom-model/embedding-v1"
        assert response.object == "list"
@pytest.mark.asyncio()
async def test_infinity_embedding_with_env(monkeypatch):
    # Set up mock response
    mock_response = AsyncMock()

    def return_val():
        return {
            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
            "model": "custom-model/embedding-v1",
            "object": "list"
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    expected_payload = {
        "model": "custom-model/embedding-v1",
        "input": ["hello world"],
        "encoding_format": "float",
        "output_dimension": 512
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.aembedding(
            model="infinity/custom-model/embedding-v1",
            input=["hello world"],
            dimensions=512,
            encoding_format="float",
            api_base="https://api.infinity.ai/embeddings",
        )

        # Assert
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        assert _url == "https://api.infinity.ai/embeddings"

        request_data = json.loads(args_to_api)
        assert request_data["input"] == expected_payload["input"]
        assert request_data["model"] == expected_payload["model"]
        assert request_data["output_dimension"] == expected_payload["output_dimension"]
        assert request_data["encoding_format"] == expected_payload["encoding_format"]

        assert response.data is not None
        assert response.usage.prompt_tokens == 100
        assert response.usage.total_tokens == 150
        assert response.model == "custom-model/embedding-v1"
        assert response.object == "list"
@pytest.mark.asyncio()
async def test_infinity_embedding_extra_params():
    mock_response = AsyncMock()

    def return_val():
        return {
            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
            "model": "custom-model/embedding-v1",
            "object": "list"
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.aembedding(
            model="infinity/custom-model/embedding-v1",
            input=["test input"],
            dimensions=512,
            encoding_format="float",
            modality="text",
            api_base="https://api.infinity.ai/embeddings",
        )

        mock_post.assert_called_once()
        json_data = json.loads(mock_post.call_args.kwargs["data"])

        # Assert the request parameters
        assert json_data["input"] == ["test input"]
        assert json_data["model"] == "custom-model/embedding-v1"
        assert json_data["output_dimension"] == 512
        assert json_data["encoding_format"] == "float"
        assert json_data["modality"] == "text"
@pytest.mark.asyncio()
async def test_infinity_embedding_prompt_token_mapping():
    mock_response = AsyncMock()

    def return_val():
        return {
            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
            "usage": {"total_tokens": 1, "prompt_tokens": 1},
            "model": "custom-model/embedding-v1",
            "object": "list"
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.aembedding(
            model="infinity/custom-model/embedding-v1",
            input=["a"],
            dimensions=512,
            encoding_format="float",
            api_base="https://api.infinity.ai/embeddings",
        )

        mock_post.assert_called_once()
        # Assert the response
        assert response.usage.prompt_tokens == 1
        assert response.usage.total_tokens == 1