[Feat] Add infinity embedding support (contributor pr) (#10196)
* Feature - infinity support for #8764 (#10009)
  * Added support for infinity embeddings
  * Added test cases
  * Fixed tests and api base
  * Updated docs and tests
  * Removed unused import
  * Updated signature
  * Updated validate params

  Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>

* fix InfinityEmbeddingConfig

Co-authored-by: Prathamesh Saraf <pratamesh1867@gmail.com>
This commit is contained in:

parent 0c2f705417
commit 104e4cb1bc

12 changed files with 529 additions and 22 deletions
@@ -20,6 +20,8 @@ REPLICATE_API_TOKEN = ""
 ANTHROPIC_API_KEY = ""
 # Infisical
 INFISICAL_TOKEN = ""
+# INFINITY
+INFINITY_API_KEY = ""
 
 # Development Configs
 LITELLM_MASTER_KEY = "sk-1234"
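For reference, the new `INFINITY_API_KEY` entry is consumed the same way as other provider keys; a minimal sketch of using it from the SDK, assuming a locally running Infinity server (both values below are illustrative):

```python
import os

import litellm

# Illustrative values; when no api_key is passed explicitly, the key is
# read via get_secret_str("INFINITY_API_KEY") in the new embedding config.
os.environ["INFINITY_API_KEY"] = "sk-..."
os.environ["INFINITY_API_BASE"] = "http://localhost:8080"

response = litellm.embedding(
    model="infinity/bge-small",
    input=["hello world"],
)
print(len(response.data[0]["embedding"]))
```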
1 .gitignore vendored

@@ -86,4 +86,5 @@ litellm/proxy/db/migrations/0_init/migration.sql
 litellm/proxy/db/migrations/*
 litellm/proxy/migrations/*config.yaml
 litellm/proxy/migrations/*
 config.yaml
+tests/litellm/litellm_core_utils/llm_cost_calc/log.txt
@@ -3,18 +3,17 @@ import TabItem from '@theme/TabItem';
 
 # Infinity
 
-| Property | Details |
-|-------|-------|
-| Description | Infinity is a high-throughput, low-latency REST API for serving text-embeddings, reranking models and clip|
-| Provider Route on LiteLLM | `infinity/` |
-| Supported Operations | `/rerank` |
-| Link to Provider Doc | [Infinity ↗](https://github.com/michaelfeil/infinity) |
+| Property                  | Details                                                                                                     |
+| ------------------------- | ----------------------------------------------------------------------------------------------------------- |
+| Description               | Infinity is a high-throughput, low-latency REST API for serving text-embeddings, reranking models and clip |
+| Provider Route on LiteLLM | `infinity/`                                                                                                 |
+| Supported Operations      | `/rerank`, `/embeddings`                                                                                    |
+| Link to Provider Doc      | [Infinity ↗](https://github.com/michaelfeil/infinity)                                                       |
 
 ## **Usage - LiteLLM Python SDK**
 
 ```python
-from litellm import rerank
+from litellm import rerank, embedding
 import os
 
 os.environ["INFINITY_API_BASE"] = "http://localhost:8080"
@@ -39,8 +38,8 @@ model_list:
   - model_name: custom-infinity-rerank
     litellm_params:
       model: infinity/rerank
-      api_key: os.environ/INFINITY_API_KEY
       api_base: https://localhost:8080
+      api_key: os.environ/INFINITY_API_KEY
 ```
 
 Start litellm
@@ -51,7 +50,9 @@ litellm --config /path/to/config.yaml
 # RUNNING on http://0.0.0.0:4000
 ```
 
-Test request
+## Test request:
+
+### Rerank
 
 ```bash
 curl http://0.0.0.0:4000/rerank \
@@ -70,15 +71,14 @@ curl http://0.0.0.0:4000/rerank \
   }'
 ```
 
-#### Supported Cohere Rerank API Params
+## Supported Cohere Rerank API Params
 
-| Param | Type | Description |
-|-------|-------|-------|
-| `query` | `str` | The query to rerank the documents against |
-| `documents` | `list[str]` | The documents to rerank |
-| `top_n` | `int` | The number of documents to return |
-| `return_documents` | `bool` | Whether to return the documents in the response |
+| Param              | Type        | Description                                     |
+| ------------------ | ----------- | ----------------------------------------------- |
+| `query`            | `str`       | The query to rerank the documents against       |
+| `documents`        | `list[str]` | The documents to rerank                         |
+| `top_n`            | `int`       | The number of documents to return               |
+| `return_documents` | `bool`      | Whether to return the documents in the response |
 
 ### Usage - Return Documents
@@ -138,6 +138,7 @@ response = rerank(
     raw_scores=True, # 👈 PROVIDER-SPECIFIC PARAM
 )
 ```
+
 </TabItem>
 
 <TabItem value="proxy" label="PROXY">
@@ -161,7 +162,7 @@ litellm --config /path/to/config.yaml
 # RUNNING on http://0.0.0.0:4000
 ```
 
-3. Test it!
+3. Test it!
 
 ```bash
 curl http://0.0.0.0:4000/rerank \
@@ -179,6 +180,121 @@ curl http://0.0.0.0:4000/rerank \
     "raw_scores": True # 👈 PROVIDER-SPECIFIC PARAM
   }'
 ```
 
 </TabItem>
 
 </Tabs>
+
+## Embeddings
+
+LiteLLM provides an OpenAI api compatible `/embeddings` endpoint for embedding calls.
+
+**Setup**
+
+Add this to your litellm proxy config.yaml
+
+```yaml
+model_list:
+  - model_name: custom-infinity-embedding
+    litellm_params:
+      model: infinity/provider/custom-embedding-v1
+      api_base: http://localhost:8080
+      api_key: os.environ/INFINITY_API_KEY
+```
+
+### Test request:
+
+```bash
+curl http://0.0.0.0:4000/embeddings \
+  -H "Authorization: Bearer sk-1234" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "custom-infinity-embedding",
+    "input": ["hello"]
+  }'
+```
+
+#### Supported Embedding API Params
+
+| Param             | Type        | Description                                                  |
+| ----------------- | ----------- | ------------------------------------------------------------ |
+| `model`           | `str`       | The embedding model to use                                   |
+| `input`           | `list[str]` | The text inputs to generate embeddings for                   |
+| `encoding_format` | `str`       | The format to return embeddings in (e.g. "float", "base64")  |
+| `modality`        | `str`       | The type of input (e.g. "text", "image", "audio")            |
+
+### Usage - Basic Examples
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import embedding
+import os
+
+os.environ["INFINITY_API_BASE"] = "http://localhost:8080"
+
+response = embedding(
+    model="infinity/bge-small",
+    input=["good morning from litellm"]
+)
+
+print(response.data[0]['embedding'])
+```
+
+</TabItem>
+
+<TabItem value="proxy" label="PROXY">
+
+```bash
+curl http://0.0.0.0:4000/embeddings \
+  -H "Authorization: Bearer sk-1234" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "custom-infinity-embedding",
+    "input": ["hello"]
+  }'
+```
+
+</TabItem>
+</Tabs>
+
+### Usage - OpenAI Client
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="<LITELLM_MASTER_KEY>",
+    base_url="<LITELLM_URL>"
+)
+
+response = client.embeddings.create(
+    model="bge-small",
+    input=["The food was delicious and the waiter..."],
+    encoding_format="float"
+)
+
+print(response.data[0].embedding)
+```
+
+</TabItem>
+
+<TabItem value="proxy" label="PROXY">
+
+```bash
+curl http://0.0.0.0:4000/embeddings \
+  -H "Authorization: Bearer sk-1234" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "bge-small",
+    "input": ["The food was delicious and the waiter..."],
+    "encoding_format": "float"
+  }'
+```
+
+</TabItem>
+</Tabs>
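The docs above cover the sync SDK, the proxy, and the OpenAI client; the async path exercised by the tests later in this commit works the same way. A minimal sketch (model name and base URL are illustrative):

```python
import asyncio

import litellm

async def main() -> None:
    # Async twin of litellm.embedding(); api_base can be passed per-call
    # instead of via the INFINITY_API_BASE env var.
    response = await litellm.aembedding(
        model="infinity/bge-small",
        input=["good morning from litellm"],
        api_base="http://localhost:8080",
    )
    print(response.usage.total_tokens)

asyncio.run(main())
```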
@@ -415,6 +415,7 @@ deepseek_models: List = []
 azure_ai_models: List = []
 jina_ai_models: List = []
 voyage_models: List = []
+infinity_models: List = []
 databricks_models: List = []
 cloudflare_models: List = []
 codestral_models: List = []

@@ -556,6 +557,8 @@ def add_known_models():
             azure_ai_models.append(key)
         elif value.get("litellm_provider") == "voyage":
             voyage_models.append(key)
+        elif value.get("litellm_provider") == "infinity":
+            infinity_models.append(key)
         elif value.get("litellm_provider") == "databricks":
             databricks_models.append(key)
         elif value.get("litellm_provider") == "cloudflare":

@@ -644,6 +647,7 @@ model_list = (
     + deepseek_models
     + azure_ai_models
     + voyage_models
+    + infinity_models
     + databricks_models
     + cloudflare_models
     + codestral_models

@@ -699,6 +703,7 @@ models_by_provider: dict = {
     "mistral": mistral_chat_models,
     "azure_ai": azure_ai_models,
     "voyage": voyage_models,
+    "infinity": infinity_models,
     "databricks": databricks_models,
     "cloudflare": cloudflare_models,
     "codestral": codestral_models,

@@ -946,6 +951,7 @@ from .llms.topaz.image_variations.transformation import TopazImageVariationConfig
 from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
 from .llms.groq.chat.transformation import GroqChatConfig
 from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
+from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
 from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
 from .llms.openai.responses.transformation import OpenAIResponsesAPIConfig
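These registry hooks make `infinity` a first-class provider key. A quick way to see the effect (the model list itself depends on what the cost map contains, so it may simply be empty):

```python
import litellm

# "infinity" is now a valid key in the provider registry.
print("infinity" in litellm.models_by_provider)    # True after this change
print(litellm.models_by_provider.get("infinity"))  # possibly [] if the cost map lists none
```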
@@ -221,6 +221,8 @@ def get_supported_openai_params(  # noqa: PLR0915
         return litellm.PredibaseConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "voyage":
         return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "infinity":
+        return litellm.InfinityEmbeddingConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "triton":
         if request_type == "embeddings":
             return litellm.TritonEmbeddingConfig().get_supported_openai_params(
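With this branch in place, the supported params for the provider can be queried through the public helper; a small sketch, with the expected output taken from `InfinityEmbeddingConfig.get_supported_openai_params` added later in this commit:

```python
from litellm import get_supported_openai_params

params = get_supported_openai_params(
    model="infinity/bge-small",
    custom_llm_provider="infinity",
)
# Per the config in this commit, this should yield:
# ["encoding_format", "modality", "dimensions"]
print(params)
```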
@@ -1,10 +1,16 @@
+from typing import Union
+
 import httpx
 
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class InfinityError(BaseLLMException):
-    def __init__(self, status_code, message):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Union[dict, httpx.Headers] = {}
+    ):
         self.status_code = status_code
         self.message = message
         self.request = httpx.Request(

@@ -16,4 +22,5 @@ class InfinityError(BaseLLMException):
             message=message,
             request=self.request,
             response=self.response,
+            headers=headers,
         )  # Call the base class constructor with the parameters it needs
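The widened `InfinityError` signature keeps old call sites working (`headers` defaults to `{}`) while letting the embedding transformation forward response headers. A usage sketch; the import path is an assumption based on the relative imports shown in this commit:

```python
import httpx

from litellm.llms.infinity.common_utils import InfinityError  # path assumed

try:
    # headers accepts a plain dict or httpx.Headers.
    raise InfinityError(
        status_code=429,
        message="rate limited",
        headers=httpx.Headers({"retry-after": "1"}),
    )
except InfinityError as err:
    print(err.status_code, err.message)
```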
5 litellm/llms/infinity/embedding/handler.py Normal file

@@ -0,0 +1,5 @@
+"""
+Infinity Embedding - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
141 litellm/llms/infinity/embedding/transformation.py Normal file

@@ -0,0 +1,141 @@
+from typing import List, Optional, Union
+
+import httpx
+
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
+from litellm.types.utils import EmbeddingResponse, Usage
+
+from ..common_utils import InfinityError
+
+
+class InfinityEmbeddingConfig(BaseEmbeddingConfig):
+    """
+    Reference: https://infinity.modal.michaelfeil.eu/docs
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        api_key: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        if api_base is None:
+            raise ValueError("api_base is required for Infinity embeddings")
+        # Remove trailing slashes and ensure clean base URL
+        api_base = api_base.rstrip("/")
+        if not api_base.endswith("/embeddings"):
+            api_base = f"{api_base}/embeddings"
+        return api_base
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = get_secret_str("INFINITY_API_KEY")
+
+        default_headers = {
+            "Authorization": f"Bearer {api_key}",
+            "accept": "application/json",
+            "Content-Type": "application/json",
+        }
+
+        # If 'Authorization' is provided in headers, it overrides the default.
+        if "Authorization" in headers:
+            default_headers["Authorization"] = headers["Authorization"]
+
+        # Merge other headers, overriding any default ones except Authorization
+        return {**default_headers, **headers}
+
+    def get_supported_openai_params(self, model: str) -> list:
+        return [
+            "encoding_format",
+            "modality",
+            "dimensions",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        """
+        Map OpenAI params to Infinity params
+
+        Reference: https://infinity.modal.michaelfeil.eu/docs
+        """
+        if "encoding_format" in non_default_params:
+            optional_params["encoding_format"] = non_default_params["encoding_format"]
+        if "modality" in non_default_params:
+            optional_params["modality"] = non_default_params["modality"]
+        if "dimensions" in non_default_params:
+            optional_params["output_dimension"] = non_default_params["dimensions"]
+        return optional_params
+
+    def transform_embedding_request(
+        self,
+        model: str,
+        input: AllEmbeddingInputValues,
+        optional_params: dict,
+        headers: dict,
+    ) -> dict:
+        return {
+            "input": input,
+            "model": model,
+            **optional_params,
+        }
+
+    def transform_embedding_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> EmbeddingResponse:
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise InfinityError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        # model_response.usage
+        model_response.model = raw_response_json.get("model")
+        model_response.data = raw_response_json.get("data")
+        model_response.object = raw_response_json.get("object")
+
+        usage = Usage(
+            prompt_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
+            total_tokens=raw_response_json.get("usage", {}).get("total_tokens", 0),
+        )
+        model_response.usage = usage
+        return model_response
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return InfinityError(
+            message=error_message, status_code=status_code, headers=headers
+        )
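Two behaviors of the new config are worth calling out: `get_complete_url` normalizes the base URL, and `map_openai_params` renames OpenAI's `dimensions` to Infinity's `output_dimension`. A sketch of both, with illustrative values:

```python
import litellm

config = litellm.InfinityEmbeddingConfig()

# Trailing slashes are stripped; "/embeddings" is appended only when missing.
print(config.get_complete_url(
    api_base="http://localhost:8080/",
    api_key=None,
    model="bge-small",
    optional_params={},
    litellm_params={},
))  # -> http://localhost:8080/embeddings

# OpenAI "dimensions" becomes Infinity "output_dimension".
print(config.map_openai_params(
    non_default_params={"dimensions": 512, "encoding_format": "float"},
    optional_params={},
    model="bge-small",
    drop_params=False,
))  # -> {"encoding_format": "float", "output_dimension": 512}
```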
@@ -22,7 +22,7 @@ from litellm.types.rerank import (
     RerankTokens,
 )
 
-from .common_utils import InfinityError
+from ..common_utils import InfinityError
 
 
 class InfinityRerankConfig(CohereRerankConfig):
@@ -3884,6 +3884,21 @@ def embedding(  # noqa: PLR0915
             aembedding=aembedding,
             litellm_params={},
         )
+    elif custom_llm_provider == "infinity":
+        response = base_llm_http_handler.embedding(
+            model=model,
+            input=input,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            api_key=api_key,
+            logging_obj=logging,
+            timeout=timeout,
+            model_response=EmbeddingResponse(),
+            optional_params=optional_params,
+            client=client,
+            aembedding=aembedding,
+            litellm_params={},
+        )
     elif custom_llm_provider == "watsonx":
         credentials = IBMWatsonXMixin.get_watsonx_credentials(
             optional_params=optional_params, api_key=api_key, api_base=api_base
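This is the branch that routes `infinity/...` embedding calls onto the shared HTTP handler; end to end it looks like the following (model name and URL are illustrative):

```python
import litellm

# Dispatched through base_llm_http_handler.embedding by the new branch above.
response = litellm.embedding(
    model="infinity/custom-model/embedding-v1",  # illustrative model name
    input=["hello"],
    api_base="http://localhost:8080",            # assumes a local Infinity server
    dimensions=512,                              # mapped to output_dimension
)
print(response.object, response.usage)
```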
@@ -2735,6 +2735,21 @@ def get_optional_params_embeddings(  # noqa: PLR0915
         )
         final_params = {**optional_params, **kwargs}
         return final_params
+    elif custom_llm_provider == "infinity":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="infinity",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.InfinityEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params={},
+            model=model,
+            drop_params=drop_params if drop_params is not None else False,
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
     elif custom_llm_provider == "fireworks_ai":
         supported_params = get_supported_openai_params(
             model=model,

@@ -5120,6 +5135,11 @@ def validate_environment(  # noqa: PLR0915
             keys_in_environment = True
         else:
             missing_keys.append("VOYAGE_API_KEY")
+    elif custom_llm_provider == "infinity":
+        if "INFINITY_API_KEY" in os.environ:
+            keys_in_environment = True
+        else:
+            missing_keys.append("INFINITY_API_KEY")
     elif custom_llm_provider == "fireworks_ai":
         if (
             "FIREWORKS_AI_API_KEY" in os.environ

@@ -6554,6 +6574,8 @@ class ProviderConfigManager:
             return litellm.TritonEmbeddingConfig()
         elif litellm.LlmProviders.WATSONX == provider:
             return litellm.IBMWatsonXEmbeddingConfig()
+        elif litellm.LlmProviders.INFINITY == provider:
+            return litellm.InfinityEmbeddingConfig()
         raise ValueError(f"Provider {provider.value} does not support embedding config")
 
     @staticmethod
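The `validate_environment` branch means missing credentials are reported before any request is made; a quick sketch:

```python
import litellm

# With INFINITY_API_KEY unset, the new branch flags it as missing; the
# printed dict shape is illustrative of validate_environment's output.
report = litellm.validate_environment(model="infinity/bge-small")
print(report)  # e.g. {"keys_in_environment": False, "missing_keys": ["INFINITY_API_KEY"]}
```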
@@ -15,7 +15,7 @@ import json
 import os
 import sys
 from datetime import datetime
-from unittest.mock import AsyncMock, patch
+from unittest.mock import patch, MagicMock, AsyncMock
 
 import pytest
 

@@ -25,6 +25,10 @@ sys.path.insert(
 from test_rerank import assert_response_shape
 import litellm
 
+from base_embedding_unit_tests import BaseLLMEmbeddingTest
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
+from litellm.types.utils import EmbeddingResponse, Usage
+
 
 @pytest.mark.asyncio()
 async def test_infinity_rerank():

@@ -182,3 +186,189 @@ async def test_infinity_rerank_with_env(monkeypatch):
     )  # total_tokens - prompt_tokens
 
     assert_response_shape(response, custom_llm_provider="infinity")
+
+
+#### Embedding Tests
+@pytest.mark.asyncio()
+async def test_infinity_embedding():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+            "usage": {"prompt_tokens": 100, "total_tokens": 150},
+            "model": "custom-model/embedding-v1",
+            "object": "list",
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "model": "custom-model/embedding-v1",
+        "input": ["hello world"],
+        "encoding_format": "float",
+        "output_dimension": 512,
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.aembedding(
+            model="infinity/custom-model/embedding-v1",
+            input=["hello world"],
+            dimensions=512,
+            encoding_format="float",
+            api_base="https://api.infinity.ai/embeddings",
+        )
+
+        # Assert
+        mock_post.assert_called_once()
+        print("call args", mock_post.call_args)
+        args_to_api = mock_post.call_args.kwargs["data"]
+        _url = mock_post.call_args.kwargs["url"]
+        assert _url == "https://api.infinity.ai/embeddings"
+
+        request_data = json.loads(args_to_api)
+        assert request_data["input"] == expected_payload["input"]
+        assert request_data["model"] == expected_payload["model"]
+        assert request_data["output_dimension"] == expected_payload["output_dimension"]
+        assert request_data["encoding_format"] == expected_payload["encoding_format"]
+
+        assert response.data is not None
+        assert response.usage.prompt_tokens == 100
+        assert response.usage.total_tokens == 150
+        assert response.model == "custom-model/embedding-v1"
+        assert response.object == "list"
+
+
+@pytest.mark.asyncio()
+async def test_infinity_embedding_with_env(monkeypatch):
+    # Set up mock response
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+            "usage": {"prompt_tokens": 100, "total_tokens": 150},
+            "model": "custom-model/embedding-v1",
+            "object": "list",
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "model": "custom-model/embedding-v1",
+        "input": ["hello world"],
+        "encoding_format": "float",
+        "output_dimension": 512,
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.aembedding(
+            model="infinity/custom-model/embedding-v1",
+            input=["hello world"],
+            dimensions=512,
+            encoding_format="float",
+            api_base="https://api.infinity.ai/embeddings",
+        )
+
+        # Assert
+        mock_post.assert_called_once()
+        print("call args", mock_post.call_args)
+        args_to_api = mock_post.call_args.kwargs["data"]
+        _url = mock_post.call_args.kwargs["url"]
+        assert _url == "https://api.infinity.ai/embeddings"
+
+        request_data = json.loads(args_to_api)
+        assert request_data["input"] == expected_payload["input"]
+        assert request_data["model"] == expected_payload["model"]
+        assert request_data["output_dimension"] == expected_payload["output_dimension"]
+        assert request_data["encoding_format"] == expected_payload["encoding_format"]
+
+        assert response.data is not None
+        assert response.usage.prompt_tokens == 100
+        assert response.usage.total_tokens == 150
+        assert response.model == "custom-model/embedding-v1"
+        assert response.object == "list"
+
+
+@pytest.mark.asyncio()
+async def test_infinity_embedding_extra_params():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+            "usage": {"prompt_tokens": 100, "total_tokens": 150},
+            "model": "custom-model/embedding-v1",
+            "object": "list",
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.aembedding(
+            model="infinity/custom-model/embedding-v1",
+            input=["test input"],
+            dimensions=512,
+            encoding_format="float",
+            modality="text",
+            api_base="https://api.infinity.ai/embeddings",
+        )
+
+        mock_post.assert_called_once()
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
+
+        # Assert the request parameters
+        assert json_data["input"] == ["test input"]
+        assert json_data["model"] == "custom-model/embedding-v1"
+        assert json_data["output_dimension"] == 512
+        assert json_data["encoding_format"] == "float"
+        assert json_data["modality"] == "text"
+
+
+@pytest.mark.asyncio()
+async def test_infinity_embedding_prompt_token_mapping():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
+            "usage": {"total_tokens": 1, "prompt_tokens": 1},
+            "model": "custom-model/embedding-v1",
+            "object": "list",
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.aembedding(
+            model="infinity/custom-model/embedding-v1",
+            input=["a"],
+            dimensions=512,
+            encoding_format="float",
+            api_base="https://api.infinity.ai/embeddings",
+        )
+
+        mock_post.assert_called_once()
+        # Assert the response
+        assert response.usage.prompt_tokens == 1
+        assert response.usage.total_tokens == 1
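To run just these embedding tests locally, pytest's Python entry point is enough; a minimal sketch, invoked from the directory containing the test file:

```python
import pytest

# Select the new embedding tests by keyword; flags are illustrative.
pytest.main(["-v", "-k", "infinity_embedding"])
```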