diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index 6a885858bc..df126fd648 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -530,6 +530,7 @@ def completion_cost(  # noqa: PLR0915
     - For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
     """
     try:
+        call_type = _infer_call_type(call_type, completion_response) or "completion"

         if (
diff --git a/litellm/litellm_core_utils/get_litellm_params.py b/litellm/litellm_core_utils/get_litellm_params.py
new file mode 100644
index 0000000000..3d8394f7af
--- /dev/null
+++ b/litellm/litellm_core_utils/get_litellm_params.py
@@ -0,0 +1,101 @@
+from typing import Optional
+
+
+def _get_base_model_from_litellm_call_metadata(
+    metadata: Optional[dict],
+) -> Optional[str]:
+    if metadata is None:
+        return None
+
+    if metadata is not None:
+        model_info = metadata.get("model_info", {})
+
+        if model_info is not None:
+            base_model = model_info.get("base_model", None)
+            if base_model is not None:
+                return base_model
+    return None
+
+
+def get_litellm_params(
+    api_key: Optional[str] = None,
+    force_timeout=600,
+    azure=False,
+    logger_fn=None,
+    verbose=False,
+    hugging_face=False,
+    replicate=False,
+    together_ai=False,
+    custom_llm_provider: Optional[str] = None,
+    api_base: Optional[str] = None,
+    litellm_call_id=None,
+    model_alias_map=None,
+    completion_call_id=None,
+    metadata: Optional[dict] = None,
+    model_info=None,
+    proxy_server_request=None,
+    acompletion=None,
+    aembedding=None,
+    preset_cache_key=None,
+    no_log=None,
+    input_cost_per_second=None,
+    input_cost_per_token=None,
+    output_cost_per_token=None,
+    output_cost_per_second=None,
+    cooldown_time=None,
+    text_completion=None,
+    azure_ad_token_provider=None,
+    user_continue_message=None,
+    base_model: Optional[str] = None,
+    litellm_trace_id: Optional[str] = None,
+    hf_model_name: Optional[str] = None,
+    custom_prompt_dict: Optional[dict] = None,
+    litellm_metadata: Optional[dict] = None,
+    disable_add_transform_inline_image_block: Optional[bool] = None,
+    drop_params: Optional[bool] = None,
+    prompt_id: Optional[str] = None,
+    prompt_variables: Optional[dict] = None,
+    async_call: Optional[bool] = None,
+    ssl_verify: Optional[bool] = None,
+    **kwargs,
+) -> dict:
+    litellm_params = {
+        "acompletion": acompletion,
+        "api_key": api_key,
+        "force_timeout": force_timeout,
+        "logger_fn": logger_fn,
+        "verbose": verbose,
+        "custom_llm_provider": custom_llm_provider,
+        "api_base": api_base,
+        "litellm_call_id": litellm_call_id,
+        "model_alias_map": model_alias_map,
+        "completion_call_id": completion_call_id,
+        "aembedding": aembedding,
+        "metadata": metadata,
+        "model_info": model_info,
+        "proxy_server_request": proxy_server_request,
+        "preset_cache_key": preset_cache_key,
+        "no-log": no_log,
+        "stream_response": {},  # litellm_call_id: ModelResponse Dict
+        "input_cost_per_token": input_cost_per_token,
+        "input_cost_per_second": input_cost_per_second,
+        "output_cost_per_token": output_cost_per_token,
+        "output_cost_per_second": output_cost_per_second,
+        "cooldown_time": cooldown_time,
+        "text_completion": text_completion,
+        "azure_ad_token_provider": azure_ad_token_provider,
+        "user_continue_message": user_continue_message,
+        "base_model": base_model
+        or _get_base_model_from_litellm_call_metadata(metadata=metadata),
+        "litellm_trace_id": litellm_trace_id,
+        "hf_model_name": hf_model_name,
+        "custom_prompt_dict": custom_prompt_dict,
+        "litellm_metadata": litellm_metadata,
"disable_add_transform_inline_image_block": disable_add_transform_inline_image_block, + "drop_params": drop_params, + "prompt_id": prompt_id, + "prompt_variables": prompt_variables, + "async_call": async_call, + "ssl_verify": ssl_verify, + } + return litellm_params diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 957b73c923..b8f1bff293 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -32,6 +32,7 @@ from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.mlflow import MlflowLogger from litellm.integrations.pagerduty.pagerduty import PagerDutyAlerting +from litellm.litellm_core_utils.get_litellm_params import get_litellm_params from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_custom_logger, redact_message_input_output_from_logging, @@ -257,10 +258,19 @@ class Logging(LiteLLMLoggingBaseClass): self.completion_start_time: Optional[datetime.datetime] = None self._llm_caching_handler: Optional[LLMCachingHandler] = None + # INITIAL LITELLM_PARAMS + litellm_params = {} + if kwargs is not None: + litellm_params = get_litellm_params(**kwargs) + litellm_params = scrub_sensitive_keys_in_metadata(litellm_params) + + self.litellm_params = litellm_params + self.model_call_details: Dict[str, Any] = { "litellm_trace_id": litellm_trace_id, "litellm_call_id": litellm_call_id, "input": _input, + "litellm_params": litellm_params, } def process_dynamic_callbacks(self): @@ -359,7 +369,10 @@ class Logging(LiteLLMLoggingBaseClass): if model is not None: self.model = model self.user = user - self.litellm_params = scrub_sensitive_keys_in_metadata(litellm_params) + self.litellm_params = { + **self.litellm_params, + **scrub_sensitive_keys_in_metadata(litellm_params), + } self.logger_fn = litellm_params.get("logger_fn", None) verbose_logger.debug(f"self.optional_params: {self.optional_params}") @@ -785,6 +798,7 @@ class Logging(LiteLLMLoggingBaseClass): used for consistent cost calculation across response headers + logging integrations. 
""" + ## RESPONSE COST ## custom_pricing = use_custom_pricing_for_model( litellm_params=( diff --git a/litellm/litellm_core_utils/mock_functions.py b/litellm/litellm_core_utils/mock_functions.py index a6e560c751..9f62e0479b 100644 --- a/litellm/litellm_core_utils/mock_functions.py +++ b/litellm/litellm_core_utils/mock_functions.py @@ -1,6 +1,12 @@ from typing import List, Optional -from ..types.utils import Embedding, EmbeddingResponse, ImageObject, ImageResponse +from ..types.utils import ( + Embedding, + EmbeddingResponse, + ImageObject, + ImageResponse, + Usage, +) def mock_embedding(model: str, mock_response: Optional[List[float]]): @@ -9,6 +15,7 @@ def mock_embedding(model: str, mock_response: Optional[List[float]]): return EmbeddingResponse( model=model, data=[Embedding(embedding=mock_response, index=0, object="embedding")], + usage=Usage(prompt_tokens=10, completion_tokens=0), ) diff --git a/litellm/main.py b/litellm/main.py index 93cf16c601..c6774c9f50 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3224,8 +3224,6 @@ def embedding( # noqa: PLR0915 **non_default_params, ) - if mock_response is not None: - return mock_embedding(model=model, mock_response=mock_response) ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### if input_cost_per_token is not None and output_cost_per_token is not None: litellm.register_model( @@ -3248,28 +3246,22 @@ def embedding( # noqa: PLR0915 } } ) + litellm_params_dict = get_litellm_params(**kwargs) + + logging: Logging = litellm_logging_obj # type: ignore + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params=litellm_params_dict, + custom_llm_provider=custom_llm_provider, + ) + + if mock_response is not None: + return mock_embedding(model=model, mock_response=mock_response) try: response: Optional[EmbeddingResponse] = None - logging: Logging = litellm_logging_obj # type: ignore - logging.update_environment_variables( - model=model, - user=user, - optional_params=optional_params, - litellm_params={ - "timeout": timeout, - "azure": azure, - "litellm_call_id": litellm_call_id, - "logger_fn": logger_fn, - "proxy_server_request": proxy_server_request, - "model_info": model_info, - "metadata": metadata, - "aembedding": aembedding, - "preset_cache_key": None, - "stream_response": {}, - "cooldown_time": cooldown_time, - }, - custom_llm_provider=custom_llm_provider, - ) + if azure is True or custom_llm_provider == "azure": # azure configs api_type = get_secret_str("AZURE_API_TYPE") or "azure" diff --git a/litellm/utils.py b/litellm/utils.py index dd43355f01..a878802ed3 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -71,6 +71,10 @@ from litellm.litellm_core_utils.exception_mapping_utils import ( exception_type, get_error_message, ) +from litellm.litellm_core_utils.get_litellm_params import ( + _get_base_model_from_litellm_call_metadata, + get_litellm_params, +) from litellm.litellm_core_utils.get_llm_provider_logic import ( _is_non_openai_azure_model, get_llm_provider, @@ -2094,88 +2098,6 @@ def register_model(model_cost: Union[str, dict]): # noqa: PLR0915 return model_cost -def get_litellm_params( - api_key: Optional[str] = None, - force_timeout=600, - azure=False, - logger_fn=None, - verbose=False, - hugging_face=False, - replicate=False, - together_ai=False, - custom_llm_provider: Optional[str] = None, - api_base: Optional[str] = None, - litellm_call_id=None, - model_alias_map=None, - completion_call_id=None, - metadata: Optional[dict] = None, - model_info=None, - 
-    proxy_server_request=None,
-    acompletion=None,
-    preset_cache_key=None,
-    no_log=None,
-    input_cost_per_second=None,
-    input_cost_per_token=None,
-    output_cost_per_token=None,
-    output_cost_per_second=None,
-    cooldown_time=None,
-    text_completion=None,
-    azure_ad_token_provider=None,
-    user_continue_message=None,
-    base_model: Optional[str] = None,
-    litellm_trace_id: Optional[str] = None,
-    hf_model_name: Optional[str] = None,
-    custom_prompt_dict: Optional[dict] = None,
-    litellm_metadata: Optional[dict] = None,
-    disable_add_transform_inline_image_block: Optional[bool] = None,
-    drop_params: Optional[bool] = None,
-    prompt_id: Optional[str] = None,
-    prompt_variables: Optional[dict] = None,
-    async_call: Optional[bool] = None,
-    ssl_verify: Optional[bool] = None,
-    **kwargs,
-) -> dict:
-    litellm_params = {
-        "acompletion": acompletion,
-        "api_key": api_key,
-        "force_timeout": force_timeout,
-        "logger_fn": logger_fn,
-        "verbose": verbose,
-        "custom_llm_provider": custom_llm_provider,
-        "api_base": api_base,
-        "litellm_call_id": litellm_call_id,
-        "model_alias_map": model_alias_map,
-        "completion_call_id": completion_call_id,
-        "metadata": metadata,
-        "model_info": model_info,
-        "proxy_server_request": proxy_server_request,
-        "preset_cache_key": preset_cache_key,
-        "no-log": no_log,
-        "stream_response": {},  # litellm_call_id: ModelResponse Dict
-        "input_cost_per_token": input_cost_per_token,
-        "input_cost_per_second": input_cost_per_second,
-        "output_cost_per_token": output_cost_per_token,
-        "output_cost_per_second": output_cost_per_second,
-        "cooldown_time": cooldown_time,
-        "text_completion": text_completion,
-        "azure_ad_token_provider": azure_ad_token_provider,
-        "user_continue_message": user_continue_message,
-        "base_model": base_model
-        or _get_base_model_from_litellm_call_metadata(metadata=metadata),
-        "litellm_trace_id": litellm_trace_id,
-        "hf_model_name": hf_model_name,
-        "custom_prompt_dict": custom_prompt_dict,
-        "litellm_metadata": litellm_metadata,
-        "disable_add_transform_inline_image_block": disable_add_transform_inline_image_block,
-        "drop_params": drop_params,
-        "prompt_id": prompt_id,
-        "prompt_variables": prompt_variables,
-        "async_call": async_call,
-        "ssl_verify": ssl_verify,
-    }
-    return litellm_params
-
-
 def _should_drop_param(k, additional_drop_params) -> bool:
     if (
         additional_drop_params is not None
@@ -5666,22 +5588,6 @@ def get_logging_id(start_time, response_obj):
         return None


-def _get_base_model_from_litellm_call_metadata(
-    metadata: Optional[dict],
-) -> Optional[str]:
-    if metadata is None:
-        return None
-
-    if metadata is not None:
-        model_info = metadata.get("model_info", {})
-
-        if model_info is not None:
-            base_model = model_info.get("base_model", None)
-            if base_model is not None:
-                return base_model
-    return None
-
-
 def _get_base_model_from_metadata(model_call_details=None):
     if model_call_details is None:
         return None
diff --git a/tests/local_testing/test_completion_cost.py b/tests/local_testing/test_completion_cost.py
index 23ff873b56..f766692e7d 100644
--- a/tests/local_testing/test_completion_cost.py
+++ b/tests/local_testing/test_completion_cost.py
@@ -2772,3 +2772,92 @@ def test_bedrock_cost_calc_with_region():
         aws_region_name="us-east-1",
     )
     assert response._hidden_params["response_cost"] > 0
+
+
+# @pytest.mark.parametrize(
+#     "base_model_arg", [
+#         {"base_model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0"},
+#         {"model_info": "anthropic.claude-3-sonnet-20240229-v1:0"},
+#     ]
+# )
+def test_cost_calculator_with_base_model():
+    resp = litellm.completion(
+        model="bedrock/random-model",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        base_model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+        mock_response="Hello, how are you?",
+    )
+    assert resp.model == "random-model"
+    assert resp._hidden_params["response_cost"] > 0
+
+
+@pytest.mark.parametrize("base_model_arg", ["litellm_param", "model_info"])
+def test_cost_calculator_with_base_model_with_router(base_model_arg):
+    from litellm import Router
+
+    model_item = {
+        "model_name": "random-model",
+        "litellm_params": {
+            "model": "bedrock/random-model",
+        },
+    }
+
+    if base_model_arg == "litellm_param":
+        model_item["litellm_params"][
+            "base_model"
+        ] = "bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
+    elif base_model_arg == "model_info":
+        model_item["model_info"] = {
+            "base_model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+        }
+
+    router = Router(model_list=[model_item])
+    resp = router.completion(
+        model="random-model",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        mock_response="Hello, how are you?",
+    )
+    assert resp.model == "random-model"
+    assert resp._hidden_params["response_cost"] > 0
+
+
+@pytest.mark.parametrize("base_model_arg", ["litellm_param", "model_info"])
+def test_cost_calculator_with_base_model_with_router_embedding(base_model_arg):
+    from litellm import Router
+
+    litellm._turn_on_debug()
+
+    model_item = {
+        "model_name": "random-model",
+        "litellm_params": {
+            "model": "bedrock/random-model",
+        },
+    }
+
+    if base_model_arg == "litellm_param":
+        model_item["litellm_params"]["base_model"] = "cohere.embed-english-v3"
+    elif base_model_arg == "model_info":
+        model_item["model_info"] = {
+            "base_model": "cohere.embed-english-v3",
+        }
+
+    router = Router(model_list=[model_item])
+    resp = router.embedding(
+        model="random-model",
+        input="Hello, how are you?",
+        mock_response=[1, 2, 3],
+    )
+    assert resp.model == "random-model"
+    assert resp._hidden_params["response_cost"] > 0
+
+
+def test_cost_calculator_with_custom_pricing():
+    resp = litellm.completion(
+        model="bedrock/random-model",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        mock_response="Hello, how are you?",
+        input_cost_per_token=0.0000008,
+        output_cost_per_token=0.0000032,
+    )
+    assert resp.model == "random-model"
+    assert resp._hidden_params["response_cost"] > 0
diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py
index ea3649ca83..03ec25397a 100644
--- a/tests/local_testing/test_utils.py
+++ b/tests/local_testing/test_utils.py
@@ -1527,3 +1527,43 @@ def test_add_custom_logger_callback_to_specific_event_e2e(monkeypatch):

     assert len(litellm.success_callback) == curr_len_success_callback
     assert len(litellm.failure_callback) == curr_len_failure_callback
+
+
+@pytest.mark.asyncio
+async def test_wrapper_kwargs_passthrough():
+    from litellm.utils import client
+    from litellm.litellm_core_utils.litellm_logging import (
+        Logging as LiteLLMLoggingObject,
+    )
+
+    # Create mock original function
+    mock_original = AsyncMock()
+
+    # Apply decorator
+    @client
+    async def test_function(**kwargs):
+        return await mock_original(**kwargs)
+
+    # Test kwargs
+    test_kwargs = {"base_model": "gpt-4o-mini"}
+
+    # Call decorated function
+    await test_function(**test_kwargs)
+
+    mock_original.assert_called_once()
+
+    # get litellm logging object
+    litellm_logging_obj: LiteLLMLoggingObject = mock_original.call_args.kwargs.get(
+        "litellm_logging_obj"
+    )
+    assert litellm_logging_obj is not None
+
f"litellm_logging_obj.model_call_details: {litellm_logging_obj.model_call_details}" + ) + + # get base model + assert ( + litellm_logging_obj.model_call_details["litellm_params"]["base_model"] + == "gpt-4o-mini" + )