add TogetherAI rerank support

This commit is contained in:
Ishaan Jaff 2024-08-27 16:25:54 -07:00
parent 15ac10af40
commit 70db82a236
3 changed files with 134 additions and 6 deletions

View file

@ -0,0 +1,44 @@
"""
Re rank api
LiteLLM supports the re-rank API format; no parameter transformation occurs
"""
import httpx
from pydantic import BaseModel
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
_get_async_httpx_client,
_get_httpx_client,
)
from litellm.rerank_api.types import RerankRequest, RerankResponse
class CohereRerank(BaseLLM):
    """Thin pass-through client for Cohere's v1 /rerank endpoint."""

    def rerank(
        self,
        model: str,
        api_key: str,
        query: str,
        documents: list[str],
        top_n: int = 3,
    ) -> RerankResponse:
        """
        Call Cohere's /v1/rerank endpoint synchronously and parse the result.

        Args:
            model: Cohere rerank model name.
            api_key: Cohere API key, sent as a bearer token.
            query: The search query to rank the documents against.
            documents: Candidate documents to be re-ranked.
            top_n: Maximum number of ranked results to return (default 3).

        Returns:
            RerankResponse built from the provider's JSON payload.
        """
        client = _get_httpx_client()
        request_data = RerankRequest(
            model=model, query=query, top_n=top_n, documents=documents
        )
        response = client.post(
            "https://api.cohere.com/v1/rerank",
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"bearer {api_key}",
            },
            json=request_data.dict(),
        )
        # NOTE(review): no response.raise_for_status() here — a non-2xx body
        # will surface as a pydantic validation error rather than an HTTP
        # error; confirm whether that is the intended failure mode.
        return RerankResponse(**response.json())

View file

@ -0,0 +1,52 @@
"""
Re rank api
LiteLLM supports the re-rank API format; no parameter transformation occurs
"""
import httpx
from pydantic import BaseModel
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
_get_async_httpx_client,
_get_httpx_client,
)
from litellm.rerank_api.types import RerankRequest, RerankResponse
class TogetherAIRerank(BaseLLM):
    """Thin pass-through client for Together AI's v1 /rerank endpoint."""

    def rerank(
        self,
        model: str,
        api_key: str,
        query: str,
        documents: list[str],
        top_n: int = 3,
    ) -> RerankResponse:
        """
        Call Together AI's /v1/rerank endpoint synchronously and parse the result.

        Args:
            model: Together AI rerank model name.
            api_key: Together AI API key, sent as a Bearer token.
            query: The search query to rank the documents against.
            documents: Candidate documents to be re-ranked.
            top_n: Maximum number of ranked results to return (default 3).

        Returns:
            RerankResponse built from selected fields of the provider's JSON
            payload ("id", "results", "meta").
        """
        client = _get_httpx_client()
        request_data = RerankRequest(
            model=model, query=query, top_n=top_n, documents=documents
        )
        raw_response = client.post(
            "https://api.together.xyz/v1/rerank",
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "authorization": f"Bearer {api_key}",
            },
            json=request_data.dict(),
        )
        _json_response = raw_response.json()
        # "meta" may be absent/None in the Together AI payload; normalize to
        # an empty dict so RerankResponse validation does not fail.
        return RerankResponse(
            id=_json_response.get("id"),
            results=_json_response.get("results"),
            meta=_json_response.get("meta") or {},
        )

View file

@ -7,6 +7,7 @@ import litellm
from litellm import get_secret from litellm import get_secret
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.llms.cohere.rerank import CohereRerank from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.togetherai.rerank import TogetherAIRerank
from litellm.types.router import * from litellm.types.router import *
from litellm.utils import supports_httpx_timeout from litellm.utils import supports_httpx_timeout
@ -15,6 +16,7 @@ from .types import RerankRequest, RerankResponse
####### ENVIRONMENT VARIABLES ################### ####### ENVIRONMENT VARIABLES ###################
# Initialize any necessary instances or variables here # Initialize any necessary instances or variables here
cohere_rerank = CohereRerank() cohere_rerank = CohereRerank()
together_rerank = TogetherAIRerank()
################################################# #################################################
@ -54,7 +56,7 @@ def rerank(
model: str, model: str,
query: str, query: str,
documents: List[str], documents: List[str],
custom_llm_provider: Literal["cohere", "together_ai"] = "cohere", custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
top_n: int = 3, top_n: int = 3,
**kwargs, **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]: ) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
@ -65,11 +67,21 @@ def rerank(
_is_async = kwargs.pop("arerank", False) is True _is_async = kwargs.pop("arerank", False) is True
optional_params = GenericLiteLLMParams(**kwargs) optional_params = GenericLiteLLMParams(**kwargs)
model, _custom_llm_provider, dynamic_api_key, api_base = (
litellm.get_llm_provider(
model=model,
custom_llm_provider=custom_llm_provider,
api_base=optional_params.api_base,
api_key=optional_params.api_key,
)
)
# Implement rerank logic here based on the custom_llm_provider # Implement rerank logic here based on the custom_llm_provider
if custom_llm_provider == "cohere": if _custom_llm_provider == "cohere":
# Implement Cohere rerank logic # Implement Cohere rerank logic
cohere_key = ( cohere_key = (
optional_params.api_key dynamic_api_key
or optional_params.api_key
or litellm.cohere_key or litellm.cohere_key
or get_secret("COHERE_API_KEY") or get_secret("COHERE_API_KEY")
or get_secret("CO_API_KEY") or get_secret("CO_API_KEY")
@ -98,11 +110,31 @@ def rerank(
api_key=cohere_key, api_key=cohere_key,
) )
pass pass
elif custom_llm_provider == "together_ai": elif _custom_llm_provider == "together_ai":
# Implement Together AI rerank logic # Implement Together AI rerank logic
pass together_key = (
dynamic_api_key
or optional_params.api_key
or litellm.togetherai_api_key
or get_secret("TOGETHERAI_API_KEY")
or litellm.api_key
)
if together_key is None:
raise ValueError(
"TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
)
response = together_rerank.rerank(
model=model,
query=query,
documents=documents,
top_n=top_n,
api_key=together_key,
)
else: else:
raise ValueError(f"Unsupported provider: {custom_llm_provider}") raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
# Placeholder return # Placeholder return
return response return response