Add cost tracking for rerank via bedrock (#8691)
* feat(bedrock/rerank): infer model region if model given as arn
* test: add unit testing to ensure bedrock region name inferred from arn on rerank
* feat(bedrock/rerank/transformation.py): include search units for bedrock rerank result (resolves https://github.com/BerriAI/litellm/issues/7258#issuecomment-2671557137)
* test(test_bedrock_completion.py): add testing for bedrock cohere rerank
* feat(cost_calculator.py): refactor rerank cost tracking to support bedrock cost tracking
* build(model_prices_and_context_window.json): add amazon.rerank model to model cost map
* fix(cost_calculator.py): bedrock/common_utils.py get base model from model w/ arn (handles rerank model)
* build(model_prices_and_context_window.json): add bedrock cohere rerank pricing
* feat(bedrock/rerank): migrate bedrock config to basererank config
* Revert "feat(bedrock/rerank): migrate bedrock config to basererank config" (reverts commit 84fae1f167)
* test: add testing to ensure large doc / queries are correctly counted
* Revert "test: add testing to ensure large doc / queries are correctly counted" (reverts commit 4337f1657e)
* fix(migrate-jina-ai-to-rerank-config): enables cost tracking
* refactor(jina_ai/): finish migrating jina ai to base rerank config (enables cost tracking)
* fix(jina_ai/rerank): e2e jina ai rerank cost tracking
* fix: cleanup dead code
* fix: fix python3.8 compatibility error
* test: fix test
* test: add e2e testing for azure ai rerank
* fix: fix linting error
* test: mark cohere as flaky
This commit is contained in:
parent 4c9517fd78
commit b682dc4ec8

26 changed files with 524 additions and 296 deletions
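A hedged sketch of what this commit enables end to end, mirroring the new unit test further down in this diff: rerank against a Bedrock model given as a full ARN, with the region inferred from the ARN and the computed cost attached to the response. The call assumes valid AWS credentials in the environment; it is illustrative, not part of the diff itself.

```python
# Illustrative usage (assumes AWS credentials are configured).
# Model ARN, query, and documents mirror the new unit test in this diff.
import litellm

response = litellm.rerank(
    model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
    query="hello",
    documents=["hello", "world"],
)
# The request is routed to us-west-2 (parsed from the ARN) and billed per
# search unit; the tracked cost is surfaced via the response's hidden params.
print(response._hidden_params["response_cost"])
```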
@@ -10,6 +10,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Supported
 | Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#deepseek-not-r1), [`bedrock/deepseek_r1/`](#deepseek-r1) |
 | Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
 | Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
+| Rerank Endpoint | `/rerank` |
 | Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |
@@ -820,6 +820,7 @@ from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
 from .llms.cohere.rerank.transformation import CohereRerankConfig
 from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
 from .llms.infinity.rerank.transformation import InfinityRerankConfig
+from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
 from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
 from .llms.together_ai.chat import TogetherAIConfig
@@ -16,15 +16,9 @@ from litellm.llms.anthropic.cost_calculation import (
 from litellm.llms.azure.cost_calculation import (
     cost_per_token as azure_openai_cost_per_token,
 )
-from litellm.llms.azure_ai.cost_calculator import (
-    cost_per_query as azure_ai_rerank_cost_per_query,
-)
 from litellm.llms.bedrock.image.cost_calculator import (
     cost_calculator as bedrock_image_cost_calculator,
 )
-from litellm.llms.cohere.cost_calculator import (
-    cost_per_query as cohere_rerank_cost_per_query,
-)
 from litellm.llms.databricks.cost_calculator import (
     cost_per_token as databricks_cost_per_token,
 )
@@ -51,10 +45,12 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
     cost_calculator as vertex_ai_image_cost_calculator,
 )
 from litellm.types.llms.openai import HttpxBinaryResponseContent
-from litellm.types.rerank import RerankResponse
+from litellm.types.rerank import RerankBilledUnits, RerankResponse
 from litellm.types.utils import (
     CallTypesLiteral,
+    LlmProviders,
     LlmProvidersSet,
+    ModelInfo,
     PassthroughCallTypes,
     Usage,
 )
@@ -64,6 +60,7 @@ from litellm.utils import (
     EmbeddingResponse,
     ImageResponse,
     ModelResponse,
+    ProviderConfigManager,
     TextCompletionResponse,
     TranscriptionResponse,
     _cached_get_model_info_helper,
@@ -114,6 +111,8 @@ def cost_per_token(  # noqa: PLR0915
     number_of_queries: Optional[int] = None,
     ### USAGE OBJECT ###
     usage_object: Optional[Usage] = None,  # just read the usage object if provided
+    ### BILLED UNITS ###
+    rerank_billed_units: Optional[RerankBilledUnits] = None,
     ### CALL TYPE ###
     call_type: CallTypesLiteral = "completion",
     audio_transcription_file_duration: float = 0.0,  # for audio transcription calls - the file time in seconds
@@ -238,6 +237,7 @@ def cost_per_token(  # noqa: PLR0915
         return rerank_cost(
             model=model,
             custom_llm_provider=custom_llm_provider,
+            billed_units=rerank_billed_units,
         )
     elif call_type == "atranscription" or call_type == "transcription":
         return openai_cost_per_second(
@@ -552,6 +552,7 @@ def completion_cost(  # noqa: PLR0915
     cost_per_token_usage_object: Optional[Usage] = _get_usage_object(
         completion_response=completion_response
     )
+    rerank_billed_units: Optional[RerankBilledUnits] = None
     model = _select_model_name_for_cost_calc(
         model=model,
         completion_response=completion_response,
@@ -698,6 +699,11 @@ def completion_cost(  # noqa: PLR0915
             else:
                 billed_units = {}

+            rerank_billed_units = RerankBilledUnits(
+                search_units=billed_units.get("search_units"),
+                total_tokens=billed_units.get("total_tokens"),
+            )
+
             search_units = (
                 billed_units.get("search_units") or 1
             )  # cohere charges per request by default.
@@ -763,6 +769,7 @@ def completion_cost(  # noqa: PLR0915
             usage_object=cost_per_token_usage_object,
             call_type=call_type,
             audio_transcription_file_duration=audio_transcription_file_duration,
+            rerank_billed_units=rerank_billed_units,
         )
         _final_cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
@@ -836,27 +843,33 @@ def response_cost_calculator(
 def rerank_cost(
     model: str,
     custom_llm_provider: Optional[str],
+    billed_units: Optional[RerankBilledUnits] = None,
 ) -> Tuple[float, float]:
     """
     Returns
     - float or None: cost of response OR none if error.
     """
-    default_num_queries = 1
     _, custom_llm_provider, _, _ = litellm.get_llm_provider(
         model=model, custom_llm_provider=custom_llm_provider
     )

     try:
-        if custom_llm_provider == "cohere":
-            return cohere_rerank_cost_per_query(
-                model=model, num_queries=default_num_queries
-            )
-        elif custom_llm_provider == "azure_ai":
-            return azure_ai_rerank_cost_per_query(
-                model=model, num_queries=default_num_queries
-            )
-        raise ValueError(
-            f"invalid custom_llm_provider for rerank model: {model}, custom_llm_provider: {custom_llm_provider}"
-        )
+        config = ProviderConfigManager.get_provider_rerank_config(
+            model=model, provider=LlmProviders(custom_llm_provider)
+        )
+
+        try:
+            model_info: Optional[ModelInfo] = litellm.get_model_info(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
+        except Exception:
+            model_info = None
+
+        return config.calculate_rerank_cost(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            billed_units=billed_units,
+            model_info=model_info,
+        )
     except Exception as e:
         raise e
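The rewritten `rerank_cost` delegates to the provider's rerank config instead of hard-coded per-provider branches. For per-query providers, the rule implemented by `BaseRerankConfig.calculate_rerank_cost` (added below) is simply `input_cost_per_query * search_units`. A standalone sketch of that arithmetic, including the guard-clause fallbacks:

```python
# Standalone sketch of the per-search-unit rule; not the litellm code itself.
from typing import Optional, Tuple

def calc_rerank_cost(
    input_cost_per_query: Optional[float], search_units: Optional[int]
) -> Tuple[float, float]:
    # Missing pricing or billing info falls back to zero cost, matching
    # the guard clauses in BaseRerankConfig.calculate_rerank_cost.
    if input_cost_per_query is None or search_units is None:
        return 0.0, 0.0
    return input_cost_per_query * search_units, 0.0

# amazon.rerank-v1:0 is priced at $0.001 per query in the cost map below:
assert calc_rerank_cost(0.001, 1) == (0.001, 0.0)
assert calc_rerank_cost(None, 1) == (0.0, 0.0)
```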
@@ -1,32 +0,0 @@
-"""
-Handles custom cost calculation for Azure AI models.
-
-Custom cost calculation for Azure AI models only requied for rerank.
-"""
-
-from typing import Tuple
-
-from litellm.utils import get_model_info
-
-
-def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
-    """
-    Calculates the cost per query for a given rerank model.
-
-    Input:
-        - model: str, the model name without provider prefix
-
-    Returns:
-        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
-    """
-    model_info = get_model_info(model=model, custom_llm_provider="azure_ai")
-
-    if (
-        "input_cost_per_query" not in model_info
-        or model_info["input_cost_per_query"] is None
-    ):
-        return 0.0, 0.0
-
-    prompt_cost = model_info["input_cost_per_query"] * num_queries
-
-    return prompt_cost, 0.0
@@ -1,9 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import httpx

-from litellm.types.rerank import OptionalRerankParams, RerankResponse
+from litellm.types.rerank import OptionalRerankParams, RerankBilledUnits, RerankResponse
+from litellm.types.utils import ModelInfo

 from ..chat.transformation import BaseLLMException
@@ -66,7 +67,7 @@ class BaseRerankConfig(ABC):
     @abstractmethod
     def map_cohere_rerank_params(
         self,
-        non_default_params: Optional[dict],
+        non_default_params: dict,
         model: str,
         drop_params: bool,
         query: str,
@@ -79,8 +80,48 @@ class BaseRerankConfig(ABC):
     ) -> OptionalRerankParams:
         pass

-    @abstractmethod
     def get_error_class(
         self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
     ) -> BaseLLMException:
-        pass
+        raise BaseLLMException(
+            status_code=status_code,
+            message=error_message,
+            headers=headers,
+        )
+
+    def calculate_rerank_cost(
+        self,
+        model: str,
+        custom_llm_provider: Optional[str] = None,
+        billed_units: Optional[RerankBilledUnits] = None,
+        model_info: Optional[ModelInfo] = None,
+    ) -> Tuple[float, float]:
+        """
+        Calculates the cost per query for a given rerank model.
+
+        Input:
+            - model: str, the model name without provider prefix
+            - custom_llm_provider: str, the provider used for the model. If provided, used to check if the litellm model info is for that provider.
+            - num_queries: int, the number of queries to calculate the cost for
+            - model_info: ModelInfo, the model info for the given model
+
+        Returns:
+            Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
+        """
+
+        if (
+            model_info is None
+            or "input_cost_per_query" not in model_info
+            or model_info["input_cost_per_query"] is None
+            or billed_units is None
+        ):
+            return 0.0, 0.0
+
+        search_units = billed_units.get("search_units")
+
+        if search_units is None:
+            return 0.0, 0.0
+
+        prompt_cost = model_info["input_cost_per_query"] * search_units
+
+        return prompt_cost, 0.0
@@ -202,6 +202,61 @@ class BaseAWSLLM:
             self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
         return credentials

+    def _get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
+        try:
+            # First check if the string contains the expected prefix
+            if not isinstance(model, str) or "arn:aws:bedrock" not in model:
+                return None
+
+            # Split the ARN and check if we have enough parts
+            parts = model.split(":")
+            if len(parts) < 4:
+                return None
+
+            # Get the region from the correct position
+            region = parts[3]
+            if not region:  # Check if region is empty
+                return None
+
+            return region
+        except Exception:
+            # Catch any unexpected errors and return None
+            return None
+
+    def _get_aws_region_name(
+        self, optional_params: dict, model: Optional[str] = None
+    ) -> str:
+        """
+        Get the AWS region name from the environment variables
+        """
+        aws_region_name = optional_params.get("aws_region_name", None)
+        ### SET REGION NAME ###
+        if aws_region_name is None:
+            # check model arn #
+            aws_region_name = self._get_aws_region_from_model_arn(model)
+            # check env #
+            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
+
+            if (
+                aws_region_name is None
+                and litellm_aws_region_name is not None
+                and isinstance(litellm_aws_region_name, str)
+            ):
+                aws_region_name = litellm_aws_region_name
+
+            standard_aws_region_name = get_secret("AWS_REGION", None)
+            if (
+                aws_region_name is None
+                and standard_aws_region_name is not None
+                and isinstance(standard_aws_region_name, str)
+            ):
+                aws_region_name = standard_aws_region_name
+
+            if aws_region_name is None:
+                aws_region_name = "us-west-2"
+
+        return aws_region_name
+
     @tracer.wrap()
     def _auth_with_web_identity_token(
         self,
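The region inference is a plain string split: a Bedrock model ARN has the shape `arn:aws:bedrock:<region>:<account>:<resource>`, so the region is the fourth colon-separated field. A standalone replica of the parsing rule, for illustration only:

```python
# Replica of _get_aws_region_from_model_arn's logic; not imported from litellm.
from typing import Optional

def region_from_arn(model: Optional[str]) -> Optional[str]:
    if not isinstance(model, str) or "arn:aws:bedrock" not in model:
        return None
    parts = model.split(":")  # arn:aws:bedrock:<region>:<account>:<resource>
    if len(parts) < 4 or not parts[3]:
        return None
    return parts[3]

assert region_from_arn(
    "arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0"
) == "us-west-2"
assert region_from_arn("amazon.rerank-v1:0") is None  # no ARN prefix
```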
@@ -423,7 +478,7 @@ class BaseAWSLLM:
         return endpoint_url, proxy_endpoint_url

     def _get_boto_credentials_from_optional_params(
-        self, optional_params: dict
+        self, optional_params: dict, model: Optional[str] = None
     ) -> Boto3CredentialsInfo:
         """
         Get boto3 credentials from optional params
@@ -443,7 +498,7 @@ class BaseAWSLLM:
         aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
         aws_access_key_id = optional_params.pop("aws_access_key_id", None)
         aws_session_token = optional_params.pop("aws_session_token", None)
-        aws_region_name = optional_params.pop("aws_region_name", None)
+        aws_region_name = self._get_aws_region_name(optional_params, model)
         aws_role_name = optional_params.pop("aws_role_name", None)
         aws_session_name = optional_params.pop("aws_session_name", None)
         aws_profile_name = optional_params.pop("aws_profile_name", None)
@@ -453,25 +508,6 @@ class BaseAWSLLM:
             "aws_bedrock_runtime_endpoint", None
         )  # https://bedrock-runtime.{region_name}.amazonaws.com

-        ### SET REGION NAME ###
-        if aws_region_name is None:
-            # check env #
-            litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None)
-
-            if litellm_aws_region_name is not None and isinstance(
-                litellm_aws_region_name, str
-            ):
-                aws_region_name = litellm_aws_region_name
-
-            standard_aws_region_name = get_secret_str("AWS_REGION", None)
-            if standard_aws_region_name is not None and isinstance(
-                standard_aws_region_name, str
-            ):
-                aws_region_name = standard_aws_region_name
-
-            if aws_region_name is None:
-                aws_region_name = "us-west-2"
-
         credentials: Credentials = self.get_credentials(
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
@@ -27,7 +27,7 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
-from litellm.utils import CustomStreamWrapper, get_secret
+from litellm.utils import CustomStreamWrapper

 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -598,61 +598,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         )
         return modelId

-    def get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
-        try:
-            # First check if the string contains the expected prefix
-            if not isinstance(model, str) or "arn:aws:bedrock" not in model:
-                return None
-
-            # Split the ARN and check if we have enough parts
-            parts = model.split(":")
-            if len(parts) < 4:
-                return None
-
-            # Get the region from the correct position
-            region = parts[3]
-            if not region:  # Check if region is empty
-                return None
-
-            return region
-        except Exception:
-            # Catch any unexpected errors and return None
-            return None
-
-    def _get_aws_region_name(
-        self, optional_params: dict, model: Optional[str] = None
-    ) -> str:
-        """
-        Get the AWS region name from the environment variables
-        """
-        aws_region_name = optional_params.get("aws_region_name", None)
-        ### SET REGION NAME ###
-        if aws_region_name is None:
-            # check model arn #
-            aws_region_name = self.get_aws_region_from_model_arn(model)
-            # check env #
-            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
-
-            if (
-                aws_region_name is None
-                and litellm_aws_region_name is not None
-                and isinstance(litellm_aws_region_name, str)
-            ):
-                aws_region_name = litellm_aws_region_name
-
-            standard_aws_region_name = get_secret("AWS_REGION", None)
-            if (
-                aws_region_name is None
-                and standard_aws_region_name is not None
-                and isinstance(standard_aws_region_name, str)
-            ):
-                aws_region_name = standard_aws_region_name
-
-            if aws_region_name is None:
-                aws_region_name = "us-west-2"
-
-        return aws_region_name
-
     def _get_model_id_from_model_with_spec(
         self,
         model: str,
@@ -318,6 +318,23 @@ class BedrockModelInfo(BaseLLMModelInfo):
     global_config = AmazonBedrockGlobalConfig()
     all_global_regions = global_config.get_all_regions()

+    @staticmethod
+    def extract_model_name_from_arn(model: str) -> str:
+        """
+        Extract the model name from an AWS Bedrock ARN.
+        Returns the string after the last '/' if 'arn' is in the input string.
+
+        Args:
+            arn (str): The ARN string to parse
+
+        Returns:
+            str: The extracted model name if 'arn' is in the string,
+                 otherwise returns the original string
+        """
+        if "arn" in model.lower():
+            return model.split("/")[-1]
+        return model
+
     @staticmethod
     def get_base_model(model: str) -> str:
         """
@@ -335,6 +352,8 @@ class BedrockModelInfo(BaseLLMModelInfo):
         if model.startswith("invoke/"):
             model = model.split("/", 1)[1]

+        model = BedrockModelInfo.extract_model_name_from_arn(model)
+
         potential_region = model.split(".", 1)[0]

         alt_potential_region = model.split("/", 1)[
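`extract_model_name_from_arn` keeps everything after the last `/`, so a model addressed by full ARN still resolves to its base entry in the cost map. A quick standalone check of the rule:

```python
# Replica of BedrockModelInfo.extract_model_name_from_arn's rule.
def model_name_from_arn(model: str) -> str:
    if "arn" in model.lower():
        return model.split("/")[-1]
    return model

arn = "arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0"
assert model_name_from_arn(arn) == "amazon.rerank-v1:0"
assert model_name_from_arn("amazon.rerank-v1:0") == "amazon.rerank-v1:0"
```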
@@ -163,7 +163,7 @@ class BedrockImageGeneration(BaseAWSLLM):
         except ImportError:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
         boto3_credentials_info = self._get_boto_credentials_from_optional_params(
-            optional_params
+            optional_params, model
         )

         ### SET RUNTIME ENDPOINT ###
@@ -6,6 +6,8 @@ import httpx
 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
 from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
 )
@@ -27,8 +29,10 @@ class BedrockRerankHandler(BaseAWSLLM):
     async def arerank(
         self,
         prepared_request: BedrockPreparedRequest,
+        client: Optional[AsyncHTTPHandler] = None,
     ):
-        client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
+        if client is None:
+            client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
         try:
             response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"])  # type: ignore
             response.raise_for_status()
@@ -54,7 +58,9 @@ class BedrockRerankHandler(BaseAWSLLM):
         _is_async: Optional[bool] = False,
         api_base: Optional[str] = None,
         extra_headers: Optional[dict] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ) -> RerankResponse:

         request_data = RerankRequest(
             model=model,
             query=query,
@@ -66,6 +72,7 @@ class BedrockRerankHandler(BaseAWSLLM):
         data = BedrockRerankConfig()._transform_request(request_data)

         prepared_request = self._prepare_request(
+            model=model,
             optional_params=optional_params,
             api_base=api_base,
             extra_headers=extra_headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _is_async:
|
if _is_async:
|
||||||
return self.arerank(prepared_request) # type: ignore
|
return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None) # type: ignore
|
||||||
|
|
||||||
client = _get_httpx_client()
|
if client is None or not isinstance(client, HTTPHandler):
|
||||||
|
client = _get_httpx_client()
|
||||||
try:
|
try:
|
||||||
response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
@ -95,10 +103,18 @@ class BedrockRerankHandler(BaseAWSLLM):
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException:
|
||||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
return BedrockRerankConfig()._transform_response(response.json())
|
logging_obj.post_call(
|
||||||
|
original_response=response.text,
|
||||||
|
api_key="",
|
||||||
|
)
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
|
||||||
|
return BedrockRerankConfig()._transform_response(response_json)
|
||||||
|
|
||||||
def _prepare_request(
|
def _prepare_request(
|
||||||
self,
|
self,
|
||||||
|
model: str,
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
extra_headers: Optional[dict],
|
extra_headers: Optional[dict],
|
||||||
data: dict,
|
data: dict,
|
||||||
|
@ -110,7 +126,7 @@ class BedrockRerankHandler(BaseAWSLLM):
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
||||||
optional_params
|
optional_params, model
|
||||||
)
|
)
|
||||||
|
|
||||||
### SET RUNTIME ENDPOINT ###
|
### SET RUNTIME ENDPOINT ###
|
||||||
|
|
|
@@ -91,7 +91,9 @@ class BedrockRerankConfig:
        example input:
        {"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
        """
-        _billed_units = RerankBilledUnits(**response.get("usage", {}))
+        _billed_units = RerankBilledUnits(
+            **response.get("usage", {"search_units": 1})
+        )  # by default 1 search unit
         _tokens = RerankTokens(**response.get("usage", {}))
         rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
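Bedrock's rerank response carries no usage block, so the transformation now defaults to one search unit per request; combined with the `input_cost_per_query` entries added to the cost map below, a single Bedrock rerank call bills as one query. A sketch of the defaulting behavior, using a plain dict in place of `RerankBilledUnits`:

```python
# The example response shape comes from the docstring above; it carries
# no "usage" key, so billing falls back to one search unit.
response: dict = {"results": [{"index": 0, "relevanceScore": 0.68}]}
billed_units = response.get("usage", {"search_units": 1})
assert billed_units["search_units"] == 1
# cohere.rerank-v3-5:0 on Bedrock is $0.002 per query in the cost map below:
assert billed_units["search_units"] * 0.002 == 0.002
```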
@@ -1,31 +0,0 @@
-"""
-Custom cost calculator for Cohere rerank models
-"""
-
-from typing import Tuple
-
-from litellm.utils import get_model_info
-
-
-def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
-    """
-    Calculates the cost per query for a given rerank model.
-
-    Input:
-        - model: str, the model name without provider prefix
-
-    Returns:
-        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
-    """
-
-    model_info = get_model_info(model=model, custom_llm_provider="cohere")
-
-    if (
-        "input_cost_per_query" not in model_info
-        or model_info["input_cost_per_query"] is None
-    ):
-        return 0.0, 0.0
-
-    prompt_cost = model_info["input_cost_per_query"] * num_queries
-
-    return prompt_cost, 0.0
@@ -1,92 +1,3 @@
 """
-Re rank api
-
-LiteLLM supports the re rank API format, no paramter transformation occurs
+HTTP calling migrated to `llm_http_handler.py`
 """
-
-from typing import Any, Dict, List, Optional, Union
-
-import litellm
-from litellm.llms.base import BaseLLM
-from litellm.llms.custom_httpx.http_handler import (
-    _get_httpx_client,
-    get_async_httpx_client,
-)
-from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
-from litellm.types.rerank import RerankRequest, RerankResponse
-
-
-class JinaAIRerank(BaseLLM):
-    def rerank(
-        self,
-        model: str,
-        api_key: str,
-        query: str,
-        documents: List[Union[str, Dict[str, Any]]],
-        top_n: Optional[int] = None,
-        rank_fields: Optional[List[str]] = None,
-        return_documents: Optional[bool] = True,
-        max_chunks_per_doc: Optional[int] = None,
-        _is_async: Optional[bool] = False,
-    ) -> RerankResponse:
-        client = _get_httpx_client()
-
-        request_data = RerankRequest(
-            model=model,
-            query=query,
-            top_n=top_n,
-            documents=documents,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-        )
-
-        # exclude None values from request_data
-        request_data_dict = request_data.dict(exclude_none=True)
-
-        if _is_async:
-            return self.async_rerank(request_data_dict, api_key)  # type: ignore # Call async method
-
-        response = client.post(
-            "https://api.jina.ai/v1/rerank",
-            headers={
-                "accept": "application/json",
-                "content-type": "application/json",
-                "authorization": f"Bearer {api_key}",
-            },
-            json=request_data_dict,
-        )
-
-        if response.status_code != 200:
-            raise Exception(response.text)
-
-        _json_response = response.json()
-
-        return JinaAIRerankConfig()._transform_response(_json_response)
-
-    async def async_rerank(  # New async method
-        self,
-        request_data_dict: Dict[str, Any],
-        api_key: str,
-    ) -> RerankResponse:
-        client = get_async_httpx_client(
-            llm_provider=litellm.LlmProviders.JINA_AI
-        )  # Use async client
-
-        response = await client.post(
-            "https://api.jina.ai/v1/rerank",
-            headers={
-                "accept": "application/json",
-                "content-type": "application/json",
-                "authorization": f"Bearer {api_key}",
-            },
-            json=request_data_dict,
-        )
-
-        if response.status_code != 200:
-            raise Exception(response.text)
-
-        _json_response = response.json()
-
-        return JinaAIRerankConfig()._transform_response(_json_response)
-
-    pass
@@ -7,30 +7,136 @@ Docs - https://jina.ai/reranker
 """

 import uuid
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from httpx import URL, Response
+
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.types.rerank import (
+    OptionalRerankParams,
     RerankBilledUnits,
     RerankResponse,
     RerankResponseMeta,
     RerankTokens,
 )
+from litellm.types.utils import ModelInfo


-class JinaAIRerankConfig:
-    def _transform_response(self, response: dict) -> RerankResponse:
-
-        _billed_units = RerankBilledUnits(**response.get("usage", {}))
-        _tokens = RerankTokens(**response.get("usage", {}))
+class JinaAIRerankConfig(BaseRerankConfig):
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        return [
+            "query",
+            "top_n",
+            "documents",
+            "return_documents",
+        ]
+
+    def map_cohere_rerank_params(
+        self,
+        non_default_params: dict,
+        model: str,
+        drop_params: bool,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        custom_llm_provider: Optional[str] = None,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+    ) -> OptionalRerankParams:
+        optional_params = {}
+        supported_params = self.get_supported_cohere_rerank_params(model)
+        for k, v in non_default_params.items():
+            if k in supported_params:
+                optional_params[k] = v
+        return OptionalRerankParams(
+            **optional_params,
+        )
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        base_path = "/v1/rerank"
+
+        if api_base is None:
+            return "https://api.jina.ai/v1/rerank"
+        base = URL(api_base)
+        # Reconstruct URL with cleaned path
+        cleaned_base = str(base.copy_with(path=base_path))
+
+        return cleaned_base
+
+    def transform_rerank_request(
+        self, model: str, optional_rerank_params: OptionalRerankParams, headers: Dict
+    ) -> Dict:
+        return {"model": model, **optional_rerank_params}
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: Dict = {},
+        optional_params: Dict = {},
+        litellm_params: Dict = {},
+    ) -> RerankResponse:
+        if raw_response.status_code != 200:
+            raise Exception(raw_response.text)
+
+        logging_obj.post_call(original_response=raw_response.text)
+
+        _json_response = raw_response.json()
+
+        _billed_units = RerankBilledUnits(**_json_response.get("usage", {}))
+        _tokens = RerankTokens(**_json_response.get("usage", {}))
         rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

-        _results: Optional[List[dict]] = response.get("results")
+        _results: Optional[List[dict]] = _json_response.get("results")

         if _results is None:
-            raise ValueError(f"No results found in the response={response}")
+            raise ValueError(f"No results found in the response={_json_response}")

         return RerankResponse(
-            id=response.get("id") or str(uuid.uuid4()),
+            id=_json_response.get("id") or str(uuid.uuid4()),
             results=_results,  # type: ignore
             meta=rerank_meta,
         )  # Return response
+
+    def validate_environment(
+        self, headers: Dict, model: str, api_key: Optional[str] = None
+    ) -> Dict:
+        if api_key is None:
+            raise ValueError(
+                "api_key is required. Set via `api_key` parameter or `JINA_API_KEY` environment variable."
+            )
+        return {
+            "accept": "application/json",
+            "content-type": "application/json",
+            "authorization": f"Bearer {api_key}",
+        }
+
+    def calculate_rerank_cost(
+        self,
+        model: str,
+        custom_llm_provider: Optional[str] = None,
+        billed_units: Optional[RerankBilledUnits] = None,
+        model_info: Optional[ModelInfo] = None,
+    ) -> Tuple[float, float]:
+        """
+        Jina AI reranker is priced at $0.000000018 per token.
+        """
+        if (
+            model_info is None
+            or "input_cost_per_token" not in model_info
+            or model_info["input_cost_per_token"] is None
+            or billed_units is None
+        ):
+            return 0.0, 0.0
+
+        total_tokens = billed_units.get("total_tokens")
+        if total_tokens is None:
+            return 0.0, 0.0
+
+        input_cost = model_info["input_cost_per_token"] * total_tokens
+        return input_cost, 0.0
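Unlike the per-query providers, Jina AI bills per token, so `calculate_rerank_cost` above multiplies `input_cost_per_token` by the `total_tokens` the API reports. With the cost-map entry added below ($0.000000018 per token), a rerank that bills 1,000 tokens costs $0.000018:

```python
# Token-based rerank pricing as implemented by JinaAIRerankConfig above;
# the per-token price mirrors the jina-reranker-v2-base-multilingual entry.
input_cost_per_token = 0.000000018
total_tokens = 1_000  # hypothetical billed_units["total_tokens"]
input_cost = input_cost_per_token * total_tokens
assert abs(input_cost - 0.000018) < 1e-12
```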
@@ -5982,6 +5982,19 @@
     "litellm_provider": "bedrock",
     "mode": "chat"
   },
+  "amazon.rerank-v1:0": {
+    "max_tokens": 32000,
+    "max_input_tokens": 32000,
+    "max_output_tokens": 32000,
+    "max_query_tokens": 32000,
+    "max_document_chunks_per_query": 100,
+    "max_tokens_per_document_chunk": 512,
+    "input_cost_per_token": 0.0,
+    "input_cost_per_query": 0.001,
+    "output_cost_per_token": 0.0,
+    "litellm_provider": "bedrock",
+    "mode": "rerank"
+  },
   "amazon.titan-text-lite-v1": {
     "max_tokens": 4000,
     "max_input_tokens": 42000,
@@ -7022,6 +7035,19 @@
     "mode": "chat",
     "supports_tool_choice": true
   },
+  "cohere.rerank-v3-5:0": {
+    "max_tokens": 32000,
+    "max_input_tokens": 32000,
+    "max_output_tokens": 32000,
+    "max_query_tokens": 32000,
+    "max_document_chunks_per_query": 100,
+    "max_tokens_per_document_chunk": 512,
+    "input_cost_per_token": 0.0,
+    "input_cost_per_query": 0.002,
+    "output_cost_per_token": 0.0,
+    "litellm_provider": "bedrock",
+    "mode": "rerank"
+  },
   "cohere.command-text-v14": {
     "max_tokens": 4096,
     "max_input_tokens": 4096,
@@ -9154,5 +9180,15 @@
     "input_cost_per_second": 0.00003333,
     "output_cost_per_second": 0.00,
     "litellm_provider": "assemblyai"
+  },
+  "jina-reranker-v2-base-multilingual": {
+    "max_tokens": 1024,
+    "max_input_tokens": 1024,
+    "max_output_tokens": 1024,
+    "max_document_chunks_per_query": 2048,
+    "input_cost_per_token": 0.000000018,
+    "output_cost_per_token": 0.000000018,
+    "litellm_provider": "jina_ai",
+    "mode": "rerank"
   }
 }
@@ -9,7 +9,6 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
 from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
-from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
 from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
 from litellm.rerank_api.rerank_utils import get_optional_rerank_params
 from litellm.secret_managers.main import get_secret, get_secret_str
@@ -20,7 +19,6 @@ from litellm.utils import ProviderConfigManager, client, exception_type
 ####### ENVIRONMENT VARIABLES ###################
 # Initialize any necessary instances or variables here
 together_rerank = TogetherAIRerank()
-jina_ai_rerank = JinaAIRerank()
 bedrock_rerank = BedrockRerankHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
 #################################################
@@ -264,16 +262,26 @@ def rerank(  # noqa: PLR0915
             raise ValueError(
                 "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
             )
-        response = jina_ai_rerank.rerank(
+
+        api_base = (
+            dynamic_api_base
+            or optional_params.api_base
+            or litellm.api_base
+            or get_secret("BEDROCK_API_BASE")  # type: ignore
+        )
+
+        response = base_llm_http_handler.rerank(
             model=model,
-            api_key=dynamic_api_key,
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            rank_fields=rank_fields,
-            return_documents=return_documents,
-            max_chunks_per_doc=max_chunks_per_doc,
+            custom_llm_provider=_custom_llm_provider,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
+            api_base=api_base,
             _is_async=_is_async,
+            headers=headers or litellm.headers or {},
+            client=client,
+            model_response=model_response,
         )
     elif _custom_llm_provider == "bedrock":
         api_base = (
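With the dedicated handler deleted, Jina AI rerank requests flow through the shared `base_llm_http_handler`, which is what wires up logging and, via the config above, cost tracking. A hedged usage sketch (assumes `JINA_AI_API_KEY` is set; the model name is the cost-map entry added below):

```python
# Illustrative call; requires JINA_AI_API_KEY in the environment.
import litellm

response = litellm.rerank(
    model="jina_ai/jina-reranker-v2-base-multilingual",
    query="What is the capital of the United States?",
    documents=["Washington, D.C. is the capital.", "Paris is in France."],
)
# Cost is now computed from the billed total_tokens in the response meta.
print(response.results)
print(response._hidden_params.get("response_cost"))
```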
@@ -295,6 +303,7 @@ def rerank(  # noqa: PLR0915
             optional_params=optional_params.model_dump(exclude_unset=True),
             api_base=api_base,
             logging_obj=litellm_logging_obj,
+            client=client,
         )
     else:
         raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
@@ -17,6 +17,17 @@ def get_optional_rerank_params(
     max_chunks_per_doc: Optional[int] = None,
     non_default_params: Optional[dict] = None,
 ) -> OptionalRerankParams:
+    all_non_default_params = non_default_params or {}
+    if query is not None:
+        all_non_default_params["query"] = query
+    if top_n is not None:
+        all_non_default_params["top_n"] = top_n
+    if documents is not None:
+        all_non_default_params["documents"] = documents
+    if return_documents is not None:
+        all_non_default_params["return_documents"] = return_documents
+    if max_chunks_per_doc is not None:
+        all_non_default_params["max_chunks_per_doc"] = max_chunks_per_doc
     return rerank_provider_config.map_cohere_rerank_params(
         model=model,
         drop_params=drop_params,
@@ -27,5 +38,5 @@ def get_optional_rerank_params(
         rank_fields=rank_fields,
         return_documents=return_documents,
         max_chunks_per_doc=max_chunks_per_doc,
-        non_default_params=non_default_params,
+        non_default_params=all_non_default_params,
     )
@@ -6198,6 +6198,8 @@ class ProviderConfigManager:
             return litellm.AzureAIRerankConfig()
         elif litellm.LlmProviders.INFINITY == provider:
             return litellm.InfinityRerankConfig()
+        elif litellm.LlmProviders.JINA_AI == provider:
+            return litellm.JinaAIRerankConfig()
         return litellm.CohereRerankConfig()

     @staticmethod
The same three hunks are applied to a second copy of the model cost map:

@@ -5982,6 +5982,19 @@
     "litellm_provider": "bedrock",
     "mode": "chat"
   },
+  "amazon.rerank-v1:0": {
+    "max_tokens": 32000,
+    "max_input_tokens": 32000,
+    "max_output_tokens": 32000,
+    "max_query_tokens": 32000,
+    "max_document_chunks_per_query": 100,
+    "max_tokens_per_document_chunk": 512,
+    "input_cost_per_token": 0.0,
+    "input_cost_per_query": 0.001,
+    "output_cost_per_token": 0.0,
+    "litellm_provider": "bedrock",
+    "mode": "rerank"
+  },
   "amazon.titan-text-lite-v1": {
     "max_tokens": 4000,
     "max_input_tokens": 42000,
@@ -7022,6 +7035,19 @@
     "mode": "chat",
     "supports_tool_choice": true
   },
+  "cohere.rerank-v3-5:0": {
+    "max_tokens": 32000,
+    "max_input_tokens": 32000,
+    "max_output_tokens": 32000,
+    "max_query_tokens": 32000,
+    "max_document_chunks_per_query": 100,
+    "max_tokens_per_document_chunk": 512,
+    "input_cost_per_token": 0.0,
+    "input_cost_per_query": 0.002,
+    "output_cost_per_token": 0.0,
+    "litellm_provider": "bedrock",
+    "mode": "rerank"
+  },
   "cohere.command-text-v14": {
     "max_tokens": 4096,
     "max_input_tokens": 4096,
@@ -9154,5 +9180,15 @@
     "input_cost_per_second": 0.00003333,
     "output_cost_per_second": 0.00,
     "litellm_provider": "assemblyai"
+  },
+  "jina-reranker-v2-base-multilingual": {
+    "max_tokens": 1024,
+    "max_input_tokens": 1024,
+    "max_output_tokens": 1024,
+    "max_document_chunks_per_query": 2048,
+    "input_cost_per_token": 0.000000018,
+    "output_cost_per_token": 0.000000018,
+    "litellm_provider": "jina_ai",
+    "mode": "rerank"
   }
 }
tests/litellm/llms/bedrock/rerank/transformation.py (new file, 14 lines)
@@ -0,0 +1,14 @@
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+from unittest.mock import MagicMock, patch
+
+from litellm import rerank
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
tests/litellm/rerank_api/test_main.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+from unittest.mock import MagicMock, patch
+
+from litellm import rerank
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+
+def test_rerank_infer_region_from_model_arn(monkeypatch):
+    mock_response = MagicMock()
+
+    monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")
+    args = {
+        "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
+        "query": "hello",
+        "documents": ["hello", "world"],
+    }
+
+    def return_val():
+        return {
+            "results": [
+                {"index": 0, "relevanceScore": 0.6716859340667725},
+                {"index": 1, "relevanceScore": 0.0004994205664843321},
+            ]
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post", return_value=mock_response) as mock_post:
+        rerank(
+            model=args["model"],
+            query=args["query"],
+            documents=args["documents"],
+            client=client,
+        )
+        mock_post.assert_called_once()
+        print(f"mock_post.call_args: {mock_post.call_args.kwargs}")
+        assert "us-west-2" in mock_post.call_args.kwargs["url"]
+        assert "us-east-1" not in mock_post.call_args.kwargs["url"]
@@ -33,6 +33,7 @@ def assert_response_shape(response, custom_llm_provider):
     expected_api_version_shape = {"version": str}

     expected_billed_units_shape = {"search_units": int}
+    expected_billed_units_total_tokens_shape = {"total_tokens": int}

     assert isinstance(response.id, expected_response_shape["id"])
     assert isinstance(response.results, expected_response_shape["results"])
@ -52,9 +53,15 @@ def assert_response_shape(response, custom_llm_provider):
|
||||||
response.meta["api_version"]["version"],
|
response.meta["api_version"]["version"],
|
||||||
expected_api_version_shape["version"],
|
expected_api_version_shape["version"],
|
||||||
)
|
)
|
||||||
|
assert isinstance(
|
||||||
|
response.meta["billed_units"], expected_meta_shape["billed_units"]
|
||||||
|
)
|
||||||
|
if "total_tokens" in response.meta["billed_units"]:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
response.meta["billed_units"], expected_meta_shape["billed_units"]
|
response.meta["billed_units"]["total_tokens"],
|
||||||
|
expected_billed_units_total_tokens_shape["total_tokens"],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
response.meta["billed_units"]["search_units"],
|
response.meta["billed_units"]["search_units"],
|
||||||
expected_billed_units_shape["search_units"],
|
expected_billed_units_shape["search_units"],
|
||||||
|
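The new branch reflects the two billing shapes rerank providers return: token-billed APIs report `total_tokens`, while search-billed APIs (Cohere-style, including Bedrock's rerank models) report `search_units`. A standalone sketch of consuming either shape; `meta` stands in for `response.meta` on a litellm RerankResponse, and which key appears is provider-dependent (an assumption based on the branch in this test helper).

# `meta` mimics response.meta as asserted above.
meta = {"billed_units": {"search_units": 1}}  # Cohere/Bedrock-style billing

billed = meta["billed_units"]
if "total_tokens" in billed:      # token-billed providers
    print("billed by tokens:", billed["total_tokens"])
else:                             # search-unit-billed providers
    print("billed by search units:", billed["search_units"])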
@@ -79,7 +86,9 @@ class BaseLLMRerankTest(ABC):
     @pytest.mark.asyncio()
     @pytest.mark.parametrize("sync_mode", [True, False])
     async def test_basic_rerank(self, sync_mode):
-        litellm.set_verbose = True
+        litellm._turn_on_debug()
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
         rerank_call_args = self.get_base_rerank_call_args()
         custom_llm_provider = self.get_custom_llm_provider()
         if sync_mode is True:
@@ -95,6 +104,9 @@ class BaseLLMRerankTest(ABC):
             assert response.id is not None
             assert response.results is not None
 
+            assert response._hidden_params["response_cost"] is not None
+            assert response._hidden_params["response_cost"] > 0
+
             assert_response_shape(
                 response=response, custom_llm_provider=custom_llm_provider.value
             )
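Setting LITELLM_LOCAL_MODEL_COST_MAP=True and loading get_model_cost_map(url="") makes pricing come from the cost map bundled with the package, so the response_cost assertions do not depend on the hosted map. A hedged sketch of the same check outside the test harness; the Cohere model name is illustrative and a COHERE_API_KEY is assumed to be set.

import os

import litellm

# Price from the packaged cost map; url="" is the same offline trick the
# base test above uses.
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

response = litellm.rerank(  # assumes COHERE_API_KEY is set in the environment
    model="cohere/rerank-english-v3.0",
    query="hello",
    documents=["hello", "world"],
)
cost = response._hidden_params["response_cost"]
assert cost is not None and cost > 0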
@@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat import ModelResponseIterator
 import httpx
 import json
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from base_rerank_unit_tests import BaseLLMRerankTest
 
 load_dotenv()
 import io
@@ -185,6 +186,7 @@ def test_completion_azure_ai_command_r():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
 def test_azure_deepseek_reasoning_content():
     import json
@@ -192,37 +194,48 @@ def test_azure_deepseek_reasoning_content():
 
     with patch.object(client, "post") as mock_post:
         mock_response = MagicMock()
 
         mock_response.text = json.dumps(
             {
                 "choices": [
                     {
                         "finish_reason": "stop",
                         "index": 0,
                         "message": {
                             "content": "<think>I am thinking here</think>\n\nThe sky is a canvas of blue",
                             "role": "assistant",
-                        }
+                        },
                     }
                 ],
             }
         )
 
         mock_response.status_code = 200
         # Add required response attributes
         mock_response.headers = {"Content-Type": "application/json"}
         mock_response.json = lambda: json.loads(mock_response.text)
         mock_post.return_value = mock_response
 
 
         response = litellm.completion(
-            model='azure_ai/deepseek-r1',
+            model="azure_ai/deepseek-r1",
            messages=[{"role": "user", "content": "Hello, world!"}],
             api_base="https://litellm8397336933.services.ai.azure.com/models/chat/completions",
             api_key="my-fake-api-key",
-            client=client
+            client=client,
         )
 
         print(response)
-        assert(response.choices[0].message.reasoning_content == "I am thinking here")
-        assert(response.choices[0].message.content == "\n\nThe sky is a canvas of blue")
+        assert response.choices[0].message.reasoning_content == "I am thinking here"
+        assert response.choices[0].message.content == "\n\nThe sky is a canvas of blue"
+
+
+class TestAzureAIRerank(BaseLLMRerankTest):
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.AZURE_AI
+
+    def get_base_rerank_call_args(self) -> dict:
+        return {
+            "model": "azure_ai/cohere-rerank-v3-english",
+            "api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
+            "api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
+        }
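TestAzureAIRerank plugs Azure AI into the shared BaseLLMRerankTest, so the Cohere rerank deployment gets the same shape and cost-tracking assertions as every other provider. A sketch of the equivalent direct call, assuming the same env vars point at your own deployment.

import os

import litellm

# Endpoint and key come from the same env vars the test args use; both are
# deployment-specific assumptions.
response = litellm.rerank(
    model="azure_ai/cohere-rerank-v3-english",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    api_base=os.getenv("AZURE_AI_COHERE_API_BASE"),
    api_key=os.getenv("AZURE_AI_COHERE_API_KEY"),
)
print(response.results)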
tests/local_testing/test_bedrock_completion.py
@@ -2186,6 +2186,16 @@ class TestBedrockRerank(BaseLLMRerankTest):
         }
 
 
+class TestBedrockCohereRerank(BaseLLMRerankTest):
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.BEDROCK
+
+    def get_base_rerank_call_args(self) -> dict:
+        return {
+            "model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0",
+        }
+
+
 @pytest.mark.parametrize(
     "messages, continue_message_index",
     [
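Note the test args pass only the ARN: the region is inferred from it (us-west-2), and cost tracking resolves pricing from the cohere.rerank-v3-5 entry in the model cost map. A hedged usage sketch, assuming standard AWS credentials in the environment.

import litellm

# Region comes from the ARN, not AWS_REGION_NAME; AWS credentials are
# assumed to be configured (env vars, profile, or instance role).
response = litellm.rerank(
    model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0",
    query="hello",
    documents=["hello", "world"],
)
print(response.results)
print(response._hidden_params["response_cost"])  # populated by rerank cost tracking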
@@ -66,6 +66,7 @@ def assert_response_shape(response, custom_llm_provider):
 
 @pytest.mark.asyncio()
 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.flaky(retries=3, delay=1)
 async def test_basic_rerank(sync_mode):
     litellm.set_verbose = True
     if sync_mode is True:
@@ -311,6 +312,7 @@ def test_complete_base_url_cohere():
         (3, None, False),
     ],
 )
+@pytest.mark.flaky(retries=3, delay=1)
 async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
     from litellm.caching.caching import Cache
 
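The retries/delay kwargs suggest the pytest-retry plugin (an inference; the diff does not name it): a failed run is retried up to three times with a one-second pause, which absorbs transient rate limits on live Cohere calls. Minimal sketch:

import random

import pytest


# Assumes the pytest-retry plugin provides this marker: the test is re-run
# up to `retries` times, sleeping `delay` seconds between attempts.
@pytest.mark.flaky(retries=3, delay=1)
def test_occasionally_flaky_network_call():
    assert random.random() < 0.9  # stand-in for a transient API failure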
@@ -1574,7 +1574,11 @@ def test_completion_cost_azure_ai_rerank(model):
                 "relevance_score": 0.990732,
             },
         ],
-        meta={},
+        meta={
+            "billed_units": {
+                "search_units": 1,
+            }
+        },
     )
     print("response", response)
     model = model
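With billed_units.search_units populated, the expected cost is simply units times the per-search-unit price from the cost map. A back-of-the-envelope sketch; the 0.002 price and the idea of a per-query entry in the map are illustrative assumptions, not quoted pricing.

# Cohere-style rerank billing: one search unit per query (up to 100 docs).
price_per_search_unit = 0.002  # illustrative, e.g. $2 per 1,000 searches
search_units = 1               # response.meta["billed_units"]["search_units"]

expected_cost = search_units * price_per_search_unit
assert abs(expected_cost - 0.002) < 1e-12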