feat(sagemaker.py): add sagemaker messages api support

Closes https://github.com/BerriAI/litellm/issues/2641
Closes https://github.com/BerriAI/litellm/pull/5178
Author: Krrish Dholakia
Date: 2024-08-23 10:31:35 -07:00
Parent: 2a6aa6da7a
Commit: f7aa787fe6
6 changed files with 112 additions and 18 deletions
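The practical effect: SageMaker endpoints that expose an OpenAI-style messages API can now be routed through the shared chat handler shown in the diff below. A usage sketch; the `sagemaker_chat/` provider prefix and the endpoint name are assumptions based on the linked issue and PR, not on this diff:

```python
import litellm

# Hypothetical call: the provider prefix and endpoint name are assumptions.
response = litellm.completion(
    model="sagemaker_chat/jumpstart-dft-meta-textgeneration-llama-3-8b-instruct",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response.choices[0].message.content)
```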


@@ -235,23 +235,28 @@ class DatabricksChatCompletion(BaseLLM):
         api_base: Optional[str],
         endpoint_type: Literal["chat_completions", "embeddings"],
         custom_endpoint: Optional[bool],
+        headers: Optional[dict],
     ) -> Tuple[str, dict]:
-        if api_key is None:
+        if api_key is None and headers is None:
             raise DatabricksError(
                 status_code=400,
-                message="Missing Databricks API Key - A call is being made to Databricks but no key is set either in the environment variables (DATABRICKS_API_KEY) or via params",
+                message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
             )
         if api_base is None:
             raise DatabricksError(
                 status_code=400,
-                message="Missing Databricks API Base - A call is being made to Databricks but no api base is set either in the environment variables (DATABRICKS_API_BASE) or via params",
+                message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
             )
-        headers = {
-            "Authorization": "Bearer {}".format(api_key),
-            "Content-Type": "application/json",
-        }
+        if headers is None:
+            headers = {
+                "Authorization": "Bearer {}".format(api_key),
+                "Content-Type": "application/json",
+            }
+        else:
+            if api_key is not None:
+                headers.update({"Authorization": "Bearer {}".format(api_key)})
         if endpoint_type == "chat_completions" and custom_endpoint is not True:
             api_base = "{}/chat/completions".format(api_base)
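The header handling above is the core behavioral change: `_validate_environment` now accepts caller-supplied headers instead of always building them from `api_key`. A standalone sketch of that resolution logic, assuming it mirrors the diff exactly (`resolve_headers` is a hypothetical name, not litellm's API):

```python
from typing import Optional

def resolve_headers(api_key: Optional[str], headers: Optional[dict]) -> dict:
    # Hypothetical helper mirroring the diff's logic: caller-supplied
    # headers now take precedence, with the bearer token layered on top
    # whenever a key is also available.
    if api_key is None and headers is None:
        raise ValueError("either api_key or headers must be provided")
    if headers is None:
        # No caller headers: build the default auth headers from the key.
        return {
            "Authorization": "Bearer {}".format(api_key),
            "Content-Type": "application/json",
        }
    if api_key is not None:
        # Caller headers win, but the bearer token is still injected.
        headers.update({"Authorization": "Bearer {}".format(api_key)})
    return headers

assert resolve_headers("sk-123", None)["Authorization"] == "Bearer sk-123"
assert resolve_headers(None, {"X-Custom": "1"}) == {"X-Custom": "1"}
```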
@@ -356,23 +361,27 @@ class DatabricksChatCompletion(BaseLLM):
         model_response: ModelResponse,
         print_verbose: Callable,
         encoding,
-        api_key,
+        api_key: Optional[str],
         logging_obj,
         optional_params: dict,
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
-        headers={},
+        headers: Optional[dict] = None,
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        custom_endpoint: Optional[bool] = None,
     ):
-        custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
+        custom_endpoint = custom_endpoint or optional_params.pop(
+            "custom_endpoint", None
+        )
         base_model: Optional[str] = optional_params.pop("base_model", None)
         api_base, headers = self._validate_environment(
             api_base=api_base,
             api_key=api_key,
             endpoint_type="chat_completions",
             custom_endpoint=custom_endpoint,
             headers=headers,
         )
         ## Load Config
         config = litellm.DatabricksConfig().get_config()
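The completion-path change is mostly plumbing, but the `custom_endpoint` precedence is worth spelling out: an explicit keyword argument now wins over the value carried in through `optional_params`. A minimal sketch of that fallback (hypothetical helper name):

```python
from typing import Optional

def pick_custom_endpoint(
    custom_endpoint: Optional[bool], optional_params: dict
) -> Optional[bool]:
    # Explicit argument wins; otherwise fall back to popping the key
    # out of optional_params, as in the diff. Note the `or` short-circuit:
    # a truthy explicit value means the pop never runs, so any
    # "custom_endpoint" key would then remain in optional_params.
    return custom_endpoint or optional_params.pop("custom_endpoint", None)

params = {"custom_endpoint": True, "temperature": 0.2}
assert pick_custom_endpoint(None, params) is True   # fell back to params
assert "custom_endpoint" not in params              # and was popped
assert pick_custom_endpoint(True, {"custom_endpoint": False}) is True
```

One subtlety of the short-circuit shown in the comments: callers should avoid passing `custom_endpoint` both explicitly and inside `optional_params`, since the latter is only consumed when the former is falsy.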
@@ -382,7 +391,7 @@ class DatabricksChatCompletion(BaseLLM):
         ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v
-        stream: bool = optional_params.pop("stream", None) or False
+        stream: bool = optional_params.get("stream", None) or False
         optional_params["stream"] = stream
         data = {
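The `pop` to `get` switch is small but deliberate: both read the flag, but `get` leaves the key in place, and since the next line writes `optional_params["stream"]` back anyway, the net effect is that the flag is normalized to a bool without a remove-then-reinsert round trip. A quick illustration:

```python
# The None-or-missing flag normalizes to False either way; get() simply
# never removes the key before the reassignment on the following line.
optional_params = {"stream": None, "max_tokens": 256}
stream: bool = optional_params.get("stream", None) or False
optional_params["stream"] = stream
assert optional_params == {"stream": False, "max_tokens": 256}
```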
@@ -565,12 +574,14 @@ class DatabricksChatCompletion(BaseLLM):
         model_response: Optional[litellm.utils.EmbeddingResponse] = None,
         client=None,
         aembedding=None,
+        headers: Optional[dict] = None,
     ) -> EmbeddingResponse:
         api_base, headers = self._validate_environment(
             api_base=api_base,
             api_key=api_key,
             endpoint_type="embeddings",
             custom_endpoint=False,
             headers=headers,
         )
         model = model
         data = {"model": model, "input": input, **optional_params}
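Since `headers` now also threads through the embedding path, a caller can supply its own header dict there too. A usage sketch, assuming the public `litellm.embedding` entry point forwards `extra_headers` down to this handler; the kwarg name, model, URL, and key below are all illustrative:

```python
import litellm

# Illustrative values only: the model, api_base, api_key, and custom
# header are assumptions for the sketch, not taken from this diff.
response = litellm.embedding(
    model="databricks/databricks-bge-large-en",
    input=["hello world"],
    api_base="https://adb-1234567890.0.azuredatabricks.net/serving-endpoints",
    api_key="dapi-example-key",
    extra_headers={"X-Request-Source": "batch-job"},  # assumed to map to `headers`
)
print(len(response.data[0]["embedding"]))
```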