Add DBRX Anthropic w/ thinking + response_format support (#9744)

* feat(databricks/chat/): add anthropic w/ reasoning content support via databricks

Allows users to call claude-3-7-sonnet with thinking via Databricks

* refactor: refactor choices transformation + add unit testing

* fix(databricks/chat/transformation.py): support thinking blocks on databricks response streaming

* feat(databricks/chat/transformation.py): support response_format for claude models

* fix(databricks/chat/transformation.py): correctly handle response_format={"type": "text"}

* feat(databricks/chat/transformation.py): support 'reasoning_effort' param mapping for anthropic

* fix: fix ruff errors

* fix: fix linting error

* test: update test

* fix(databricks/chat/transformation.py): handle json mode output parsing

* fix(databricks/chat/transformation.py): handle json mode on streaming

* test: update test

* test: update dbrx testing

* test: update testing

* fix(base_model_iterator.py): handle non-json chunk

* test: update tests

* fix: fix ruff check

* fix: fix databricks config import

* fix: handle _tool = none

* test: skip invalid test
This commit is contained in:
Krish Dholakia 2025-04-04 22:13:32 -07:00 committed by GitHub
parent e3b231bc11
commit 5099aac1a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 872 additions and 340 deletions

View file

@ -2,6 +2,7 @@ import json
from abc import abstractmethod
from typing import Optional, Union
import litellm
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
@ -33,6 +34,18 @@ class BaseModelResponseIterator:
self, str_line: str
) -> Union[GenericStreamingChunk, ModelResponseStream]:
# chunk is a str at this point
stripped_chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(
str_line
)
try:
if stripped_chunk is not None:
stripped_json_chunk: Optional[dict] = json.loads(stripped_chunk)
else:
stripped_json_chunk = None
except json.JSONDecodeError:
stripped_json_chunk = None
if "[DONE]" in str_line:
return GenericStreamingChunk(
text="",
@ -42,9 +55,8 @@ class BaseModelResponseIterator:
index=0,
tool_use=None,
)
elif str_line.startswith("data:"):
data_json = json.loads(str_line[5:])
return self.chunk_parser(chunk=data_json)
elif stripped_json_chunk:
return self.chunk_parser(chunk=stripped_json_chunk)
else:
return GenericStreamingChunk(
text="",
@ -85,6 +97,7 @@ class BaseModelResponseIterator:
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
@ -99,7 +112,9 @@ class BaseModelResponseIterator:
str_line = str_line[index:]
# chunk is a str at this point
return self._handle_string_chunk(str_line=str_line)
chunk = self._handle_string_chunk(str_line=str_line)
return chunk
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e: