mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
add TogetherAI rerank support
This commit is contained in:
parent
b8bc185bd5
commit
dc42ad0021
3 changed files with 134 additions and 6 deletions
44
litellm/llms/cohere/rerank.py
Normal file
44
litellm/llms/cohere/rerank.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
"""
|
||||||
|
Re rank api
|
||||||
|
|
||||||
|
LiteLLM supports the re-rank API format; no parameter transformation occurs
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from litellm.llms.base import BaseLLM
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
_get_async_httpx_client,
|
||||||
|
_get_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.rerank_api.types import RerankRequest, RerankResponse
|
||||||
|
|
||||||
|
|
||||||
|
class CohereRerank(BaseLLM):
    """Synchronous handler for Cohere's `/v1/rerank` endpoint.

    LiteLLM passes the re-rank parameters through unchanged — the request
    body is exactly the `RerankRequest` fields, no transformation occurs.
    """

    def rerank(
        self,
        model: str,
        api_key: str,
        query: str,
        documents: list[str],
        top_n: int = 3,
    ) -> RerankResponse:
        """Re-rank `documents` against `query` using a Cohere rerank model.

        Args:
            model: Cohere rerank model name.
            api_key: Cohere API key, sent as a bearer token.
            query: Search query the documents are scored against.
            documents: Candidate documents to re-rank.
            top_n: Number of top-scoring results to return (default 3).

        Returns:
            RerankResponse: parsed from Cohere's JSON response.

        Raises:
            httpx.HTTPStatusError: if the API returns a non-2xx status.
        """
        client = _get_httpx_client()

        request_data = RerankRequest(
            model=model, query=query, top_n=top_n, documents=documents
        )

        response = client.post(
            "https://api.cohere.com/v1/rerank",
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"bearer {api_key}",
            },
            json=request_data.dict(),
        )
        # Fail loudly on HTTP errors instead of passing an error payload
        # into RerankResponse and surfacing an opaque validation error.
        response.raise_for_status()

        return RerankResponse(**response.json())
|
52
litellm/llms/togetherai/rerank.py
Normal file
52
litellm/llms/togetherai/rerank.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
"""
|
||||||
|
Re rank api
|
||||||
|
|
||||||
|
LiteLLM supports the re-rank API format; no parameter transformation occurs
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from litellm.llms.base import BaseLLM
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
_get_async_httpx_client,
|
||||||
|
_get_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.rerank_api.types import RerankRequest, RerankResponse
|
||||||
|
|
||||||
|
|
||||||
|
class TogetherAIRerank(BaseLLM):
    """Synchronous handler for TogetherAI's `/v1/rerank` endpoint.

    LiteLLM passes the re-rank parameters through unchanged — the request
    body is exactly the `RerankRequest` fields, no transformation occurs.
    """

    def rerank(
        self,
        model: str,
        api_key: str,
        query: str,
        documents: list[str],
        top_n: int = 3,
    ) -> RerankResponse:
        """Re-rank `documents` against `query` using a TogetherAI model.

        Args:
            model: TogetherAI rerank model name.
            api_key: TogetherAI API key, sent as a bearer token.
            query: Search query the documents are scored against.
            documents: Candidate documents to re-rank.
            top_n: Number of top-scoring results to return (default 3).

        Returns:
            RerankResponse: built field-by-field from the raw JSON, since
            TogetherAI's payload may omit `meta` (defaulted to {} here).

        Raises:
            httpx.HTTPStatusError: if the API returns a non-2xx status.
        """
        client = _get_httpx_client()

        request_data = RerankRequest(
            model=model, query=query, top_n=top_n, documents=documents
        )

        response = client.post(
            "https://api.together.xyz/v1/rerank",
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "authorization": f"Bearer {api_key}",
            },
            json=request_data.dict(),
        )
        # Fail loudly on HTTP errors instead of reading fields out of an
        # error payload and returning a half-empty RerankResponse.
        response.raise_for_status()

        _json_response = response.json()
        return RerankResponse(
            id=_json_response.get("id"),
            results=_json_response.get("results"),
            meta=_json_response.get("meta") or {},
        )
|
|
@ -7,6 +7,7 @@ import litellm
|
||||||
from litellm import get_secret
|
from litellm import get_secret
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.llms.cohere.rerank import CohereRerank
|
from litellm.llms.cohere.rerank import CohereRerank
|
||||||
|
from litellm.llms.togetherai.rerank import TogetherAIRerank
|
||||||
from litellm.types.router import *
|
from litellm.types.router import *
|
||||||
from litellm.utils import supports_httpx_timeout
|
from litellm.utils import supports_httpx_timeout
|
||||||
|
|
||||||
|
@ -15,6 +16,7 @@ from .types import RerankRequest, RerankResponse
|
||||||
####### ENVIRONMENT VARIABLES ###################
|
####### ENVIRONMENT VARIABLES ###################
|
||||||
# Initialize any necessary instances or variables here
|
# Initialize any necessary instances or variables here
|
||||||
cohere_rerank = CohereRerank()
|
cohere_rerank = CohereRerank()
|
||||||
|
together_rerank = TogetherAIRerank()
|
||||||
#################################################
|
#################################################
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,7 +56,7 @@ def rerank(
|
||||||
model: str,
|
model: str,
|
||||||
query: str,
|
query: str,
|
||||||
documents: List[str],
|
documents: List[str],
|
||||||
custom_llm_provider: Literal["cohere", "together_ai"] = "cohere",
|
custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
|
||||||
top_n: int = 3,
|
top_n: int = 3,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
|
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
|
||||||
|
@ -65,11 +67,21 @@ def rerank(
|
||||||
_is_async = kwargs.pop("arerank", False) is True
|
_is_async = kwargs.pop("arerank", False) is True
|
||||||
optional_params = GenericLiteLLMParams(**kwargs)
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
|
||||||
|
model, _custom_llm_provider, dynamic_api_key, api_base = (
|
||||||
|
litellm.get_llm_provider(
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
api_base=optional_params.api_base,
|
||||||
|
api_key=optional_params.api_key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Implement rerank logic here based on the custom_llm_provider
|
# Implement rerank logic here based on the custom_llm_provider
|
||||||
if custom_llm_provider == "cohere":
|
if _custom_llm_provider == "cohere":
|
||||||
# Implement Cohere rerank logic
|
# Implement Cohere rerank logic
|
||||||
cohere_key = (
|
cohere_key = (
|
||||||
optional_params.api_key
|
dynamic_api_key
|
||||||
|
or optional_params.api_key
|
||||||
or litellm.cohere_key
|
or litellm.cohere_key
|
||||||
or get_secret("COHERE_API_KEY")
|
or get_secret("COHERE_API_KEY")
|
||||||
or get_secret("CO_API_KEY")
|
or get_secret("CO_API_KEY")
|
||||||
|
@ -98,11 +110,31 @@ def rerank(
|
||||||
api_key=cohere_key,
|
api_key=cohere_key,
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
elif custom_llm_provider == "together_ai":
|
elif _custom_llm_provider == "together_ai":
|
||||||
# Implement Together AI rerank logic
|
# Implement Together AI rerank logic
|
||||||
pass
|
together_key = (
|
||||||
|
dynamic_api_key
|
||||||
|
or optional_params.api_key
|
||||||
|
or litellm.togetherai_api_key
|
||||||
|
or get_secret("TOGETHERAI_API_KEY")
|
||||||
|
or litellm.api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
if together_key is None:
|
||||||
|
raise ValueError(
|
||||||
|
"TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = together_rerank.rerank(
|
||||||
|
model=model,
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
top_n=top_n,
|
||||||
|
api_key=together_key,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported provider: {custom_llm_provider}")
|
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
|
||||||
|
|
||||||
# Placeholder return
|
# Placeholder return
|
||||||
return response
|
return response
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue