add test for rerank on custom api base

This commit is contained in:
Ishaan Jaff 2024-08-27 18:25:51 -07:00
parent 09e9e4aebf
commit a80b2aebbb
3 changed files with 60 additions and 5 deletions

View file

@ -22,6 +22,7 @@ class CohereRerank(BaseLLM):
self, self,
model: str, model: str,
api_key: str, api_key: str,
api_base: str,
query: str, query: str,
documents: List[Union[str, Dict[str, Any]]], documents: List[Union[str, Dict[str, Any]]],
top_n: Optional[int] = None, top_n: Optional[int] = None,
@ -43,11 +44,11 @@ class CohereRerank(BaseLLM):
request_data_dict = request_data.dict(exclude_none=True) request_data_dict = request_data.dict(exclude_none=True)
if _is_async: if _is_async:
return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method return self.async_rerank(request_data_dict, api_key, api_base) # type: ignore # Call async method
client = _get_httpx_client() client = _get_httpx_client()
response = client.post( response = client.post(
"https://api.cohere.com/v1/rerank", api_base,
headers={ headers={
"accept": "application/json", "accept": "application/json",
"content-type": "application/json", "content-type": "application/json",
@ -62,11 +63,12 @@ class CohereRerank(BaseLLM):
self, self,
request_data_dict: Dict[str, Any], request_data_dict: Dict[str, Any],
api_key: str, api_key: str,
api_base: str,
) -> RerankResponse: ) -> RerankResponse:
client = _get_async_httpx_client() client = _get_async_httpx_client()
response = await client.post( response = await client.post(
"https://api.cohere.com/v1/rerank", api_base,
headers={ headers={
"accept": "application/json", "accept": "application/json",
"content-type": "application/json", "content-type": "application/json",

View file

@ -27,7 +27,7 @@ async def arerank(
custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None, custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
top_n: Optional[int] = None, top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None, rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True, return_documents: Optional[bool] = None,
max_chunks_per_doc: Optional[int] = None, max_chunks_per_doc: Optional[int] = None,
**kwargs, **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]: ) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
@ -112,7 +112,7 @@ def rerank(
optional_params.api_base optional_params.api_base
or litellm.api_base or litellm.api_base
or get_secret("COHERE_API_BASE") or get_secret("COHERE_API_BASE")
or "https://api.cohere.ai/v1/generate" or "https://api.cohere.com/v1/rerank"
) )
headers: Dict = litellm.headers or {} headers: Dict = litellm.headers or {}
@ -126,6 +126,7 @@ def rerank(
return_documents=return_documents, return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc, max_chunks_per_doc=max_chunks_per_doc,
api_key=cohere_key, api_key=cohere_key,
api_base=api_base,
_is_async=_is_async, _is_async=_is_async,
) )
pass pass

View file

@ -125,3 +125,55 @@ async def test_basic_rerank_together_ai(sync_mode):
assert response.results is not None assert response.results is not None
assert_response_shape(response, custom_llm_provider="together_ai") assert_response_shape(response, custom_llm_provider="together_ai")
@pytest.mark.asyncio()
async def test_rerank_custom_api_base():
    """Check that ``litellm.arerank`` posts to a caller-supplied ``api_base``.

    ``AsyncHTTPHandler.post`` is patched, so no network request is made.
    The test asserts that the patched ``post`` is called exactly once, that
    its first positional argument is the custom base URL, and that the
    ``json`` keyword argument equals the expected rerank payload (whose
    ``model`` value carries no ``cohere/`` prefix).
    """
    custom_api_base = "https://exampleopenaiendpoint-production.up.railway.app/"

    # Canned HTTP response: ``json`` is a plain (non-async) callable that
    # builds a fresh rerank-shaped body on every call.
    fake_http_response = AsyncMock()
    fake_http_response.status_code = 200
    fake_http_response.json = lambda: {
        "id": "cmpl-mockid",
        "results": [{"index": 0, "relevance_score": 0.95}],
        "meta": {
            "api_version": {"version": "1.0"},
            "billed_units": {"search_units": 1},
        },
    }

    # Request body the patched ``post`` is expected to receive as ``json=``.
    expected_payload = {
        "model": "Salesforce/Llama-Rank-V1",
        "query": "hello",
        "documents": ["hello", "world"],
        "top_n": 3,
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        return_value=fake_http_response,
    ) as mock_post:
        response = await litellm.arerank(
            model="cohere/Salesforce/Llama-Rank-V1",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
            api_base=custom_api_base,
        )

        print("async re rank response: ", response)

        # Assert
        mock_post.assert_called_once()

        # ``call_args`` unpacks into (positional args tuple, keyword args dict).
        positional_args, keyword_args = mock_post.call_args
        args_to_api = keyword_args["json"]
        print("Arguments passed to API=", args_to_api)
        print("url = ", positional_args)

        # The URL is passed as the first positional argument to ``post``.
        assert positional_args[0] == custom_api_base
        assert args_to_api == expected_payload

        assert response.id is not None
        assert response.results is not None

        assert_response_shape(response, custom_llm_provider="cohere")