mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Add cost tracking for rerank via bedrock (#8691)
* feat(bedrock/rerank): infer model region if model given as arn * test: add unit testing to ensure bedrock region name inferred from arn on rerank * feat(bedrock/rerank/transformation.py): include search units for bedrock rerank result Resolves https://github.com/BerriAI/litellm/issues/7258#issuecomment-2671557137 * test(test_bedrock_completion.py): add testing for bedrock cohere rerank * feat(cost_calculator.py): refactor rerank cost tracking to support bedrock cost tracking * build(model_prices_and_context_window.json): add amazon.rerank model to model cost map * fix(cost_calculator.py): bedrock/common_utils.py get base model from model w/ arn -> handles rerank model * build(model_prices_and_context_window.json): add bedrock cohere rerank pricing * feat(bedrock/rerank): migrate bedrock config to basererank config * Revert "feat(bedrock/rerank): migrate bedrock config to basererank config" This reverts commit84fae1f167
. * test: add testing to ensure large doc / queries are correctly counted * Revert "test: add testing to ensure large doc / queries are correctly counted" This reverts commit4337f1657e
. * fix(migrate-jina-ai-to-rerank-config): enables cost tracking * refactor(jina_ai/): finish migrating jina ai to base rerank config enables cost tracking * fix(jina_ai/rerank): e2e jina ai rerank cost tracking * fix: cleanup dead code * fix: fix python3.8 compatibility error * test: fix test * test: add e2e testing for azure ai rerank * fix: fix linting error * test: mark cohere as flaky
This commit is contained in:
parent
4c9517fd78
commit
b682dc4ec8
26 changed files with 524 additions and 296 deletions
|
@ -10,6 +10,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Suppor
|
|||
| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#deepseek-not-r1), [`bedrock/deepseek_r1/`](#deepseek-r1) |
|
||||
| Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
|
||||
| Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
|
||||
| Rerank Endpoint | `/rerank` |
|
||||
| Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |
|
||||
|
||||
|
||||
|
|
|
@ -820,6 +820,7 @@ from .llms.cohere.completion.transformation import CohereTextConfig as CohereCon
|
|||
from .llms.cohere.rerank.transformation import CohereRerankConfig
|
||||
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
|
||||
from .llms.infinity.rerank.transformation import InfinityRerankConfig
|
||||
from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
|
||||
from .llms.clarifai.chat.transformation import ClarifaiConfig
|
||||
from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
|
||||
from .llms.together_ai.chat import TogetherAIConfig
|
||||
|
|
|
@ -16,15 +16,9 @@ from litellm.llms.anthropic.cost_calculation import (
|
|||
from litellm.llms.azure.cost_calculation import (
|
||||
cost_per_token as azure_openai_cost_per_token,
|
||||
)
|
||||
from litellm.llms.azure_ai.cost_calculator import (
|
||||
cost_per_query as azure_ai_rerank_cost_per_query,
|
||||
)
|
||||
from litellm.llms.bedrock.image.cost_calculator import (
|
||||
cost_calculator as bedrock_image_cost_calculator,
|
||||
)
|
||||
from litellm.llms.cohere.cost_calculator import (
|
||||
cost_per_query as cohere_rerank_cost_per_query,
|
||||
)
|
||||
from litellm.llms.databricks.cost_calculator import (
|
||||
cost_per_token as databricks_cost_per_token,
|
||||
)
|
||||
|
@ -51,10 +45,12 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
|
|||
cost_calculator as vertex_ai_image_cost_calculator,
|
||||
)
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
from litellm.types.rerank import RerankResponse
|
||||
from litellm.types.rerank import RerankBilledUnits, RerankResponse
|
||||
from litellm.types.utils import (
|
||||
CallTypesLiteral,
|
||||
LlmProviders,
|
||||
LlmProvidersSet,
|
||||
ModelInfo,
|
||||
PassthroughCallTypes,
|
||||
Usage,
|
||||
)
|
||||
|
@ -64,6 +60,7 @@ from litellm.utils import (
|
|||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
ProviderConfigManager,
|
||||
TextCompletionResponse,
|
||||
TranscriptionResponse,
|
||||
_cached_get_model_info_helper,
|
||||
|
@ -114,6 +111,8 @@ def cost_per_token( # noqa: PLR0915
|
|||
number_of_queries: Optional[int] = None,
|
||||
### USAGE OBJECT ###
|
||||
usage_object: Optional[Usage] = None, # just read the usage object if provided
|
||||
### BILLED UNITS ###
|
||||
rerank_billed_units: Optional[RerankBilledUnits] = None,
|
||||
### CALL TYPE ###
|
||||
call_type: CallTypesLiteral = "completion",
|
||||
audio_transcription_file_duration: float = 0.0, # for audio transcription calls - the file time in seconds
|
||||
|
@ -238,6 +237,7 @@ def cost_per_token( # noqa: PLR0915
|
|||
return rerank_cost(
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
billed_units=rerank_billed_units,
|
||||
)
|
||||
elif call_type == "atranscription" or call_type == "transcription":
|
||||
return openai_cost_per_second(
|
||||
|
@ -552,6 +552,7 @@ def completion_cost( # noqa: PLR0915
|
|||
cost_per_token_usage_object: Optional[Usage] = _get_usage_object(
|
||||
completion_response=completion_response
|
||||
)
|
||||
rerank_billed_units: Optional[RerankBilledUnits] = None
|
||||
model = _select_model_name_for_cost_calc(
|
||||
model=model,
|
||||
completion_response=completion_response,
|
||||
|
@ -698,6 +699,11 @@ def completion_cost( # noqa: PLR0915
|
|||
else:
|
||||
billed_units = {}
|
||||
|
||||
rerank_billed_units = RerankBilledUnits(
|
||||
search_units=billed_units.get("search_units"),
|
||||
total_tokens=billed_units.get("total_tokens"),
|
||||
)
|
||||
|
||||
search_units = (
|
||||
billed_units.get("search_units") or 1
|
||||
) # cohere charges per request by default.
|
||||
|
@ -763,6 +769,7 @@ def completion_cost( # noqa: PLR0915
|
|||
usage_object=cost_per_token_usage_object,
|
||||
call_type=call_type,
|
||||
audio_transcription_file_duration=audio_transcription_file_duration,
|
||||
rerank_billed_units=rerank_billed_units,
|
||||
)
|
||||
_final_cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
|
||||
|
||||
|
@ -836,27 +843,33 @@ def response_cost_calculator(
|
|||
def rerank_cost(
|
||||
model: str,
|
||||
custom_llm_provider: Optional[str],
|
||||
billed_units: Optional[RerankBilledUnits] = None,
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
Returns
|
||||
- float or None: cost of response OR none if error.
|
||||
"""
|
||||
default_num_queries = 1
|
||||
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
try:
|
||||
if custom_llm_provider == "cohere":
|
||||
return cohere_rerank_cost_per_query(
|
||||
model=model, num_queries=default_num_queries
|
||||
config = ProviderConfigManager.get_provider_rerank_config(
|
||||
model=model, provider=LlmProviders(custom_llm_provider)
|
||||
)
|
||||
|
||||
try:
|
||||
model_info: Optional[ModelInfo] = litellm.get_model_info(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
elif custom_llm_provider == "azure_ai":
|
||||
return azure_ai_rerank_cost_per_query(
|
||||
model=model, num_queries=default_num_queries
|
||||
)
|
||||
raise ValueError(
|
||||
f"invalid custom_llm_provider for rerank model: {model}, custom_llm_provider: {custom_llm_provider}"
|
||||
except Exception:
|
||||
model_info = None
|
||||
|
||||
return config.calculate_rerank_cost(
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
billed_units=billed_units,
|
||||
model_info=model_info,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
"""
|
||||
Handles custom cost calculation for Azure AI models.
|
||||
|
||||
Custom cost calculation for Azure AI models only requied for rerank.
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
|
||||
def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
|
||||
"""
|
||||
Calculates the cost per query for a given rerank model.
|
||||
|
||||
Input:
|
||||
- model: str, the model name without provider prefix
|
||||
|
||||
Returns:
|
||||
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
|
||||
"""
|
||||
model_info = get_model_info(model=model, custom_llm_provider="azure_ai")
|
||||
|
||||
if (
|
||||
"input_cost_per_query" not in model_info
|
||||
or model_info["input_cost_per_query"] is None
|
||||
):
|
||||
return 0.0, 0.0
|
||||
|
||||
prompt_cost = model_info["input_cost_per_query"] * num_queries
|
||||
|
||||
return prompt_cost, 0.0
|
|
@ -1,9 +1,10 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.rerank import OptionalRerankParams, RerankResponse
|
||||
from litellm.types.rerank import OptionalRerankParams, RerankBilledUnits, RerankResponse
|
||||
from litellm.types.utils import ModelInfo
|
||||
|
||||
from ..chat.transformation import BaseLLMException
|
||||
|
||||
|
@ -66,7 +67,7 @@ class BaseRerankConfig(ABC):
|
|||
@abstractmethod
|
||||
def map_cohere_rerank_params(
|
||||
self,
|
||||
non_default_params: Optional[dict],
|
||||
non_default_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
query: str,
|
||||
|
@ -79,8 +80,48 @@ class BaseRerankConfig(ABC):
|
|||
) -> OptionalRerankParams:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
pass
|
||||
raise BaseLLMException(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def calculate_rerank_cost(
|
||||
self,
|
||||
model: str,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
billed_units: Optional[RerankBilledUnits] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
Calculates the cost per query for a given rerank model.
|
||||
|
||||
Input:
|
||||
- model: str, the model name without provider prefix
|
||||
- custom_llm_provider: str, the provider used for the model. If provided, used to check if the litellm model info is for that provider.
|
||||
- num_queries: int, the number of queries to calculate the cost for
|
||||
- model_info: ModelInfo, the model info for the given model
|
||||
|
||||
Returns:
|
||||
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
|
||||
"""
|
||||
|
||||
if (
|
||||
model_info is None
|
||||
or "input_cost_per_query" not in model_info
|
||||
or model_info["input_cost_per_query"] is None
|
||||
or billed_units is None
|
||||
):
|
||||
return 0.0, 0.0
|
||||
|
||||
search_units = billed_units.get("search_units")
|
||||
|
||||
if search_units is None:
|
||||
return 0.0, 0.0
|
||||
|
||||
prompt_cost = model_info["input_cost_per_query"] * search_units
|
||||
|
||||
return prompt_cost, 0.0
|
||||
|
|
|
@ -202,6 +202,61 @@ class BaseAWSLLM:
|
|||
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
|
||||
return credentials
|
||||
|
||||
def _get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
# First check if the string contains the expected prefix
|
||||
if not isinstance(model, str) or "arn:aws:bedrock" not in model:
|
||||
return None
|
||||
|
||||
# Split the ARN and check if we have enough parts
|
||||
parts = model.split(":")
|
||||
if len(parts) < 4:
|
||||
return None
|
||||
|
||||
# Get the region from the correct position
|
||||
region = parts[3]
|
||||
if not region: # Check if region is empty
|
||||
return None
|
||||
|
||||
return region
|
||||
except Exception:
|
||||
# Catch any unexpected errors and return None
|
||||
return None
|
||||
|
||||
def _get_aws_region_name(
|
||||
self, optional_params: dict, model: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Get the AWS region name from the environment variables
|
||||
"""
|
||||
aws_region_name = optional_params.get("aws_region_name", None)
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check model arn #
|
||||
aws_region_name = self._get_aws_region_from_model_arn(model)
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
||||
if (
|
||||
aws_region_name is None
|
||||
and litellm_aws_region_name is not None
|
||||
and isinstance(litellm_aws_region_name, str)
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
if (
|
||||
aws_region_name is None
|
||||
and standard_aws_region_name is not None
|
||||
and isinstance(standard_aws_region_name, str)
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
if aws_region_name is None:
|
||||
aws_region_name = "us-west-2"
|
||||
|
||||
return aws_region_name
|
||||
|
||||
@tracer.wrap()
|
||||
def _auth_with_web_identity_token(
|
||||
self,
|
||||
|
@ -423,7 +478,7 @@ class BaseAWSLLM:
|
|||
return endpoint_url, proxy_endpoint_url
|
||||
|
||||
def _get_boto_credentials_from_optional_params(
|
||||
self, optional_params: dict
|
||||
self, optional_params: dict, model: Optional[str] = None
|
||||
) -> Boto3CredentialsInfo:
|
||||
"""
|
||||
Get boto3 credentials from optional params
|
||||
|
@ -443,7 +498,7 @@ class BaseAWSLLM:
|
|||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_region_name = self._get_aws_region_name(optional_params, model)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
|
@ -453,25 +508,6 @@ class BaseAWSLLM:
|
|||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret_str("AWS_REGION_NAME", None)
|
||||
|
||||
if litellm_aws_region_name is not None and isinstance(
|
||||
litellm_aws_region_name, str
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret_str("AWS_REGION", None)
|
||||
if standard_aws_region_name is not None and isinstance(
|
||||
standard_aws_region_name, str
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
if aws_region_name is None:
|
||||
aws_region_name = "us-west-2"
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
|
|
|
@ -27,7 +27,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
from litellm.utils import CustomStreamWrapper, get_secret
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
@ -598,61 +598,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
)
|
||||
return modelId
|
||||
|
||||
def get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
# First check if the string contains the expected prefix
|
||||
if not isinstance(model, str) or "arn:aws:bedrock" not in model:
|
||||
return None
|
||||
|
||||
# Split the ARN and check if we have enough parts
|
||||
parts = model.split(":")
|
||||
if len(parts) < 4:
|
||||
return None
|
||||
|
||||
# Get the region from the correct position
|
||||
region = parts[3]
|
||||
if not region: # Check if region is empty
|
||||
return None
|
||||
|
||||
return region
|
||||
except Exception:
|
||||
# Catch any unexpected errors and return None
|
||||
return None
|
||||
|
||||
def _get_aws_region_name(
|
||||
self, optional_params: dict, model: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Get the AWS region name from the environment variables
|
||||
"""
|
||||
aws_region_name = optional_params.get("aws_region_name", None)
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check model arn #
|
||||
aws_region_name = self.get_aws_region_from_model_arn(model)
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
||||
if (
|
||||
aws_region_name is None
|
||||
and litellm_aws_region_name is not None
|
||||
and isinstance(litellm_aws_region_name, str)
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
if (
|
||||
aws_region_name is None
|
||||
and standard_aws_region_name is not None
|
||||
and isinstance(standard_aws_region_name, str)
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
if aws_region_name is None:
|
||||
aws_region_name = "us-west-2"
|
||||
|
||||
return aws_region_name
|
||||
|
||||
def _get_model_id_from_model_with_spec(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -318,6 +318,23 @@ class BedrockModelInfo(BaseLLMModelInfo):
|
|||
global_config = AmazonBedrockGlobalConfig()
|
||||
all_global_regions = global_config.get_all_regions()
|
||||
|
||||
@staticmethod
|
||||
def extract_model_name_from_arn(model: str) -> str:
|
||||
"""
|
||||
Extract the model name from an AWS Bedrock ARN.
|
||||
Returns the string after the last '/' if 'arn' is in the input string.
|
||||
|
||||
Args:
|
||||
arn (str): The ARN string to parse
|
||||
|
||||
Returns:
|
||||
str: The extracted model name if 'arn' is in the string,
|
||||
otherwise returns the original string
|
||||
"""
|
||||
if "arn" in model.lower():
|
||||
return model.split("/")[-1]
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def get_base_model(model: str) -> str:
|
||||
"""
|
||||
|
@ -335,6 +352,8 @@ class BedrockModelInfo(BaseLLMModelInfo):
|
|||
if model.startswith("invoke/"):
|
||||
model = model.split("/", 1)[1]
|
||||
|
||||
model = BedrockModelInfo.extract_model_name_from_arn(model)
|
||||
|
||||
potential_region = model.split(".", 1)[0]
|
||||
|
||||
alt_potential_region = model.split("/", 1)[
|
||||
|
|
|
@ -163,7 +163,7 @@ class BedrockImageGeneration(BaseAWSLLM):
|
|||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
||||
optional_params
|
||||
optional_params, model
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
|
|
|
@ -6,6 +6,8 @@ import httpx
|
|||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
|
@ -27,8 +29,10 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
async def arerank(
|
||||
self,
|
||||
prepared_request: BedrockPreparedRequest,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
):
|
||||
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
|
||||
if client is None:
|
||||
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
|
||||
try:
|
||||
response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
||||
response.raise_for_status()
|
||||
|
@ -54,7 +58,9 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
_is_async: Optional[bool] = False,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
) -> RerankResponse:
|
||||
|
||||
request_data = RerankRequest(
|
||||
model=model,
|
||||
query=query,
|
||||
|
@ -66,6 +72,7 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
data = BedrockRerankConfig()._transform_request(request_data)
|
||||
|
||||
prepared_request = self._prepare_request(
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
|
@ -83,9 +90,10 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
)
|
||||
|
||||
if _is_async:
|
||||
return self.arerank(prepared_request) # type: ignore
|
||||
return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None) # type: ignore
|
||||
|
||||
client = _get_httpx_client()
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = _get_httpx_client()
|
||||
try:
|
||||
response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
||||
response.raise_for_status()
|
||||
|
@ -95,10 +103,18 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return BedrockRerankConfig()._transform_response(response.json())
|
||||
logging_obj.post_call(
|
||||
original_response=response.text,
|
||||
api_key="",
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
return BedrockRerankConfig()._transform_response(response_json)
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
extra_headers: Optional[dict],
|
||||
data: dict,
|
||||
|
@ -110,7 +126,7 @@ class BedrockRerankHandler(BaseAWSLLM):
|
|||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
||||
optional_params
|
||||
optional_params, model
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
|
|
|
@ -91,7 +91,9 @@ class BedrockRerankConfig:
|
|||
example input:
|
||||
{"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
|
||||
"""
|
||||
_billed_units = RerankBilledUnits(**response.get("usage", {}))
|
||||
_billed_units = RerankBilledUnits(
|
||||
**response.get("usage", {"search_units": 1})
|
||||
) # by default 1 search unit
|
||||
_tokens = RerankTokens(**response.get("usage", {}))
|
||||
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
|
||||
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
"""
|
||||
Custom cost calculator for Cohere rerank models
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
|
||||
def cost_per_query(model: str, num_queries: int = 1) -> Tuple[float, float]:
|
||||
"""
|
||||
Calculates the cost per query for a given rerank model.
|
||||
|
||||
Input:
|
||||
- model: str, the model name without provider prefix
|
||||
|
||||
Returns:
|
||||
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
|
||||
"""
|
||||
|
||||
model_info = get_model_info(model=model, custom_llm_provider="cohere")
|
||||
|
||||
if (
|
||||
"input_cost_per_query" not in model_info
|
||||
or model_info["input_cost_per_query"] is None
|
||||
):
|
||||
return 0.0, 0.0
|
||||
|
||||
prompt_cost = model_info["input_cost_per_query"] * num_queries
|
||||
|
||||
return prompt_cost, 0.0
|
|
@ -1,92 +1,3 @@
|
|||
"""
|
||||
Re rank api
|
||||
|
||||
LiteLLM supports the re rank API format, no paramter transformation occurs
|
||||
HTTP calling migrated to `llm_http_handler.py`
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
|
||||
from litellm.types.rerank import RerankRequest, RerankResponse
|
||||
|
||||
|
||||
class JinaAIRerank(BaseLLM):
|
||||
def rerank(
|
||||
self,
|
||||
model: str,
|
||||
api_key: str,
|
||||
query: str,
|
||||
documents: List[Union[str, Dict[str, Any]]],
|
||||
top_n: Optional[int] = None,
|
||||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
_is_async: Optional[bool] = False,
|
||||
) -> RerankResponse:
|
||||
client = _get_httpx_client()
|
||||
|
||||
request_data = RerankRequest(
|
||||
model=model,
|
||||
query=query,
|
||||
top_n=top_n,
|
||||
documents=documents,
|
||||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
)
|
||||
|
||||
# exclude None values from request_data
|
||||
request_data_dict = request_data.dict(exclude_none=True)
|
||||
|
||||
if _is_async:
|
||||
return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method
|
||||
|
||||
response = client.post(
|
||||
"https://api.jina.ai/v1/rerank",
|
||||
headers={
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"authorization": f"Bearer {api_key}",
|
||||
},
|
||||
json=request_data_dict,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(response.text)
|
||||
|
||||
_json_response = response.json()
|
||||
|
||||
return JinaAIRerankConfig()._transform_response(_json_response)
|
||||
|
||||
async def async_rerank( # New async method
|
||||
self,
|
||||
request_data_dict: Dict[str, Any],
|
||||
api_key: str,
|
||||
) -> RerankResponse:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.JINA_AI
|
||||
) # Use async client
|
||||
|
||||
response = await client.post(
|
||||
"https://api.jina.ai/v1/rerank",
|
||||
headers={
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"authorization": f"Bearer {api_key}",
|
||||
},
|
||||
json=request_data_dict,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(response.text)
|
||||
|
||||
_json_response = response.json()
|
||||
|
||||
return JinaAIRerankConfig()._transform_response(_json_response)
|
||||
|
||||
pass
|
||||
|
|
|
@ -7,30 +7,136 @@ Docs - https://jina.ai/reranker
|
|||
"""
|
||||
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from httpx import URL, Response
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
|
||||
from litellm.types.rerank import (
|
||||
OptionalRerankParams,
|
||||
RerankBilledUnits,
|
||||
RerankResponse,
|
||||
RerankResponseMeta,
|
||||
RerankTokens,
|
||||
)
|
||||
from litellm.types.utils import ModelInfo
|
||||
|
||||
|
||||
class JinaAIRerankConfig:
|
||||
def _transform_response(self, response: dict) -> RerankResponse:
|
||||
class JinaAIRerankConfig(BaseRerankConfig):
|
||||
def get_supported_cohere_rerank_params(self, model: str) -> list:
|
||||
return [
|
||||
"query",
|
||||
"top_n",
|
||||
"documents",
|
||||
"return_documents",
|
||||
]
|
||||
|
||||
_billed_units = RerankBilledUnits(**response.get("usage", {}))
|
||||
_tokens = RerankTokens(**response.get("usage", {}))
|
||||
def map_cohere_rerank_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
query: str,
|
||||
documents: List[Union[str, Dict[str, Any]]],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
top_n: Optional[int] = None,
|
||||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
) -> OptionalRerankParams:
|
||||
optional_params = {}
|
||||
supported_params = self.get_supported_cohere_rerank_params(model)
|
||||
for k, v in non_default_params.items():
|
||||
if k in supported_params:
|
||||
optional_params[k] = v
|
||||
return OptionalRerankParams(
|
||||
**optional_params,
|
||||
)
|
||||
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
base_path = "/v1/rerank"
|
||||
|
||||
if api_base is None:
|
||||
return "https://api.jina.ai/v1/rerank"
|
||||
base = URL(api_base)
|
||||
# Reconstruct URL with cleaned path
|
||||
cleaned_base = str(base.copy_with(path=base_path))
|
||||
|
||||
return cleaned_base
|
||||
|
||||
def transform_rerank_request(
|
||||
self, model: str, optional_rerank_params: OptionalRerankParams, headers: Dict
|
||||
) -> Dict:
|
||||
return {"model": model, **optional_rerank_params}
|
||||
|
||||
def transform_rerank_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Response,
|
||||
model_response: RerankResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str] = None,
|
||||
request_data: Dict = {},
|
||||
optional_params: Dict = {},
|
||||
litellm_params: Dict = {},
|
||||
) -> RerankResponse:
|
||||
if raw_response.status_code != 200:
|
||||
raise Exception(raw_response.text)
|
||||
|
||||
logging_obj.post_call(original_response=raw_response.text)
|
||||
|
||||
_json_response = raw_response.json()
|
||||
|
||||
_billed_units = RerankBilledUnits(**_json_response.get("usage", {}))
|
||||
_tokens = RerankTokens(**_json_response.get("usage", {}))
|
||||
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
|
||||
|
||||
_results: Optional[List[dict]] = response.get("results")
|
||||
_results: Optional[List[dict]] = _json_response.get("results")
|
||||
|
||||
if _results is None:
|
||||
raise ValueError(f"No results found in the response={response}")
|
||||
raise ValueError(f"No results found in the response={_json_response}")
|
||||
|
||||
return RerankResponse(
|
||||
id=response.get("id") or str(uuid.uuid4()),
|
||||
id=_json_response.get("id") or str(uuid.uuid4()),
|
||||
results=_results, # type: ignore
|
||||
meta=rerank_meta,
|
||||
) # Return response
|
||||
|
||||
def validate_environment(
|
||||
self, headers: Dict, model: str, api_key: Optional[str] = None
|
||||
) -> Dict:
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"api_key is required. Set via `api_key` parameter or `JINA_API_KEY` environment variable."
|
||||
)
|
||||
return {
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"authorization": f"Bearer {api_key}",
|
||||
}
|
||||
|
||||
def calculate_rerank_cost(
|
||||
self,
|
||||
model: str,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
billed_units: Optional[RerankBilledUnits] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
Jina AI reranker is priced at $0.000000018 per token.
|
||||
"""
|
||||
if (
|
||||
model_info is None
|
||||
or "input_cost_per_token" not in model_info
|
||||
or model_info["input_cost_per_token"] is None
|
||||
or billed_units is None
|
||||
):
|
||||
return 0.0, 0.0
|
||||
|
||||
total_tokens = billed_units.get("total_tokens")
|
||||
if total_tokens is None:
|
||||
return 0.0, 0.0
|
||||
|
||||
input_cost = model_info["input_cost_per_token"] * total_tokens
|
||||
return input_cost, 0.0
|
||||
|
|
|
@ -5982,6 +5982,19 @@
|
|||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"amazon.rerank-v1:0": {
|
||||
"max_tokens": 32000,
|
||||
"max_input_tokens": 32000,
|
||||
"max_output_tokens": 32000,
|
||||
"max_query_tokens": 32000,
|
||||
"max_document_chunks_per_query": 100,
|
||||
"max_tokens_per_document_chunk": 512,
|
||||
"input_cost_per_token": 0.0,
|
||||
"input_cost_per_query": 0.001,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "rerank"
|
||||
},
|
||||
"amazon.titan-text-lite-v1": {
|
||||
"max_tokens": 4000,
|
||||
"max_input_tokens": 42000,
|
||||
|
@ -7022,6 +7035,19 @@
|
|||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"cohere.rerank-v3-5:0": {
|
||||
"max_tokens": 32000,
|
||||
"max_input_tokens": 32000,
|
||||
"max_output_tokens": 32000,
|
||||
"max_query_tokens": 32000,
|
||||
"max_document_chunks_per_query": 100,
|
||||
"max_tokens_per_document_chunk": 512,
|
||||
"input_cost_per_token": 0.0,
|
||||
"input_cost_per_query": 0.002,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "rerank"
|
||||
},
|
||||
"cohere.command-text-v14": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
|
@ -9154,5 +9180,15 @@
|
|||
"input_cost_per_second": 0.00003333,
|
||||
"output_cost_per_second": 0.00,
|
||||
"litellm_provider": "assemblyai"
|
||||
},
|
||||
"jina-reranker-v2-base-multilingual": {
|
||||
"max_tokens": 1024,
|
||||
"max_input_tokens": 1024,
|
||||
"max_output_tokens": 1024,
|
||||
"max_document_chunks_per_query": 2048,
|
||||
"input_cost_per_token": 0.000000018,
|
||||
"output_cost_per_token": 0.000000018,
|
||||
"litellm_provider": "jina_ai",
|
||||
"mode": "rerank"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,7 +9,6 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
|||
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
|
||||
from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
|
||||
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
|
||||
from litellm.rerank_api.rerank_utils import get_optional_rerank_params
|
||||
from litellm.secret_managers.main import get_secret, get_secret_str
|
||||
|
@ -20,7 +19,6 @@ from litellm.utils import ProviderConfigManager, client, exception_type
|
|||
####### ENVIRONMENT VARIABLES ###################
|
||||
# Initialize any necessary instances or variables here
|
||||
together_rerank = TogetherAIRerank()
|
||||
jina_ai_rerank = JinaAIRerank()
|
||||
bedrock_rerank = BedrockRerankHandler()
|
||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
#################################################
|
||||
|
@ -264,16 +262,26 @@ def rerank( # noqa: PLR0915
|
|||
raise ValueError(
|
||||
"Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
|
||||
)
|
||||
response = jina_ai_rerank.rerank(
|
||||
|
||||
api_base = (
|
||||
dynamic_api_base
|
||||
or optional_params.api_base
|
||||
or litellm.api_base
|
||||
or get_secret("BEDROCK_API_BASE") # type: ignore
|
||||
)
|
||||
|
||||
response = base_llm_http_handler.rerank(
|
||||
model=model,
|
||||
api_key=dynamic_api_key,
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n,
|
||||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
max_chunks_per_doc=max_chunks_per_doc,
|
||||
custom_llm_provider=_custom_llm_provider,
|
||||
optional_rerank_params=optional_rerank_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
timeout=optional_params.timeout,
|
||||
api_key=dynamic_api_key or optional_params.api_key,
|
||||
api_base=api_base,
|
||||
_is_async=_is_async,
|
||||
headers=headers or litellm.headers or {},
|
||||
client=client,
|
||||
model_response=model_response,
|
||||
)
|
||||
elif _custom_llm_provider == "bedrock":
|
||||
api_base = (
|
||||
|
@ -295,6 +303,7 @@ def rerank( # noqa: PLR0915
|
|||
optional_params=optional_params.model_dump(exclude_unset=True),
|
||||
api_base=api_base,
|
||||
logging_obj=litellm_logging_obj,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
|
||||
|
|
|
@ -17,6 +17,17 @@ def get_optional_rerank_params(
|
|||
max_chunks_per_doc: Optional[int] = None,
|
||||
non_default_params: Optional[dict] = None,
|
||||
) -> OptionalRerankParams:
|
||||
all_non_default_params = non_default_params or {}
|
||||
if query is not None:
|
||||
all_non_default_params["query"] = query
|
||||
if top_n is not None:
|
||||
all_non_default_params["top_n"] = top_n
|
||||
if documents is not None:
|
||||
all_non_default_params["documents"] = documents
|
||||
if return_documents is not None:
|
||||
all_non_default_params["return_documents"] = return_documents
|
||||
if max_chunks_per_doc is not None:
|
||||
all_non_default_params["max_chunks_per_doc"] = max_chunks_per_doc
|
||||
return rerank_provider_config.map_cohere_rerank_params(
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
|
@ -27,5 +38,5 @@ def get_optional_rerank_params(
|
|||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
max_chunks_per_doc=max_chunks_per_doc,
|
||||
non_default_params=non_default_params,
|
||||
non_default_params=all_non_default_params,
|
||||
)
|
||||
|
|
|
@ -6198,6 +6198,8 @@ class ProviderConfigManager:
|
|||
return litellm.AzureAIRerankConfig()
|
||||
elif litellm.LlmProviders.INFINITY == provider:
|
||||
return litellm.InfinityRerankConfig()
|
||||
elif litellm.LlmProviders.JINA_AI == provider:
|
||||
return litellm.JinaAIRerankConfig()
|
||||
return litellm.CohereRerankConfig()
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -5982,6 +5982,19 @@
|
|||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"amazon.rerank-v1:0": {
|
||||
"max_tokens": 32000,
|
||||
"max_input_tokens": 32000,
|
||||
"max_output_tokens": 32000,
|
||||
"max_query_tokens": 32000,
|
||||
"max_document_chunks_per_query": 100,
|
||||
"max_tokens_per_document_chunk": 512,
|
||||
"input_cost_per_token": 0.0,
|
||||
"input_cost_per_query": 0.001,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "rerank"
|
||||
},
|
||||
"amazon.titan-text-lite-v1": {
|
||||
"max_tokens": 4000,
|
||||
"max_input_tokens": 42000,
|
||||
|
@ -7022,6 +7035,19 @@
|
|||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"cohere.rerank-v3-5:0": {
|
||||
"max_tokens": 32000,
|
||||
"max_input_tokens": 32000,
|
||||
"max_output_tokens": 32000,
|
||||
"max_query_tokens": 32000,
|
||||
"max_document_chunks_per_query": 100,
|
||||
"max_tokens_per_document_chunk": 512,
|
||||
"input_cost_per_token": 0.0,
|
||||
"input_cost_per_query": 0.002,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "rerank"
|
||||
},
|
||||
"cohere.command-text-v14": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
|
@ -9154,5 +9180,15 @@
|
|||
"input_cost_per_second": 0.00003333,
|
||||
"output_cost_per_second": 0.00,
|
||||
"litellm_provider": "assemblyai"
|
||||
},
|
||||
"jina-reranker-v2-base-multilingual": {
|
||||
"max_tokens": 1024,
|
||||
"max_input_tokens": 1024,
|
||||
"max_output_tokens": 1024,
|
||||
"max_document_chunks_per_query": 2048,
|
||||
"input_cost_per_token": 0.000000018,
|
||||
"output_cost_per_token": 0.000000018,
|
||||
"litellm_provider": "jina_ai",
|
||||
"mode": "rerank"
|
||||
}
|
||||
}
|
||||
|
|
14
tests/litellm/llms/bedrock/rerank/transformation.py
Normal file
14
tests/litellm/llms/bedrock/rerank/transformation.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from litellm import rerank
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
51
tests/litellm/rerank_api/test_main.py
Normal file
51
tests/litellm/rerank_api/test_main.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from litellm import rerank
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
|
||||
def test_rerank_infer_region_from_model_arn(monkeypatch):
|
||||
mock_response = MagicMock()
|
||||
|
||||
monkeypatch.setenv("AWS_REGION_NAME", "us-east-1")
|
||||
args = {
|
||||
"model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
|
||||
"query": "hello",
|
||||
"documents": ["hello", "world"],
|
||||
}
|
||||
|
||||
def return_val():
|
||||
return {
|
||||
"results": [
|
||||
{"index": 0, "relevanceScore": 0.6716859340667725},
|
||||
{"index": 1, "relevanceScore": 0.0004994205664843321},
|
||||
]
|
||||
}
|
||||
|
||||
mock_response.json = return_val
|
||||
mock_response.headers = {"key": "value"}
|
||||
mock_response.status_code = 200
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post", return_value=mock_response) as mock_post:
|
||||
rerank(
|
||||
model=args["model"],
|
||||
query=args["query"],
|
||||
documents=args["documents"],
|
||||
client=client,
|
||||
)
|
||||
mock_post.assert_called_once()
|
||||
print(f"mock_post.call_args: {mock_post.call_args.kwargs}")
|
||||
assert "us-west-2" in mock_post.call_args.kwargs["url"]
|
||||
assert "us-east-1" not in mock_post.call_args.kwargs["url"]
|
|
@ -33,6 +33,7 @@ def assert_response_shape(response, custom_llm_provider):
|
|||
expected_api_version_shape = {"version": str}
|
||||
|
||||
expected_billed_units_shape = {"search_units": int}
|
||||
expected_billed_units_total_tokens_shape = {"total_tokens": int}
|
||||
|
||||
assert isinstance(response.id, expected_response_shape["id"])
|
||||
assert isinstance(response.results, expected_response_shape["results"])
|
||||
|
@ -52,9 +53,15 @@ def assert_response_shape(response, custom_llm_provider):
|
|||
response.meta["api_version"]["version"],
|
||||
expected_api_version_shape["version"],
|
||||
)
|
||||
assert isinstance(
|
||||
response.meta["billed_units"], expected_meta_shape["billed_units"]
|
||||
)
|
||||
if "total_tokens" in response.meta["billed_units"]:
|
||||
assert isinstance(
|
||||
response.meta["billed_units"], expected_meta_shape["billed_units"]
|
||||
response.meta["billed_units"]["total_tokens"],
|
||||
expected_billed_units_total_tokens_shape["total_tokens"],
|
||||
)
|
||||
else:
|
||||
assert isinstance(
|
||||
response.meta["billed_units"]["search_units"],
|
||||
expected_billed_units_shape["search_units"],
|
||||
|
@ -79,7 +86,9 @@ class BaseLLMRerankTest(ABC):
|
|||
@pytest.mark.asyncio()
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
async def test_basic_rerank(self, sync_mode):
|
||||
litellm.set_verbose = True
|
||||
litellm._turn_on_debug()
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
rerank_call_args = self.get_base_rerank_call_args()
|
||||
custom_llm_provider = self.get_custom_llm_provider()
|
||||
if sync_mode is True:
|
||||
|
@ -95,6 +104,9 @@ class BaseLLMRerankTest(ABC):
|
|||
assert response.id is not None
|
||||
assert response.results is not None
|
||||
|
||||
assert response._hidden_params["response_cost"] is not None
|
||||
assert response._hidden_params["response_cost"] > 0
|
||||
|
||||
assert_response_shape(
|
||||
response=response, custom_llm_provider=custom_llm_provider.value
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat import ModelResponseIterator
|
|||
import httpx
|
||||
import json
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from base_rerank_unit_tests import BaseLLMRerankTest
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
|
@ -185,6 +186,7 @@ def test_completion_azure_ai_command_r():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_azure_deepseek_reasoning_content():
|
||||
import json
|
||||
|
||||
|
@ -192,37 +194,48 @@ def test_azure_deepseek_reasoning_content():
|
|||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
mock_response = MagicMock()
|
||||
|
||||
|
||||
mock_response.text = json.dumps(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": "<think>I am thinking here</think>\n\nThe sky is a canvas of blue",
|
||||
"role": "assistant",
|
||||
}
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": "<think>I am thinking here</think>\n\nThe sky is a canvas of blue",
|
||||
"role": "assistant",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
mock_response.status_code = 200
|
||||
# Add required response attributes
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
mock_response.json = lambda: json.loads(mock_response.text)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
|
||||
response = litellm.completion(
|
||||
model='azure_ai/deepseek-r1',
|
||||
model="azure_ai/deepseek-r1",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
api_base="https://litellm8397336933.services.ai.azure.com/models/chat/completions",
|
||||
api_key="my-fake-api-key",
|
||||
client=client
|
||||
)
|
||||
client=client,
|
||||
)
|
||||
|
||||
print(response)
|
||||
assert(response.choices[0].message.reasoning_content == "I am thinking here")
|
||||
assert(response.choices[0].message.content == "\n\nThe sky is a canvas of blue")
|
||||
assert response.choices[0].message.reasoning_content == "I am thinking here"
|
||||
assert response.choices[0].message.content == "\n\nThe sky is a canvas of blue"
|
||||
|
||||
|
||||
class TestAzureAIRerank(BaseLLMRerankTest):
|
||||
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
||||
return litellm.LlmProviders.AZURE_AI
|
||||
|
||||
def get_base_rerank_call_args(self) -> dict:
|
||||
return {
|
||||
"model": "azure_ai/cohere-rerank-v3-english",
|
||||
"api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
|
||||
"api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
|
||||
}
|
||||
|
|
|
@ -2186,6 +2186,16 @@ class TestBedrockRerank(BaseLLMRerankTest):
|
|||
}
|
||||
|
||||
|
||||
class TestBedrockCohereRerank(BaseLLMRerankTest):
|
||||
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
||||
return litellm.LlmProviders.BEDROCK
|
||||
|
||||
def get_base_rerank_call_args(self) -> dict:
|
||||
return {
|
||||
"model": "bedrock/arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"messages, continue_message_index",
|
||||
[
|
||||
|
|
|
@ -66,6 +66,7 @@ def assert_response_shape(response, custom_llm_provider):
|
|||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_basic_rerank(sync_mode):
|
||||
litellm.set_verbose = True
|
||||
if sync_mode is True:
|
||||
|
@ -311,6 +312,7 @@ def test_complete_base_url_cohere():
|
|||
(3, None, False),
|
||||
],
|
||||
)
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
|
||||
from litellm.caching.caching import Cache
|
||||
|
||||
|
|
|
@ -1574,7 +1574,11 @@ def test_completion_cost_azure_ai_rerank(model):
|
|||
"relevance_score": 0.990732,
|
||||
},
|
||||
],
|
||||
meta={},
|
||||
meta={
|
||||
"billed_units": {
|
||||
"search_units": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
print("response", response)
|
||||
model = model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue