add rerank params

This commit is contained in:
Ishaan Jaff 2024-08-27 16:45:39 -07:00
parent 1ed4e91b9b
commit 2aa119864a
4 changed files with 60 additions and 13 deletions

View file

@ -4,6 +4,8 @@ Re rank api
LiteLLM supports the re rank API format, no parameter transformation occurs LiteLLM supports the re rank API format, no parameter transformation occurs
""" """
from typing import Any, Dict, List, Optional, Union
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
@ -21,14 +23,26 @@ class CohereRerank(BaseLLM):
model: str, model: str,
api_key: str, api_key: str,
query: str, query: str,
documents: list[str], documents: list[Union[str, Dict[str, Any]]],
top_n: int = 3, top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
) -> RerankResponse: ) -> RerankResponse:
client = _get_httpx_client() client = _get_httpx_client()
request_data = RerankRequest( request_data = RerankRequest(
model=model, query=query, top_n=top_n, documents=documents model=model,
query=query,
top_n=top_n,
documents=documents,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
) )
request_data_dict = request_data.dict(exclude_none=True)
response = client.post( response = client.post(
"https://api.cohere.com/v1/rerank", "https://api.cohere.com/v1/rerank",
headers={ headers={
@ -36,7 +50,7 @@ class CohereRerank(BaseLLM):
"content-type": "application/json", "content-type": "application/json",
"Authorization": f"bearer {api_key}", "Authorization": f"bearer {api_key}",
}, },
json=request_data.dict(), json=request_data_dict,
) )
return RerankResponse(**response.json()) return RerankResponse(**response.json())

View file

@ -4,6 +4,8 @@ Re rank api
LiteLLM supports the re rank API format, no parameter transformation occurs LiteLLM supports the re rank API format, no parameter transformation occurs
""" """
from typing import Any, Dict, List, Optional, Union
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
@ -21,15 +23,28 @@ class TogetherAIRerank(BaseLLM):
model: str, model: str,
api_key: str, api_key: str,
query: str, query: str,
documents: list[str], documents: list[Union[str, Dict[str, Any]]],
top_n: int = 3, top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
) -> RerankResponse: ) -> RerankResponse:
client = _get_httpx_client() client = _get_httpx_client()
request_data = RerankRequest( request_data = RerankRequest(
model=model, query=query, top_n=top_n, documents=documents model=model,
query=query,
top_n=top_n,
documents=documents,
rank_fields=rank_fields,
return_documents=return_documents,
) )
# exclude None values from request_data
request_data_dict = request_data.dict(exclude_none=True)
if max_chunks_per_doc is not None:
raise ValueError("TogetherAI does not support max_chunks_per_doc")
response = client.post( response = client.post(
"https://api.together.xyz/v1/rerank", "https://api.together.xyz/v1/rerank",
headers={ headers={
@ -37,10 +52,14 @@ class TogetherAIRerank(BaseLLM):
"content-type": "application/json", "content-type": "application/json",
"authorization": f"Bearer {api_key}", "authorization": f"Bearer {api_key}",
}, },
json=request_data.dict(), json=request_data_dict,
) )
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json() _json_response = response.json()
response = RerankResponse( response = RerankResponse(
id=_json_response.get("id"), id=_json_response.get("id"),
results=_json_response.get("results"), results=_json_response.get("results"),

View file

@ -55,9 +55,12 @@ async def arerank(
def rerank( def rerank(
model: str, model: str,
query: str, query: str,
documents: List[str], documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None, custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
top_n: int = 3, top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
**kwargs, **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]: ) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
""" """
@ -107,6 +110,9 @@ def rerank(
query=query, query=query,
documents=documents, documents=documents,
top_n=top_n, top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
api_key=cohere_key, api_key=cohere_key,
) )
pass pass
@ -130,6 +136,9 @@ def rerank(
query=query, query=query,
documents=documents, documents=documents,
top_n=top_n, top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
api_key=together_key, api_key=together_key,
) )

View file

@ -4,17 +4,22 @@ https://docs.cohere.com/reference/rerank
""" """
from typing import List, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
class RerankRequest(BaseModel): class RerankRequest(BaseModel):
model: str model: str
query: str query: str
top_n: int top_n: Optional[int] = None
documents: list[str] documents: List[Union[str, dict]]
rank_fields: Optional[List[str]] = None
return_documents: Optional[bool] = None
max_chunks_per_doc: Optional[int] = None
class RerankResponse(BaseModel): class RerankResponse(BaseModel):
id: str id: str
results: list[dict] # Contains index and relevance_score results: List[dict] # Contains index and relevance_score
meta: dict # Contains api_version and billed_units meta: dict # Contains api_version and billed_units