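"""
Unit tests for the Infinity rerank provider in litellm.

Covers: request payload construction, the `return_documents` flag, and
api_base resolution from the INFINITY_API_BASE environment variable.
All HTTP traffic is mocked; no live Infinity server is required.
"""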
import json
import os
import sys
from unittest.mock import AsyncMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

from test_rerank import assert_response_shape

import litellm

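# For reference, a minimal sketch of an unmocked call (not executed here);
# the api_base below is a placeholder, not a real endpoint:
#
#   response = await litellm.arerank(
#       model="infinity/rerank-model",
#       query="hello",
#       documents=["hello", "world"],
#       top_n=3,
#       api_base="http://localhost:7997",  # hypothetical local Infinity server
#   )
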
@pytest.mark.asyncio()
async def test_infinity_rerank():
    """Rerank via the Infinity provider builds the expected request payload."""
    # Mock the Infinity /rerank HTTP response.
    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    expected_payload = {
        "model": "rerank-model",
        "query": "hello",
        "top_n": 3,
        "documents": ["hello", "world"],
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="infinity/rerank-model",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
            api_base="https://api.infinity.ai",
        )

        print("async rerank response: ", response)

        # Assert the request hit the expected endpoint with the expected payload.
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", _url)
        assert _url == "https://api.infinity.ai/rerank"

        request_data = json.loads(args_to_api)
        assert request_data["query"] == expected_payload["query"]
        assert request_data["documents"] == expected_payload["documents"]
        assert request_data["top_n"] == expected_payload["top_n"]
        assert request_data["model"] == expected_payload["model"]

        assert response.id is not None
        assert response.results is not None
        assert response.meta["tokens"]["input_tokens"] == 100
        assert (
            response.meta["tokens"]["output_tokens"] == 50
        )  # total_tokens - prompt_tokens

        assert_response_shape(response, custom_llm_provider="infinity")

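# Expected normalized result shape when return_documents=True (illustrative,
# matching the assertion in the test below):
#   {"index": 0, "relevance_score": 0.95, "document": {"text": "hello"}}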
@pytest.mark.asyncio()
async def test_infinity_rerank_with_return_documents():
    """With return_documents=True, each result carries its source document."""
    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95, "document": "hello"}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="infinity/rerank-model",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
            return_documents=True,
            api_base="https://api.infinity.ai",
        )
        # Infinity returns the raw document string; litellm normalizes it
        # into the {"text": ...} shape asserted here.
        assert response.results[0]["document"] == {"text": "hello"}
        assert_response_shape(response, custom_llm_provider="infinity")

@pytest.mark.asyncio()
async def test_infinity_rerank_with_env(monkeypatch):
    """The api_base falls back to the INFINITY_API_BASE environment variable."""
    # Set up mock response
    mock_response = AsyncMock()

    def return_val():
        return {
            "id": "cmpl-mockid",
            "results": [{"index": 0, "relevance_score": 0.95}],
            "usage": {"prompt_tokens": 100, "total_tokens": 150},
        }

    mock_response.json = return_val
    mock_response.headers = {"key": "value"}
    mock_response.status_code = 200

    # Set environment variable; note the call below passes no api_base.
    monkeypatch.setenv("INFINITY_API_BASE", "https://env.infinity.ai")

    expected_payload = {
        "model": "rerank-model",
        "query": "hello",
        "top_n": 3,
        "documents": ["hello", "world"],
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="infinity/rerank-model",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
        )

        print("async rerank response: ", response)

        # Assert the request was routed to the env-configured base URL.
        mock_post.assert_called_once()
        print("call args", mock_post.call_args)
        args_to_api = mock_post.call_args.kwargs["data"]
        _url = mock_post.call_args.kwargs["url"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", _url)
        assert _url == "https://env.infinity.ai/rerank"

        request_data = json.loads(args_to_api)
        assert request_data["query"] == expected_payload["query"]
        assert request_data["documents"] == expected_payload["documents"]
        assert request_data["top_n"] == expected_payload["top_n"]
        assert request_data["model"] == expected_payload["model"]

        assert response.id is not None
        assert response.results is not None
        assert response.meta["tokens"]["input_tokens"] == 100
        assert (
            response.meta["tokens"]["output_tokens"] == 50
        )  # total_tokens - prompt_tokens

        assert_response_shape(response, custom_llm_provider="infinity")