litellm/litellm/llms/databricks.py
Krish Dholakia f9e6507cd1
LiteLLM Minor Fixes + Improvements (#5474)
* feat(proxy/_types.py): add lago billing to callbacks ui

Closes https://github.com/BerriAI/litellm/issues/5472

* fix(anthropic.py): return anthropic prompt caching information

Fixes https://github.com/BerriAI/litellm/issues/5364

* feat(bedrock/chat.py): support 'json_schema' for bedrock models

Closes https://github.com/BerriAI/litellm/issues/5434

* fix(bedrock/embed/embeddings.py): support async embeddings for amazon titan models

* fix: linting fixes

* fix: handle key errors

* fix(bedrock/chat.py): fix bedrock ai21 streaming object

* feat(bedrock/embed): support bedrock embedding optional params

* fix(databricks.py): fix usage chunk

* fix(internal_user_endpoints.py): apply internal user defaults, if user role updated

Fixes issue where user update wouldn't apply defaults

* feat(slack_alerting.py): provide multiple slack channels for a given alert type

multiple channels might be interested in receiving an alert for a given type

* docs(alerting.md): add multiple channel alerting to docs
2024-09-02 14:29:57 -07:00

794 lines
27 KiB
Python

# What is this?
## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
import copy
import json
import os
import time
import types
from enum import Enum
from functools import partial
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import (
ChatCompletionDeltaChunk,
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import (
CustomStreamingDecoder,
GenericStreamingChunk,
ProviderField,
)
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class DatabricksError(Exception):
    """Exception raised by this handler for failed Databricks API calls.

    Attributes:
        status_code: HTTP-style status code describing the failure.
        message: Error text; also used as the exception's message.
        request: Placeholder ``httpx.Request`` (POST to the Databricks docs URL).
        response: ``httpx.Response`` carrying ``status_code`` and bound to
            ``request``.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # Synthetic request/response pair built from the status code only.
        # NOTE(review): presumably consumed by generic exception mapping that
        # expects `.request` / `.response` on provider errors - confirm.
        placeholder_request = httpx.Request(
            method="POST", url="https://docs.databricks.com/"
        )
        self.request = placeholder_request
        self.response = httpx.Response(
            status_code=status_code, request=placeholder_request
        )
        # Call the base class constructor with the parameters it needs
        super().__init__(message)
class DatabricksConfig:
    """
    Default chat-completion parameters for the Databricks API.

    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request

    Non-None values passed to the constructor are stored on the *class*
    (not the instance), so they are visible to every later `get_config()` call.
    """

    max_tokens: Optional[int] = None
    # temperature / top_p are fractional sampling parameters, hence float.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop: Optional[Union[List[str], str]] = None
    n: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        stop: Optional[Union[List[str], str]] = None,
        n: Optional[int] = None,
    ) -> None:
        overrides = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "stop": stop,
            "n": n,
        }
        for key, value in overrides.items():
            if value is not None:
                # Written to the class on purpose: acts as a shared default.
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return the class-level settings that are currently not None.

        Dunder attributes, functions, classmethods and staticmethods are
        excluded, leaving only the configured parameter values.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Databricks API Key.",
                field_value="dapi...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Databricks API Base.",
                field_value="https://adb-..",
            ),
        ]

    def get_supported_openai_params(self):
        """Return the OpenAI parameter names this provider accepts."""
        return ["stream", "stop", "temperature", "top_p", "max_tokens", "n"]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        """Copy supported OpenAI params into `optional_params` and return it.

        `stream` is only forwarded when it equals True; every other supported
        parameter is copied unchanged. Unsupported keys are ignored.
        `optional_params` is mutated in place.
        """
        passthrough = ("max_tokens", "n", "temperature", "top_p", "stop")
        for param, value in non_default_params.items():
            if param in passthrough:
                optional_params[param] = value
            elif param == "stream" and value == True:  # noqa: E712 - keep loose equality
                optional_params["stream"] = value
        return optional_params
class DatabricksEmbeddingConfig:
    """
    Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task

    A non-None `instruction` passed to the constructor is stored on the class
    itself, so it is returned by every later `get_config()` call.
    """

    instruction: Optional[str] = (
        None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
    )

    def __init__(self, instruction: Optional[str] = None) -> None:
        # Only an explicitly supplied value overrides the class-level default.
        if instruction is not None:
            setattr(self.__class__, "instruction", instruction)

    @classmethod
    def get_config(cls):
        """Collect the class-level settings that currently hold a value."""
        skipped_kinds = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        settings = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith("__"):
                continue
            if isinstance(attr_value, skipped_kinds) or attr_value is None:
                continue
            settings[attr_name] = attr_value
        return settings

    def get_supported_openai_params(
        self,
    ):  # no optional openai embedding params supported
        return list()

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        """Return `optional_params` untouched; nothing is mapped for embeddings."""
        return optional_params
async def make_call(
    client: AsyncHTTPHandler,
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
    streaming_decoder: Optional[CustomStreamingDecoder] = None,
):
    """Open an async streaming POST and return an iterator over its chunks.

    Args:
        client: Async HTTP handler used to send the request.
        api_base: Full URL to POST to.
        headers: Request headers.
        data: JSON-encoded request body.
        model: Unused here; kept for signature compatibility.
        messages: Original chat messages, forwarded to the logger.
        logging_obj: Logger whose `post_call` is invoked once the stream is set up.
        streaming_decoder: Optional custom decoder; when given, it consumes the
            raw byte stream instead of the default line-based iterator.

    Returns:
        The decoder's async byte iterator, or a `ModelResponseIterator` over
        the response lines.

    Raises:
        DatabricksError: If the response status code is not 200.
    """
    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        # The request was sent with stream=True, so the body has not been
        # loaded yet; httpx refuses `.text` on an unread streaming response.
        await response.aread()
        raise DatabricksError(status_code=response.status_code, message=response.text)

    if streaming_decoder is not None:
        completion_stream: Any = streaming_decoder.aiter_bytes(
            response.aiter_bytes(chunk_size=1024)
        )
    else:
        completion_stream = ModelResponseIterator(
            streaming_response=response.aiter_lines(), sync_stream=False
        )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream
def make_sync_call(
    client: Optional[HTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
    streaming_decoder: Optional[CustomStreamingDecoder] = None,
):
    """Open a blocking streaming POST and return an iterator over its chunks.

    Args:
        client: Sync HTTP handler; a new `HTTPHandler` is created when None.
        api_base: Full URL to POST to.
        headers: Request headers.
        data: JSON-encoded request body.
        model: Unused here; kept for signature compatibility.
        messages: Original chat messages, forwarded to the logger.
        logging_obj: Logger whose `post_call` is invoked once the stream is set up.
        streaming_decoder: Optional custom decoder; when given, it consumes the
            raw byte stream instead of the default line-based iterator.

    Returns:
        The decoder's byte iterator, or a `ModelResponseIterator` over the
        response lines.

    Raises:
        DatabricksError: If the response status code is not 200.
    """
    if client is None:
        client = HTTPHandler()  # Create a new client if none provided

    response = client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        # Load the streamed body, then report it as text: passing the raw
        # bytes from read() would make the error message a `bytes` object.
        response.read()
        raise DatabricksError(status_code=response.status_code, message=response.text)

    if streaming_decoder is not None:
        completion_stream = streaming_decoder.iter_bytes(
            response.iter_bytes(chunk_size=1024)
        )
    else:
        completion_stream = ModelResponseIterator(
            streaming_response=response.iter_lines(), sync_stream=True
        )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )

    return completion_stream
class DatabricksChatCompletion(BaseLLM):
    """
    Handler for Databricks chat-completion and embedding endpoints.

    Builds the request URL and headers, then dispatches the call in its sync,
    async, or streaming variant. HTTP and transport failures are re-raised as
    `DatabricksError`.
    """

    def __init__(self) -> None:
        super().__init__()

    # makes headers for API call
    def _validate_environment(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        endpoint_type: Literal["chat_completions", "embeddings"],
        custom_endpoint: Optional[bool],
        headers: Optional[dict],
    ) -> Tuple[str, dict]:
        """Validate credentials and build the final URL and headers.

        Args:
            api_key: Bearer token; may be None when `headers` already carries auth.
            api_base: Base URL of the endpoint.
            endpoint_type: Which route suffix to append to `api_base`.
            custom_endpoint: When True, `api_base` is used verbatim.
            headers: Caller-supplied headers. Updated in place with the
                Authorization header when `api_key` is given.

        Returns:
            `(api_base, headers)` ready to be used for the request.

        Raises:
            DatabricksError: status 400 when both `api_key` and `headers` are
                missing, or when `api_base` is missing.
        """
        if api_key is None and headers is None:
            raise DatabricksError(
                status_code=400,
                message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
            )

        if api_base is None:
            raise DatabricksError(
                status_code=400,
                message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_BASE) or via params",
            )

        if headers is None:
            headers = {
                "Authorization": "Bearer {}".format(api_key),
                "Content-Type": "application/json",
            }
        elif api_key is not None:
            headers.update({"Authorization": "Bearer {}".format(api_key)})

        if endpoint_type == "chat_completions" and custom_endpoint is not True:
            api_base = "{}/chat/completions".format(api_base)
        elif endpoint_type == "embeddings" and custom_endpoint is not True:
            api_base = "{}/embeddings".format(api_base)
        return api_base, headers

    async def acompletion_stream_function(
        self,
        model: str,
        messages: list,
        custom_llm_provider: str,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        client: Optional[AsyncHTTPHandler] = None,
        streaming_decoder: Optional[CustomStreamingDecoder] = None,
    ) -> CustomStreamWrapper:
        """Return a `CustomStreamWrapper` that lazily opens the async stream.

        `data` is mutated to force `stream=True` and is serialised immediately;
        the HTTP request itself is deferred to `make_call`, which the wrapper
        invokes. `client` is not bound into that call here.
        """
        headers = {} if headers is None else headers
        data["stream"] = True

        deferred_call = partial(
            make_call,
            api_base=api_base,
            headers=headers,
            data=json.dumps(data),
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            streaming_decoder=streaming_decoder,
        )
        return CustomStreamWrapper(
            completion_stream=None,
            make_call=deferred_call,
            model=model,
            custom_llm_provider=custom_llm_provider,
            logging_obj=logging_obj,
        )

    async def acompletion_function(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        base_model: Optional[str],
        optional_params: dict,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> ModelResponse:
        """Perform a non-streaming async chat completion.

        Returns:
            The parsed `ModelResponse`; when `base_model` is given it is
            recorded in the response's hidden params under "model".

        Raises:
            DatabricksError: with the upstream status code on HTTP errors,
                408 on timeouts, and 500 for any other failure.
        """
        headers = {} if headers is None else headers
        if timeout is None:
            timeout = httpx.Timeout(timeout=600.0, connect=5.0)

        self.async_handler = AsyncHTTPHandler(timeout=timeout)

        try:
            response = await self.async_handler.post(
                api_base, headers=headers, data=json.dumps(data)
            )
            response.raise_for_status()
            response_json = response.json()
        except httpx.HTTPStatusError as e:
            raise DatabricksError(
                status_code=e.response.status_code,
                message=e.response.text,
            ) from e
        except httpx.TimeoutException as e:
            raise DatabricksError(
                status_code=408, message="Timeout error occurred."
            ) from e
        except Exception as e:
            raise DatabricksError(status_code=500, message=str(e)) from e

        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response=response_json,
            additional_args={"complete_input_dict": data},
        )
        parsed_response = ModelResponse(**response_json)

        if base_model is not None:
            parsed_response._hidden_params["model"] = base_model
        return parsed_response

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key: Optional[str],
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        custom_endpoint: Optional[bool] = None,
        streaming_decoder: Optional[
            CustomStreamingDecoder
        ] = None,  # if openai-compatible api needs custom stream decoder - e.g. sagemaker
    ):
        """Entry point for chat completions (sync, async, streaming).

        `optional_params` is mutated: "custom_endpoint" and "base_model" are
        removed, `DatabricksConfig` defaults are filled in for missing keys,
        and "stream" is normalised to a bool.

        Returns:
            - `acompletion is True`: the coroutine of the matching async method.
            - sync + streaming: a `CustomStreamWrapper`.
            - sync + non-streaming: a `ModelResponse`.

        Raises:
            DatabricksError: on invalid environment or a failed sync request.
        """
        # Always pop, so the routing flag never ends up in the request body
        # even when the keyword argument already enabled it.
        custom_endpoint_param = optional_params.pop("custom_endpoint", None)
        custom_endpoint = custom_endpoint or custom_endpoint_param
        base_model: Optional[str] = optional_params.pop("base_model", None)
        api_base, headers = self._validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="chat_completions",
            custom_endpoint=custom_endpoint,
            headers=headers,
        )

        ## Load Config
        config = litellm.DatabricksConfig().get_config()
        for k, v in config.items():
            # explicit call params win over configured defaults
            if k not in optional_params:
                optional_params[k] = v

        stream: bool = optional_params.get("stream", None) or False
        optional_params["stream"] = stream

        data = {
            "model": model,
            "messages": messages,
            **optional_params,
        }

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": headers,
            },
        )

        if acompletion is True:
            # A sync handler cannot serve the async path; drop it.
            if client is not None and isinstance(client, HTTPHandler):
                client = None
            if stream is True:
                print_verbose("makes async databricks streaming POST request")
                data["stream"] = stream
                return self.acompletion_stream_function(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=api_base,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=stream,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    client=client,
                    custom_llm_provider=custom_llm_provider,
                    streaming_decoder=streaming_decoder,
                )
            return self.acompletion_function(
                model=model,
                messages=messages,
                data=data,
                api_base=api_base,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                encoding=encoding,
                api_key=api_key,
                logging_obj=logging_obj,
                optional_params=optional_params,
                stream=stream,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                headers=headers,
                timeout=timeout,
                base_model=base_model,
            )

        ## COMPLETION CALL
        if stream is True:
            # The request is issued later by the wrapper via make_sync_call,
            # so no handler is created on this path.
            return CustomStreamWrapper(
                completion_stream=None,
                make_call=partial(
                    make_sync_call,
                    client=None,
                    api_base=api_base,
                    headers=headers,  # type: ignore
                    data=json.dumps(data),
                    model=model,
                    messages=messages,
                    logging_obj=logging_obj,
                    streaming_decoder=streaming_decoder,
                ),
                model=model,
                custom_llm_provider=custom_llm_provider,
                logging_obj=logging_obj,
            )

        if client is None or not isinstance(client, HTTPHandler):
            client = HTTPHandler(timeout=timeout)  # type: ignore

        try:
            response = client.post(api_base, headers=headers, data=json.dumps(data))
            response.raise_for_status()
            response_json = response.json()
        except httpx.HTTPStatusError as e:
            # Use the response attached to the exception: the local name is
            # unbound if post() itself raised.
            raise DatabricksError(
                status_code=e.response.status_code, message=e.response.text
            ) from e
        except httpx.TimeoutException as e:
            raise DatabricksError(
                status_code=408, message="Timeout error occurred."
            ) from e
        except Exception as e:
            raise DatabricksError(status_code=500, message=str(e)) from e

        parsed_response = ModelResponse(**response_json)

        if base_model is not None:
            parsed_response._hidden_params["model"] = base_model

        return parsed_response

    async def aembedding(
        self,
        input: list,
        data: dict,
        model_response: ModelResponse,
        timeout: float,
        api_key: str,
        api_base: str,
        logging_obj,
        headers: dict,
        client=None,
    ) -> EmbeddingResponse:
        """Perform an async embedding request.

        A caller-supplied `client` is reused only when it is an
        `AsyncHTTPHandler`; otherwise a new async handler is created.
        Any failure is logged through `logging_obj.post_call` and re-raised.

        Raises:
            DatabricksError: with the upstream status code on HTTP errors,
                408 on timeouts, and 500 for any other failure.
        """
        try:
            if client is None or not isinstance(client, AsyncHTTPHandler):
                self.async_client = AsyncHTTPHandler(timeout=timeout)  # type: ignore
            else:
                self.async_client = client

            try:
                response = await self.async_client.post(
                    api_base,
                    headers=headers,
                    data=json.dumps(data),
                )  # type: ignore

                response.raise_for_status()

                response_json = response.json()
            except httpx.HTTPStatusError as e:
                raise DatabricksError(
                    status_code=e.response.status_code,
                    message=e.response.text,
                ) from e
            except httpx.TimeoutException as e:
                raise DatabricksError(
                    status_code=408, message="Timeout error occurred."
                ) from e
            except Exception as e:
                raise DatabricksError(status_code=500, message=str(e)) from e

            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=response_json,
            )
            return EmbeddingResponse(**response_json)
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                original_response=str(e),
            )
            raise e

    def embedding(
        self,
        model: str,
        input: list,
        timeout: float,
        logging_obj,
        api_key: Optional[str],
        api_base: Optional[str],
        optional_params: dict,
        model_response: Optional[litellm.utils.EmbeddingResponse] = None,
        client=None,
        aembedding=None,
        headers: Optional[dict] = None,
    ) -> EmbeddingResponse:
        """Entry point for embeddings (sync, or async when `aembedding` is set).

        Returns:
            An `EmbeddingResponse`, or the `aembedding` coroutine when
            `aembedding == True`.

        Raises:
            DatabricksError: on invalid environment or a failed sync request.
        """
        api_base, headers = self._validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="embeddings",
            custom_endpoint=False,
            headers=headers,
        )
        data = {"model": model, "input": input, **optional_params}

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data, "api_base": api_base},
        )

        if aembedding == True:  # noqa: E712 - keep loose equality
            return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers)  # type: ignore

        # An async handler cannot serve the blocking path; replace it.
        if client is None or isinstance(client, AsyncHTTPHandler):
            self.client = HTTPHandler(timeout=timeout)  # type: ignore
        else:
            self.client = client

        ## EMBEDDING CALL
        try:
            response = self.client.post(
                api_base,
                headers=headers,
                data=json.dumps(data),
            )  # type: ignore

            response.raise_for_status()  # type: ignore

            response_json = response.json()  # type: ignore
        except httpx.HTTPStatusError as e:
            # Use the response attached to the exception: the local name is
            # unbound if post() itself raised.
            raise DatabricksError(
                status_code=e.response.status_code,
                message=e.response.text,
            ) from e
        except httpx.TimeoutException as e:
            raise DatabricksError(
                status_code=408, message="Timeout error occurred."
            ) from e
        except Exception as e:
            raise DatabricksError(status_code=500, message=str(e)) from e

        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=response_json,
        )

        return litellm.EmbeddingResponse(**response_json)
class ModelResponseIterator:
    """Sync and async iterator that turns streamed response lines into
    `GenericStreamingChunk` objects.

    Each line may carry a leading `data:` prefix. Blank lines yield an empty
    chunk, and the `[DONE]` sentinel ends iteration in both modes.
    """

    def __init__(self, streaming_response, sync_stream: bool):
        # `sync_stream` is accepted but not consulted; the caller picks the
        # mode by iterating with `for` or `async for`.
        self.streaming_response = streaming_response

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        """Convert one decoded JSON chunk into a `GenericStreamingChunk`.

        Reads text, the first tool call, the finish reason, and usage from
        the first choice of the chunk.
        """
        try:
            processed_chunk = litellm.ModelResponse(**chunk, stream=True)  # type: ignore

            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None

            first_choice = processed_chunk.choices[0]
            delta = first_choice.delta  # type: ignore

            if delta.content is not None:
                text = delta.content

            # Only the first tool call of the delta is forwarded.
            if (
                delta.tool_calls is not None
                and len(delta.tool_calls) > 0
                and delta.tool_calls[0].function is not None
                and delta.tool_calls[0].function.arguments is not None
            ):
                first_tool_call = delta.tool_calls[0]
                tool_use = ChatCompletionToolCallChunk(
                    id=first_tool_call.id,
                    type="function",
                    function=ChatCompletionToolCallFunctionChunk(
                        name=first_tool_call.function.name,
                        arguments=first_tool_call.function.arguments,
                    ),
                    index=first_choice.index,
                )

            if first_choice.finish_reason is not None:
                is_finished = True
                finish_reason = first_choice.finish_reason

            if hasattr(processed_chunk, "usage") and isinstance(
                processed_chunk.usage, litellm.Usage
            ):
                usage_chunk: litellm.Usage = processed_chunk.usage

                usage = ChatCompletionUsageBlock(
                    prompt_tokens=usage_chunk.prompt_tokens,
                    completion_tokens=usage_chunk.completion_tokens,
                    total_tokens=usage_chunk.total_tokens,
                )

            return GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=0,
            )
        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

    def _parse_line(self, raw_line: str) -> Optional[GenericStreamingChunk]:
        """Decode one streamed line; return None for the `[DONE]` sentinel.

        Raises:
            ValueError: if the payload is not valid JSON (or the chunk parser
                rejects it).
        """
        payload = raw_line.strip()
        # Remove only the leading SSE field name. A blanket replace would also
        # delete "data:" occurring inside the model's generated content.
        if payload.startswith("data:"):
            payload = payload[len("data:") :].strip()

        if payload == "[DONE]":
            return None
        if len(payload) == 0:
            return GenericStreamingChunk(
                text="",
                is_finished=False,
                finish_reason="",
                usage=None,
                index=0,
                tool_use=None,
            )
        return self.chunk_parser(chunk=json.loads(payload))

    # Sync iterator
    def __iter__(self):
        self.response_iterator = self.streaming_response
        return self

    def __next__(self):
        try:
            chunk = self.response_iterator.__next__()
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            parsed_chunk = self._parse_line(chunk)
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

        if parsed_chunk is None:
            # `[DONE]` sentinel: end of stream, same as the async path.
            raise StopIteration
        return parsed_chunk

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        try:
            chunk = await self.async_response_iterator.__anext__()
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            parsed_chunk = self._parse_line(chunk)
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

        if parsed_chunk is None:
            raise StopAsyncIteration
        return parsed_chunk