(code refactor) - Add BaseRerankConfig. Use BaseRerankConfig for cohere/rerank and azure_ai/rerank (#7319)

* add base rerank config
* working sync cohere rerank
* update rerank types
* update base rerank config
* remove old rerank
* add new cohere handler.py
* add cohere rerank transform
* add get_provider_rerank_config
* add rerank to base llm http handler
* add rerank utils
* add arerank to llm http handler.py
* add AzureAIRerankConfig
* updates rerank config
* update test rerank
* fix unused imports
* update get_provider_rerank_config
* test_basic_rerank_caching
* fix unused import
* test rerank
This commit is contained in:
parent a790d43116
commit 5f15b0aa20

19 changed files with 645 additions and 425 deletions
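Context for the hunks below: the commit replaces per-provider rerank handlers with a shared config abstraction plus a single HTTP handler. A minimal sketch of what a BaseRerankConfig-style interface could look like, with assumed method names (the real abstract class lives at litellm/llms/base_llm/rerank/transformation.py, per the import added in the first hunk):

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

class BaseRerankConfigSketch(ABC):
    """Illustrative stand-in for litellm's BaseRerankConfig; the method
    names here are assumptions, not the library's exact API."""

    @abstractmethod
    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """Return the provider's rerank endpoint URL."""

    @abstractmethod
    def transform_rerank_request(
        self, model: str, optional_rerank_params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Map unified rerank params onto the provider's request body."""

    @abstractmethod
    def transform_rerank_response(
        self, model: str, raw_response: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Map the provider's raw response back to litellm's RerankResponse shape."""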
@@ -6,23 +6,23 @@ from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.azure_ai.rerank import AzureAIRerank
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
-from litellm.llms.cohere.rerank import CohereRerank
 from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
 from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
+from litellm.rerank_api.rerank_utils import get_optional_rerank_params
 from litellm.secret_managers.main import get_secret
-from litellm.types.rerank import RerankResponse
+from litellm.types.rerank import OptionalRerankParams, RerankResponse
 from litellm.types.router import *
-from litellm.utils import client, exception_type
+from litellm.utils import ProviderConfigManager, client, exception_type
 
 ####### ENVIRONMENT VARIABLES ###################
 # Initialize any necessary instances or variables here
-cohere_rerank = CohereRerank()
 together_rerank = TogetherAIRerank()
-azure_ai_rerank = AzureAIRerank()
 jina_ai_rerank = JinaAIRerank()
 bedrock_rerank = BedrockRerankHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
 #################################################
 
 
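The import changes summarize the whole refactor: the CohereRerank and AzureAIRerank handler instances disappear (matching "fix unused imports" in the commit message), and ProviderConfigManager.get_provider_rerank_config becomes the lookup point. A hedged sketch of that dispatch; "CohereRerankConfig" is an assumed name, while AzureAIRerankConfig is named in the commit message:

# Stand-in classes to illustrate the lookup; the real configs live under
# litellm/llms/<provider>/rerank/.
class BaseRerankConfig: ...
class CohereRerankConfig(BaseRerankConfig): ...
class AzureAIRerankConfig(BaseRerankConfig): ...

def get_provider_rerank_config(model: str, provider: str) -> BaseRerankConfig:
    # Sketch of what ProviderConfigManager.get_provider_rerank_config
    # plausibly does: map a provider to its rerank config instance.
    configs = {
        "cohere": CohereRerankConfig,
        "azure_ai": AzureAIRerankConfig,
    }
    try:
        return configs[provider]()
    except KeyError:
        raise ValueError(f"no rerank config registered for provider={provider}")

The diff's litellm.LlmProviders(_custom_llm_provider) call suggests the real lookup keys on the provider enum rather than a raw string.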
@@ -107,18 +107,36 @@ def rerank( # noqa: PLR0915
         )
     )
 
     model_params_dict = {
         "top_n": top_n,
         "rank_fields": rank_fields,
         "return_documents": return_documents,
         "max_chunks_per_doc": max_chunks_per_doc,
         "documents": documents,
     }
+    rerank_provider_config: BaseRerankConfig = (
+        ProviderConfigManager.get_provider_rerank_config(
+            model=model,
+            provider=litellm.LlmProviders(_custom_llm_provider),
+        )
+    )
+
+    optional_rerank_params: OptionalRerankParams = get_optional_rerank_params(
+        rerank_provider_config=rerank_provider_config,
+        model=model,
+        drop_params=kwargs.get("drop_params") or litellm.drop_params or False,
+        query=query,
+        documents=documents,
+        custom_llm_provider=_custom_llm_provider,
+        top_n=top_n,
+        rank_fields=rank_fields,
+        return_documents=return_documents,
+        max_chunks_per_doc=max_chunks_per_doc,
+        non_default_params=kwargs,
+    )
 
     if isinstance(optional_params.timeout, str):
         optional_params.timeout = float(optional_params.timeout)
 
     model_response = RerankResponse()
 
     litellm_logging_obj.update_environment_variables(
         model=model,
         user=user,
-        optional_params=model_params_dict,
+        optional_params=optional_rerank_params,
         litellm_params={
             "litellm_call_id": litellm_call_id,
             "proxy_server_request": proxy_server_request,
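rerank() now normalizes user arguments through get_optional_rerank_params before any provider branch runs, so unsupported params are handled centrally via drop_params instead of per handler. A plausible shape for that helper, assuming the provider config exposes a supported-params list (the real implementation in litellm/rerank_api/rerank_utils.py may differ):

from typing import Any, Dict, List

def get_optional_rerank_params_sketch(
    supported_params: List[str],  # assumed to come from the provider's rerank config
    drop_params: bool,
    query: str,
    documents: List[Any],
    **candidate_params: Any,
) -> Dict[str, Any]:
    # Keep query/documents, forward only non-None params the provider
    # supports, and drop or reject the rest depending on drop_params.
    params: Dict[str, Any] = {"query": query, "documents": documents}
    for name, value in candidate_params.items():
        if value is None:
            continue
        if name in supported_params:
            params[name] = value
        elif not drop_params:
            raise ValueError(f"{name} is not a supported rerank param for this provider")
    return params

Under this scheme, passing max_chunks_per_doc to a provider that does not support it is either silently dropped (drop_params=True) or rejected before any network call.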
@@ -135,19 +153,9 @@ def rerank( # noqa: PLR0915
     if _custom_llm_provider == "cohere":
         # Implement Cohere rerank logic
         api_key: Optional[str] = (
-            dynamic_api_key
-            or optional_params.api_key
-            or litellm.cohere_key
-            or get_secret("COHERE_API_KEY")  # type: ignore
-            or get_secret("CO_API_KEY")  # type: ignore
-            or litellm.api_key
+            dynamic_api_key or optional_params.api_key or litellm.api_key
         )
 
-        if api_key is None:
-            raise ValueError(
-                "Cohere API key is required, please set 'COHERE_API_KEY' in your environment"
-            )
-
         api_base: Optional[str] = (
             dynamic_api_base
             or optional_params.api_base
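The inline Cohere key fallback chain (litellm.cohere_key, COHERE_API_KEY, CO_API_KEY) is gone from rerank() itself; only the dynamic/param/global key survives here, with provider-specific resolution presumably moving into the new transformation layer. For reference, the deleted chain as a standalone sketch:

import os
from typing import Optional

def resolve_cohere_api_key(
    dynamic_api_key: Optional[str],
    cohere_key: Optional[str],  # litellm.cohere_key in the old code
    global_key: Optional[str],  # litellm.api_key in the old code
) -> Optional[str]:
    # Mirrors the fallback order the removed lines implemented: an explicit
    # key wins, then Cohere's two env var spellings, then the global key.
    return (
        dynamic_api_key
        or cohere_key
        or os.environ.get("COHERE_API_KEY")
        or os.environ.get("CO_API_KEY")
        or global_key
    )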
@@ -160,23 +168,18 @@ def rerank( # noqa: PLR0915
             raise Exception(
                 "Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var."
             )
 
-        headers = headers or litellm.headers or {}
-
-        response = cohere_rerank.rerank(
+        response = base_llm_http_handler.rerank(
             model=model,
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
-            api_key=api_key,
+            custom_llm_provider=_custom_llm_provider,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
             api_base=api_base,
             _is_async=_is_async,
-            headers=headers,
-            litellm_logging_obj=litellm_logging_obj,
+            headers=headers or litellm.headers or {},
             client=client,
+            model_response=model_response,
         )
     elif _custom_llm_provider == "azure_ai":
         api_base = (
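The call-site change is behavior-preserving for users: litellm.rerank keeps its public signature, and only the internal routing moves from the per-provider CohereRerank handler to base_llm_http_handler. A quick usage example (model name illustrative):

import litellm

# Requires COHERE_API_KEY in the environment (or api_key=... in the call).
response = litellm.rerank(
    model="cohere/rerank-english-v3.0",
    query="What is the capital of France?",
    documents=[
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
    ],
    top_n=1,
)
print(response.results)  # Cohere-style results: index + relevance_score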
@@ -185,47 +188,18 @@ def rerank( # noqa: PLR0915
             or litellm.api_base
             or get_secret("AZURE_AI_API_BASE")  # type: ignore
         )
-        # set API KEY
-        api_key = (
-            dynamic_api_key
-            or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
-            or litellm.openai_key
-            or get_secret("AZURE_AI_API_KEY")  # type: ignore
-        )
-
-        headers = headers or litellm.headers or {}
-
-        if api_key is None:
-            raise ValueError(
-                "Azure AI API key is required, please set 'AZURE_AI_API_KEY' in your environment"
-            )
-
         if api_base is None:
             raise Exception(
                 "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
             )
 
-        ## LOAD CONFIG - if set
-        config = litellm.OpenAIConfig.get_config()
-        for k, v in config.items():
-            if (
-                k not in optional_params
-            ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
-                optional_params[k] = v
-
-        response = azure_ai_rerank.rerank(
+        response = base_llm_http_handler.rerank(
             model=model,
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
-            api_key=api_key,
+            custom_llm_provider=_custom_llm_provider,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
             api_base=api_base,
             _is_async=_is_async,
-            headers=headers,
-            litellm_logging_obj=litellm_logging_obj,
+            headers=headers or litellm.headers or {},
             client=client,
+            model_response=model_response,
         )
     elif _custom_llm_provider == "together_ai":
         # Implement Together AI rerank logic
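The azure_ai branch gets the identical treatment: key resolution and the OpenAI config-merge loop move out, the shared handler moves in. Since the commit message also notes "add arerank to llm http handler.py", the async path rides the same code; a hedged usage sketch, with the deployment name made up and the env vars taken from the diff's get_secret() fallbacks:

import asyncio
import litellm

async def main() -> None:
    # Assumes AZURE_AI_API_BASE and AZURE_AI_API_KEY are set, matching the
    # lookups shown in the hunk above; the deployment name is illustrative.
    response = await litellm.arerank(
        model="azure_ai/my-cohere-rerank-deployment",
        query="What is the capital of France?",
        documents=[
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
        ],
        top_n=1,
    )
    print(response.results)

asyncio.run(main())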