Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(cost_calculator.py): handle custom pricing at deployment level for router (#9855)
* fix(cost_calculator.py): handle custom pricing at deployment level for router
* test: add unit tests
* fix(router.py): show custom pricing on UI check correct model str
* fix: fix linting error
* docs(custom_pricing.md): clarify custom pricing for proxy

  Fixes https://github.com/BerriAI/litellm/issues/8573#issuecomment-2790420740
* test: update code qa test
* fix: cleanup traceback
* fix: handle litellm param custom pricing
* test: update test
* fix(cost_calculator.py): add router model id to list of potential model names
* fix(cost_calculator.py): fix router model id check
* fix: router.py - maintain older model registry approach
* fix: fix ruff check
* fix(router.py): router get deployment info add custom values to mapped dict
* test: update test
* fix(utils.py): update only if value is non-null
* test: add unit test
This commit is contained in:
parent baa9bd6338
commit e1eb5e32c1
16 changed files with 193 additions and 37 deletions
@@ -26,10 +26,12 @@ model_list:
   - model_name: sagemaker-completion-model
     litellm_params:
       model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
+    model_info:
       input_cost_per_second: 0.000420
   - model_name: sagemaker-embedding-model
     litellm_params:
       model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
+    model_info:
       input_cost_per_second: 0.000420
 ```
 
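Note: for per-second pricing like the SageMaker entries above, cost scales with wall-clock duration rather than token counts. A quick back-of-the-envelope check (the 10-second duration is illustrative, not from the commit):

```python
# Rate taken from the YAML above; duration is an assumed example value.
input_cost_per_second = 0.000420
call_duration_seconds = 10.0
print(input_cost_per_second * call_duration_seconds)  # 0.0042
```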
@@ -55,11 +57,33 @@ model_list:
       api_key: os.environ/AZURE_API_KEY
       api_base: os.environ/AZURE_API_BASE
       api_version: os.envrion/AZURE_API_VERSION
+    model_info:
       input_cost_per_token: 0.000421 # 👈 ONLY to track cost per token
       output_cost_per_token: 0.000520 # 👈 ONLY to track cost per token
 ```
 
-### Debugging
+## Override Model Cost Map
+
+You can override [our model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) with your own custom pricing for a mapped model.
+
+Just add a `model_info` key to your model in the config, and override the desired keys.
+
+Example: Override Anthropic's model cost map for the `prod/claude-3-5-sonnet-20241022` model.
+
+```yaml
+model_list:
+  - model_name: "prod/claude-3-5-sonnet-20241022"
+    litellm_params:
+      model: "anthropic/claude-3-5-sonnet-20241022"
+      api_key: os.environ/ANTHROPIC_PROD_API_KEY
+    model_info:
+      input_cost_per_token: 0.000006
+      output_cost_per_token: 0.00003
+      cache_creation_input_token_cost: 0.0000075
+      cache_read_input_token_cost: 0.0000006
+```
+
+## Debugging
+
 If you're custom pricing is not being used or you're seeing errors, please check the following:
 
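Note: a quick way to sanity-check an override like the one documented above is to register it and compare `completion_cost` against hand-computed arithmetic. A minimal sketch (the canned response and token counts are assumptions, not part of the commit):

```python
import litellm
from litellm import ModelResponse, completion_cost
from litellm.types.utils import Usage

# Register the same override as the YAML above.
litellm.register_model(
    model_cost={
        "anthropic/claude-3-5-sonnet-20241022": {
            "litellm_provider": "anthropic",
            "mode": "chat",
            "input_cost_per_token": 0.000006,
            "output_cost_per_token": 0.00003,
        }
    }
)

# A canned response with known usage:
# 1000 * 0.000006 + 500 * 0.00003 = 0.021
resp = ModelResponse(
    model="anthropic/claude-3-5-sonnet-20241022",
    usage=Usage(prompt_tokens=1000, completion_tokens=500, total_tokens=1500),
)
print(completion_cost(completion_response=resp))
```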
@@ -403,6 +403,7 @@ def _select_model_name_for_cost_calc(
     base_model: Optional[str] = None,
     custom_pricing: Optional[bool] = None,
     custom_llm_provider: Optional[str] = None,
+    router_model_id: Optional[str] = None,
 ) -> Optional[str]:
     """
     1. If custom pricing is true, return received model name
@@ -417,12 +418,6 @@ def _select_model_name_for_cost_calc(
         model=model, custom_llm_provider=custom_llm_provider
     )
 
-    if custom_pricing is True:
-        return_model = model
-
-    if base_model is not None:
-        return_model = base_model
-
     completion_response_model: Optional[str] = None
     if completion_response is not None:
         if isinstance(completion_response, BaseModel):
@@ -430,6 +425,16 @@ def _select_model_name_for_cost_calc(
         elif isinstance(completion_response, dict):
             completion_response_model = completion_response.get("model", None)
     hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
 
+    if custom_pricing is True:
+        if router_model_id is not None and router_model_id in litellm.model_cost:
+            return_model = router_model_id
+        else:
+            return_model = model
+
+    if base_model is not None:
+        return_model = base_model
+
     if completion_response_model is None and hidden_params is not None:
         if (
             hidden_params.get("model", None) is not None
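Note: the moved branch above gives the router's deployment id precedence over the raw model name whenever custom pricing is set and that id has been registered in `litellm.model_cost`, with `base_model` still overriding both. A standalone sketch of that precedence (names simplified from the hunk; `registered_models` stands in for the cost-map keys):

```python
from typing import Optional

def select_model_for_cost(
    model: str,
    custom_pricing: Optional[bool],
    router_model_id: Optional[str],
    base_model: Optional[str],
    registered_models: set,  # stands in for litellm.model_cost keys
) -> str:
    # Registered router model id beats the received model name;
    # base_model beats both.
    return_model = model
    if custom_pricing:
        if router_model_id is not None and router_model_id in registered_models:
            return_model = router_model_id
    if base_model is not None:
        return_model = base_model
    return return_model

# Deployment-level pricing registered under the router id wins:
assert select_model_for_cost(
    "anthropic/claude-3-5-sonnet-20240620",
    True,
    "my-unique-model-id",
    None,
    {"my-unique-model-id"},
) == "my-unique-model-id"
```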
@@ -559,6 +564,7 @@ def completion_cost( # noqa: PLR0915
     base_model: Optional[str] = None,
     standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
     litellm_model_name: Optional[str] = None,
+    router_model_id: Optional[str] = None,
 ) -> float:
     """
     Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.
@@ -617,12 +623,12 @@ def completion_cost( # noqa: PLR0915
         custom_llm_provider=custom_llm_provider,
         custom_pricing=custom_pricing,
         base_model=base_model,
+        router_model_id=router_model_id,
     )
 
     potential_model_names = [selected_model]
     if model is not None:
         potential_model_names.append(model)
 
     for idx, model in enumerate(potential_model_names):
         try:
             verbose_logger.info(
@@ -943,6 +949,7 @@ def response_cost_calculator(
     prompt: str = "",
     standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
     litellm_model_name: Optional[str] = None,
+    router_model_id: Optional[str] = None,
 ) -> float:
     """
     Returns
@@ -973,6 +980,8 @@ def response_cost_calculator(
             base_model=base_model,
             prompt=prompt,
             standard_built_in_tools_params=standard_built_in_tools_params,
+            litellm_model_name=litellm_model_name,
+            router_model_id=router_model_id,
         )
         return response_cost
     except Exception as e:
@@ -1,7 +1,6 @@
 from typing import Literal, Optional
 
 import litellm
-from litellm._logging import verbose_logger
 from litellm.exceptions import BadRequestError
 from litellm.types.utils import LlmProviders, LlmProvidersSet
 
@@ -43,9 +42,6 @@ def get_supported_openai_params( # noqa: PLR0915
     provider_config = None
 
     if provider_config and request_type == "chat_completion":
-        verbose_logger.info(
-            f"using provider_config: {provider_config} for checking supported params"
-        )
         return provider_config.get_supported_openai_params(model=model)
 
     if custom_llm_provider == "bedrock":
@@ -622,7 +622,6 @@ class Logging(LiteLLMLoggingBaseClass):
                 ] = RawRequestTypedDict(
                     error=str(e),
                 )
-                traceback.print_exc()
                 _metadata[
                     "raw_request"
                 ] = "Unable to Log \
@@ -906,6 +905,7 @@ class Logging(LiteLLMLoggingBaseClass):
         ],
         cache_hit: Optional[bool] = None,
         litellm_model_name: Optional[str] = None,
+        router_model_id: Optional[str] = None,
     ) -> Optional[float]:
         """
         Calculate response cost using result + logging object variables.
@@ -944,6 +944,7 @@ class Logging(LiteLLMLoggingBaseClass):
                 "custom_pricing": custom_pricing,
                 "prompt": prompt,
                 "standard_built_in_tools_params": self.standard_built_in_tools_params,
+                "router_model_id": router_model_id,
             }
         except Exception as e:  # error creating kwargs for cost calculation
             debug_info = StandardLoggingModelCostFailureDebugInformation(
@@ -36,11 +36,16 @@ class ResponseMetadata:
         self, logging_obj: LiteLLMLoggingObject, model: Optional[str], kwargs: dict
     ) -> None:
         """Set hidden parameters on the response"""
 
+        ## ADD OTHER HIDDEN PARAMS
+        model_id = kwargs.get("model_info", {}).get("id", None)
         new_params = {
             "litellm_call_id": getattr(logging_obj, "litellm_call_id", None),
-            "model_id": kwargs.get("model_info", {}).get("id", None),
             "api_base": get_api_base(model=model or "", optional_params=kwargs),
-            "response_cost": logging_obj._response_cost_calculator(result=self.result),
+            "model_id": model_id,
+            "response_cost": logging_obj._response_cost_calculator(
+                result=self.result, litellm_model_name=model, router_model_id=model_id
+            ),
             "additional_headers": process_response_headers(
                 self._get_value_from_hidden_params("additional_headers") or {}
             ),
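Note: with the deployment's model id resolved once and threaded into `_response_cost_calculator`, both `model_id` and `response_cost` in a response's hidden params now reflect deployment-level pricing. A minimal sketch of observing this, echoing the unit test added later in this commit (the model name and key are test placeholders):

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-prod",
            "litellm_params": {
                "model": "anthropic/claude-3-5-sonnet-20240620",
                "api_key": "test_api_key",
            },
            "model_info": {
                "id": "my-unique-model-id",
                "input_cost_per_token": 0.000006,
                "output_cost_per_token": 0.00003,
            },
        }
    ]
)

resp = router.completion(
    model="claude-prod",
    messages=[{"role": "user", "content": "Hello!"}],
    mock_response=True,  # no network call
)
# Both values are now derived from the deployment, not the raw model name.
print(resp._hidden_params["model_id"], resp._hidden_params["response_cost"])
```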
@@ -1,7 +1,6 @@
 import copy
 import json
 import re
-import traceback
 import uuid
 import xml.etree.ElementTree as ET
 from enum import Enum
@@ -748,7 +747,6 @@ def convert_to_anthropic_image_obj(
             data=base64_data,
         )
     except Exception as e:
-        traceback.print_exc()
         if "Error: Unable to fetch image from URL" in str(e):
             raise e
         raise Exception(
@@ -100,7 +100,6 @@ async def cache_ping():
     except Exception as e:
         import traceback
 
-        traceback.print_exc()
         error_message = {
             "message": f"Service Unhealthy ({str(e)})",
             "litellm_cache_params": safe_dumps(litellm_cache_params),
@@ -816,9 +816,6 @@ async def add_member_to_organization(
         return user_object, organization_membership
 
     except Exception as e:
-        import traceback
-
-        traceback.print_exc()
         raise ValueError(
             f"Error adding member={member} to organization={organization_id}: {e}"
         )
@@ -116,6 +116,7 @@ from litellm.types.router import (
     AllowedFailsPolicy,
     AssistantsTypedDict,
     CredentialLiteLLMParams,
+    CustomPricingLiteLLMParams,
     CustomRoutingStrategyBase,
     Deployment,
     DeploymentTypedDict,
@@ -132,6 +133,7 @@ from litellm.types.router import (
 )
 from litellm.types.services import ServiceTypes
 from litellm.types.utils import GenericBudgetConfigType
+from litellm.types.utils import ModelInfo
 from litellm.types.utils import ModelInfo as ModelMapInfo
 from litellm.types.utils import StandardLoggingPayload
 from litellm.utils import (
@@ -3324,7 +3326,6 @@ class Router:
 
             return response
         except Exception as new_exception:
-            traceback.print_exc()
             parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             verbose_router_logger.error(
                 "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
@@ -4301,7 +4302,20 @@ class Router:
                 model_info=_model_info,
             )
 
+            for field in CustomPricingLiteLLMParams.model_fields.keys():
+                if deployment.litellm_params.get(field) is not None:
+                    _model_info[field] = deployment.litellm_params[field]
+
             ## REGISTER MODEL INFO IN LITELLM MODEL COST MAP
+            model_id = deployment.model_info.id
+            if model_id is not None:
+                litellm.register_model(
+                    model_cost={
+                        model_id: _model_info,
+                    }
+                )
+
+            ## OLD MODEL REGISTRATION ## Kept to prevent breaking changes
             _model_name = deployment.litellm_params.model
             if deployment.litellm_params.custom_llm_provider is not None:
                 _model_name = (
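Note: after `set_model_list` runs the loop above, each deployment's custom pricing lives in `litellm.model_cost` under its model id, which is what the `router_model_id` check in `_select_model_name_for_cost_calc` relies on. A small sketch of what that registration amounts to (the id and rates are illustrative):

```python
import litellm

# What the loop above effectively does for one deployment whose
# litellm_params carry custom pricing fields.
_model_info = {
    "litellm_provider": "anthropic",
    "input_cost_per_token": 0.0000008,
    "output_cost_per_token": 0.0000032,
}
litellm.register_model(model_cost={"my-deployment-id": _model_info})

assert "my-deployment-id" in litellm.model_cost
assert litellm.model_cost["my-deployment-id"]["input_cost_per_token"] == 0.0000008
```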
@@ -4802,6 +4816,42 @@ class Router:
         model_name = model_info["model_name"]
         return self.get_model_list(model_name=model_name)
 
+    def get_deployment_model_info(
+        self, model_id: str, model_name: str
+    ) -> Optional[ModelInfo]:
+        """
+        For a given model id, return the model info
+
+        1. Check if model_id is in model info
+        2. If not, check if litellm model name is in model info
+        3. If not, return None
+        """
+        from litellm.utils import _update_dictionary
+
+        model_info: Optional[ModelInfo] = None
+        litellm_model_name_model_info: Optional[ModelInfo] = None
+
+        try:
+            model_info = litellm.get_model_info(model=model_id)
+        except Exception:
+            pass
+
+        try:
+            litellm_model_name_model_info = litellm.get_model_info(model=model_name)
+        except Exception:
+            pass
+
+        if model_info is not None and litellm_model_name_model_info is not None:
+            model_info = cast(
+                ModelInfo,
+                _update_dictionary(
+                    cast(dict, litellm_model_name_model_info).copy(),
+                    cast(dict, model_info),
+                ),
+            )
+
+        return model_info
+
     def _set_model_group_info(  # noqa: PLR0915
         self, model_group: str, user_facing_model_group_name: str
     ) -> Optional[ModelGroupInfo]:
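Note: the new helper layers id-level info over name-level info (non-null values win, via the `_update_dictionary` change later in this commit), and falls back to whichever lookup succeeded. Usage mirroring the new router unit test:

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gemini/*",
            "litellm_params": {"model": "gemini/*"},
            "model_info": {"id": "1"},
        }
    ]
)

# The id "1" has no registered info of its own, so the helper
# falls back to the litellm model name's entry in the cost map.
info = router.get_deployment_model_info(
    model_id="1", model_name="gemini/gemini-1.5-flash"
)
assert info is not None
```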
@@ -4860,9 +4910,16 @@ class Router:
 
         # get model info
         try:
-            model_info = litellm.get_model_info(model=litellm_params.model)
+            model_id = model.get("model_info", {}).get("id", None)
+            if model_id is not None:
+                model_info = self.get_deployment_model_info(
+                    model_id=model_id, model_name=litellm_params.model
+                )
+            else:
+                model_info = None
         except Exception:
             model_info = None
 
         # get llm provider
         litellm_model, llm_provider = "", ""
         try:
@@ -162,7 +162,15 @@ class CredentialLiteLLMParams(BaseModel):
     watsonx_region_name: Optional[str] = None
 
 
-class GenericLiteLLMParams(CredentialLiteLLMParams):
+class CustomPricingLiteLLMParams(BaseModel):
+    ## CUSTOM PRICING ##
+    input_cost_per_token: Optional[float] = None
+    output_cost_per_token: Optional[float] = None
+    input_cost_per_second: Optional[float] = None
+    output_cost_per_second: Optional[float] = None
+
+
+class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
     """
     LiteLLM Params without 'model' arg (used across completion / assistants api)
     """
@@ -184,12 +192,6 @@ class GenericLiteLLMParams(CredentialLiteLLMParams):
     ## LOGGING PARAMS ##
     litellm_trace_id: Optional[str] = None
 
-    ## CUSTOM PRICING ##
-    input_cost_per_token: Optional[float] = None
-    output_cost_per_token: Optional[float] = None
-    input_cost_per_second: Optional[float] = None
-    output_cost_per_second: Optional[float] = None
-
     max_file_size_mb: Optional[float] = None
 
     # Deployment budgets
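Note: the refactor above just lifts the pricing fields into a reusable mixin, so the router can iterate `CustomPricingLiteLLMParams.model_fields` (as in the `set_model_list` hunk) without hard-coding field names. A self-contained sketch of the same pattern (class names are stand-ins, not the library's):

```python
from typing import Optional
from pydantic import BaseModel

class CustomPricing(BaseModel):  # stands in for CustomPricingLiteLLMParams
    input_cost_per_token: Optional[float] = None
    output_cost_per_token: Optional[float] = None

class GenericParams(CustomPricing):  # stands in for GenericLiteLLMParams
    api_key: Optional[str] = None

litellm_params = {"api_key": "sk-...", "input_cost_per_token": 0.0000008}

# Copy only the pricing fields a deployment actually sets:
model_info = {
    field: litellm_params[field]
    for field in CustomPricing.model_fields
    if litellm_params.get(field) is not None
}
assert model_info == {"input_cost_per_token": 0.0000008}
```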
@@ -2245,7 +2245,8 @@ def supports_embedding_image_input(
 ####### HELPER FUNCTIONS ################
 def _update_dictionary(existing_dict: Dict, new_dict: dict) -> dict:
     for k, v in new_dict.items():
-        existing_dict[k] = v
+        if v is not None:
+            existing_dict[k] = v
 
     return existing_dict
 
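Note: this `_update_dictionary` change is what makes the merge in `get_deployment_model_info` safe: a `None` in the id-level info no longer clobbers a real value from the name-level info. A quick illustration of the new semantics (standalone copy of the patched helper; the rates are illustrative):

```python
def _update_dictionary(existing_dict: dict, new_dict: dict) -> dict:
    # Patched behavior: only non-null values overwrite.
    for k, v in new_dict.items():
        if v is not None:
            existing_dict[k] = v
    return existing_dict

base = {"input_cost_per_token": 0.000001, "mode": "chat"}
override = {"input_cost_per_token": 0.000006, "mode": None}
assert _update_dictionary(base.copy(), override) == {
    "input_cost_per_token": 0.000006,
    "mode": "chat",  # None no longer wipes the existing value
}
```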
@@ -31,7 +31,7 @@ def get_all_functions_called_in_tests(base_dir):
     specifically in files containing the word 'router'.
     """
     called_functions = set()
-    test_dirs = ["local_testing", "router_unit_tests"]
+    test_dirs = ["local_testing", "router_unit_tests", "litellm"]
 
     for test_dir in test_dirs:
         dir_path = os.path.join(base_dir, test_dir)
@@ -151,3 +151,63 @@ def test_handle_realtime_stream_cost_calculation():
         litellm_model_name="gpt-3.5-turbo",
     )
     assert cost == 0.0  # No usage, no cost
+
+
+def test_custom_pricing_with_router_model_id():
+    from litellm import Router
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "prod/claude-3-5-sonnet-20240620",
+                "litellm_params": {
+                    "model": "anthropic/claude-3-5-sonnet-20240620",
+                    "api_key": "test_api_key",
+                },
+                "model_info": {
+                    "id": "my-unique-model-id",
+                    "input_cost_per_token": 0.000006,
+                    "output_cost_per_token": 0.00003,
+                    "cache_creation_input_token_cost": 0.0000075,
+                    "cache_read_input_token_cost": 0.0000006,
+                },
+            },
+            {
+                "model_name": "claude-3-5-sonnet-20240620",
+                "litellm_params": {
+                    "model": "anthropic/claude-3-5-sonnet-20240620",
+                    "api_key": "test_api_key",
+                },
+                "model_info": {
+                    "input_cost_per_token": 100,
+                    "output_cost_per_token": 200,
+                },
+            },
+        ]
+    )
+
+    result = router.completion(
+        model="claude-3-5-sonnet-20240620",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        mock_response=True,
+    )
+
+    result_2 = router.completion(
+        model="prod/claude-3-5-sonnet-20240620",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        mock_response=True,
+    )
+
+    assert (
+        result._hidden_params["response_cost"]
+        > result_2._hidden_params["response_cost"]
+    )
+
+    model_info = router.get_deployment_model_info(
+        model_id="my-unique-model-id", model_name="anthropic/claude-3-5-sonnet-20240620"
+    )
+    assert model_info is not None
+    assert model_info["input_cost_per_token"] == 0.000006
+    assert model_info["output_cost_per_token"] == 0.00003
+    assert model_info["cache_creation_input_token_cost"] == 0.0000075
+    assert model_info["cache_read_input_token_cost"] == 0.0000006
@@ -2954,9 +2954,6 @@ def test_cost_calculator_with_custom_pricing():
 @pytest.mark.asyncio
 async def test_cost_calculator_with_custom_pricing_router(model_item, custom_pricing):
     from litellm import Router
-
-    litellm._turn_on_debug()
-
     if custom_pricing == "litellm_params":
         model_item["litellm_params"]["input_cost_per_token"] = 0.0000008
         model_item["litellm_params"]["output_cost_per_token"] = 0.0000032
@@ -314,12 +314,14 @@ def test_get_model_info_custom_model_router():
                 "input_cost_per_token": 1,
                 "output_cost_per_token": 1,
                 "model": "openai/meta-llama/Meta-Llama-3-8B-Instruct",
-                "model_id": "c20d603e-1166-4e0f-aa65-ed9c476ad4ca",
             },
+            "model_info": {
+                "id": "c20d603e-1166-4e0f-aa65-ed9c476ad4ca",
+            }
         }
     ]
 )
-info = get_model_info("openai/meta-llama/Meta-Llama-3-8B-Instruct")
+info = get_model_info("c20d603e-1166-4e0f-aa65-ed9c476ad4ca")
 print("info", info)
 assert info is not None
 
@@ -451,3 +451,11 @@ def test_router_get_deployment_credentials():
     credentials = router.get_deployment_credentials(model_id="1")
     assert credentials is not None
     assert credentials["api_key"] == "123"
+
+
+def test_router_get_deployment_model_info():
+    router = Router(
+        model_list=[{"model_name": "gemini/*", "litellm_params": {"model": "gemini/*"}, "model_info": {"id": "1"}}]
+    )
+    model_info = router.get_deployment_model_info(model_id="1", model_name="gemini/gemini-1.5-flash")
+    assert model_info is not None