mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
add tg ai rerank support
This commit is contained in:
parent
b8bc185bd5
commit
dc42ad0021
3 changed files with 134 additions and 6 deletions
|
@ -7,6 +7,7 @@ import litellm
|
|||
from litellm import get_secret
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.cohere.rerank import CohereRerank
|
||||
from litellm.llms.togetherai.rerank import TogetherAIRerank
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
|
||||
|
@ -15,6 +16,7 @@ from .types import RerankRequest, RerankResponse
|
|||
####### ENVIRONMENT VARIABLES ###################
|
||||
# Initialize any necessary instances or variables here
|
||||
cohere_rerank = CohereRerank()
|
||||
together_rerank = TogetherAIRerank()
|
||||
#################################################
|
||||
|
||||
|
||||
|
@ -54,7 +56,7 @@ def rerank(
|
|||
model: str,
|
||||
query: str,
|
||||
documents: List[str],
|
||||
custom_llm_provider: Literal["cohere", "together_ai"] = "cohere",
|
||||
custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
|
||||
top_n: int = 3,
|
||||
**kwargs,
|
||||
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
|
||||
|
@ -65,11 +67,21 @@ def rerank(
|
|||
_is_async = kwargs.pop("arerank", False) is True
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
model, _custom_llm_provider, dynamic_api_key, api_base = (
|
||||
litellm.get_llm_provider(
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=optional_params.api_base,
|
||||
api_key=optional_params.api_key,
|
||||
)
|
||||
)
|
||||
|
||||
# Implement rerank logic here based on the custom_llm_provider
|
||||
if custom_llm_provider == "cohere":
|
||||
if _custom_llm_provider == "cohere":
|
||||
# Implement Cohere rerank logic
|
||||
cohere_key = (
|
||||
optional_params.api_key
|
||||
dynamic_api_key
|
||||
or optional_params.api_key
|
||||
or litellm.cohere_key
|
||||
or get_secret("COHERE_API_KEY")
|
||||
or get_secret("CO_API_KEY")
|
||||
|
@ -98,11 +110,31 @@ def rerank(
|
|||
api_key=cohere_key,
|
||||
)
|
||||
pass
|
||||
elif custom_llm_provider == "together_ai":
|
||||
elif _custom_llm_provider == "together_ai":
|
||||
# Implement Together AI rerank logic
|
||||
pass
|
||||
together_key = (
|
||||
dynamic_api_key
|
||||
or optional_params.api_key
|
||||
or litellm.togetherai_api_key
|
||||
or get_secret("TOGETHERAI_API_KEY")
|
||||
or litellm.api_key
|
||||
)
|
||||
|
||||
if together_key is None:
|
||||
raise ValueError(
|
||||
"TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
|
||||
)
|
||||
|
||||
response = together_rerank.rerank(
|
||||
model=model,
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n,
|
||||
api_key=together_key,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {custom_llm_provider}")
|
||||
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
|
||||
|
||||
# Placeholder return
|
||||
return response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue