diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index 9bd5f90be3..b680fc8a54 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -509,7 +509,7 @@ def completion_cost(
                 ):
                     model = completion_response._hidden_params.get("model", model)
                 custom_llm_provider = completion_response._hidden_params.get(
-                    "custom_llm_provider", ""
+                    "custom_llm_provider", custom_llm_provider or ""
                 )
                 region_name = completion_response._hidden_params.get(
                     "region_name", region_name
@@ -732,7 +732,7 @@ def response_cost_calculator(
                 )
         return response_cost
     except litellm.NotFoundError as e:
-        print_verbose(
+        verbose_logger.debug(  # debug since it can be spammy in logs, for calls
            f"Model={model} for LLM Provider={custom_llm_provider} not found in completion cost map."
        )
        return None
diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
index 363b222fea..0567e6e053 100644
--- a/litellm/llms/databricks.py
+++ b/litellm/llms/databricks.py
@@ -259,87 +259,6 @@ class DatabricksChatCompletion(BaseLLM):
             api_base = "{}/embeddings".format(api_base)
         return api_base, headers
 
-    def process_response(
-        self,
-        model: str,
-        response: Union[requests.Response, httpx.Response],
-        model_response: ModelResponse,
-        stream: bool,
-        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
-        optional_params: dict,
-        api_key: str,
-        data: Union[dict, str],
-        messages: List,
-        print_verbose,
-        encoding,
-    ) -> ModelResponse:
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        try:
-            completion_response = response.json()
-        except:
-            raise DatabricksError(
-                message=response.text, status_code=response.status_code
-            )
-        if "error" in completion_response:
-            raise DatabricksError(
-                message=str(completion_response["error"]),
-                status_code=response.status_code,
-            )
-        else:
-            text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
-                if content["type"] == "text":
-                    text_content += content["text"]
-                ## TOOL CALLING
-                elif content["type"] == "tool_use":
-                    tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
-                    )
-
-            _message = litellm.Message(
-                tool_calls=tool_calls,
-                content=text_content or None,
-            )
-            model_response.choices[0].message = _message  # type: ignore
-            model_response._hidden_params["original_response"] = completion_response[
-                "content"
-            ]  # allow user to access raw anthropic tool calling response
-
-            model_response.choices[0].finish_reason = map_finish_reason(
-                completion_response["stop_reason"]
-            )
-
-            ## CALCULATING USAGE
-            prompt_tokens = completion_response["usage"]["input_tokens"]
-            completion_tokens = completion_response["usage"]["output_tokens"]
-            total_tokens = prompt_tokens + completion_tokens
-
-            model_response.created = int(time.time())
-            model_response.model = model
-            usage = Usage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-            )
-            setattr(model_response, "usage", usage)  # type: ignore
-            return model_response
-
     async def acompletion_stream_function(
         self,
         model: str,
@@ -392,6 +311,7 @@ class DatabricksChatCompletion(BaseLLM):
         logging_obj,
         stream,
         data: dict,
+        base_model: Optional[str],
         optional_params: dict,
         litellm_params=None,
         logger_fn=None,
@@ -420,7 +340,11 @@ class DatabricksChatCompletion(BaseLLM):
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))
 
-        return ModelResponse(**response_json)
+        response = ModelResponse(**response_json)
+
+        if base_model is not None:
+            response._hidden_params["model"] = base_model
+        return response
 
     def completion(
         self,
@@ -443,6 +367,7 @@ class DatabricksChatCompletion(BaseLLM):
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
+        base_model: Optional[str] = optional_params.pop("base_model", None)
         api_base, headers = self._validate_environment(
             api_base=api_base,
             api_key=api_key,
@@ -476,11 +401,11 @@ class DatabricksChatCompletion(BaseLLM):
                 "headers": headers,
             },
         )
-        if acompletion == True:
+        if acompletion is True:
             if client is not None and isinstance(client, HTTPHandler):
                 client = None
             if (
-                stream is not None and stream == True
+                stream is not None and stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 print_verbose("makes async anthropic streaming POST request")
                 data["stream"] = stream
@@ -521,6 +446,7 @@ class DatabricksChatCompletion(BaseLLM):
                 logger_fn=logger_fn,
                 headers=headers,
                 timeout=timeout,
+                base_model=base_model,
             )
         else:
             if client is None or not isinstance(client, HTTPHandler):
@@ -562,7 +488,12 @@ class DatabricksChatCompletion(BaseLLM):
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))
 
-        return ModelResponse(**response_json)
+        response = ModelResponse(**response_json)
+
+        if base_model is not None:
+            response._hidden_params["model"] = base_model
+
+        return response
 
     async def aembedding(
         self,
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 5419c25ff1..e2d35c9720 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -939,6 +939,8 @@ async def test_partner_models_httpx(model, sync_mode):
             response_format_tests(response=response)
 
         print(f"response: {response}")
+
+        assert response._hidden_params["response_cost"] > 0
     except litellm.RateLimitError as e:
         pass
     except Exception as e:
diff --git a/litellm/tests/test_completion_cost.py b/litellm/tests/test_completion_cost.py
index 41448bd562..53cbaa31dc 100644
--- a/litellm/tests/test_completion_cost.py
+++ b/litellm/tests/test_completion_cost.py
@@ -918,6 +918,42 @@ def test_vertex_ai_llama_predict_cost():
     assert predictive_cost == 0
 
 
+def test_vertex_ai_mistral_predict_cost():
+    from litellm.types.utils import Choices, Message, ModelResponse, Usage
+
+    response_object = ModelResponse(
+        id="26c0ef045020429d9c5c9b078c01e564",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="Hello! I'm Litellm Bot, your helpful assistant. While I can't provide real-time weather updates, I can help you find a reliable weather service or guide you on how to check the weather on your device. Would you like assistance with that?",
+                    role="assistant",
+                    tool_calls=None,
+                    function_call=None,
+                ),
+            )
+        ],
+        created=1722124652,
+        model="vertex_ai/mistral-large",
+        object="chat.completion",
+        system_fingerprint=None,
+        usage=Usage(prompt_tokens=32, completion_tokens=55, total_tokens=87),
+    )
+    model = "mistral-large@2407"
+    messages = [{"role": "user", "content": "Hey, hows it going???"}]
+    custom_llm_provider = "vertex_ai"
+    predictive_cost = completion_cost(
+        completion_response=response_object,
+        model=model,
+        messages=messages,
+        custom_llm_provider=custom_llm_provider,
+    )
+
+    assert predictive_cost > 0
+
+
 @pytest.mark.parametrize("model", ["openai/tts-1", "azure/tts-1"])
 def test_completion_cost_tts(model):
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"