diff --git a/.env.example b/.env.example
index 82b09ca25e..54986a97cd 100644
--- a/.env.example
+++ b/.env.example
@@ -20,6 +20,8 @@ REPLICATE_API_TOKEN = ""
ANTHROPIC_API_KEY = ""
# Infisical
INFISICAL_TOKEN = ""
+# INFINITY
+INFINITY_API_KEY = ""
# Development Configs
LITELLM_MASTER_KEY = "sk-1234"
diff --git a/.gitignore b/.gitignore
index 4259b80f55..e8c18bed4c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -86,4 +86,5 @@ litellm/proxy/db/migrations/0_init/migration.sql
litellm/proxy/db/migrations/*
litellm/proxy/migrations/*config.yaml
litellm/proxy/migrations/*
+config.yaml
tests/litellm/litellm_core_utils/llm_cost_calc/log.txt
diff --git a/docs/my-website/docs/providers/infinity.md b/docs/my-website/docs/providers/infinity.md
index 091503bf18..7900d5adb4 100644
--- a/docs/my-website/docs/providers/infinity.md
+++ b/docs/my-website/docs/providers/infinity.md
@@ -3,18 +3,17 @@ import TabItem from '@theme/TabItem';
# Infinity
-| Property | Details |
-|-------|-------|
-| Description | Infinity is a high-throughput, low-latency REST API for serving text-embeddings, reranking models and clip|
-| Provider Route on LiteLLM | `infinity/` |
-| Supported Operations | `/rerank` |
-| Link to Provider Doc | [Infinity ↗](https://github.com/michaelfeil/infinity) |
-
+| Property | Details |
+| ------------------------- | ---------------------------------------------------------------------------------------------------------- |
+| Description               | Infinity is a high-throughput, low-latency REST API for serving text embeddings, reranking models, and CLIP |
+| Provider Route on LiteLLM | `infinity/` |
+| Supported Operations | `/rerank`, `/embeddings` |
+| Link to Provider Doc | [Infinity ↗](https://github.com/michaelfeil/infinity) |
## **Usage - LiteLLM Python SDK**
```python
-from litellm import rerank
+from litellm import rerank, embedding
import os
os.environ["INFINITY_API_BASE"] = "http://localhost:8080"
@@ -39,8 +38,8 @@ model_list:
- model_name: custom-infinity-rerank
litellm_params:
model: infinity/rerank
- api_key: os.environ/INFINITY_API_KEY
api_base: https://localhost:8080
+ api_key: os.environ/INFINITY_API_KEY
```
Start litellm
@@ -51,7 +50,9 @@ litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
-Test request
+## Test request
+
+### Rerank
```bash
curl http://0.0.0.0:4000/rerank \
@@ -70,15 +71,14 @@ curl http://0.0.0.0:4000/rerank \
}'
```
+#### Supported Cohere Rerank API Params
-## Supported Cohere Rerank API Params
-
-| Param | Type | Description |
-|-------|-------|-------|
-| `query` | `str` | The query to rerank the documents against |
-| `documents` | `list[str]` | The documents to rerank |
-| `top_n` | `int` | The number of documents to return |
-| `return_documents` | `bool` | Whether to return the documents in the response |
+| Param | Type | Description |
+| ------------------ | ----------- | ----------------------------------------------- |
+| `query` | `str` | The query to rerank the documents against |
+| `documents` | `list[str]` | The documents to rerank |
+| `top_n` | `int` | The number of documents to return |
+| `return_documents` | `bool` | Whether to return the documents in the response |
### Usage - Return Documents
@@ -138,6 +138,7 @@ response = rerank(
raw_scores=True, # 👈 PROVIDER-SPECIFIC PARAM
)
```
+
@@ -161,7 +162,7 @@ litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
-3. Test it!
+3. Test it!
```bash
curl http://0.0.0.0:4000/rerank \
@@ -179,6 +180,121 @@ curl http://0.0.0.0:4000/rerank \
"raw_scores": True # 👈 PROVIDER-SPECIFIC PARAM
}'
```
+
+
+## Embeddings
+
+LiteLLM provides an OpenAI API-compatible `/embeddings` endpoint for embedding calls.
+
+**Setup**
+
+Add this to your litellm proxy config.yaml
+
+```yaml
+model_list:
+ - model_name: custom-infinity-embedding
+ litellm_params:
+ model: infinity/provider/custom-embedding-v1
+ api_base: http://localhost:8080
+ api_key: os.environ/INFINITY_API_KEY
+```
+
+### Test request
+
+```bash
+curl http://0.0.0.0:4000/embeddings \
+ -H "Authorization: Bearer sk-1234" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "custom-infinity-embedding",
+ "input": ["hello"]
+ }'
+```
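+
+The response follows the OpenAI embeddings format. An illustrative example of the shape (the values below are made up):
+
+```json
+{
+  "object": "list",
+  "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+  "model": "custom-infinity-embedding",
+  "usage": {"prompt_tokens": 2, "total_tokens": 2}
+}
+```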
+
+#### Supported Embedding API Params
+
+| Param             | Type        | Description                                                               |
+| ----------------- | ----------- | ------------------------------------------------------------------------- |
+| `model`           | `str`       | The embedding model to use                                                |
+| `input`           | `list[str]` | The text inputs to generate embeddings for                                |
+| `encoding_format` | `str`       | The format to return embeddings in (e.g. "float", "base64")               |
+| `dimensions`      | `int`       | The number of output dimensions (sent to Infinity as `output_dimension`)  |
+| `modality`        | `str`       | The type of input (e.g. "text", "image", "audio")                         |
+
+### Usage - Basic Examples
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import embedding
+import os
+
+os.environ["INFINITY_API_BASE"] = "http://localhost:8080"
+
+response = embedding(
+ model="infinity/bge-small",
+ input=["good morning from litellm"]
+)
+
+print(response.data[0]['embedding'])
+```
+
+</TabItem>
+
+<TabItem value="proxy" label="PROXY">
+
+```bash
+curl http://0.0.0.0:4000/embeddings \
+ -H "Authorization: Bearer sk-1234" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "custom-infinity-embedding",
+ "input": ["hello"]
+ }'
+```
+
+</TabItem>
+</Tabs>
+
+### Usage - OpenAI Client
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="sk-1234",               # your LiteLLM Proxy key (matches the curl examples above)
+    base_url="http://0.0.0.0:4000"   # your LiteLLM Proxy base URL
+)
+
+response = client.embeddings.create(
+ model="bge-small",
+ input=["The food was delicious and the waiter..."],
+ encoding_format="float"
+)
+
+print(response.data[0].embedding)
+```
+
+</TabItem>
+
+<TabItem value="proxy" label="PROXY">
+
+```bash
+curl http://0.0.0.0:4000/embeddings \
+ -H "Authorization: Bearer sk-1234" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "bge-small",
+ "input": ["The food was delicious and the waiter..."],
+ "encoding_format": "float"
+ }'
+```
+
+</TabItem>
+</Tabs>
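+
+### Usage - `dimensions` and `modality` params
+
+LiteLLM maps the OpenAI-style `dimensions` param to Infinity's `output_dimension` field and forwards `modality` unchanged. A minimal SDK sketch (the model name and api base below are placeholders):
+
+```python
+from litellm import embedding
+import os
+
+os.environ["INFINITY_API_BASE"] = "http://localhost:8080"
+
+response = embedding(
+    model="infinity/bge-small",           # placeholder model name
+    input=["good morning from litellm"],
+    dimensions=512,                       # sent to Infinity as `output_dimension`
+    modality="text",                      # forwarded as-is
+)
+
+print(response.data[0]["embedding"])
+```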
diff --git a/litellm/__init__.py b/litellm/__init__.py
index e9dadbfaf6..a7b747541a 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -415,6 +415,7 @@ deepseek_models: List = []
azure_ai_models: List = []
jina_ai_models: List = []
voyage_models: List = []
+infinity_models: List = []
databricks_models: List = []
cloudflare_models: List = []
codestral_models: List = []
@@ -556,6 +557,8 @@ def add_known_models():
azure_ai_models.append(key)
elif value.get("litellm_provider") == "voyage":
voyage_models.append(key)
+ elif value.get("litellm_provider") == "infinity":
+ infinity_models.append(key)
elif value.get("litellm_provider") == "databricks":
databricks_models.append(key)
elif value.get("litellm_provider") == "cloudflare":
@@ -644,6 +647,7 @@ model_list = (
+ deepseek_models
+ azure_ai_models
+ voyage_models
+ + infinity_models
+ databricks_models
+ cloudflare_models
+ codestral_models
@@ -699,6 +703,7 @@ models_by_provider: dict = {
"mistral": mistral_chat_models,
"azure_ai": azure_ai_models,
"voyage": voyage_models,
+ "infinity": infinity_models,
"databricks": databricks_models,
"cloudflare": cloudflare_models,
"codestral": codestral_models,
@@ -946,6 +951,7 @@ from .llms.topaz.image_variations.transformation import TopazImageVariationConfi
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from .llms.groq.chat.transformation import GroqChatConfig
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
+from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
from .llms.mistral.mistral_chat_transformation import MistralConfig
from .llms.openai.responses.transformation import OpenAIResponsesAPIConfig
diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py
index bcf9fdb961..c0f638ddc2 100644
--- a/litellm/litellm_core_utils/get_supported_openai_params.py
+++ b/litellm/litellm_core_utils/get_supported_openai_params.py
@@ -221,6 +221,8 @@ def get_supported_openai_params( # noqa: PLR0915
return litellm.PredibaseConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "voyage":
return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model)
+ elif custom_llm_provider == "infinity":
+ return litellm.InfinityEmbeddingConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "triton":
if request_type == "embeddings":
return litellm.TritonEmbeddingConfig().get_supported_openai_params(
diff --git a/litellm/llms/infinity/rerank/common_utils.py b/litellm/llms/infinity/common_utils.py
similarity index 76%
rename from litellm/llms/infinity/rerank/common_utils.py
rename to litellm/llms/infinity/common_utils.py
index 99477d1a33..089818c829 100644
--- a/litellm/llms/infinity/rerank/common_utils.py
+++ b/litellm/llms/infinity/common_utils.py
@@ -1,10 +1,16 @@
+from typing import Union
import httpx
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class InfinityError(BaseLLMException):
- def __init__(self, status_code, message):
+ def __init__(
+ self,
+ status_code: int,
+ message: str,
+ headers: Union[dict, httpx.Headers] = {}
+ ):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
@@ -16,4 +22,5 @@ class InfinityError(BaseLLMException):
message=message,
request=self.request,
response=self.response,
+ headers=headers,
) # Call the base class constructor with the parameters it needs
diff --git a/litellm/llms/infinity/embedding/handler.py b/litellm/llms/infinity/embedding/handler.py
new file mode 100644
index 0000000000..cdcb99c433
--- /dev/null
+++ b/litellm/llms/infinity/embedding/handler.py
@@ -0,0 +1,5 @@
+"""
+Infinity Embedding - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
diff --git a/litellm/llms/infinity/embedding/transformation.py b/litellm/llms/infinity/embedding/transformation.py
new file mode 100644
index 0000000000..824dcd38da
--- /dev/null
+++ b/litellm/llms/infinity/embedding/transformation.py
@@ -0,0 +1,141 @@
+from typing import List, Optional, Union
+
+import httpx
+
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
+from litellm.types.utils import EmbeddingResponse, Usage
+
+from ..common_utils import InfinityError
+
+
+class InfinityEmbeddingConfig(BaseEmbeddingConfig):
+ """
+ Reference: https://infinity.modal.michaelfeil.eu/docs
+ """
+
+ def __init__(self) -> None:
+ pass
+
+ def get_complete_url(
+ self,
+ api_base: Optional[str],
+ api_key: Optional[str],
+ model: str,
+ optional_params: dict,
+ litellm_params: dict,
+ stream: Optional[bool] = None,
+ ) -> str:
+ if api_base is None:
+ raise ValueError("api_base is required for Infinity embeddings")
+ # Remove trailing slashes and ensure clean base URL
+ api_base = api_base.rstrip("/")
+ if not api_base.endswith("/embeddings"):
+ api_base = f"{api_base}/embeddings"
+ return api_base
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ ) -> dict:
+ if api_key is None:
+ api_key = get_secret_str("INFINITY_API_KEY")
+
+ default_headers = {
+ "Authorization": f"Bearer {api_key}",
+ "accept": "application/json",
+ "Content-Type": "application/json",
+ }
+
+ # If 'Authorization' is provided in headers, it overrides the default.
+ if "Authorization" in headers:
+ default_headers["Authorization"] = headers["Authorization"]
+
+ # Merge other headers, overriding any default ones except Authorization
+ return {**default_headers, **headers}
+
+ def get_supported_openai_params(self, model: str) -> list:
+ return [
+ "encoding_format",
+ "modality",
+ "dimensions",
+ ]
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
+ """
+ Map OpenAI params to Infinity params
+
+ Reference: https://infinity.modal.michaelfeil.eu/docs
+ """
+ if "encoding_format" in non_default_params:
+ optional_params["encoding_format"] = non_default_params["encoding_format"]
+ if "modality" in non_default_params:
+ optional_params["modality"] = non_default_params["modality"]
+ if "dimensions" in non_default_params:
+ optional_params["output_dimension"] = non_default_params["dimensions"]
+ return optional_params
+
+ def transform_embedding_request(
+ self,
+ model: str,
+ input: AllEmbeddingInputValues,
+ optional_params: dict,
+ headers: dict,
+ ) -> dict:
+ return {
+ "input": input,
+ "model": model,
+ **optional_params,
+ }
+
+ def transform_embedding_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: EmbeddingResponse,
+ logging_obj: LiteLLMLoggingObj,
+ api_key: Optional[str] = None,
+ request_data: dict = {},
+ optional_params: dict = {},
+ litellm_params: dict = {},
+ ) -> EmbeddingResponse:
+ try:
+ raw_response_json = raw_response.json()
+ except Exception:
+ raise InfinityError(
+ message=raw_response.text, status_code=raw_response.status_code
+ )
+
+        # map the top-level fields from the Infinity response
+ model_response.model = raw_response_json.get("model")
+ model_response.data = raw_response_json.get("data")
+ model_response.object = raw_response_json.get("object")
+
+ usage = Usage(
+ prompt_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
+ total_tokens=raw_response_json.get("usage", {}).get("total_tokens", 0),
+ )
+ model_response.usage = usage
+ return model_response
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+ ) -> BaseLLMException:
+ return InfinityError(
+ message=error_message, status_code=status_code, headers=headers
+ )
diff --git a/litellm/llms/infinity/rerank/transformation.py b/litellm/llms/infinity/rerank/transformation.py
index 1e7234ab17..4b75fa121b 100644
--- a/litellm/llms/infinity/rerank/transformation.py
+++ b/litellm/llms/infinity/rerank/transformation.py
@@ -22,7 +22,7 @@ from litellm.types.rerank import (
RerankTokens,
)
-from .common_utils import InfinityError
+from ..common_utils import InfinityError
class InfinityRerankConfig(CohereRerankConfig):
diff --git a/litellm/main.py b/litellm/main.py
index 9bb1cf0c15..80486fbe02 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3884,6 +3884,21 @@ def embedding( # noqa: PLR0915
aembedding=aembedding,
litellm_params={},
)
+ elif custom_llm_provider == "infinity":
+ response = base_llm_http_handler.embedding(
+ model=model,
+ input=input,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ api_key=api_key,
+ logging_obj=logging,
+ timeout=timeout,
+ model_response=EmbeddingResponse(),
+ optional_params=optional_params,
+ client=client,
+ aembedding=aembedding,
+ litellm_params={},
+ )
elif custom_llm_provider == "watsonx":
credentials = IBMWatsonXMixin.get_watsonx_credentials(
optional_params=optional_params, api_key=api_key, api_base=api_base
diff --git a/litellm/utils.py b/litellm/utils.py
index 3efd188717..38e604943a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2735,6 +2735,21 @@ def get_optional_params_embeddings( # noqa: PLR0915
)
final_params = {**optional_params, **kwargs}
return final_params
+ elif custom_llm_provider == "infinity":
+ supported_params = get_supported_openai_params(
+ model=model,
+ custom_llm_provider="infinity",
+ request_type="embeddings",
+ )
+ _check_valid_arg(supported_params=supported_params)
+ optional_params = litellm.InfinityEmbeddingConfig().map_openai_params(
+ non_default_params=non_default_params,
+ optional_params={},
+ model=model,
+ drop_params=drop_params if drop_params is not None else False,
+ )
+ final_params = {**optional_params, **kwargs}
+ return final_params
elif custom_llm_provider == "fireworks_ai":
supported_params = get_supported_openai_params(
model=model,
@@ -5120,6 +5135,11 @@ def validate_environment( # noqa: PLR0915
keys_in_environment = True
else:
missing_keys.append("VOYAGE_API_KEY")
+ elif custom_llm_provider == "infinity":
+ if "INFINITY_API_KEY" in os.environ:
+ keys_in_environment = True
+ else:
+ missing_keys.append("INFINITY_API_KEY")
elif custom_llm_provider == "fireworks_ai":
if (
"FIREWORKS_AI_API_KEY" in os.environ
@@ -6554,6 +6574,8 @@ class ProviderConfigManager:
return litellm.TritonEmbeddingConfig()
elif litellm.LlmProviders.WATSONX == provider:
return litellm.IBMWatsonXEmbeddingConfig()
+ elif litellm.LlmProviders.INFINITY == provider:
+ return litellm.InfinityEmbeddingConfig()
raise ValueError(f"Provider {provider.value} does not support embedding config")
@staticmethod
diff --git a/tests/llm_translation/test_infinity.py b/tests/llm_translation/test_infinity.py
index eb986b8ab5..802789377c 100644
--- a/tests/llm_translation/test_infinity.py
+++ b/tests/llm_translation/test_infinity.py
@@ -15,7 +15,7 @@ import json
import os
import sys
from datetime import datetime
-from unittest.mock import AsyncMock, patch
+from unittest.mock import patch, MagicMock, AsyncMock
import pytest
@@ -25,6 +25,10 @@ sys.path.insert(
from test_rerank import assert_response_shape
import litellm
+from base_embedding_unit_tests import BaseLLMEmbeddingTest
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
+from litellm.types.utils import EmbeddingResponse, Usage
+
@pytest.mark.asyncio()
async def test_infinity_rerank():
@@ -182,3 +186,189 @@ async def test_infinity_rerank_with_env(monkeypatch):
) # total_tokens - prompt_tokens
assert_response_shape(response, custom_llm_provider="infinity")
+
+#### Embedding Tests
+@pytest.mark.asyncio()
+async def test_infinity_embedding():
+ mock_response = AsyncMock()
+
+ def return_val():
+ return {
+ "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+ "usage": {"prompt_tokens": 100, "total_tokens": 150},
+ "model": "custom-model/embedding-v1",
+ "object": "list"
+ }
+
+ mock_response.json = return_val
+ mock_response.headers = {"key": "value"}
+ mock_response.status_code = 200
+
+ expected_payload = {
+ "model": "custom-model/embedding-v1",
+ "input": ["hello world"],
+ "encoding_format": "float",
+ "output_dimension": 512
+ }
+
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
+ response = await litellm.aembedding(
+ model="infinity/custom-model/embedding-v1",
+ input=["hello world"],
+ dimensions=512,
+ encoding_format="float",
+            api_base="https://api.infinity.ai/embeddings",
+        )
+
+ # Assert
+ mock_post.assert_called_once()
+ print("call args", mock_post.call_args)
+ args_to_api = mock_post.call_args.kwargs["data"]
+ _url = mock_post.call_args.kwargs["url"]
+ assert _url == "https://api.infinity.ai/embeddings"
+
+ request_data = json.loads(args_to_api)
+ assert request_data["input"] == expected_payload["input"]
+ assert request_data["model"] == expected_payload["model"]
+ assert request_data["output_dimension"] == expected_payload["output_dimension"]
+ assert request_data["encoding_format"] == expected_payload["encoding_format"]
+
+ assert response.data is not None
+ assert response.usage.prompt_tokens == 100
+ assert response.usage.total_tokens == 150
+ assert response.model == "custom-model/embedding-v1"
+ assert response.object == "list"
+
+
+@pytest.mark.asyncio()
+async def test_infinity_embedding_with_env(monkeypatch):
+    # Pull the API key from the environment (exercises the INFINITY_API_KEY fallback)
+    monkeypatch.setenv("INFINITY_API_KEY", "test-api-key")
+
+    # Set up mock response
+ mock_response = AsyncMock()
+
+ def return_val():
+ return {
+ "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+ "usage": {"prompt_tokens": 100, "total_tokens": 150},
+ "model": "custom-model/embedding-v1",
+ "object": "list"
+ }
+
+ mock_response.json = return_val
+ mock_response.headers = {"key": "value"}
+ mock_response.status_code = 200
+
+ expected_payload = {
+ "model": "custom-model/embedding-v1",
+ "input": ["hello world"],
+ "encoding_format": "float",
+ "output_dimension": 512
+ }
+
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
+ response = await litellm.aembedding(
+ model="infinity/custom-model/embedding-v1",
+ input=["hello world"],
+ dimensions=512,
+ encoding_format="float",
+ api_base="https://api.infinity.ai/embeddings",
+ )
+
+ # Assert
+ mock_post.assert_called_once()
+ print("call args", mock_post.call_args)
+ args_to_api = mock_post.call_args.kwargs["data"]
+ _url = mock_post.call_args.kwargs["url"]
+ assert _url == "https://api.infinity.ai/embeddings"
+
+ request_data = json.loads(args_to_api)
+ assert request_data["input"] == expected_payload["input"]
+ assert request_data["model"] == expected_payload["model"]
+ assert request_data["output_dimension"] == expected_payload["output_dimension"]
+ assert request_data["encoding_format"] == expected_payload["encoding_format"]
+
+ assert response.data is not None
+ assert response.usage.prompt_tokens == 100
+ assert response.usage.total_tokens == 150
+ assert response.model == "custom-model/embedding-v1"
+ assert response.object == "list"
+
+
+@pytest.mark.asyncio()
+async def test_infinity_embedding_extra_params():
+ mock_response = AsyncMock()
+
+ def return_val():
+ return {
+ "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+ "usage": {"prompt_tokens": 100, "total_tokens": 150},
+ "model": "custom-model/embedding-v1",
+ "object": "list"
+ }
+
+ mock_response.json = return_val
+ mock_response.headers = {"key": "value"}
+ mock_response.status_code = 200
+
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
+ response = await litellm.aembedding(
+ model="infinity/custom-model/embedding-v1",
+ input=["test input"],
+ dimensions=512,
+ encoding_format="float",
+ modality="text",
+ api_base="https://api.infinity.ai/embeddings",
+ )
+
+ mock_post.assert_called_once()
+ json_data = json.loads(mock_post.call_args.kwargs["data"])
+
+ # Assert the request parameters
+ assert json_data["input"] == ["test input"]
+ assert json_data["model"] == "custom-model/embedding-v1"
+ assert json_data["output_dimension"] == 512
+ assert json_data["encoding_format"] == "float"
+ assert json_data["modality"] == "text"
+
+
+@pytest.mark.asyncio()
+async def test_infinity_embedding_prompt_token_mapping():
+ mock_response = AsyncMock()
+
+ def return_val():
+ return {
+ "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+ "usage": {"total_tokens": 1, "prompt_tokens": 1},
+ "model": "custom-model/embedding-v1",
+ "object": "list"
+ }
+
+ mock_response.json = return_val
+ mock_response.headers = {"key": "value"}
+ mock_response.status_code = 200
+
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+ return_value=mock_response,
+ ) as mock_post:
+ response = await litellm.aembedding(
+ model="infinity/custom-model/embedding-v1",
+ input=["a"],
+ dimensions=512,
+ encoding_format="float",
+ api_base="https://api.infinity.ai/embeddings",
+ )
+
+ mock_post.assert_called_once()
+ # Assert the response
+ assert response.usage.prompt_tokens == 1
+ assert response.usage.total_tokens == 1