diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md
index 371f9394a..1c2a599ca 100644
--- a/docs/my-website/docs/completion/input.md
+++ b/docs/my-website/docs/completion/input.md
@@ -59,6 +59,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea
 |NLP Cloud| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
 |Petals| ✅ | ✅ | | ✅ | | | | | | |
 |Ollama| ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | | | | | ✅ | | |
+|Databricks| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | |
 
 :::note
diff --git a/docs/my-website/docs/providers/databricks.md b/docs/my-website/docs/providers/databricks.md
new file mode 100644
index 000000000..08a3e4f76
--- /dev/null
+++ b/docs/my-website/docs/providers/databricks.md
@@ -0,0 +1,202 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# 🆕 Databricks
+
+LiteLLM supports all models on Databricks.
+
+## Usage
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+### ENV VAR
+```python
+import os
+os.environ["DATABRICKS_API_KEY"] = ""
+os.environ["DATABRICKS_API_BASE"] = ""
+```
+
+### Example Call
+
+```python
+from litellm import completion
+import os
+## set ENV variables
+os.environ["DATABRICKS_API_KEY"] = "databricks key"
+os.environ["DATABRICKS_API_BASE"] = "databricks base url" # e.g.: https://adb-3064715882934586.6.azuredatabricks.net/serving-endpoints
+
+# Databricks dbrx-instruct call
+response = completion(
+    model="databricks/databricks-dbrx-instruct",
+    messages = [{ "content": "Hello, how are you?","role": "user"}]
+)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+   ```yaml
+   model_list:
+     - model_name: dbrx-instruct
+       litellm_params:
+         model: databricks/databricks-dbrx-instruct
+         api_key: os.environ/DATABRICKS_API_KEY
+         api_base: os.environ/DATABRICKS_API_BASE
+   ```
+
+2. Start the proxy
+
+   ```bash
+   $ litellm --config /path/to/config.yaml --debug
+   ```
+
+3. Send Request to LiteLLM Proxy Server
+
+   <Tabs>
+
+   <TabItem value="openai" label="OpenAI v1.0.0+">
+
+   ```python
+   import openai
+   client = openai.OpenAI(
+       api_key="sk-1234",             # pass litellm proxy key, if you're using virtual keys
+       base_url="http://0.0.0.0:4000" # litellm-proxy-base url
+   )
+
+   response = client.chat.completions.create(
+       model="dbrx-instruct",
+       messages = [
+           {
+               "role": "system",
+               "content": "Be a good human!"
+           },
+           {
+               "role": "user",
+               "content": "What do you know about earth?"
+           }
+       ]
+   )
+
+   print(response)
+   ```
+
+   </TabItem>
+
+   <TabItem value="curl" label="curl">
+
+   ```shell
+   curl --location 'http://0.0.0.0:4000/chat/completions' \
+   --header 'Authorization: Bearer sk-1234' \
+   --header 'Content-Type: application/json' \
+   --data '{
+       "model": "dbrx-instruct",
+       "messages": [
+           {
+               "role": "system",
+               "content": "Be a good human!"
+           },
+           {
+               "role": "user",
+               "content": "What do you know about earth?"
+           }
+       ]
+   }'
+   ```
+
+   </TabItem>
+
+   </Tabs>
+
+</TabItem>
+</Tabs>
+
+## Passing additional params - max_tokens, temperature
+See all litellm.completion supported params [here](../completion/input.md#translated-openai-params)
+
+```python
+# !pip install litellm
+from litellm import completion
+import os
+## set ENV variables
+os.environ["DATABRICKS_API_KEY"] = "databricks key"
+os.environ["DATABRICKS_API_BASE"] = "databricks base url"
+
+# Databricks dbrx-instruct call
+response = completion(
+    model="databricks/databricks-dbrx-instruct",
+    messages = [{ "content": "Hello, how are you?","role": "user"}],
+    max_tokens=20,
+    temperature=0.5
+)
+```
+
+**proxy**
+
+```yaml
+  model_list:
+    - model_name: dbrx-instruct
+      litellm_params:
+        model: databricks/databricks-dbrx-instruct
+        api_key: os.environ/DATABRICKS_API_KEY
+        api_base: os.environ/DATABRICKS_API_BASE
+        max_tokens: 20
+        temperature: 0.5
+```
+
+## Passing Databricks-specific params - 'instruction'
+
+For embedding models, Databricks lets you pass in an additional param 'instruction'. [Full Spec](https://github.com/BerriAI/litellm/blob/43353c28b341df0d9992b45c6ce464222ebd7984/litellm/llms/databricks.py#L164)
+
+
+```python
+# !pip install litellm
+from litellm import embedding
+import os
+## set ENV variables
+os.environ["DATABRICKS_API_KEY"] = "databricks key"
+os.environ["DATABRICKS_API_BASE"] = "databricks url"
+
+# Databricks bge-large-en call
+response = embedding(
+    model="databricks/databricks-bge-large-en",
+    input=["good morning from litellm"],
+    instruction="Represent this sentence for searching relevant passages:",
+)
+```
+
+**proxy**
+
+```yaml
+  model_list:
+    - model_name: bge-large
+      litellm_params:
+        model: databricks/databricks-bge-large-en
+        api_key: os.environ/DATABRICKS_API_KEY
+        api_base: os.environ/DATABRICKS_API_BASE
+        instruction: "Represent this sentence for searching relevant passages:"
+```
+
+
+## Supported Databricks Chat Completion Models
+Here are examples of calling Databricks chat models with LiteLLM:
+
+| Model Name | Command |
+|----------------------------|------------------------------------------------------------------|
+| databricks-dbrx-instruct | `completion(model='databricks/databricks-dbrx-instruct', messages=messages)` |
+| databricks-meta-llama-3-70b-instruct | `completion(model='databricks/databricks-meta-llama-3-70b-instruct', messages=messages)` |
+| databricks-llama-2-70b-chat | `completion(model='databricks/databricks-llama-2-70b-chat', messages=messages)` |
+| databricks-mixtral-8x7b-instruct | `completion(model='databricks/databricks-mixtral-8x7b-instruct', messages=messages)` |
+| databricks-mpt-30b-instruct | `completion(model='databricks/databricks-mpt-30b-instruct', messages=messages)` |
+| databricks-mpt-7b-instruct | `completion(model='databricks/databricks-mpt-7b-instruct', messages=messages)` |
+
+## Supported Databricks Embedding Models
+Here are examples of calling Databricks embedding models with LiteLLM:
+
+| Model Name | Command |
+|----------------------------|------------------------------------------------------------------|
+| databricks-bge-large-en | `embedding(model='databricks/databricks-bge-large-en', input=input)` |
diff --git a/docs/my-website/docs/providers/predibase.md b/docs/my-website/docs/providers/predibase.md
index 3d5bbaef4..31713aef1 100644
--- a/docs/my-website/docs/providers/predibase.md
+++ b/docs/my-website/docs/providers/predibase.md
@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
 
-# 🆕 Predibase
+# Predibase
 
 LiteLLM supports all models on Predibase
 
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index f840ed789..fddbf2838 100644
---
a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -133,6 +133,7 @@ const sidebars = { "providers/cohere", "providers/anyscale", "providers/huggingface", + "providers/databricks", "providers/watsonx", "providers/predibase", "providers/triton-inference-server", diff --git a/litellm/__init__.py b/litellm/__init__.py index 2af65f790..7e0f22e8f 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -69,6 +69,7 @@ retry = True ### AUTH ### api_key: Optional[str] = None openai_key: Optional[str] = None +databricks_key: Optional[str] = None azure_key: Optional[str] = None anthropic_key: Optional[str] = None replicate_key: Optional[str] = None @@ -616,6 +617,7 @@ provider_list: List = [ "watsonx", "triton", "predibase", + "databricks", "custom", # custom apis ] @@ -731,6 +733,7 @@ from .utils import ( ) from .llms.huggingface_restapi import HuggingfaceConfig from .llms.anthropic import AnthropicConfig +from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig from .llms.predibase import PredibaseConfig from .llms.anthropic_text import AnthropicTextConfig from .llms.replicate import ReplicateConfig diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 0adbd95bf..4df25944b 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -7,8 +7,12 @@ _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0) class AsyncHTTPHandler: def __init__( - self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000 + self, + timeout: Optional[Union[float, httpx.Timeout]] = None, + concurrent_limit=1000, ): + if timeout is None: + timeout = _DEFAULT_TIMEOUT # Create a client with a connection pool self.client = httpx.AsyncClient( timeout=timeout, @@ -59,7 +63,7 @@ class AsyncHTTPHandler: class HTTPHandler: def __init__( self, - timeout: Optional[httpx.Timeout] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, concurrent_limit=1000, client: Optional[httpx.Client] = None, ): diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py new file mode 100644 index 000000000..b306d425e --- /dev/null +++ b/litellm/llms/databricks.py @@ -0,0 +1,678 @@ +# What is this? 
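+# Adds Databricks model-serving support to LiteLLM: chat-completion and embedding
+# request/response handling plus streaming chunk parsing.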
+## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request +import os, types +import json +from enum import Enum +import requests, copy # type: ignore +import time +from typing import Callable, Optional, List, Union, Tuple, Literal +from litellm.utils import ( + ModelResponse, + Usage, + map_finish_reason, + CustomStreamWrapper, + EmbeddingResponse, +) +import litellm +from .prompt_templates.factory import prompt_factory, custom_prompt +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from .base import BaseLLM +import httpx # type: ignore +from litellm.types.llms.databricks import GenericStreamingChunk + + +class DatabricksError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.request = httpx.Request(method="POST", url="https://docs.databricks.com/") + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +class DatabricksConfig: + """ + Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request + """ + + max_tokens: Optional[int] = None + temperature: Optional[int] = None + top_p: Optional[int] = None + top_k: Optional[int] = None + stop: Optional[Union[List[str], str]] = None + n: Optional[int] = None + + def __init__( + self, + max_tokens: Optional[int] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + top_k: Optional[int] = None, + stop: Optional[Union[List[str], str]] = None, + n: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self): + return ["stream", "stop", "temperature", "top_p", "max_tokens", "n"] + + def map_openai_params(self, non_default_params: dict, optional_params: dict): + for param, value in non_default_params.items(): + if param == "max_tokens": + optional_params["max_tokens"] = value + if param == "n": + optional_params["n"] = value + if param == "stream" and value == True: + optional_params["stream"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + if param == "stop": + optional_params["stop"] = value + return optional_params + + def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk: + try: + text = "" + is_finished = False + finish_reason = None + logprobs = None + usage = None + original_chunk = None # this is used for function/tool calling + chunk_data = chunk_data.replace("data:", "") + chunk_data = chunk_data.strip() + if len(chunk_data) == 0: + return { + "text": "", + "is_finished": is_finished, + "finish_reason": finish_reason, + } + chunk_data_dict = json.loads(chunk_data) + str_line = litellm.ModelResponse(**chunk_data_dict, stream=True) + + if len(str_line.choices) > 0: + if ( + str_line.choices[0].delta is not None # type: ignore + and str_line.choices[0].delta.content is not None # type: ignore + ): + text = str_line.choices[0].delta.content # type: 
ignore + else: # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai + original_chunk = str_line + if str_line.choices[0].finish_reason: + is_finished = True + finish_reason = str_line.choices[0].finish_reason + if finish_reason == "content_filter": + if hasattr(str_line.choices[0], "content_filter_result"): + error_message = json.dumps( + str_line.choices[0].content_filter_result # type: ignore + ) + else: + error_message = "Azure Response={}".format( + str(dict(str_line)) + ) + raise litellm.AzureOpenAIError( + status_code=400, message=error_message + ) + + # checking for logprobs + if ( + hasattr(str_line.choices[0], "logprobs") + and str_line.choices[0].logprobs is not None + ): + logprobs = str_line.choices[0].logprobs + else: + logprobs = None + + usage = getattr(str_line, "usage", None) + + return GenericStreamingChunk( + text=text, + is_finished=is_finished, + finish_reason=finish_reason, + logprobs=logprobs, + original_chunk=original_chunk, + usage=usage, + ) + except Exception as e: + raise e + + +class DatabricksEmbeddingConfig: + """ + Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task + """ + + instruction: Optional[str] = ( + None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries + ) + + def __init__(self, instruction: Optional[str] = None) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params( + self, + ): # no optional openai embedding params supported + return [] + + def map_openai_params(self, non_default_params: dict, optional_params: dict): + return optional_params + + +class DatabricksChatCompletion(BaseLLM): + def __init__(self) -> None: + super().__init__() + + # makes headers for API call + + def _validate_environment( + self, + api_key: Optional[str], + api_base: Optional[str], + endpoint_type: Literal["chat_completions", "embeddings"], + ) -> Tuple[str, dict]: + if api_key is None: + raise DatabricksError( + status_code=400, + message="Missing Databricks API Key - A call is being made to Databricks but no key is set either in the environment variables (DATABRICKS_API_KEY) or via params", + ) + + if api_base is None: + raise DatabricksError( + status_code=400, + message="Missing Databricks API Base - A call is being made to Databricks but no api base is set either in the environment variables (DATABRICKS_API_BASE) or via params", + ) + + headers = { + "Authorization": "Bearer {}".format(api_key), + "Content-Type": "application/json", + } + + if endpoint_type == "chat_completions": + api_base = "{}/chat/completions".format(api_base) + elif endpoint_type == "embeddings": + api_base = "{}/embeddings".format(api_base) + return api_base, headers + + def process_response( + self, + model: str, + response: Union[requests.Response, httpx.Response], + model_response: ModelResponse, + stream: bool, + logging_obj: litellm.utils.Logging, + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: List, + print_verbose, + 
encoding, + ) -> ModelResponse: + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + ## RESPONSE OBJECT + try: + completion_response = response.json() + except: + raise DatabricksError( + message=response.text, status_code=response.status_code + ) + if "error" in completion_response: + raise DatabricksError( + message=str(completion_response["error"]), + status_code=response.status_code, + ) + else: + text_content = "" + tool_calls = [] + for content in completion_response["content"]: + if content["type"] == "text": + text_content += content["text"] + ## TOOL CALLING + elif content["type"] == "tool_use": + tool_calls.append( + { + "id": content["id"], + "type": "function", + "function": { + "name": content["name"], + "arguments": json.dumps(content["input"]), + }, + } + ) + + _message = litellm.Message( + tool_calls=tool_calls, + content=text_content or None, + ) + model_response.choices[0].message = _message # type: ignore + model_response._hidden_params["original_response"] = completion_response[ + "content" + ] # allow user to access raw anthropic tool calling response + + model_response.choices[0].finish_reason = map_finish_reason( + completion_response["stop_reason"] + ) + + ## CALCULATING USAGE + prompt_tokens = completion_response["usage"]["input_tokens"] + completion_tokens = completion_response["usage"]["output_tokens"] + total_tokens = prompt_tokens + completion_tokens + + model_response["created"] = int(time.time()) + model_response["model"] = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + setattr(model_response, "usage", usage) # type: ignore + return model_response + + async def acompletion_stream_function( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + stream, + data: dict, + optional_params=None, + litellm_params=None, + logger_fn=None, + headers={}, + ): + self.async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=600.0, connect=5.0) + ) + data["stream"] = True + try: + response = await self.async_handler.post( + api_base, headers=headers, data=json.dumps(data), stream=True + ) + response.raise_for_status() + + completion_stream = response.aiter_lines() + except httpx.HTTPStatusError as e: + raise DatabricksError( + status_code=e.response.status_code, message=response.text + ) + except httpx.TimeoutException as e: + raise DatabricksError(status_code=408, message="Timeout error occurred.") + except Exception as e: + raise DatabricksError(status_code=500, message=str(e)) + + streamwrapper = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="databricks", + logging_obj=logging_obj, + ) + return streamwrapper + + async def acompletion_function( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + stream, + data: dict, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ) -> ModelResponse: + if timeout is None: + timeout = httpx.Timeout(timeout=600.0, connect=5.0) + + self.async_handler = AsyncHTTPHandler(timeout=timeout) + + 
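+        # Non-streaming async path: a single POST; HTTP status and timeout failures
+        # are re-raised below as DatabricksError with a matching status code.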
try: + response = await self.async_handler.post( + api_base, headers=headers, data=json.dumps(data) + ) + response.raise_for_status() + + response_json = response.json() + except httpx.HTTPStatusError as e: + raise DatabricksError( + status_code=e.response.status_code, + message=response.text if response else str(e), + ) + except httpx.TimeoutException as e: + raise DatabricksError(status_code=408, message="Timeout error occurred.") + except Exception as e: + raise DatabricksError(status_code=500, message=str(e)) + + return ModelResponse(**response_json) + + def completion( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + ): + api_base, headers = self._validate_environment( + api_base=api_base, api_key=api_key, endpoint_type="chat_completions" + ) + ## Load Config + config = litellm.DatabricksConfig().get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + stream = optional_params.pop("stream", None) + + data = { + "model": model, + "messages": messages, + **optional_params, + } + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + if acompletion == True: + if ( + stream is not None and stream == True + ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) + print_verbose("makes async anthropic streaming POST request") + data["stream"] = stream + return self.acompletion_stream_function( + model=model, + messages=messages, + data=data, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + ) + else: + return self.acompletion_function( + model=model, + messages=messages, + data=data, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + ) + else: + if client is None or isinstance(client, AsyncHTTPHandler): + self.client = HTTPHandler(timeout=timeout) # type: ignore + else: + self.client = client + ## COMPLETION CALL + if ( + stream is not None and stream == True + ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) + print_verbose("makes dbrx streaming POST request") + data["stream"] = stream + try: + response = self.client.post( + api_base, headers=headers, data=json.dumps(data), stream=stream + ) + response.raise_for_status() + completion_stream = response.iter_lines() + except httpx.HTTPStatusError as e: + raise DatabricksError( + status_code=e.response.status_code, message=response.text + ) + except 
httpx.TimeoutException as e: + raise DatabricksError( + status_code=408, message="Timeout error occurred." + ) + except Exception as e: + raise DatabricksError(status_code=408, message=str(e)) + + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="databricks", + logging_obj=logging_obj, + ) + return streaming_response + + else: + try: + response = self.client.post( + api_base, headers=headers, data=json.dumps(data) + ) + response.raise_for_status() + + response_json = response.json() + except httpx.HTTPStatusError as e: + raise DatabricksError( + status_code=e.response.status_code, message=response.text + ) + except httpx.TimeoutException as e: + raise DatabricksError( + status_code=408, message="Timeout error occurred." + ) + except Exception as e: + raise DatabricksError(status_code=500, message=str(e)) + + return ModelResponse(**response_json) + + async def aembedding( + self, + input: list, + data: dict, + model_response: ModelResponse, + timeout: float, + api_key: str, + api_base: str, + logging_obj, + headers: dict, + client=None, + ) -> EmbeddingResponse: + response = None + try: + if client is None or isinstance(client, AsyncHTTPHandler): + self.async_client = AsyncHTTPHandler(timeout=timeout) # type: ignore + else: + self.async_client = client + + try: + response = await self.async_client.post( + api_base, + headers=headers, + data=json.dumps(data), + ) # type: ignore + + response.raise_for_status() + + response_json = response.json() + except httpx.HTTPStatusError as e: + raise DatabricksError( + status_code=e.response.status_code, + message=response.text if response else str(e), + ) + except httpx.TimeoutException as e: + raise DatabricksError( + status_code=408, message="Timeout error occurred." 
+ ) + except Exception as e: + raise DatabricksError(status_code=500, message=str(e)) + + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=response_json, + ) + return EmbeddingResponse(**response_json) + except Exception as e: + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + original_response=str(e), + ) + raise e + + def embedding( + self, + model: str, + input: list, + timeout: float, + logging_obj, + api_key: Optional[str], + api_base: Optional[str], + optional_params: dict, + model_response: Optional[litellm.utils.EmbeddingResponse] = None, + client=None, + aembedding=None, + ) -> EmbeddingResponse: + api_base, headers = self._validate_environment( + api_base=api_base, api_key=api_key, endpoint_type="embeddings" + ) + model = model + data = {"model": model, "input": input, **optional_params} + + ## LOGGING + logging_obj.pre_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data, "api_base": api_base}, + ) + + if aembedding == True: + return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers) # type: ignore + if client is None or isinstance(client, AsyncHTTPHandler): + self.client = HTTPHandler(timeout=timeout) # type: ignore + else: + self.client = client + + ## EMBEDDING CALL + try: + response = self.client.post( + api_base, + headers=headers, + data=json.dumps(data), + ) # type: ignore + + response.raise_for_status() # type: ignore + + response_json = response.json() # type: ignore + except httpx.HTTPStatusError as e: + raise DatabricksError( + status_code=e.response.status_code, + message=response.text if response else str(e), + ) + except httpx.TimeoutException as e: + raise DatabricksError(status_code=408, message="Timeout error occurred.") + except Exception as e: + raise DatabricksError(status_code=500, message=str(e)) + + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=response_json, + ) + + return litellm.EmbeddingResponse(**response_json) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 9d143f5d9..2e0196faa 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -404,6 +404,7 @@ class OpenAIChatCompletion(BaseLLM): self, model_response: ModelResponse, timeout: Union[float, httpx.Timeout], + optional_params: dict, model: Optional[str] = None, messages: Optional[list] = None, print_verbose: Optional[Callable] = None, @@ -411,7 +412,6 @@ class OpenAIChatCompletion(BaseLLM): api_base: Optional[str] = None, acompletion: bool = False, logging_obj=None, - optional_params=None, litellm_params=None, logger_fn=None, headers: Optional[dict] = None, @@ -795,10 +795,10 @@ class OpenAIChatCompletion(BaseLLM): model: str, input: list, timeout: float, + logging_obj, api_key: Optional[str] = None, api_base: Optional[str] = None, model_response: Optional[litellm.utils.EmbeddingResponse] = None, - logging_obj=None, optional_params=None, client=None, aembedding=None, diff --git a/litellm/main.py b/litellm/main.py index 42c4eb8ff..dc4cf0001 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -73,6 +73,7 @@ from .llms import ( ) from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.azure import AzureChatCompletion +from .llms.databricks import DatabricksChatCompletion from 
.llms.azure_text import AzureTextCompletion from .llms.anthropic import AnthropicChatCompletion from .llms.anthropic_text import AnthropicTextCompletion @@ -111,6 +112,7 @@ from litellm.utils import ( ####### ENVIRONMENT VARIABLES ################### openai_chat_completions = OpenAIChatCompletion() openai_text_completions = OpenAITextCompletion() +databricks_chat_completions = DatabricksChatCompletion() anthropic_chat_completions = AnthropicChatCompletion() anthropic_text_completions = AnthropicTextCompletion() azure_chat_completions = AzureChatCompletion() @@ -329,6 +331,7 @@ async def acompletion( or custom_llm_provider == "anthropic" or custom_llm_provider == "predibase" or custom_llm_provider == "bedrock" + or custom_llm_provider == "databricks" or custom_llm_provider in litellm.openai_compatible_providers ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. init_response = await loop.run_in_executor(None, func_with_context) @@ -1615,6 +1618,61 @@ def completion( ) return response response = model_response + elif custom_llm_provider == "databricks": + api_base = ( + api_base # for databricks we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or os.getenv("DATABRICKS_API_BASE") + ) + + # set API KEY + api_key = ( + api_key + or litellm.api_key # for databricks we check in get_llm_provider and pass in the api key from there + or litellm.databricks_key + or get_secret("DATABRICKS_API_KEY") + ) + + headers = headers or litellm.headers + + ## COMPLETION CALL + try: + response = databricks_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + custom_prompt_dict=custom_prompt_dict, + client=client, # pass AsyncOpenAI, OpenAI client + encoding=encoding, + ) + except Exception as e: + ## LOGGING - log the original exception returned + logging.post_call( + input=messages, + api_key=api_key, + original_response=str(e), + additional_args={"headers": headers}, + ) + raise e + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={"headers": headers}, + ) elif custom_llm_provider == "openrouter": api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1" @@ -2669,7 +2727,7 @@ def batch_completion_models_all_responses(*args, **kwargs): ### EMBEDDING ENDPOINTS #################### @client -async def aembedding(*args, **kwargs): +async def aembedding(*args, **kwargs) -> EmbeddingResponse: """ Asynchronously calls the `embedding` function with the given arguments and keyword arguments. @@ -2714,12 +2772,13 @@ async def aembedding(*args, **kwargs): or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "ollama" or custom_llm_provider == "vertex_ai" + or custom_llm_provider == "databricks" ): # currently implemented aiohttp calls for just azure and openai, soon all. 
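+            # databricks embeddings follow the same pattern as the other async providers here:
+            # the sync embedding() call returns a coroutine, which is awaited below.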
# Await normally init_response = await loop.run_in_executor(None, func_with_context) - if isinstance(init_response, dict) or isinstance( - init_response, ModelResponse - ): ## CACHING SCENARIO + if isinstance(init_response, dict): + response = EmbeddingResponse(**init_response) + elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO response = init_response elif asyncio.iscoroutine(init_response): response = await init_response @@ -2759,7 +2818,7 @@ def embedding( litellm_logging_obj=None, logger_fn=None, **kwargs, -): +) -> EmbeddingResponse: """ Embedding function that calls an API to generate embeddings for the given input. @@ -2907,7 +2966,7 @@ def embedding( ) try: response = None - logging = litellm_logging_obj + logging: Logging = litellm_logging_obj # type: ignore logging.update_environment_variables( model=model, user=user, @@ -2997,6 +3056,32 @@ def embedding( client=client, aembedding=aembedding, ) + elif custom_llm_provider == "databricks": + api_base = ( + api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE") + ) # type: ignore + + # set API KEY + api_key = ( + api_key + or litellm.api_key + or litellm.databricks_key + or get_secret("DATABRICKS_API_KEY") + ) # type: ignore + + ## EMBEDDING CALL + response = databricks_chat_completions.embedding( + model=model, + input=input, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + ) elif custom_llm_provider == "cohere": cohere_key = ( api_key diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index f599b1817..1b319c160 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -131,6 +131,27 @@ def test_completion_azure_command_r(): pytest.fail(f"Error occurred: {e}") +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_completion_databricks(sync_mode): + litellm.set_verbose = True + + if sync_mode: + response: litellm.ModelResponse = completion( + model="databricks/databricks-dbrx-instruct", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) # type: ignore + + else: + response: litellm.ModelResponse = await litellm.acompletion( + model="databricks/databricks-dbrx-instruct", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) # type: ignore + print(f"response: {response}") + + response_format_tests(response=response) + + # @pytest.mark.skip(reason="local test") @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index a441b0e70..30988dba1 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -535,6 +535,37 @@ async def test_triton_embeddings(): pytest.fail(f"Error occurred: {e}") +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_databricks_embeddings(sync_mode): + try: + litellm.set_verbose = True + litellm.drop_params = True + + if sync_mode: + response = litellm.embedding( + model="databricks/databricks-bge-large-en", + input=["good morning from litellm"], + instruction="Represent this sentence for searching relevant passages:", + ) + else: + response = await litellm.aembedding( + model="databricks/databricks-bge-large-en", + input=["good morning from litellm"], + instruction="Represent this sentence for searching relevant passages:", + ) + + 
print(f"response: {response}") + + openai.types.CreateEmbeddingResponse.model_validate( + response.model_dump(), strict=True + ) + # stubbed endpoint is setup to return this + # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3] + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + # test_voyage_embeddings() # def test_xinference_embeddings(): # try: diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py index 5c33cfa0e..31ee3d99b 100644 --- a/litellm/tests/test_optional_params.py +++ b/litellm/tests/test_optional_params.py @@ -83,6 +83,20 @@ def test_azure_optional_params_embeddings(): assert optional_params["user"] == "John" +def test_databricks_optional_params(): + litellm.drop_params = True + optional_params = get_optional_params( + model="", + user="John", + custom_llm_provider="databricks", + max_tokens=10, + temperature=0.2, + ) + print(f"optional_params: {optional_params}") + assert len(optional_params) == 2 + assert "user" not in optional_params + + def test_azure_gpt_optional_params_gpt_vision(): # for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here optional_params = litellm.utils.get_optional_params( diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 554a77eef..237d3895d 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -951,6 +951,62 @@ def test_vertex_ai_stream(): # test_completion_vertexai_stream_bad_key() +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_completion_databricks_streaming(sync_mode): + litellm.set_verbose = True + model_name = "databricks/databricks-dbrx-instruct" + try: + if sync_mode: + final_chunk: Optional[litellm.ModelResponse] = None + response: litellm.CustomStreamWrapper = completion( # type: ignore + model=model_name, + messages=messages, + max_tokens=10, # type: ignore + stream=True, + ) + complete_response = "" + # Add any assertions here to check the response + has_finish_reason = False + for idx, chunk in enumerate(response): + final_chunk = chunk + chunk, finished = streaming_format_tests(idx, chunk) + if finished: + has_finish_reason = True + break + complete_response += chunk + if has_finish_reason == False: + raise Exception("finish reason not set") + if complete_response.strip() == "": + raise Exception("Empty response received") + else: + response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore + model=model_name, + messages=messages, + max_tokens=100, # type: ignore + stream=True, + ) + complete_response = "" + # Add any assertions here to check the response + has_finish_reason = False + idx = 0 + final_chunk: Optional[litellm.ModelResponse] = None + async for chunk in response: + final_chunk = chunk + chunk, finished = streaming_format_tests(idx, chunk) + if finished: + has_finish_reason = True + break + complete_response += chunk + idx += 1 + if has_finish_reason == False: + raise Exception("finish reason not set") + if complete_response.strip() == "": + raise Exception("Empty response received") + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + @pytest.mark.parametrize("sync_mode", [False, True]) @pytest.mark.asyncio async def test_completion_replicate_llama3_streaming(sync_mode): diff --git a/litellm/types/llms/databricks.py b/litellm/types/llms/databricks.py new file mode 100644 index 000000000..770e05fe3 --- /dev/null +++ b/litellm/types/llms/databricks.py 
@@ -0,0 +1,21 @@ +from typing import TypedDict, Any, Union, Optional +import json +from typing_extensions import ( + Self, + Protocol, + TypeGuard, + override, + get_origin, + runtime_checkable, + Required, +) +from pydantic import BaseModel + + +class GenericStreamingChunk(TypedDict, total=False): + text: Required[str] + is_finished: Required[bool] + finish_reason: Required[Optional[str]] + logprobs: Optional[BaseModel] + original_chunk: Optional[BaseModel] + usage: Optional[BaseModel] diff --git a/litellm/utils.py b/litellm/utils.py index 72b734ffa..33dfb261e 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -568,7 +568,7 @@ class StreamingChoices(OpenAIObject): if delta is not None: if isinstance(delta, Delta): self.delta = delta - if isinstance(delta, dict): + elif isinstance(delta, dict): self.delta = Delta(**delta) else: self.delta = Delta() @@ -676,7 +676,10 @@ class ModelResponse(OpenAIObject): created = created model = model if usage is not None: - usage = usage + if isinstance(usage, dict): + usage = Usage(**usage) + else: + usage = usage elif stream is None or stream == False: usage = Usage() elif ( @@ -763,7 +766,13 @@ class EmbeddingResponse(OpenAIObject): _hidden_params: dict = {} def __init__( - self, model=None, usage=None, stream=False, response_ms=None, data=None + self, + model=None, + usage=None, + stream=False, + response_ms=None, + data=None, + **params, ): object = "list" if response_ms: @@ -5035,6 +5044,19 @@ def get_optional_params_embeddings( default_params = {"user": None, "encoding_format": None, "dimensions": None} + def _check_valid_arg(supported_params: Optional[list]): + if supported_params is None: + return + unsupported_params = {} + for k in non_default_params.keys(): + if k not in supported_params: + unsupported_params[k] = non_default_params[k] + if unsupported_params and not litellm.drop_params: + raise UnsupportedParamsError( + status_code=500, + message=f"{custom_llm_provider} does not support parameters: {unsupported_params}, for model={model}. 
To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n", + ) + non_default_params = { k: v for k, v in passed_params.items() @@ -5060,6 +5082,18 @@ def get_optional_params_embeddings( non_default_params.pop(k, None) final_params = {**non_default_params, **kwargs} return final_params + if custom_llm_provider == "databricks": + supported_params = get_supported_openai_params( + model=model or "", + custom_llm_provider="databricks", + request_type="embeddings", + ) + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.DatabricksEmbeddingConfig().map_openai_params( + non_default_params=non_default_params, optional_params={} + ) + final_params = {**optional_params, **kwargs} + return final_params if custom_llm_provider == "vertex_ai": if len(non_default_params.keys()) > 0: if litellm.drop_params is True: # drop the unsupported non-default values @@ -5846,6 +5880,14 @@ def get_optional_params( optional_params = litellm.MistralConfig().map_openai_params( non_default_params=non_default_params, optional_params=optional_params ) + elif custom_llm_provider == "databricks": + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.DatabricksConfig().map_openai_params( + non_default_params=non_default_params, optional_params=optional_params + ) elif custom_llm_provider == "groq": supported_params = get_supported_openai_params( model=model, custom_llm_provider=custom_llm_provider @@ -6333,7 +6375,11 @@ def get_first_chars_messages(kwargs: dict) -> str: return "" -def get_supported_openai_params(model: str, custom_llm_provider: str) -> Optional[list]: +def get_supported_openai_params( + model: str, + custom_llm_provider: str, + request_type: Literal["chat_completion", "embeddings"] = "chat_completion", +) -> Optional[list]: """ Returns the supported openai params for a given model + provider @@ -6506,6 +6552,11 @@ def get_supported_openai_params(model: str, custom_llm_provider: str) -> Optiona "frequency_penalty", "presence_penalty", ] + elif custom_llm_provider == "databricks": + if request_type == "chat_completion": + return litellm.DatabricksConfig().get_supported_openai_params() + elif request_type == "embeddings": + return litellm.DatabricksEmbeddingConfig().get_supported_openai_params() elif custom_llm_provider == "palm" or custom_llm_provider == "gemini": return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"] elif custom_llm_provider == "vertex_ai": @@ -11017,6 +11068,8 @@ class CustomStreamWrapper: elif self.custom_llm_provider and self.custom_llm_provider == "clarifai": response_obj = self.handle_clarifai_completion_chunk(chunk) completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] elif self.model == "replicate" or self.custom_llm_provider == "replicate": response_obj = self.handle_replicate_chunk(chunk) completion_obj["content"] = response_obj["text"] @@ -11268,6 +11321,17 @@ class CustomStreamWrapper: and self.stream_options.get("include_usage", False) == True ): model_response.usage = response_obj["usage"] + elif self.custom_llm_provider == "databricks": + response_obj = litellm.DatabricksConfig()._chunk_parser(chunk) + completion_obj["content"] = response_obj["text"] + print_verbose(f"completion obj content: {completion_obj['content']}") + if response_obj["is_finished"]: + 
self.received_finish_reason = response_obj["finish_reason"] + if ( + self.stream_options + and self.stream_options.get("include_usage", False) == True + ): + model_response.usage = response_obj["usage"] elif self.custom_llm_provider == "azure_text": response_obj = self.handle_azure_text_completion_chunk(chunk) completion_obj["content"] = response_obj["text"] @@ -11677,6 +11741,7 @@ class CustomStreamWrapper: or self.custom_llm_provider == "replicate" or self.custom_llm_provider == "cached_response" or self.custom_llm_provider == "predibase" + or self.custom_llm_provider == "databricks" or self.custom_llm_provider == "bedrock" or self.custom_llm_provider in litellm.openai_compatible_endpoints ):
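For a quick local check of the streaming path added in this diff, a minimal sketch is shown below. The API key and workspace URL are placeholders, and the model name assumes a `databricks-dbrx-instruct` serving endpoint as used in the new tests; this is a sketch, not part of the patched files.

```python
# Minimal streaming sketch against the new databricks provider (placeholder credentials).
import os
from litellm import completion

os.environ["DATABRICKS_API_KEY"] = "dapi-placeholder"                               # placeholder
os.environ["DATABRICKS_API_BASE"] = "https://<workspace-host>/serving-endpoints"    # placeholder

response = completion(
    model="databricks/databricks-dbrx-instruct",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=10,
    stream=True,
)

# CustomStreamWrapper yields OpenAI-style chunks; print the content deltas as they arrive.
for chunk in response:
    delta = chunk.choices[0].delta.content or ""
    print(delta, end="", flush=True)
```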