mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
Add cost tracking for rerank via bedrock (#8691)
* feat(bedrock/rerank): infer model region if model given as arn * test: add unit testing to ensure bedrock region name inferred from arn on rerank * feat(bedrock/rerank/transformation.py): include search units for bedrock rerank result Resolves https://github.com/BerriAI/litellm/issues/7258#issuecomment-2671557137 * test(test_bedrock_completion.py): add testing for bedrock cohere rerank * feat(cost_calculator.py): refactor rerank cost tracking to support bedrock cost tracking * build(model_prices_and_context_window.json): add amazon.rerank model to model cost map * fix(cost_calculator.py): bedrock/common_utils.py get base model from model w/ arn -> handles rerank model * build(model_prices_and_context_window.json): add bedrock cohere rerank pricing * feat(bedrock/rerank): migrate bedrock config to basererank config * Revert "feat(bedrock/rerank): migrate bedrock config to basererank config" This reverts commit84fae1f167
. * test: add testing to ensure large doc / queries are correctly counted * Revert "test: add testing to ensure large doc / queries are correctly counted" This reverts commit4337f1657e
. * fix(migrate-jina-ai-to-rerank-config): enables cost tracking * refactor(jina_ai/): finish migrating jina ai to base rerank config enables cost tracking * fix(jina_ai/rerank): e2e jina ai rerank cost tracking * fix: cleanup dead code * fix: fix python3.8 compatibility error * test: fix test * test: add e2e testing for azure ai rerank * fix: fix linting error * test: mark cohere as flaky
This commit is contained in:
parent
4c9517fd78
commit
b682dc4ec8
26 changed files with 524 additions and 296 deletions
|
@ -9,7 +9,6 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
|||
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
|
||||
from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
|
||||
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
|
||||
from litellm.rerank_api.rerank_utils import get_optional_rerank_params
|
||||
from litellm.secret_managers.main import get_secret, get_secret_str
|
||||
|
@ -20,7 +19,6 @@ from litellm.utils import ProviderConfigManager, client, exception_type
|
|||
####### ENVIRONMENT VARIABLES ###################
|
||||
# Initialize any necessary instances or variables here
|
||||
together_rerank = TogetherAIRerank()
|
||||
jina_ai_rerank = JinaAIRerank()
|
||||
bedrock_rerank = BedrockRerankHandler()
|
||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
#################################################
|
||||
|
@ -264,16 +262,26 @@ def rerank( # noqa: PLR0915
|
|||
raise ValueError(
|
||||
"Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
|
||||
)
|
||||
response = jina_ai_rerank.rerank(
|
||||
|
||||
api_base = (
|
||||
dynamic_api_base
|
||||
or optional_params.api_base
|
||||
or litellm.api_base
|
||||
or get_secret("BEDROCK_API_BASE") # type: ignore
|
||||
)
|
||||
|
||||
response = base_llm_http_handler.rerank(
|
||||
model=model,
|
||||
api_key=dynamic_api_key,
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n,
|
||||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
max_chunks_per_doc=max_chunks_per_doc,
|
||||
custom_llm_provider=_custom_llm_provider,
|
||||
optional_rerank_params=optional_rerank_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
timeout=optional_params.timeout,
|
||||
api_key=dynamic_api_key or optional_params.api_key,
|
||||
api_base=api_base,
|
||||
_is_async=_is_async,
|
||||
headers=headers or litellm.headers or {},
|
||||
client=client,
|
||||
model_response=model_response,
|
||||
)
|
||||
elif _custom_llm_provider == "bedrock":
|
||||
api_base = (
|
||||
|
@ -295,6 +303,7 @@ def rerank( # noqa: PLR0915
|
|||
optional_params=optional_params.model_dump(exclude_unset=True),
|
||||
api_base=api_base,
|
||||
logging_obj=litellm_logging_obj,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue