Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 11:43:54 +00:00
add test for rerank on custom api base

Parent: 09e9e4aebf
Commit: a80b2aebbb
3 changed files with 60 additions and 5 deletions
@@ -22,6 +22,7 @@ class CohereRerank(BaseLLM):
         self,
         model: str,
         api_key: str,
+        api_base: str,
         query: str,
         documents: List[Union[str, Dict[str, Any]]],
         top_n: Optional[int] = None,
@@ -43,11 +44,11 @@ class CohereRerank(BaseLLM):
         request_data_dict = request_data.dict(exclude_none=True)

         if _is_async:
-            return self.async_rerank(request_data_dict, api_key)  # type: ignore # Call async method
+            return self.async_rerank(request_data_dict, api_key, api_base)  # type: ignore # Call async method

         client = _get_httpx_client()
         response = client.post(
-            "https://api.cohere.com/v1/rerank",
+            api_base,
             headers={
                 "accept": "application/json",
                 "content-type": "application/json",
@@ -62,11 +63,12 @@ class CohereRerank(BaseLLM):
         self,
         request_data_dict: Dict[str, Any],
         api_key: str,
+        api_base: str,
     ) -> RerankResponse:
         client = _get_async_httpx_client()

         response = await client.post(
-            "https://api.cohere.com/v1/rerank",
+            api_base,
             headers={
                 "accept": "application/json",
                 "content-type": "application/json",
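The two hunks above thread api_base through to the HTTP call instead of posting to the hard-coded Cohere URL. A minimal sketch of that pattern with httpx; the helper name post_rerank and the Authorization header format are illustrative assumptions, not the actual litellm handler:

import httpx
from typing import Any, Dict

def post_rerank(api_base: str, api_key: str, request_data_dict: Dict[str, Any]) -> dict:
    # Post the rerank payload to whatever base URL the caller resolved,
    # e.g. "https://api.cohere.com/v1/rerank" or a custom endpoint.
    client = httpx.Client()
    response = client.post(
        api_base,
        headers={
            "accept": "application/json",
            "content-type": "application/json",
            # Auth header format is assumed here; the diff does not show it.
            "Authorization": f"bearer {api_key}",
        },
        json=request_data_dict,
    )
    return response.json()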
@@ -27,7 +27,7 @@ async def arerank(
     custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
     top_n: Optional[int] = None,
     rank_fields: Optional[List[str]] = None,
-    return_documents: Optional[bool] = True,
+    return_documents: Optional[bool] = None,
     max_chunks_per_doc: Optional[int] = None,
     **kwargs,
 ) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
@@ -112,7 +112,7 @@ def rerank(
             optional_params.api_base
             or litellm.api_base
             or get_secret("COHERE_API_BASE")
-            or "https://api.cohere.ai/v1/generate"
+            or "https://api.cohere.com/v1/rerank"
         )

         headers: Dict = litellm.headers or {}
@@ -126,6 +126,7 @@ def rerank(
             return_documents=return_documents,
             max_chunks_per_doc=max_chunks_per_doc,
             api_key=cohere_key,
+            api_base=api_base,
             _is_async=_is_async,
         )
         pass
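For reference, the api_base resolution order introduced above, written out as a standalone sketch; resolve_cohere_api_base is a hypothetical helper, and os.environ.get stands in for litellm's get_secret:

import os
from typing import Optional

def resolve_cohere_api_base(
    explicit_api_base: Optional[str] = None,
    litellm_api_base: Optional[str] = None,
) -> str:
    # First match wins: per-call api_base, global litellm.api_base,
    # the COHERE_API_BASE secret, then the default Cohere rerank endpoint.
    return (
        explicit_api_base
        or litellm_api_base
        or os.environ.get("COHERE_API_BASE")
        or "https://api.cohere.com/v1/rerank"
    )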
@@ -125,3 +125,55 @@ async def test_basic_rerank_together_ai(sync_mode):
         assert response.results is not None

         assert_response_shape(response, custom_llm_provider="together_ai")
+
+
+@pytest.mark.asyncio()
+async def test_rerank_custom_api_base():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "id": "cmpl-mockid",
+            "results": [{"index": 0, "relevance_score": 0.95}],
+            "meta": {
+                "api_version": {"version": "1.0"},
+                "billed_units": {"search_units": 1},
+            },
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "model": "Salesforce/Llama-Rank-V1",
+        "query": "hello",
+        "documents": ["hello", "world"],
+        "top_n": 3,
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.arerank(
+            model="cohere/Salesforce/Llama-Rank-V1",
+            query="hello",
+            documents=["hello", "world"],
+            top_n=3,
+            api_base="https://exampleopenaiendpoint-production.up.railway.app/",
+        )
+
+        print("async re rank response: ", response)
+
+        # Assert
+        mock_post.assert_called_once()
+        _url, kwargs = mock_post.call_args
+        args_to_api = kwargs["json"]
+        print("Arguments passed to API=", args_to_api)
+        print("url = ", _url)
+        assert _url[0] == "https://exampleopenaiendpoint-production.up.railway.app/"
+        assert args_to_api == expected_payload
+        assert response.id is not None
+        assert response.results is not None
+
+        assert_response_shape(response, custom_llm_provider="cohere")
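Based on the test above, a usage sketch for pointing the cohere rerank route at a custom endpoint; the URL is a placeholder, while the model and inputs mirror the test:

import asyncio
import litellm

async def main() -> None:
    # Same call shape as the new test, but against a stand-in endpoint.
    response = await litellm.arerank(
        model="cohere/Salesforce/Llama-Rank-V1",
        query="hello",
        documents=["hello", "world"],
        top_n=3,
        api_base="https://my-rerank-endpoint.example.com",  # placeholder endpoint
    )
    print(response.results)

asyncio.run(main())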