""" Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions` """ from typing import ( TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Tuple, Union, cast, ) import httpx from pydantic import BaseModel from litellm.constants import RESPONSE_FORMAT_TOOL_NAME from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import ( _handle_invalid_parallel_tool_calls, _should_convert_tool_call_to_json_mode, ) from litellm.litellm_core_utils.prompt_templates.common_utils import ( handle_messages_with_content_list_to_str_conversion, strip_name_from_messages, ) from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator from litellm.types.llms.anthropic import AllAnthropicToolsValues from litellm.types.llms.databricks import ( AllDatabricksContentValues, DatabricksChoice, DatabricksFunction, DatabricksResponse, DatabricksTool, ) from litellm.types.llms.openai import ( AllMessageValues, ChatCompletionRedactedThinkingBlock, ChatCompletionThinkingBlock, ChatCompletionToolChoiceFunctionParam, ChatCompletionToolChoiceObjectParam, ) from litellm.types.utils import ( ChatCompletionMessageToolCall, Choices, Message, ModelResponse, ModelResponseStream, ProviderField, Usage, ) from ...anthropic.chat.transformation import AnthropicConfig from ...openai_like.chat.transformation import OpenAILikeChatConfig from ..common_utils import DatabricksBase, DatabricksException if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj LiteLLMLoggingObj = _LiteLLMLoggingObj else: LiteLLMLoggingObj = Any class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig): """ Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request """ max_tokens: Optional[int] = None temperature: Optional[int] = None top_p: Optional[int] = None top_k: Optional[int] = None stop: Optional[Union[List[str], str]] = None n: Optional[int] = None def __init__( self, max_tokens: Optional[int] = None, temperature: Optional[int] = None, top_p: Optional[int] = None, top_k: Optional[int] = None, stop: Optional[Union[List[str], str]] = None, n: Optional[int] = None, ) -> None: locals_ = locals().copy() for key, value in locals_.items(): if key != "self" and value is not None: setattr(self.__class__, key, value) @classmethod def get_config(cls): return super().get_config() def get_required_params(self) -> List[ProviderField]: """For a given provider, return it's required fields with a description""" return [ ProviderField( field_name="api_key", field_type="string", field_description="Your Databricks API Key.", field_value="dapi...", ), ProviderField( field_name="api_base", field_type="string", field_description="Your Databricks API Base.", field_value="https://adb-..", ), ] def validate_environment( self, headers: dict, model: str, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: api_base, headers = self.databricks_validate_environment( api_base=api_base, api_key=api_key, endpoint_type="chat_completions", custom_endpoint=False, headers=headers, ) # Ensure Content-Type header is set headers["Content-Type"] = "application/json" return headers def get_complete_url( self, api_base: Optional[str], api_key: Optional[str], model: str, optional_params: dict, litellm_params: dict, stream: Optional[bool] = None, ) -> str: api_base = self._get_api_base(api_base) complete_url = 
f"{api_base}/chat/completions" return complete_url def get_supported_openai_params(self, model: Optional[str] = None) -> list: return [ "stream", "stop", "temperature", "top_p", "max_tokens", "max_completion_tokens", "n", "response_format", "tools", "tool_choice", "reasoning_effort", "thinking", ] def convert_anthropic_tool_to_databricks_tool( self, tool: Optional[AllAnthropicToolsValues] ) -> Optional[DatabricksTool]: if tool is None: return None return DatabricksTool( type="function", function=DatabricksFunction( name=tool["name"], parameters=cast(dict, tool.get("input_schema") or {}), ), ) def _map_openai_to_dbrx_tool(self, model: str, tools: List) -> List[DatabricksTool]: # if not claude, send as is if "claude" not in model: return tools # if claude, convert to anthropic tool and then to databricks tool anthropic_tools = self._map_tools(tools=tools) databricks_tools = [ cast(DatabricksTool, self.convert_anthropic_tool_to_databricks_tool(tool)) for tool in anthropic_tools ] return databricks_tools def map_response_format_to_databricks_tool( self, model: str, value: Optional[dict], optional_params: dict, is_thinking_enabled: bool, ) -> Optional[DatabricksTool]: if value is None: return None tool = self.map_response_format_to_anthropic_tool( value, optional_params, is_thinking_enabled ) databricks_tool = self.convert_anthropic_tool_to_databricks_tool(tool) return databricks_tool def map_openai_params( self, non_default_params: dict, optional_params: dict, model: str, drop_params: bool, replace_max_completion_tokens_with_max_tokens: bool = True, ) -> dict: is_thinking_enabled = self.is_thinking_enabled(non_default_params) mapped_params = super().map_openai_params( non_default_params, optional_params, model, drop_params ) if "tools" in mapped_params: mapped_params["tools"] = self._map_openai_to_dbrx_tool( model=model, tools=mapped_params["tools"] ) if ( "max_completion_tokens" in non_default_params and replace_max_completion_tokens_with_max_tokens ): mapped_params["max_tokens"] = non_default_params[ "max_completion_tokens" ] # most openai-compatible providers support 'max_tokens' not 'max_completion_tokens' mapped_params.pop("max_completion_tokens", None) if "response_format" in non_default_params and "claude" in model: _tool = self.map_response_format_to_databricks_tool( model, non_default_params["response_format"], mapped_params, is_thinking_enabled, ) if _tool is not None: self._add_tools_to_optional_params( optional_params=optional_params, tools=[_tool] ) optional_params["json_mode"] = True if not is_thinking_enabled: _tool_choice = ChatCompletionToolChoiceObjectParam( type="function", function=ChatCompletionToolChoiceFunctionParam( name=RESPONSE_FORMAT_TOOL_NAME ), ) optional_params["tool_choice"] = _tool_choice optional_params.pop( "response_format", None ) # unsupported for claude models - if json_schema -> convert to tool call if "reasoning_effort" in non_default_params and "claude" in model: optional_params["thinking"] = AnthropicConfig._map_reasoning_effort( non_default_params.get("reasoning_effort") ) optional_params.pop("reasoning_effort", None) ## handle thinking tokens self.update_optional_params_with_thinking_tokens( non_default_params=non_default_params, optional_params=mapped_params ) return mapped_params def _should_fake_stream(self, optional_params: dict) -> bool: """ Databricks doesn't support 'response_format' while streaming """ if optional_params.get("response_format") is not None: return True return False def _transform_messages( self, messages: 
    def _transform_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> List[AllMessageValues]:
        """
        Databricks does not support:
        - content in list format.
        - 'name' in user message.
        """
        new_messages = []
        for idx, message in enumerate(messages):
            if isinstance(message, BaseModel):
                _message = message.model_dump(exclude_none=True)
            else:
                _message = message
            new_messages.append(_message)
        new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
        new_messages = strip_name_from_messages(new_messages)
        return super()._transform_messages(messages=new_messages, model=model)

    @staticmethod
    def extract_content_str(
        content: Optional[AllDatabricksContentValues],
    ) -> Optional[str]:
        if content is None:
            return None
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            content_str = ""
            for item in content:
                if item["type"] == "text":
                    content_str += item["text"]
            return content_str
        else:
            raise Exception(f"Unsupported content type: {type(content)}")

    @staticmethod
    def extract_reasoning_content(
        content: Optional[AllDatabricksContentValues],
    ) -> Tuple[
        Optional[str],
        Optional[
            List[
                Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]
            ]
        ],
    ]:
        """
        Extract and return the reasoning content and thinking blocks
        """
        if content is None:
            return None, None
        thinking_blocks: Optional[
            List[
                Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]
            ]
        ] = None
        reasoning_content: Optional[str] = None
        if isinstance(content, list):
            for item in content:
                if item["type"] == "reasoning":
                    for summary_item in item["summary"]:
                        if reasoning_content is None:
                            reasoning_content = ""
                        reasoning_content += summary_item["text"]
                        thinking_block = ChatCompletionThinkingBlock(
                            type="thinking",
                            thinking=summary_item["text"],
                            signature=summary_item["signature"],
                        )
                        if thinking_blocks is None:
                            thinking_blocks = []
                        thinking_blocks.append(thinking_block)
        return reasoning_content, thinking_blocks

    def _transform_choices(
        self, choices: List[DatabricksChoice], json_mode: Optional[bool] = None
    ) -> List[Choices]:
        transformed_choices = []

        for choice in choices:
            ## HANDLE JSON MODE - anthropic returns single function call
            tool_calls = choice["message"].get("tool_calls", None)
            if tool_calls is not None:
                _openai_tool_calls = []
                for _tc in tool_calls:
                    _openai_tc = ChatCompletionMessageToolCall(**_tc)  # type: ignore
                    _openai_tool_calls.append(_openai_tc)
                fixed_tool_calls = _handle_invalid_parallel_tool_calls(
                    _openai_tool_calls
                )

                if fixed_tool_calls is not None:
                    tool_calls = fixed_tool_calls

            translated_message: Optional[Message] = None
            finish_reason: Optional[str] = None
            if tool_calls and _should_convert_tool_call_to_json_mode(
                tool_calls=tool_calls,
                convert_tool_call_to_json_mode=json_mode,
            ):
                # to support response_format on claude models
                json_mode_content_str: Optional[str] = (
                    str(tool_calls[0]["function"].get("arguments", "")) or None
                )
                if json_mode_content_str is not None:
                    translated_message = Message(content=json_mode_content_str)
                    finish_reason = "stop"

            if translated_message is None:
                ## get the content str
                content_str = DatabricksConfig.extract_content_str(
                    choice["message"]["content"]
                )

                ## get the reasoning content
                (
                    reasoning_content,
                    thinking_blocks,
                ) = DatabricksConfig.extract_reasoning_content(
                    choice["message"].get("content")
                )

                translated_message = Message(
                    role="assistant",
                    content=content_str,
                    reasoning_content=reasoning_content,
                    thinking_blocks=thinking_blocks,
                    tool_calls=choice["message"].get("tool_calls"),
                )

            if finish_reason is None:
                finish_reason = choice["finish_reason"]

            translated_choice = Choices(
                finish_reason=finish_reason,
                index=choice["index"],
                message=translated_message,
                logprobs=None,
                enhancements=None,
            )

            transformed_choices.append(translated_choice)

        return transformed_choices
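    # Illustrative JSON-mode conversion performed by `_transform_choices`
    # (payload shape assumed from the OpenAI tool-call format):
    #
    #   {"index": 0,
    #    "finish_reason": "tool_calls",
    #    "message": {"tool_calls": [{"function": {"name": RESPONSE_FORMAT_TOOL_NAME,
    #                                             "arguments": '{"answer": 42}'}}]}}
    #   ->
    #   Choices(index=0, finish_reason="stop",
    #           message=Message(content='{"answer": 42}'))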
index=choice["index"], message=translated_message, logprobs=None, enhancements=None, ) transformed_choices.append(translated_choice) return transformed_choices def transform_response( self, model: str, raw_response: httpx.Response, model_response: ModelResponse, logging_obj: LiteLLMLoggingObj, request_data: dict, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, encoding: Any, api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: ## LOGGING logging_obj.post_call( input=messages, api_key=api_key, original_response=raw_response.text, additional_args={"complete_input_dict": request_data}, ) ## RESPONSE OBJECT try: completion_response = DatabricksResponse(**raw_response.json()) # type: ignore except Exception as e: response_headers = getattr(raw_response, "headers", None) raise DatabricksException( message="Unable to get json response - {}, Original Response: {}".format( str(e), raw_response.text ), status_code=raw_response.status_code, headers=response_headers, ) model_response.model = completion_response["model"] model_response.id = completion_response["id"] model_response.created = completion_response["created"] setattr(model_response, "usage", Usage(**completion_response["usage"])) model_response.choices = self._transform_choices( # type: ignore choices=completion_response["choices"], json_mode=json_mode, ) return model_response def get_model_response_iterator( self, streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], sync_stream: bool, json_mode: Optional[bool] = False, ): return DatabricksChatResponseIterator( streaming_response=streaming_response, sync_stream=sync_stream, json_mode=json_mode, ) class DatabricksChatResponseIterator(BaseModelResponseIterator): def __init__( self, streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], sync_stream: bool, json_mode: Optional[bool] = False, ): super().__init__(streaming_response, sync_stream) self.json_mode = json_mode self._last_function_name = None # Track the last seen function name def chunk_parser(self, chunk: dict) -> ModelResponseStream: try: translated_choices = [] for choice in chunk["choices"]: tool_calls = choice["delta"].get("tool_calls") if tool_calls and self.json_mode: # 1. Check if the function name is set and == RESPONSE_FORMAT_TOOL_NAME # 2. If no function name, just args -> check last function name (saved via state variable) # 3. Convert args to json # 4. Convert json to message # 5. Set content to message.content # 6. 
    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        try:
            translated_choices = []
            for choice in chunk["choices"]:
                tool_calls = choice["delta"].get("tool_calls")
                if tool_calls and self.json_mode:
                    # 1. Check if the function name is set and == RESPONSE_FORMAT_TOOL_NAME
                    # 2. If no function name, just args -> check last function name (saved via state variable)
                    # 3. Convert args to json
                    # 4. Convert json to message
                    # 5. Set content to message.content
                    # 6. Set tool_calls to None
                    from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
                    from litellm.llms.base_llm.base_utils import (
                        _convert_tool_response_to_message,
                    )

                    # Check if this chunk has a function name
                    function_name = tool_calls[0].get("function", {}).get("name")
                    if function_name is not None:
                        self._last_function_name = function_name

                    # If we have a saved function name that matches RESPONSE_FORMAT_TOOL_NAME
                    # or this chunk has the matching function name
                    if (
                        self._last_function_name == RESPONSE_FORMAT_TOOL_NAME
                        or function_name == RESPONSE_FORMAT_TOOL_NAME
                    ):
                        # Convert tool calls to message format
                        message = _convert_tool_response_to_message(tool_calls)
                        if message is not None:
                            if message.content == "{}":  # empty json
                                message.content = ""
                            choice["delta"]["content"] = message.content
                            choice["delta"]["tool_calls"] = None
                elif tool_calls:
                    for _tc in tool_calls:
                        if _tc.get("function", {}).get("arguments") == "{}":
                            _tc["function"]["arguments"] = ""  # avoid invalid json

                # extract the content str
                content_str = DatabricksConfig.extract_content_str(
                    choice["delta"].get("content")
                )

                # extract the reasoning content
                (
                    reasoning_content,
                    thinking_blocks,
                ) = DatabricksConfig.extract_reasoning_content(
                    choice["delta"].get("content")
                )

                choice["delta"]["content"] = content_str
                choice["delta"]["reasoning_content"] = reasoning_content
                choice["delta"]["thinking_blocks"] = thinking_blocks
                translated_choices.append(choice)
            return ModelResponseStream(
                id=chunk["id"],
                object="chat.completion.chunk",
                created=chunk["created"],
                model=chunk["model"],
                choices=translated_choices,
            )
        except KeyError as e:
            raise DatabricksException(
                message=f"KeyError: {e}, Got unexpected response from Databricks: {chunk}",
                status_code=400,
            )
        except Exception as e:
            raise e
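
# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library's public API):
# exercises the pure static extractors above on a hand-built Databricks-style
# content list. The payload values are assumptions for demonstration only.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    sample_content = [
        {"type": "text", "text": "Hello!"},
        {
            "type": "reasoning",
            "summary": [{"text": "step-by-step thoughts", "signature": "sig-abc"}],
        },
    ]
    # Text items are concatenated into a plain string.
    print(DatabricksConfig.extract_content_str(sample_content))
    # -> "Hello!"
    # Reasoning summaries become (reasoning_content, thinking_blocks).
    print(DatabricksConfig.extract_reasoning_content(sample_content))
    # -> ("step-by-step thoughts",
    #     [{"type": "thinking", "thinking": "step-by-step thoughts", "signature": "sig-abc"}])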