From 104e4cb1bcaea2d420a5685492b51404beee4033 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 21 Apr 2025 20:01:29 -0700
Subject: [PATCH] [Feat] Add infinity embedding support (contributor pr)
 (#10196)

* Feature - infinity support for #8764 (#10009)

* Added support for infinity embeddings

* Added test cases

* Fixed tests and api base

* Updated docs and tests

* Removed unused import

* Updated signature

* Added support for infinity embeddings

* Added test cases

* Fixed tests and api base

* Updated docs and tests

* Removed unused import

* Updated signature

* Updated validate params

---------

Co-authored-by: Ishaan Jaff

* fix InfinityEmbeddingConfig

---------

Co-authored-by: Prathamesh Saraf
---
 .env.example                                   |   2 +
 .gitignore                                     |   1 +
 docs/my-website/docs/providers/infinity.md     | 154 ++++++++++++--
 litellm/__init__.py                            |   6 +
 .../get_supported_openai_params.py             |   2 +
 .../infinity/{rerank => }/common_utils.py      |   9 +-
 litellm/llms/infinity/embedding/handler.py     |   5 +
 .../llms/infinity/embedding/transformation.py  | 141 +++++++++++++
 .../llms/infinity/rerank/transformation.py     |   2 +-
 litellm/main.py                                |  15 ++
 litellm/utils.py                               |  22 ++
 tests/llm_translation/test_infinity.py         | 192 +++++++++++++++++-
 12 files changed, 529 insertions(+), 22 deletions(-)
 rename litellm/llms/infinity/{rerank => }/common_utils.py (76%)
 create mode 100644 litellm/llms/infinity/embedding/handler.py
 create mode 100644 litellm/llms/infinity/embedding/transformation.py

diff --git a/.env.example b/.env.example
index 82b09ca25e..54986a97cd 100644
--- a/.env.example
+++ b/.env.example
@@ -20,6 +20,8 @@ REPLICATE_API_TOKEN = ""
 ANTHROPIC_API_KEY = ""
 # Infisical
 INFISICAL_TOKEN = ""
+# INFINITY
+INFINITY_API_KEY = ""
 
 # Development Configs
 LITELLM_MASTER_KEY = "sk-1234"
diff --git a/.gitignore b/.gitignore
index 4259b80f55..e8c18bed4c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -86,4 +86,5 @@ litellm/proxy/db/migrations/0_init/migration.sql
 litellm/proxy/db/migrations/*
 litellm/proxy/migrations/*config.yaml
 litellm/proxy/migrations/*
+config.yaml
 tests/litellm/litellm_core_utils/llm_cost_calc/log.txt
diff --git a/docs/my-website/docs/providers/infinity.md b/docs/my-website/docs/providers/infinity.md
index 091503bf18..7900d5adb4 100644
--- a/docs/my-website/docs/providers/infinity.md
+++ b/docs/my-website/docs/providers/infinity.md
@@ -3,18 +3,17 @@ import TabItem from '@theme/TabItem';
 
 # Infinity
 
-| Property | Details |
-|-------|-------|
-| Description | Infinity is a high-throughput, low-latency REST API for serving text-embeddings, reranking models and clip|
-| Provider Route on LiteLLM | `infinity/` |
-| Supported Operations | `/rerank` |
-| Link to Provider Doc | [Infinity ↗](https://github.com/michaelfeil/infinity) |
-
+| Property                  | Details                                                                                                      |
+| ------------------------- | ------------------------------------------------------------------------------------------------------------ |
+| Description               | Infinity is a high-throughput, low-latency REST API for serving text embeddings, reranking models, and CLIP |
+| Provider Route on LiteLLM | `infinity/`                                                                                                  |
+| Supported Operations      | `/rerank`, `/embeddings`                                                                                     |
+| Link to Provider Doc      | [Infinity ↗](https://github.com/michaelfeil/infinity)                                                        |
 
 ## **Usage - LiteLLM Python SDK**
 
 ```python
-from litellm import rerank
+from litellm import rerank, embedding
 import os
 
 os.environ["INFINITY_API_BASE"] = "http://localhost:8080"
@@ -39,8 +38,8 @@ model_list:
   - model_name: custom-infinity-rerank
     litellm_params:
       model: infinity/rerank
-      api_key: os.environ/INFINITY_API_KEY
       api_base: https://localhost:8080
+      api_key: os.environ/INFINITY_API_KEY
 ```
 
 Start litellm
@@ -51,7 +50,9 @@ litellm --config /path/to/config.yaml
 # RUNNING on http://0.0.0.0:4000
 ```
 
-Test request
+## Test request:
+
+### Rerank
 
 ```bash
 curl http://0.0.0.0:4000/rerank \
@@ -70,15 +71,14 @@ curl http://0.0.0.0:4000/rerank \
   }'
 ```
 
+#### Supported Cohere Rerank API Params
 
-## Supported Cohere Rerank API Params
-
-| Param | Type | Description |
-|-------|-------|-------|
-| `query` | `str` | The query to rerank the documents against |
-| `documents` | `list[str]` | The documents to rerank |
-| `top_n` | `int` | The number of documents to return |
-| `return_documents` | `bool` | Whether to return the documents in the response |
+| Param              | Type        | Description                                      |
+| ------------------ | ----------- | ------------------------------------------------ |
+| `query`            | `str`       | The query to rerank the documents against        |
+| `documents`        | `list[str]` | The documents to rerank                          |
+| `top_n`            | `int`       | The number of documents to return                |
+| `return_documents` | `bool`      | Whether to return the documents in the response  |
 
 ### Usage - Return Documents
 
@@ -138,6 +138,7 @@ response = rerank(
     raw_scores=True, # 👈 PROVIDER-SPECIFIC PARAM
 )
 ```
+
 
 
@@ -161,7 +162,7 @@ litellm --config /path/to/config.yaml
 # RUNNING on http://0.0.0.0:4000
 ```
 
-3. Test it! 
+3. Test it!
 
 ```bash
 curl http://0.0.0.0:4000/rerank \
@@ -179,6 +180,121 @@ curl http://0.0.0.0:4000/rerank \
     "raw_scores": True # 👈 PROVIDER-SPECIFIC PARAM
   }'
 ```
+
+
+## Embeddings
+
+LiteLLM provides an OpenAI-compatible `/embeddings` endpoint for embedding calls.
+
+**Setup**
+
+Add this to your litellm proxy config.yaml
+
+```yaml
+model_list:
+  - model_name: custom-infinity-embedding
+    litellm_params:
+      model: infinity/provider/custom-embedding-v1
+      api_base: http://localhost:8080
+      api_key: os.environ/INFINITY_API_KEY
+```
+
+### Test request:
+
+```bash
+curl http://0.0.0.0:4000/embeddings \
+  -H "Authorization: Bearer sk-1234" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "custom-infinity-embedding",
+    "input": ["hello"]
+  }'
+```
+
+#### Supported Embedding API Params
+
+| Param             | Type        | Description                                                 |
+| ----------------- | ----------- | ----------------------------------------------------------- |
+| `model`           | `str`       | The embedding model to use                                  |
+| `input`           | `list[str]` | The text inputs to generate embeddings for                  |
+| `encoding_format` | `str`       | The format to return embeddings in (e.g. "float", "base64") |
+| `modality`        | `str`       | The type of input (e.g. "text", "image", "audio")           |
+| `dimensions`      | `int`       | The number of dimensions for the output embeddings          |
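+
+> **Note:** The OpenAI-style `dimensions` param is translated to Infinity's `output_dimension` request field; `encoding_format` and `modality` are passed through unchanged. A minimal sketch (the model name is illustrative):
+
+```python
+response = embedding(
+    model="infinity/bge-small",
+    input=["good morning from litellm"],
+    dimensions=512,  # sent to Infinity as `output_dimension`
+)
+```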
"text", "image", "audio") | + +### Usage - Basic Examples + + + + +```python +from litellm import embedding +import os + +os.environ["INFINITY_API_BASE"] = "http://localhost:8080" + +response = embedding( + model="infinity/bge-small", + input=["good morning from litellm"] +) + +print(response.data[0]['embedding']) +``` + + + + + +```bash +curl http://0.0.0.0:4000/embeddings \ + -H "Authorization: Bearer sk-1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "custom-infinity-embedding", + "input": ["hello"] + }' +``` + + + + +### Usage - OpenAI Client + + + + +```python +from openai import OpenAI + +client = OpenAI( + api_key="", + base_url="" +) + +response = client.embeddings.create( + model="bge-small", + input=["The food was delicious and the waiter..."], + encoding_format="float" +) + +print(response.data[0].embedding) +``` + + + + + +```bash +curl http://0.0.0.0:4000/embeddings \ + -H "Authorization: Bearer sk-1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "bge-small", + "input": ["The food was delicious and the waiter..."], + "encoding_format": "float" + }' +``` + + + diff --git a/litellm/__init__.py b/litellm/__init__.py index e9dadbfaf6..a7b747541a 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -415,6 +415,7 @@ deepseek_models: List = [] azure_ai_models: List = [] jina_ai_models: List = [] voyage_models: List = [] +infinity_models: List = [] databricks_models: List = [] cloudflare_models: List = [] codestral_models: List = [] @@ -556,6 +557,8 @@ def add_known_models(): azure_ai_models.append(key) elif value.get("litellm_provider") == "voyage": voyage_models.append(key) + elif value.get("litellm_provider") == "infinity": + infinity_models.append(key) elif value.get("litellm_provider") == "databricks": databricks_models.append(key) elif value.get("litellm_provider") == "cloudflare": @@ -644,6 +647,7 @@ model_list = ( + deepseek_models + azure_ai_models + voyage_models + + infinity_models + databricks_models + cloudflare_models + codestral_models @@ -699,6 +703,7 @@ models_by_provider: dict = { "mistral": mistral_chat_models, "azure_ai": azure_ai_models, "voyage": voyage_models, + "infinity": infinity_models, "databricks": databricks_models, "cloudflare": cloudflare_models, "codestral": codestral_models, @@ -946,6 +951,7 @@ from .llms.topaz.image_variations.transformation import TopazImageVariationConfi from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig from .llms.groq.chat.transformation import GroqChatConfig from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig +from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig from .llms.azure_ai.chat.transformation import AzureAIStudioConfig from .llms.mistral.mistral_chat_transformation import MistralConfig from .llms.openai.responses.transformation import OpenAIResponsesAPIConfig diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py index bcf9fdb961..c0f638ddc2 100644 --- a/litellm/litellm_core_utils/get_supported_openai_params.py +++ b/litellm/litellm_core_utils/get_supported_openai_params.py @@ -221,6 +221,8 @@ def get_supported_openai_params( # noqa: PLR0915 return litellm.PredibaseConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "voyage": return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model) + elif custom_llm_provider == "infinity": + return 
litellm.InfinityEmbeddingConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "triton": if request_type == "embeddings": return litellm.TritonEmbeddingConfig().get_supported_openai_params( diff --git a/litellm/llms/infinity/rerank/common_utils.py b/litellm/llms/infinity/common_utils.py similarity index 76% rename from litellm/llms/infinity/rerank/common_utils.py rename to litellm/llms/infinity/common_utils.py index 99477d1a33..089818c829 100644 --- a/litellm/llms/infinity/rerank/common_utils.py +++ b/litellm/llms/infinity/common_utils.py @@ -1,10 +1,16 @@ +from typing import Union import httpx from litellm.llms.base_llm.chat.transformation import BaseLLMException class InfinityError(BaseLLMException): - def __init__(self, status_code, message): + def __init__( + self, + status_code: int, + message: str, + headers: Union[dict, httpx.Headers] = {} + ): self.status_code = status_code self.message = message self.request = httpx.Request( @@ -16,4 +22,5 @@ class InfinityError(BaseLLMException): message=message, request=self.request, response=self.response, + headers=headers, ) # Call the base class constructor with the parameters it needs diff --git a/litellm/llms/infinity/embedding/handler.py b/litellm/llms/infinity/embedding/handler.py new file mode 100644 index 0000000000..cdcb99c433 --- /dev/null +++ b/litellm/llms/infinity/embedding/handler.py @@ -0,0 +1,5 @@ +""" +Infinity Embedding - uses `llm_http_handler.py` to make httpx requests + +Request/Response transformation is handled in `transformation.py` +""" diff --git a/litellm/llms/infinity/embedding/transformation.py b/litellm/llms/infinity/embedding/transformation.py new file mode 100644 index 0000000000..824dcd38da --- /dev/null +++ b/litellm/llms/infinity/embedding/transformation.py @@ -0,0 +1,141 @@ +from typing import List, Optional, Union + +import httpx + +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues +from litellm.types.utils import EmbeddingResponse, Usage + +from ..common_utils import InfinityError + + +class InfinityEmbeddingConfig(BaseEmbeddingConfig): + """ + Reference: https://infinity.modal.michaelfeil.eu/docs + """ + + def __init__(self) -> None: + pass + + def get_complete_url( + self, + api_base: Optional[str], + api_key: Optional[str], + model: str, + optional_params: dict, + litellm_params: dict, + stream: Optional[bool] = None, + ) -> str: + if api_base is None: + raise ValueError("api_base is required for Infinity embeddings") + # Remove trailing slashes and ensure clean base URL + api_base = api_base.rstrip("/") + if not api_base.endswith("/embeddings"): + api_base = f"{api_base}/embeddings" + return api_base + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + if api_key is None: + api_key = get_secret_str("INFINITY_API_KEY") + + default_headers = { + "Authorization": f"Bearer {api_key}", + "accept": "application/json", + "Content-Type": "application/json", + } + + # If 'Authorization' is provided in headers, it overrides the default. 
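+        # NOTE: the dict merge in the return statement also applies caller-supplied
+        # headers last, so a caller-provided Authorization wins either way; the
+        # explicit check below simply makes that precedence obvious.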
+ if "Authorization" in headers: + default_headers["Authorization"] = headers["Authorization"] + + # Merge other headers, overriding any default ones except Authorization + return {**default_headers, **headers} + + def get_supported_openai_params(self, model: str) -> list: + return [ + "encoding_format", + "modality", + "dimensions", + ] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + """ + Map OpenAI params to Infinity params + + Reference: https://infinity.modal.michaelfeil.eu/docs + """ + if "encoding_format" in non_default_params: + optional_params["encoding_format"] = non_default_params["encoding_format"] + if "modality" in non_default_params: + optional_params["modality"] = non_default_params["modality"] + if "dimensions" in non_default_params: + optional_params["output_dimension"] = non_default_params["dimensions"] + return optional_params + + def transform_embedding_request( + self, + model: str, + input: AllEmbeddingInputValues, + optional_params: dict, + headers: dict, + ) -> dict: + return { + "input": input, + "model": model, + **optional_params, + } + + def transform_embedding_response( + self, + model: str, + raw_response: httpx.Response, + model_response: EmbeddingResponse, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str] = None, + request_data: dict = {}, + optional_params: dict = {}, + litellm_params: dict = {}, + ) -> EmbeddingResponse: + try: + raw_response_json = raw_response.json() + except Exception: + raise InfinityError( + message=raw_response.text, status_code=raw_response.status_code + ) + + # model_response.usage + model_response.model = raw_response_json.get("model") + model_response.data = raw_response_json.get("data") + model_response.object = raw_response_json.get("object") + + usage = Usage( + prompt_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0), + total_tokens=raw_response_json.get("usage", {}).get("total_tokens", 0), + ) + model_response.usage = usage + return model_response + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + return InfinityError( + message=error_message, status_code=status_code, headers=headers + ) diff --git a/litellm/llms/infinity/rerank/transformation.py b/litellm/llms/infinity/rerank/transformation.py index 1e7234ab17..4b75fa121b 100644 --- a/litellm/llms/infinity/rerank/transformation.py +++ b/litellm/llms/infinity/rerank/transformation.py @@ -22,7 +22,7 @@ from litellm.types.rerank import ( RerankTokens, ) -from .common_utils import InfinityError +from ..common_utils import InfinityError class InfinityRerankConfig(CohereRerankConfig): diff --git a/litellm/main.py b/litellm/main.py index 9bb1cf0c15..80486fbe02 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3884,6 +3884,21 @@ def embedding( # noqa: PLR0915 aembedding=aembedding, litellm_params={}, ) + elif custom_llm_provider == "infinity": + response = base_llm_http_handler.embedding( + model=model, + input=input, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + litellm_params={}, + ) elif custom_llm_provider == "watsonx": credentials = IBMWatsonXMixin.get_watsonx_credentials( optional_params=optional_params, api_key=api_key, api_base=api_base diff --git a/litellm/utils.py 
index 3efd188717..38e604943a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2735,6 +2735,21 @@ def get_optional_params_embeddings(  # noqa: PLR0915
         )
         final_params = {**optional_params, **kwargs}
         return final_params
+    elif custom_llm_provider == "infinity":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="infinity",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.InfinityEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params={},
+            model=model,
+            drop_params=drop_params if drop_params is not None else False,
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
     elif custom_llm_provider == "fireworks_ai":
         supported_params = get_supported_openai_params(
             model=model,
@@ -5120,6 +5135,11 @@ def validate_environment(  # noqa: PLR0915
             keys_in_environment = True
         else:
             missing_keys.append("VOYAGE_API_KEY")
+    elif custom_llm_provider == "infinity":
+        if "INFINITY_API_KEY" in os.environ:
+            keys_in_environment = True
+        else:
+            missing_keys.append("INFINITY_API_KEY")
     elif custom_llm_provider == "fireworks_ai":
         if (
             "FIREWORKS_AI_API_KEY" in os.environ
@@ -6554,6 +6574,8 @@ class ProviderConfigManager:
             return litellm.TritonEmbeddingConfig()
         elif litellm.LlmProviders.WATSONX == provider:
             return litellm.IBMWatsonXEmbeddingConfig()
+        elif litellm.LlmProviders.INFINITY == provider:
+            return litellm.InfinityEmbeddingConfig()
         raise ValueError(f"Provider {provider.value} does not support embedding config")
 
     @staticmethod
diff --git a/tests/llm_translation/test_infinity.py b/tests/llm_translation/test_infinity.py
index eb986b8ab5..802789377c 100644
--- a/tests/llm_translation/test_infinity.py
+++ b/tests/llm_translation/test_infinity.py
@@ -15,7 +15,7 @@ import json
 import os
 import sys
 from datetime import datetime
-from unittest.mock import AsyncMock, patch
+from unittest.mock import patch, MagicMock, AsyncMock
 
 import pytest
 
@@ -25,6 +25,10 @@ sys.path.insert(
 from test_rerank import assert_response_shape
 import litellm
 
+from base_embedding_unit_tests import BaseLLMEmbeddingTest
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
+from litellm.types.utils import EmbeddingResponse, Usage
+
 
 @pytest.mark.asyncio()
 async def test_infinity_rerank():
@@ -182,3 +186,189 @@ async def test_infinity_rerank_with_env(monkeypatch):
     )  # total_tokens - prompt_tokens
 
     assert_response_shape(response, custom_llm_provider="infinity")
+
+
+#### Embedding Tests
+
+@pytest.mark.asyncio()
+async def test_infinity_embedding():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+            "usage": {"prompt_tokens": 100, "total_tokens": 150},
+            "model": "custom-model/embedding-v1",
+            "object": "list",
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "model": "custom-model/embedding-v1",
+        "input": ["hello world"],
+        "encoding_format": "float",
+        "output_dimension": 512,
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.aembedding(
+            model="infinity/custom-model/embedding-v1",
+            input=["hello world"],
+            dimensions=512,
+            encoding_format="float",
+            api_base="https://api.infinity.ai/embeddings",
+        )
+
+        # Assert
+        mock_post.assert_called_once()
+        print("call args", mock_post.call_args)
+        args_to_api = mock_post.call_args.kwargs["data"]
+        _url = mock_post.call_args.kwargs["url"]
+        assert _url == "https://api.infinity.ai/embeddings"
+
+        request_data = json.loads(args_to_api)
+        assert request_data["input"] == expected_payload["input"]
+        assert request_data["model"] == expected_payload["model"]
+        assert request_data["output_dimension"] == expected_payload["output_dimension"]
+        assert request_data["encoding_format"] == expected_payload["encoding_format"]
+
+        assert response.data is not None
+        assert response.usage.prompt_tokens == 100
+        assert response.usage.total_tokens == 150
+        assert response.model == "custom-model/embedding-v1"
+        assert response.object == "list"
+
+
+@pytest.mark.asyncio()
+async def test_infinity_embedding_with_env(monkeypatch):
+    # Set the API key via the environment so this test exercises the env-var path
+    monkeypatch.setenv("INFINITY_API_KEY", "test-api-key")
+
+    # Set up mock response
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+            "usage": {"prompt_tokens": 100, "total_tokens": 150},
+            "model": "custom-model/embedding-v1",
+            "object": "list",
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "model": "custom-model/embedding-v1",
+        "input": ["hello world"],
+        "encoding_format": "float",
+        "output_dimension": 512,
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.aembedding(
+            model="infinity/custom-model/embedding-v1",
+            input=["hello world"],
+            dimensions=512,
+            encoding_format="float",
+            api_base="https://api.infinity.ai/embeddings",
+        )
+
+        # Assert
+        mock_post.assert_called_once()
+        print("call args", mock_post.call_args)
+        args_to_api = mock_post.call_args.kwargs["data"]
+        _url = mock_post.call_args.kwargs["url"]
+        assert _url == "https://api.infinity.ai/embeddings"
+
+        request_data = json.loads(args_to_api)
+        assert request_data["input"] == expected_payload["input"]
+        assert request_data["model"] == expected_payload["model"]
+        assert request_data["output_dimension"] == expected_payload["output_dimension"]
+        assert request_data["encoding_format"] == expected_payload["encoding_format"]
+
+        assert response.data is not None
+        assert response.usage.prompt_tokens == 100
+        assert response.usage.total_tokens == 150
+        assert response.model == "custom-model/embedding-v1"
+        assert response.object == "list"
+
+
+@pytest.mark.asyncio()
+async def test_infinity_embedding_extra_params():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+            "usage": {"prompt_tokens": 100, "total_tokens": 150},
+            "model": "custom-model/embedding-v1",
+            "object": "list",
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.aembedding(
+            model="infinity/custom-model/embedding-v1",
+            input=["test input"],
+            dimensions=512,
+            encoding_format="float",
+            modality="text",
+            api_base="https://api.infinity.ai/embeddings",
+        )
+
+        mock_post.assert_called_once()
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
+
+        # Assert the request parameters
+        assert json_data["input"] == ["test input"]
+        assert json_data["model"] == "custom-model/embedding-v1"
+        assert json_data["output_dimension"] == 512
+        assert json_data["encoding_format"] == "float"
+        assert json_data["modality"] == "text"
+
+
+@pytest.mark.asyncio()
+async def test_infinity_embedding_prompt_token_mapping():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+            "usage": {"total_tokens": 1, "prompt_tokens": 1},
+            "model": "custom-model/embedding-v1",
+            "object": "list",
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.aembedding(
+            model="infinity/custom-model/embedding-v1",
+            input=["a"],
+            dimensions=512,
+            encoding_format="float",
+            api_base="https://api.infinity.ai/embeddings",
+        )
+
+        mock_post.assert_called_once()
+        # Assert the response
+        assert response.usage.prompt_tokens == 1
+        assert response.usage.total_tokens == 1