Add DBRX Anthropic w/ thinking + response_format support (#9744)

* feat(databricks/chat/): add anthropic w/ reasoning content support via databricks

Allows user to call claude-3-7-sonnet with thinking via databricks

* refactor: refactor choices transformation + add unit testing

* fix(databricks/chat/transformation.py): support thinking blocks on databricks response streaming

* feat(databricks/chat/transformation.py): support response_format for claude models

* fix(databricks/chat/transformation.py): correctly handle response_format={"type": "text"}

* feat(databricks/chat/transformation.py): support 'reasoning_effort' param mapping for anthropic

* fix: fix ruff errors

* fix: fix linting error

* test: update test

* fix(databricks/chat/transformation.py): handle json mode output parsing

* fix(databricks/chat/transformation.py): handle json mode on streaming

* test: update test

* test: update dbrx testing

* test: update testing

* fix(base_model_iterator.py): handle non-json chunk

* test: update tests

* fix: fix ruff check

* fix: fix databricks config import

* fix: handle _tool = none

* test: skip invalid test
This commit is contained in:
Krish Dholakia 2025-04-04 22:13:32 -07:00 committed by GitHub
parent e3b231bc11
commit 5099aac1a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 872 additions and 340 deletions

View file

@ -1,9 +1,35 @@
from typing import Literal, Optional, Tuple
from .exceptions import DatabricksError
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class DatabricksException(BaseLLMException):
pass
class DatabricksBase:
def _get_api_base(self, api_base: Optional[str]) -> str:
if api_base is None:
try:
from databricks.sdk import WorkspaceClient
databricks_client = WorkspaceClient()
api_base = (
api_base or f"{databricks_client.config.host}/serving-endpoints"
)
return api_base
except ImportError:
raise DatabricksException(
status_code=400,
message=(
"Either set the DATABRICKS_API_BASE and DATABRICKS_API_KEY environment variables, "
"or install the databricks-sdk Python library."
),
)
return api_base
def _get_databricks_credentials(
self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
) -> Tuple[str, dict]:
@ -23,7 +49,7 @@ class DatabricksBase:
return api_base, headers
except ImportError:
raise DatabricksError(
raise DatabricksException(
status_code=400,
message=(
"If the Databricks base URL and API key are not set, the databricks-sdk "
@ -43,7 +69,7 @@ class DatabricksBase:
) -> Tuple[str, dict]:
if api_key is None and headers is None:
if custom_endpoint is not None:
raise DatabricksError(
raise DatabricksException(
status_code=400,
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
)
@ -54,7 +80,7 @@ class DatabricksBase:
if api_base is None:
if custom_endpoint:
raise DatabricksError(
raise DatabricksException(
status_code=400,
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
)