mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
add test for rerank on custom api base
This commit is contained in:
parent
09e9e4aebf
commit
a80b2aebbb
3 changed files with 60 additions and 5 deletions
|
@ -22,6 +22,7 @@ class CohereRerank(BaseLLM):
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
|
api_base: str,
|
||||||
query: str,
|
query: str,
|
||||||
documents: List[Union[str, Dict[str, Any]]],
|
documents: List[Union[str, Dict[str, Any]]],
|
||||||
top_n: Optional[int] = None,
|
top_n: Optional[int] = None,
|
||||||
|
@ -43,11 +44,11 @@ class CohereRerank(BaseLLM):
|
||||||
request_data_dict = request_data.dict(exclude_none=True)
|
request_data_dict = request_data.dict(exclude_none=True)
|
||||||
|
|
||||||
if _is_async:
|
if _is_async:
|
||||||
return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method
|
return self.async_rerank(request_data_dict, api_key, api_base) # type: ignore # Call async method
|
||||||
|
|
||||||
client = _get_httpx_client()
|
client = _get_httpx_client()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"https://api.cohere.com/v1/rerank",
|
api_base,
|
||||||
headers={
|
headers={
|
||||||
"accept": "application/json",
|
"accept": "application/json",
|
||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
|
@ -62,11 +63,12 @@ class CohereRerank(BaseLLM):
|
||||||
self,
|
self,
|
||||||
request_data_dict: Dict[str, Any],
|
request_data_dict: Dict[str, Any],
|
||||||
api_key: str,
|
api_key: str,
|
||||||
|
api_base: str,
|
||||||
) -> RerankResponse:
|
) -> RerankResponse:
|
||||||
client = _get_async_httpx_client()
|
client = _get_async_httpx_client()
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"https://api.cohere.com/v1/rerank",
|
api_base,
|
||||||
headers={
|
headers={
|
||||||
"accept": "application/json",
|
"accept": "application/json",
|
||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
|
|
|
@ -27,7 +27,7 @@ async def arerank(
|
||||||
custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
|
custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
|
||||||
top_n: Optional[int] = None,
|
top_n: Optional[int] = None,
|
||||||
rank_fields: Optional[List[str]] = None,
|
rank_fields: Optional[List[str]] = None,
|
||||||
return_documents: Optional[bool] = True,
|
return_documents: Optional[bool] = None,
|
||||||
max_chunks_per_doc: Optional[int] = None,
|
max_chunks_per_doc: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
|
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
|
||||||
|
@ -112,7 +112,7 @@ def rerank(
|
||||||
optional_params.api_base
|
optional_params.api_base
|
||||||
or litellm.api_base
|
or litellm.api_base
|
||||||
or get_secret("COHERE_API_BASE")
|
or get_secret("COHERE_API_BASE")
|
||||||
or "https://api.cohere.ai/v1/generate"
|
or "https://api.cohere.com/v1/rerank"
|
||||||
)
|
)
|
||||||
|
|
||||||
headers: Dict = litellm.headers or {}
|
headers: Dict = litellm.headers or {}
|
||||||
|
@ -126,6 +126,7 @@ def rerank(
|
||||||
return_documents=return_documents,
|
return_documents=return_documents,
|
||||||
max_chunks_per_doc=max_chunks_per_doc,
|
max_chunks_per_doc=max_chunks_per_doc,
|
||||||
api_key=cohere_key,
|
api_key=cohere_key,
|
||||||
|
api_base=api_base,
|
||||||
_is_async=_is_async,
|
_is_async=_is_async,
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -125,3 +125,55 @@ async def test_basic_rerank_together_ai(sync_mode):
|
||||||
assert response.results is not None
|
assert response.results is not None
|
||||||
|
|
||||||
assert_response_shape(response, custom_llm_provider="together_ai")
|
assert_response_shape(response, custom_llm_provider="together_ai")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_rerank_custom_api_base():
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
|
||||||
|
def return_val():
|
||||||
|
return {
|
||||||
|
"id": "cmpl-mockid",
|
||||||
|
"results": [{"index": 0, "relevance_score": 0.95}],
|
||||||
|
"meta": {
|
||||||
|
"api_version": {"version": "1.0"},
|
||||||
|
"billed_units": {"search_units": 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response.json = return_val
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
expected_payload = {
|
||||||
|
"model": "Salesforce/Llama-Rank-V1",
|
||||||
|
"query": "hello",
|
||||||
|
"documents": ["hello", "world"],
|
||||||
|
"top_n": 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_post:
|
||||||
|
response = await litellm.arerank(
|
||||||
|
model="cohere/Salesforce/Llama-Rank-V1",
|
||||||
|
query="hello",
|
||||||
|
documents=["hello", "world"],
|
||||||
|
top_n=3,
|
||||||
|
api_base="https://exampleopenaiendpoint-production.up.railway.app/",
|
||||||
|
)
|
||||||
|
|
||||||
|
print("async re rank response: ", response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
_url, kwargs = mock_post.call_args
|
||||||
|
args_to_api = kwargs["json"]
|
||||||
|
print("Arguments passed to API=", args_to_api)
|
||||||
|
print("url = ", _url)
|
||||||
|
assert _url[0] == "https://exampleopenaiendpoint-production.up.railway.app/"
|
||||||
|
assert args_to_api == expected_payload
|
||||||
|
assert response.id is not None
|
||||||
|
assert response.results is not None
|
||||||
|
|
||||||
|
assert_response_shape(response, custom_llm_provider="cohere")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue