mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
add Together AI rerank support
This commit is contained in:
parent
15ac10af40
commit
70db82a236
3 changed files with 134 additions and 6 deletions
44
litellm/llms/cohere/rerank.py
Normal file
44
litellm/llms/cohere/rerank.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
"""
|
||||
Re rank api
|
||||
|
||||
LiteLLM supports the re rank API format, no paramter transformation occurs
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_async_httpx_client,
|
||||
_get_httpx_client,
|
||||
)
|
||||
from litellm.rerank_api.types import RerankRequest, RerankResponse
|
||||
|
||||
|
||||
class CohereRerank(BaseLLM):
|
||||
def rerank(
|
||||
self,
|
||||
model: str,
|
||||
api_key: str,
|
||||
query: str,
|
||||
documents: list[str],
|
||||
top_n: int = 3,
|
||||
) -> RerankResponse:
|
||||
client = _get_httpx_client()
|
||||
request_data = RerankRequest(
|
||||
model=model, query=query, top_n=top_n, documents=documents
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"https://api.cohere.com/v1/rerank",
|
||||
headers={
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"Authorization": f"bearer {api_key}",
|
||||
},
|
||||
json=request_data.dict(),
|
||||
)
|
||||
|
||||
return RerankResponse(**response.json())
|
||||
|
||||
pass
|
52
litellm/llms/togetherai/rerank.py
Normal file
52
litellm/llms/togetherai/rerank.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
"""
|
||||
Re rank api
|
||||
|
||||
LiteLLM supports the re rank API format, no paramter transformation occurs
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_async_httpx_client,
|
||||
_get_httpx_client,
|
||||
)
|
||||
from litellm.rerank_api.types import RerankRequest, RerankResponse
|
||||
|
||||
|
||||
class TogetherAIRerank(BaseLLM):
|
||||
def rerank(
|
||||
self,
|
||||
model: str,
|
||||
api_key: str,
|
||||
query: str,
|
||||
documents: list[str],
|
||||
top_n: int = 3,
|
||||
) -> RerankResponse:
|
||||
client = _get_httpx_client()
|
||||
|
||||
request_data = RerankRequest(
|
||||
model=model, query=query, top_n=top_n, documents=documents
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"https://api.together.xyz/v1/rerank",
|
||||
headers={
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"authorization": f"Bearer {api_key}",
|
||||
},
|
||||
json=request_data.dict(),
|
||||
)
|
||||
|
||||
_json_response = response.json()
|
||||
response = RerankResponse(
|
||||
id=_json_response.get("id"),
|
||||
results=_json_response.get("results"),
|
||||
meta=_json_response.get("meta") or {},
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
pass
|
|
@ -7,6 +7,7 @@ import litellm
|
|||
from litellm import get_secret
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.cohere.rerank import CohereRerank
|
||||
from litellm.llms.togetherai.rerank import TogetherAIRerank
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
|
||||
|
@ -15,6 +16,7 @@ from .types import RerankRequest, RerankResponse
|
|||
####### ENVIRONMENT VARIABLES ###################
|
||||
# Initialize any necessary instances or variables here
|
||||
cohere_rerank = CohereRerank()
|
||||
together_rerank = TogetherAIRerank()
|
||||
#################################################
|
||||
|
||||
|
||||
|
@ -54,7 +56,7 @@ def rerank(
|
|||
model: str,
|
||||
query: str,
|
||||
documents: List[str],
|
||||
custom_llm_provider: Literal["cohere", "together_ai"] = "cohere",
|
||||
custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
|
||||
top_n: int = 3,
|
||||
**kwargs,
|
||||
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
|
||||
|
@ -65,11 +67,21 @@ def rerank(
|
|||
_is_async = kwargs.pop("arerank", False) is True
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
model, _custom_llm_provider, dynamic_api_key, api_base = (
|
||||
litellm.get_llm_provider(
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=optional_params.api_base,
|
||||
api_key=optional_params.api_key,
|
||||
)
|
||||
)
|
||||
|
||||
# Implement rerank logic here based on the custom_llm_provider
|
||||
if custom_llm_provider == "cohere":
|
||||
if _custom_llm_provider == "cohere":
|
||||
# Implement Cohere rerank logic
|
||||
cohere_key = (
|
||||
optional_params.api_key
|
||||
dynamic_api_key
|
||||
or optional_params.api_key
|
||||
or litellm.cohere_key
|
||||
or get_secret("COHERE_API_KEY")
|
||||
or get_secret("CO_API_KEY")
|
||||
|
@ -98,11 +110,31 @@ def rerank(
|
|||
api_key=cohere_key,
|
||||
)
|
||||
pass
|
||||
elif custom_llm_provider == "together_ai":
|
||||
elif _custom_llm_provider == "together_ai":
|
||||
# Implement Together AI rerank logic
|
||||
pass
|
||||
together_key = (
|
||||
dynamic_api_key
|
||||
or optional_params.api_key
|
||||
or litellm.togetherai_api_key
|
||||
or get_secret("TOGETHERAI_API_KEY")
|
||||
or litellm.api_key
|
||||
)
|
||||
|
||||
if together_key is None:
|
||||
raise ValueError(
|
||||
"TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
|
||||
)
|
||||
|
||||
response = together_rerank.rerank(
|
||||
model=model,
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n,
|
||||
api_key=together_key,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {custom_llm_provider}")
|
||||
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
|
||||
|
||||
# Placeholder return
|
||||
return response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue