diff --git a/litellm/__init__.py b/litellm/__init__.py
index 92610afd9..ebdac8c6e 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -69,6 +69,7 @@ retry = True
 ### AUTH ###
 api_key: Optional[str] = None
 openai_key: Optional[str] = None
+databricks_key: Optional[str] = None
 azure_key: Optional[str] = None
 anthropic_key: Optional[str] = None
 replicate_key: Optional[str] = None
@@ -615,6 +616,7 @@ provider_list: List = [
     "watsonx",
     "triton",
     "predibase",
+    "databricks",
     "custom",  # custom apis
 ]
@@ -730,6 +732,7 @@ from .utils import (
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
+from .llms.databricks import DatabricksConfig
 from .llms.predibase import PredibaseConfig
 from .llms.anthropic_text import AnthropicTextConfig
 from .llms.replicate import ReplicateConfig
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 0adbd95bf..4df25944b 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -7,8 +7,12 @@ _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
 class AsyncHTTPHandler:
     def __init__(
-        self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
+        self,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        concurrent_limit=1000,
     ):
+        if timeout is None:
+            timeout = _DEFAULT_TIMEOUT
         # Create a client with a connection pool
         self.client = httpx.AsyncClient(
             timeout=timeout,
@@ -59,7 +63,7 @@ class AsyncHTTPHandler:
 class HTTPHandler:
     def __init__(
         self,
-        timeout: Optional[httpx.Timeout] = None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
         concurrent_limit=1000,
         client: Optional[httpx.Client] = None,
     ):
diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
new file mode 100644
index 000000000..3212c7ad1
--- /dev/null
+++ b/litellm/llms/databricks.py
@@ -0,0 +1,506 @@
+# What is this?
+## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+import os, types
+import json
+from enum import Enum
+import requests, copy  # type: ignore
+import time
+from typing import Callable, Optional, List, Union, Tuple
+from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
+import litellm
+from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from .base import BaseLLM
+import httpx  # type: ignore
+from litellm.types.llms.databricks import GenericStreamingChunk
+
+
+class DatabricksError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class DatabricksConfig:
+    """
+    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+    """
+
+    max_tokens: Optional[int] = None
+    temperature: Optional[int] = None
+    top_p: Optional[int] = None
+    top_k: Optional[int] = None
+    stop: Optional[Union[List[str], str]] = None
+    n: Optional[int] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[int] = None,
+        top_p: Optional[int] = None,
+        top_k: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        n: Optional[int] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return ["stream", "stop", "temperature", "top_p", "max_tokens", "n"]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_tokens"] = value
+            if param == "n":
+                optional_params["n"] = value
+            if param == "stream" and value == True:
+                optional_params["stream"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+        return optional_params
+
+    def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
+        try:
+            text = ""
+            is_finished = False
+            finish_reason = None
+            logprobs = None
+            usage = None
+            original_chunk = None  # this is used for function/tool calling
+            chunk_data = chunk_data.replace("data:", "")
+            chunk_data = chunk_data.strip()
+            if len(chunk_data) == 0:
+                return {
+                    "text": "",
+                    "is_finished": is_finished,
+                    "finish_reason": finish_reason,
+                }
+            chunk_data_dict = json.loads(chunk_data)
+            str_line = litellm.ModelResponse(**chunk_data_dict, stream=True)
+
+            if len(str_line.choices) > 0:
+                if (
+                    str_line.choices[0].delta is not None  # type: ignore
+                    and str_line.choices[0].delta.content is not None  # type: ignore
+                ):
+                    text = str_line.choices[0].delta.content  # type: ignore
+                else:  # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai
+                    original_chunk = str_line
+                if str_line.choices[0].finish_reason:
+                    is_finished = True
+                    finish_reason = str_line.choices[0].finish_reason
+                    if finish_reason == "content_filter":
+                        if hasattr(str_line.choices[0], "content_filter_result"):
+                            error_message = json.dumps(
+                                str_line.choices[0].content_filter_result  # type: ignore
+                            )
+                        else:
+                            error_message = "Azure Response={}".format(
+                                str(dict(str_line))
+                            )
+                        raise litellm.AzureOpenAIError(
+                            status_code=400, message=error_message
+                        )
+
+                # checking for logprobs
+                if (
+                    hasattr(str_line.choices[0], "logprobs")
+                    and str_line.choices[0].logprobs is not None
+                ):
+                    logprobs = str_line.choices[0].logprobs
+                else:
+                    logprobs = None
+
+            usage = getattr(str_line, "usage", None)
+
+            return GenericStreamingChunk(
+                text=text,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                logprobs=logprobs,
+                original_chunk=original_chunk,
+                usage=usage,
+            )
+        except Exception as e:
+            raise e
+
+
+class DatabricksChatCompletion(BaseLLM):
+    def __init__(self) -> None:
+        super().__init__()
+
+    # makes headers for API call
+
+    def _validate_environment(
+        self, api_key: Optional[str], api_base: Optional[str]
+    ) -> Tuple[str, dict]:
+        if api_key is None:
+            raise DatabricksError(
+                status_code=400,
+                message="Missing Databricks API Key - A call is being made to Databricks but no key is set either in the environment variables (DATABRICKS_API_KEY) or via params",
+            )
+
+        if api_base is None:
+            raise DatabricksError(
+                status_code=400,
+                message="Missing Databricks API Base - A call is being made to Databricks but no api base is set either in the environment variables (DATABRICKS_API_BASE) or via params",
+            )
+
+        headers = {
+            "Authorization": "Bearer {}".format(api_key),
+            "Content-Type": "application/json",
+        }
+
+        api_base = "{}/chat/completions".format(api_base)
+        return api_base, headers
+
+    def process_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except:
+            raise DatabricksError(
+                message=response.text, status_code=response.status_code
+            )
+        if "error" in completion_response:
+            raise DatabricksError(
+                message=str(completion_response["error"]),
+                status_code=response.status_code,
+            )
+        else:
+            text_content = ""
+            tool_calls = []
+            for content in completion_response["content"]:
+                if content["type"] == "text":
+                    text_content += content["text"]
+                ## TOOL CALLING
+                elif content["type"] == "tool_use":
+                    tool_calls.append(
+                        {
+                            "id": content["id"],
+                            "type": "function",
+                            "function": {
+                                "name": content["name"],
+                                "arguments": json.dumps(content["input"]),
+                            },
+                        }
+                    )
+
+            _message = litellm.Message(
+                tool_calls=tool_calls,
+                content=text_content or None,
+            )
+            model_response.choices[0].message = _message  # type: ignore
+            model_response._hidden_params["original_response"] = completion_response[
+                "content"
+            ]  # allow user to access raw anthropic tool calling response
+
+            model_response.choices[0].finish_reason = map_finish_reason(
+                completion_response["stop_reason"]
+            )
+
+        ## CALCULATING USAGE
+        prompt_tokens = completion_response["usage"]["input_tokens"]
+        completion_tokens = completion_response["usage"]["output_tokens"]
+        total_tokens = prompt_tokens + completion_tokens
+
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        )
+        setattr(model_response, "usage", usage)  # type: ignore
+        return model_response
+
+    async def acompletion_stream_function(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        stream,
+        data: dict,
+        optional_params=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+    ):
+        self.async_handler = AsyncHTTPHandler(
+            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        )
+        data["stream"] = True
+        try:
+            response = await self.async_handler.post(
+                api_base, headers=headers, data=json.dumps(data), stream=True
+            )
+            response.raise_for_status()
+
+            completion_stream = response.aiter_lines()
+        except httpx.HTTPStatusError as e:
+            raise DatabricksError(
+                status_code=e.response.status_code, message=response.text
+            )
+        except httpx.TimeoutException as e:
+            raise DatabricksError(status_code=408, message="Timeout error occurred.")
+        except Exception as e:
+            raise DatabricksError(status_code=500, message=str(e))
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=completion_stream,
+            model=model,
+            custom_llm_provider="databricks",
+            logging_obj=logging_obj,
+        )
+        return streamwrapper
+
+    async def acompletion_function(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        stream,
+        data: dict,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+    ) -> ModelResponse:
+        if timeout is None:
+            timeout = httpx.Timeout(timeout=600.0, connect=5.0)
+
+        self.async_handler = AsyncHTTPHandler(timeout=timeout)
+
+        try:
+            response = await self.async_handler.post(
+                api_base, headers=headers, data=json.dumps(data)
+            )
+            response.raise_for_status()
+
+            response_json = response.json()
+        except httpx.HTTPStatusError as e:
+            raise DatabricksError(
+                status_code=e.response.status_code,
+                message=response.text if response else str(e),
+            )
+        except httpx.TimeoutException as e:
+            raise DatabricksError(status_code=408, message="Timeout error occurred.")
+        except Exception as e:
+            raise DatabricksError(status_code=500, message=str(e))
+
+        return ModelResponse(**response_json)
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ):
+        api_base, headers = self._validate_environment(
+            api_base=api_base, api_key=api_key
+        )
+        ## Load Config
+        config = litellm.DatabricksConfig().get_config()
+        for k, v in config.items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        stream = optional_params.pop("stream", None)
+
+        data = {
+            "model": model,
+            "messages": messages,
+            **optional_params,
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+        if acompletion == True:
+            if (
+                stream is not None and stream == True
+            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
+                print_verbose("makes async anthropic streaming POST request")
+                data["stream"] = stream
+                return self.acompletion_stream_function(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=api_base,
+                    custom_prompt_dict=custom_prompt_dict,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    api_key=api_key,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=stream,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                )
+            else:
+                return self.acompletion_function(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=api_base,
+                    custom_prompt_dict=custom_prompt_dict,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    api_key=api_key,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=stream,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                )
+        else:
+            if client is None or isinstance(client, AsyncHTTPHandler):
+                self.client = HTTPHandler(timeout=timeout)  # type: ignore
+            else:
+                self.client = client
+            ## COMPLETION CALL
+            if (
+                stream is not None and stream == True
+            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
+                print_verbose("makes dbrx streaming POST request")
+                data["stream"] = stream
+                try:
+                    response = self.client.post(
+                        api_base, headers=headers, data=json.dumps(data), stream=stream
+                    )
+                    response.raise_for_status()
+                    completion_stream = response.iter_lines()
+                except httpx.HTTPStatusError as e:
+                    raise DatabricksError(
+                        status_code=e.response.status_code, message=response.text
+                    )
+                except httpx.TimeoutException as e:
+                    raise DatabricksError(
+                        status_code=408, message="Timeout error occurred."
+                    )
+                except Exception as e:
+                    raise DatabricksError(status_code=408, message=str(e))
+
+                streaming_response = CustomStreamWrapper(
+                    completion_stream=completion_stream,
+                    model=model,
+                    custom_llm_provider="databricks",
+                    logging_obj=logging_obj,
+                )
+                return streaming_response
+
+            else:
+                try:
+                    response = self.client.post(
+                        api_base, headers=headers, data=json.dumps(data)
+                    )
+                    response.raise_for_status()
+
+                    response_json = response.json()
+                except httpx.HTTPStatusError as e:
+                    raise DatabricksError(
+                        status_code=e.response.status_code, message=response.text
+                    )
+                except httpx.TimeoutException as e:
+                    raise DatabricksError(
+                        status_code=408, message="Timeout error occurred."
+                    )
+                except Exception as e:
+                    raise DatabricksError(status_code=500, message=str(e))
+
+                return ModelResponse(**response_json)
+
+    def embedding(self):
+        # logic for parsing in - calling - parsing out model embedding calls
+        pass
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 9d143f5d9..0bac50639 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -404,6 +404,7 @@ class OpenAIChatCompletion(BaseLLM):
         self,
         model_response: ModelResponse,
         timeout: Union[float, httpx.Timeout],
+        optional_params: dict,
         model: Optional[str] = None,
         messages: Optional[list] = None,
         print_verbose: Optional[Callable] = None,
@@ -411,7 +412,6 @@ class OpenAIChatCompletion(BaseLLM):
         api_base: Optional[str] = None,
         acompletion: bool = False,
         logging_obj=None,
-        optional_params=None,
         litellm_params=None,
         logger_fn=None,
         headers: Optional[dict] = None,
diff --git a/litellm/main.py b/litellm/main.py
index 42c4eb8ff..7757fead1 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -73,6 +73,7 @@ from .llms import (
 )
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.azure import AzureChatCompletion
+from .llms.databricks import DatabricksChatCompletion
 from .llms.azure_text import AzureTextCompletion
 from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
@@ -111,6 +112,7 @@ from litellm.utils import (
 ####### ENVIRONMENT VARIABLES ###################
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
+databricks_chat_completions = DatabricksChatCompletion()
 anthropic_chat_completions = AnthropicChatCompletion()
 anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()
@@ -329,6 +331,7 @@ async def acompletion(
             or custom_llm_provider == "anthropic"
             or custom_llm_provider == "predibase"
             or custom_llm_provider == "bedrock"
+            or custom_llm_provider == "databricks"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -1615,6 +1618,61 @@ def completion(
                 )
                 return response
             response = model_response
+        elif custom_llm_provider == "databricks":
+            api_base = (
+                api_base  # for databricks we check in get_llm_provider and pass in the api base from there
+                or litellm.api_base
+                or os.getenv("DATABRICKS_API_BASE")
+            )
+
+            # set API KEY
+            api_key = (
+                api_key
+                or litellm.api_key  # for databricks we check in get_llm_provider and pass in the api key from there
+                or litellm.databricks_key
+                or get_secret("DATABRICKS_API_KEY")
+            )
+
+            headers = headers or litellm.headers
+
+            ## COMPLETION CALL
+            try:
+                response = databricks_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    api_key=api_key,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout,  # type: ignore
+                    custom_prompt_dict=custom_prompt_dict,
+                    client=client,  # pass AsyncOpenAI, OpenAI client
+                    encoding=encoding,
+                )
+            except Exception as e:
+                ## LOGGING - log the original exception returned
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=str(e),
+                    additional_args={"headers": headers},
+                )
+                raise e
+
+            if optional_params.get("stream", False):
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                    additional_args={"headers": headers},
+                )
         elif custom_llm_provider == "openrouter":
             api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1"
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index f599b1817..1b319c160 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -131,6 +131,27 @@ def test_completion_azure_command_r():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_databricks(sync_mode):
+    litellm.set_verbose = True
+
+    if sync_mode:
+        response: litellm.ModelResponse = completion(
+            model="databricks/databricks-dbrx-instruct",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )  # type: ignore
+
+    else:
+        response: litellm.ModelResponse = await litellm.acompletion(
+            model="databricks/databricks-dbrx-instruct",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )  # type: ignore
+    print(f"response: {response}")
+
+    response_format_tests(response=response)
+
+
 # @pytest.mark.skip(reason="local test")
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 554a77eef..237d3895d 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -951,6 +951,62 @@ def test_vertex_ai_stream():
 # test_completion_vertexai_stream_bad_key()
 
 
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_databricks_streaming(sync_mode):
+    litellm.set_verbose = True
+    model_name = "databricks/databricks-dbrx-instruct"
+    try:
+        if sync_mode:
+            final_chunk: Optional[litellm.ModelResponse] = None
+            response: litellm.CustomStreamWrapper = completion(  # type: ignore
+                model=model_name,
+                messages=messages,
+                max_tokens=10,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            for idx, chunk in enumerate(response):
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+        else:
+            response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
+                model=model_name,
+                messages=messages,
+                max_tokens=100,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            idx = 0
+            final_chunk: Optional[litellm.ModelResponse] = None
+            async for chunk in response:
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+                idx += 1
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.parametrize("sync_mode", [False, True])
 @pytest.mark.asyncio
 async def test_completion_replicate_llama3_streaming(sync_mode):
diff --git a/litellm/types/llms/databricks.py b/litellm/types/llms/databricks.py
new file mode 100644
index 000000000..770e05fe3
--- /dev/null
+++ b/litellm/types/llms/databricks.py
@@ -0,0 +1,21 @@
+from typing import TypedDict, Any, Union, Optional
+import json
+from typing_extensions import (
+    Self,
+    Protocol,
+    TypeGuard,
+    override,
+    get_origin,
+    runtime_checkable,
+    Required,
+)
+from pydantic import BaseModel
+
+
+class GenericStreamingChunk(TypedDict, total=False):
+    text: Required[str]
+    is_finished: Required[bool]
+    finish_reason: Required[Optional[str]]
+    logprobs: Optional[BaseModel]
+    original_chunk: Optional[BaseModel]
+    usage: Optional[BaseModel]
diff --git a/litellm/utils.py b/litellm/utils.py
index c8c491754..8189ee058 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -568,7 +568,7 @@ class StreamingChoices(OpenAIObject):
         if delta is not None:
             if isinstance(delta, Delta):
                 self.delta = delta
-            if isinstance(delta, dict):
+            elif isinstance(delta, dict):
                 self.delta = Delta(**delta)
         else:
             self.delta = Delta()
@@ -676,7 +676,10 @@ class ModelResponse(OpenAIObject):
             created = created
         model = model
         if usage is not None:
-            usage = usage
+            if isinstance(usage, dict):
+                usage = Usage(**usage)
+            else:
+                usage = usage
         elif stream is None or stream == False:
             usage = Usage()
         elif (
@@ -11012,6 +11015,8 @@ class CustomStreamWrapper:
             elif self.custom_llm_provider and self.custom_llm_provider == "clarifai":
                 response_obj = self.handle_clarifai_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.model == "replicate" or self.custom_llm_provider == "replicate":
                 response_obj = self.handle_replicate_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11263,6 +11268,17 @@ class CustomStreamWrapper:
                     and self.stream_options.get("include_usage", False) == True
                 ):
                     model_response.usage = response_obj["usage"]
+            elif self.custom_llm_provider == "databricks":
+                response_obj = litellm.DatabricksConfig()._chunk_parser(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) == True
+                ):
+                    model_response.usage = response_obj["usage"]
             elif self.custom_llm_provider == "azure_text":
                 response_obj = self.handle_azure_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11672,6 +11688,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "replicate"
                 or self.custom_llm_provider == "cached_response"
                 or self.custom_llm_provider == "predibase"
+                or self.custom_llm_provider == "databricks"
                 or self.custom_llm_provider == "bedrock"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
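
Usage sketch (illustrative, not part of the diff): the snippet below mirrors the calls exercised by the new tests in test_completion.py and test_streaming.py. The workspace URL and token are placeholders; per DatabricksChatCompletion._validate_environment and the new branch in main.py, the only requirements are that DATABRICKS_API_KEY and DATABRICKS_API_BASE are set (via the environment, litellm.databricks_key / litellm.api_base, or call params), and "/chat/completions" is appended to the configured base.

# Minimal sketch - placeholder credentials, not a verified endpoint URL format.
import os

import litellm

os.environ["DATABRICKS_API_KEY"] = "dapi-..."  # placeholder personal access token
os.environ["DATABRICKS_API_BASE"] = "https://<workspace-host>/serving-endpoints"  # placeholder base URL

# Non-streaming: routed through the DatabricksChatCompletion.completion() handler added above.
response = litellm.completion(
    model="databricks/databricks-dbrx-instruct",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response.choices[0].message.content)

# Streaming: each SSE line is parsed by DatabricksConfig()._chunk_parser() inside CustomStreamWrapper.
for chunk in litellm.completion(
    model="databricks/databricks-dbrx-instruct",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
):
    print(chunk)

For reference, _chunk_parser() takes one "data: ..." line at a time, strips the SSE prefix, and rebuilds an OpenAI-style delta; a hypothetical chunk payload would be handled like this:

# Hypothetical SSE line, shaped like an OpenAI chat.completion.chunk.
line = 'data: {"choices": [{"index": 0, "delta": {"content": "Hi"}, "finish_reason": null}]}'
parsed = litellm.DatabricksConfig()._chunk_parser(line)
assert parsed["text"] == "Hi" and parsed["is_finished"] is False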