Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 19:54:13 +00:00
(feat) add infinity rerank models (#7321)
* Support Infinity Reranker (custom reranking models) (#7247)
  * Support Infinity Reranker
  * Clean code
  * Included transformation.py
  * Clean code
  * Added Infinity reranker test
  * Clean code

  Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>

* transform_rerank_response
* update handler.py
* infinity rerank updates
* ci/cd run again
* add infinity unit tests
* docs add instruction on how to add a new provider for rerank

Co-authored-by: Hao Shan <53949959+haoshan98@users.noreply.github.com>
parent 50204d9a6d
commit 6641e75e0c

11 changed files with 414 additions and 1 deletion
docs/my-website/docs/adding_provider/directory_structure.md (new file, 21 lines)

@@ -0,0 +1,21 @@
# Directory Structure

When adding a new provider, you need to create a directory for the provider that follows this structure:

```
litellm/llms/
└── provider_name/
    ├── completion/
    │   ├── handler.py
    │   └── transformation.py
    ├── chat/
    │   ├── handler.py
    │   └── transformation.py
    ├── embed/
    │   ├── handler.py
    │   └── transformation.py
    └── rerank/
        ├── handler.py
        └── transformation.py
```
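For example, the Infinity rerank support added in this commit only implements the `rerank/` endpoint and keeps the provider's error class in a shared `common_utils.py`:

```
litellm/llms/
└── infinity/
    └── rerank/
        ├── common_utils.py      # InfinityError exception
        ├── handler.py           # thin stub; requests go through llm_http_handler.py
        └── transformation.py    # InfinityRerankConfig request/response transformation
```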
docs/my-website/docs/adding_provider/new_rerank_provider.md (new file, 81 lines)

@@ -0,0 +1,81 @@
# Add Rerank Provider

LiteLLM **follows the Cohere Rerank API format** for all rerank providers. Here's how to add a new rerank provider:
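Every rerank provider therefore accepts the same Cohere-style parameters and must return a `RerankResponse`. A rough sketch of both shapes, with illustrative values borrowed from the tests added later in this commit:

```python
# Cohere-style rerank request (the parameters every rerank provider accepts)
request = {
    "model": "your_provider/rerank-model",  # illustrative model name
    "query": "hello",
    "documents": ["hello", "world"],
    "top_n": 3,
}

# Cohere-style rerank response (what your transformation must produce)
response = {
    "id": "cmpl-mockid",
    "results": [{"index": 0, "relevance_score": 0.95}],
    "meta": {"tokens": {"input_tokens": 100, "output_tokens": 50}},
}
```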
## 1. Create a transformation.py file

Create a config class named `<Provider><Endpoint>Config` that inherits from `BaseRerankConfig`:

```python
from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse

class YourProviderRerankConfig(BaseRerankConfig):
    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
            "query",
            "documents",
            "top_n",
            # ... other supported params
        ]

    def transform_rerank_request(self, model: str, optional_rerank_params: OptionalRerankParams, headers: dict) -> dict:
        # Transform request to RerankRequest spec
        return rerank_request.model_dump(exclude_none=True)

    def transform_rerank_response(self, model: str, raw_response: httpx.Response, ...) -> RerankResponse:
        # Transform provider response to RerankResponse
        return RerankResponse(**raw_response_json)
```
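The skeleton above leaves `transform_rerank_request` unimplemented. A hypothetical sketch of what it might look like for a provider that accepts the Cohere-style fields as-is (the field handling shown here is illustrative and not part of the skeleton; the Cohere provider's own `transformation.py`, which the Infinity config in this commit inherits from, is a good reference):

```python
class YourProviderRerankConfig(BaseRerankConfig):
    def transform_rerank_request(
        self, model: str, optional_rerank_params: OptionalRerankParams, headers: dict
    ) -> dict:
        # Validate required params, build the Cohere-spec RerankRequest,
        # then dump it to a plain dict for the HTTP handler to send as JSON.
        if "query" not in optional_rerank_params:
            raise ValueError("query is required for rerank")
        rerank_request = RerankRequest(
            model=model,
            query=optional_rerank_params["query"],
            documents=optional_rerank_params.get("documents", []),
            top_n=optional_rerank_params.get("top_n"),
        )
        return rerank_request.model_dump(exclude_none=True)
```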
## 2. Register Your Provider

Add your provider to `litellm.utils.get_provider_rerank_config()`:

```python
elif litellm.LlmProviders.YOUR_PROVIDER == provider:
    return litellm.YourProviderRerankConfig()
```

## 3. Add Provider to `rerank_api/main.py`

Add a code block that handles calls to your provider. Your provider should use the `base_llm_http_handler.rerank` method:

```python
elif _custom_llm_provider == "your_provider":
    ...
    response = base_llm_http_handler.rerank(
        model=model,
        custom_llm_provider=_custom_llm_provider,
        optional_rerank_params=optional_rerank_params,
        logging_obj=litellm_logging_obj,
        timeout=optional_params.timeout,
        api_key=dynamic_api_key or optional_params.api_key,
        api_base=api_base,
        _is_async=_is_async,
        headers=headers or litellm.headers or {},
        client=client,
        model_response=model_response,
    )
    ...
```
## 4. Add Tests

Add a test file to [`tests/llm_translation`](https://github.com/BerriAI/litellm/tree/main/tests/llm_translation):

```python
def test_basic_rerank_cohere():
    response = litellm.rerank(
        model="cohere/rerank-english-v3.0",
        query="hello",
        documents=["hello", "world"],
        top_n=3,
    )

    print("re rank response: ", response)

    assert response.id is not None
    assert response.results is not None
```
@@ -320,6 +320,13 @@ const sidebars = {
         "load_test_rpm",
       ]
     },
+    {
+      type: "category",
+      label: "Adding Providers",
+      items: [
+        "adding_provider/directory_structure",
+        "adding_provider/new_rerank_provider"],
+    },
     {
       type: "category",
       label: "Logging & Observability",
@@ -126,6 +126,7 @@ azure_key: Optional[str] = None
 anthropic_key: Optional[str] = None
 replicate_key: Optional[str] = None
 cohere_key: Optional[str] = None
+infinity_key: Optional[str] = None
 clarifai_key: Optional[str] = None
 maritalk_key: Optional[str] = None
 ai21_key: Optional[str] = None

@@ -1025,6 +1026,7 @@ from .llms.replicate.chat.transformation import ReplicateConfig
 from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
 from .llms.cohere.rerank.transformation import CohereRerankConfig
 from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
+from .llms.infinity.rerank.transformation import InfinityRerankConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
 from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
 from .llms.together_ai.chat import TogetherAIConfig
litellm/llms/infinity/rerank/common_utils.py (new file, 19 lines)

@@ -0,0 +1,19 @@
import httpx

from litellm.llms.base_llm.chat.transformation import BaseLLMException


class InfinityError(BaseLLMException):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://github.com/michaelfeil/infinity"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            request=self.request,
            response=self.response,
        )  # Call the base class constructor with the parameters it needs
litellm/llms/infinity/rerank/handler.py (new file, 5 lines)

@@ -0,0 +1,5 @@
"""
Infinity Rerank - uses `llm_http_handler.py` to make httpx requests

Request/Response transformation is handled in `transformation.py`
"""
litellm/llms/infinity/rerank/transformation.py (new file, 91 lines)

@@ -0,0 +1,91 @@
"""
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.

Why separate file? Make it easy to see how transformation works
"""

import uuid
from typing import List, Optional

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import RerankBilledUnits, RerankResponseMeta, RerankTokens
from litellm.types.utils import RerankResponse

from .common_utils import InfinityError


class InfinityRerankConfig(CohereRerankConfig):
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            api_key = (
                get_secret_str("INFINITY_API_KEY")
                or get_secret_str("INFINITY_API_KEY")
                or litellm.infinity_key
            )

        default_headers = {
            "Authorization": f"bearer {api_key}",
            "accept": "application/json",
            "content-type": "application/json",
        }

        # If 'Authorization' is provided in headers, it overrides the default.
        if "Authorization" in headers:
            default_headers["Authorization"] = headers["Authorization"]

        # Merge other headers, overriding any default ones except Authorization
        return {**default_headers, **headers}

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Transform Infinity rerank response

        No transformation required, Infinity follows Cohere API response format
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise InfinityError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        _billed_units = RerankBilledUnits(**raw_response_json.get("usage", {}))
        _tokens = RerankTokens(
            input_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
            output_tokens=(
                raw_response_json.get("usage", {}).get("total_tokens", 0)
                - raw_response_json.get("usage", {}).get("prompt_tokens", 0)
            ),
        )
        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

        _results: Optional[List[dict]] = raw_response_json.get("results")

        if _results is None:
            raise ValueError(f"No results found in the response={raw_response_json}")

        return RerankResponse(
            id=raw_response_json.get("id") or str(uuid.uuid4()),
            results=_results,  # type: ignore
            meta=rerank_meta,
        )  # Return response
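Note how the token accounting works here: Infinity's `usage` block reports `prompt_tokens` and `total_tokens`, so output tokens are derived by subtraction. With the mocked usage from the tests below, for instance:

```python
usage = {"prompt_tokens": 100, "total_tokens": 150}  # values from the mocked response in the tests
input_tokens = usage["prompt_tokens"]                           # 100
output_tokens = usage["total_tokens"] - usage["prompt_tokens"]  # 50
```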
@@ -76,7 +76,9 @@ def rerank(  # noqa: PLR0915
     model: str,
     query: str,
     documents: List[Union[str, Dict[str, Any]]],
-    custom_llm_provider: Optional[Literal["cohere", "together_ai", "azure_ai"]] = None,
+    custom_llm_provider: Optional[
+        Literal["cohere", "together_ai", "azure_ai", "infinity"]
+    ] = None,
     top_n: Optional[int] = None,
     rank_fields: Optional[List[str]] = None,
     return_documents: Optional[bool] = True,
@@ -188,6 +190,37 @@ def rerank(  # noqa: PLR0915
             or litellm.api_base
             or get_secret("AZURE_AI_API_BASE")  # type: ignore
         )
+
+        response = base_llm_http_handler.rerank(
+            model=model,
+            custom_llm_provider=_custom_llm_provider,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
+            api_base=api_base,
+            _is_async=_is_async,
+            headers=headers or litellm.headers or {},
+            client=client,
+            model_response=model_response,
+        )
+    elif _custom_llm_provider == "infinity":
+        # Implement Infinity rerank logic
+        api_key: Optional[str] = (
+            dynamic_api_key or optional_params.api_key or litellm.api_key
+        )
+
+        api_base: Optional[str] = (
+            dynamic_api_base
+            or optional_params.api_base
+            or litellm.api_base
+            or get_secret("INFINITY_API_BASE")  # type: ignore
+        )
+
+        if api_base is None:
+            raise Exception(
+                "Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var."
+            )
+
         response = base_llm_http_handler.rerank(
             model=model,
             custom_llm_provider=_custom_llm_provider,
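Taken together with the transformation above, calling the new provider looks roughly like this; the base URL and model name are illustrative and mirror the values used in the tests added below. The Infinity branch requires an `api_base`, passed in the call or via the `INFINITY_API_BASE` environment variable:

```python
import os

import litellm

# Either pass api_base=... in the call or set the env var (illustrative URL).
# An API key, if your deployment needs one, can be set via INFINITY_API_KEY.
os.environ["INFINITY_API_BASE"] = "https://api.infinity.ai"

response = litellm.rerank(
    model="infinity/rerank-model",  # illustrative model name
    query="hello",
    documents=["hello", "world"],
    top_n=3,
)
print(response.results)
```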
@@ -1741,6 +1741,7 @@ class LlmProviders(str, Enum):
     HOSTED_VLLM = "hosted_vllm"
     LM_STUDIO = "lm_studio"
     GALADRIEL = "galadriel"
+    INFINITY = "infinity"


 class LiteLLMLoggingBaseClass:
@@ -6214,6 +6214,8 @@ class ProviderConfigManager:
             return litellm.CohereRerankConfig()
         elif litellm.LlmProviders.AZURE_AI == provider:
             return litellm.AzureAIRerankConfig()
+        elif litellm.LlmProviders.INFINITY == provider:
+            return litellm.InfinityRerankConfig()
         return litellm.CohereRerankConfig()
tests/llm_translation/test_infinity.py (new file, 151 lines)

@@ -0,0 +1,151 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system-path


import litellm

import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system-path
from test_rerank import assert_response_shape
import litellm


@pytest.mark.asyncio()
async def test_infinity_rerank():
    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    expected_payload = {
        "model": "rerank-model",
        "query": "hello",
        "top_n": 3,
        "documents": ["hello", "world"],
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="infinity/rerank-model",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
            api_base="https://api.infinity.ai",
        )

        print("async re rank response: ", response)

        # Assert
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", _url)
        assert _url == "https://api.infinity.ai/v1/rerank"

        request_data = json.loads(args_to_api)
        assert request_data["query"] == expected_payload["query"]
        assert request_data["documents"] == expected_payload["documents"]
        assert request_data["top_n"] == expected_payload["top_n"]
        assert request_data["model"] == expected_payload["model"]

        assert response.id is not None
        assert response.results is not None
        assert response.meta["tokens"]["input_tokens"] == 100
        assert (
            response.meta["tokens"]["output_tokens"] == 50
        )  # total_tokens - prompt_tokens

        assert_response_shape(response, custom_llm_provider="infinity")


@pytest.mark.asyncio()
async def test_infinity_rerank_with_env(monkeypatch):
    # Set up mock response
    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    # Set environment variable
    monkeypatch.setenv("INFINITY_API_BASE", "https://env.infinity.ai")

    expected_payload = {
        "model": "rerank-model",
        "query": "hello",
        "top_n": 3,
        "documents": ["hello", "world"],
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="infinity/rerank-model",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
        )

        print("async re rank response: ", response)

        # Assert
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", _url)
        assert _url == "https://env.infinity.ai/v1/rerank"

        request_data = json.loads(args_to_api)
        assert request_data["query"] == expected_payload["query"]
        assert request_data["documents"] == expected_payload["documents"]
        assert request_data["top_n"] == expected_payload["top_n"]
        assert request_data["model"] == expected_payload["model"]

        assert response.id is not None
        assert response.results is not None
        assert response.meta["tokens"]["input_tokens"] == 100
        assert (
            response.meta["tokens"]["output_tokens"] == 50
        )  # total_tokens - prompt_tokens

        assert_response_shape(response, custom_llm_provider="infinity")