From b9f01c9f5bee01741fd0abb5737f0f6fe6f42e6a Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Fri, 11 Apr 2025 23:20:49 -0700 Subject: [PATCH] fix(databricks/common_utils.py): fix custom endpoint check (#9925) * fix(databricks/common_utils.py): fix custom endpoint check Fixes https://github.com/BerriAI/litellm/issues/9915 * fix(common_utils.py): add unit test to ensure custom_endpoint=False is handled correctly Fixes https://github.com/BerriAI/litellm/issues/9915 --- litellm/llms/databricks/common_utils.py | 2 +- .../test_databricks_common_utils.py | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 tests/litellm/llms/databricks/test_databricks_common_utils.py diff --git a/litellm/llms/databricks/common_utils.py b/litellm/llms/databricks/common_utils.py index eab9e2f825..1353b5b13f 100644 --- a/litellm/llms/databricks/common_utils.py +++ b/litellm/llms/databricks/common_utils.py @@ -68,7 +68,7 @@ class DatabricksBase: headers: Optional[dict], ) -> Tuple[str, dict]: if api_key is None and not headers: # handle empty headers - if custom_endpoint is not None: + if custom_endpoint is True: raise DatabricksException( status_code=400, message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", diff --git a/tests/litellm/llms/databricks/test_databricks_common_utils.py b/tests/litellm/llms/databricks/test_databricks_common_utils.py new file mode 100644 index 0000000000..7f7ec8e900 --- /dev/null +++ b/tests/litellm/llms/databricks/test_databricks_common_utils.py @@ -0,0 +1,32 @@ +import json +import os +import sys + +import pytest +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../../..") +) # Adds the parent directory to the system path +from unittest.mock import MagicMock, patch + +from litellm.llms.databricks.common_utils import DatabricksBase + + +def test_databricks_validate_environment(): + databricks_base = DatabricksBase() + + with patch.object( + databricks_base, "_get_databricks_credentials" + ) as mock_get_credentials: + try: + databricks_base.databricks_validate_environment( + api_key=None, + api_base="my_api_base", + endpoint_type="chat_completions", + custom_endpoint=False, + headers=None, + ) + except Exception: + pass + mock_get_credentials.assert_called_once()