diff --git a/litellm/llms/databricks/common_utils.py b/litellm/llms/databricks/common_utils.py index eab9e2f825..1353b5b13f 100644 --- a/litellm/llms/databricks/common_utils.py +++ b/litellm/llms/databricks/common_utils.py @@ -68,7 +68,7 @@ class DatabricksBase: headers: Optional[dict], ) -> Tuple[str, dict]: if api_key is None and not headers: # handle empty headers - if custom_endpoint is not None: + if custom_endpoint is True: raise DatabricksException( status_code=400, message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", diff --git a/tests/litellm/llms/databricks/test_databricks_common_utils.py b/tests/litellm/llms/databricks/test_databricks_common_utils.py new file mode 100644 index 0000000000..7f7ec8e900 --- /dev/null +++ b/tests/litellm/llms/databricks/test_databricks_common_utils.py @@ -0,0 +1,32 @@ +import json +import os +import sys + +import pytest +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../../..") +) # Adds the parent directory to the system path +from unittest.mock import MagicMock, patch + +from litellm.llms.databricks.common_utils import DatabricksBase + + +def test_databricks_validate_environment(): + databricks_base = DatabricksBase() + + with patch.object( + databricks_base, "_get_databricks_credentials" + ) as mock_get_credentials: + try: + databricks_base.databricks_validate_environment( + api_key=None, + api_base="my_api_base", + endpoint_type="chat_completions", + custom_endpoint=False, + headers=None, + ) + except Exception: + pass + mock_get_credentials.assert_called_once()