(code refactor) - Add BaseRerankConfig. Use BaseRerankConfig for cohere/rerank and azure_ai/rerank (#7319)

* add base rerank config

* working sync cohere rerank

* update rerank types

* update base rerank config

* remove old rerank

* add new cohere handler.py

* add cohere rerank transform

* add get_provider_rerank_config

* add rerank to base llm http handler

* add rerank utils

* add arerank to llm http handler.py

* add AzureAIRerankConfig

* updates rerank config

* update test rerank

* fix unused imports

* update get_provider_rerank_config

* test_basic_rerank_caching

* fix unused import

* test rerank
This commit is contained in:
Ishaan Jaff 2024-12-19 17:03:34 -08:00 committed by GitHub
parent a790d43116
commit 5f15b0aa20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 645 additions and 425 deletions

View file

@ -171,6 +171,7 @@ from openai import OpenAIError as OriginalError
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from ._logging import verbose_logger
from .caching.caching import (
@ -6204,6 +6205,17 @@ class ProviderConfigManager:
return litellm.VoyageEmbeddingConfig()
raise ValueError(f"Provider {provider} does not support embedding config")
@staticmethod
def get_provider_rerank_config(
model: str,
provider: LlmProviders,
) -> BaseRerankConfig:
if litellm.LlmProviders.COHERE == provider:
return litellm.CohereRerankConfig()
elif litellm.LlmProviders.AZURE_AI == provider:
return litellm.AzureAIRerankConfig()
return litellm.CohereRerankConfig()
def get_end_user_id_for_cost_tracking(
litellm_params: dict,