mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
feat(sagemaker.py): add sagemaker messages api support
Closes https://github.com/BerriAI/litellm/issues/2641 Closes https://github.com/BerriAI/litellm/pull/5178
This commit is contained in:
parent
2a6aa6da7a
commit
f7aa787fe6
6 changed files with 112 additions and 18 deletions
|
@ -235,23 +235,28 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
api_base: Optional[str],
|
||||
endpoint_type: Literal["chat_completions", "embeddings"],
|
||||
custom_endpoint: Optional[bool],
|
||||
headers: Optional[dict],
|
||||
) -> Tuple[str, dict]:
|
||||
if api_key is None:
|
||||
if api_key is None and headers is None:
|
||||
raise DatabricksError(
|
||||
status_code=400,
|
||||
message="Missing Databricks API Key - A call is being made to Databricks but no key is set either in the environment variables (DATABRICKS_API_KEY) or via params",
|
||||
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
raise DatabricksError(
|
||||
status_code=400,
|
||||
message="Missing Databricks API Base - A call is being made to Databricks but no api base is set either in the environment variables (DATABRICKS_API_BASE) or via params",
|
||||
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if headers is None:
|
||||
headers = {
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
else:
|
||||
if api_key is not None:
|
||||
headers.update({"Authorization": "Bearer {}".format(api_key)})
|
||||
|
||||
if endpoint_type == "chat_completions" and custom_endpoint is not True:
|
||||
api_base = "{}/chat/completions".format(api_base)
|
||||
|
@ -356,23 +361,27 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
api_key: Optional[str],
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
headers: Optional[dict] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
custom_endpoint: Optional[bool] = None,
|
||||
):
|
||||
custom_endpoint: Optional[bool] = optional_params.pop("custom_endpoint", None)
|
||||
custom_endpoint = custom_endpoint or optional_params.pop(
|
||||
"custom_endpoint", None
|
||||
)
|
||||
base_model: Optional[str] = optional_params.pop("base_model", None)
|
||||
api_base, headers = self._validate_environment(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
endpoint_type="chat_completions",
|
||||
custom_endpoint=custom_endpoint,
|
||||
headers=headers,
|
||||
)
|
||||
## Load Config
|
||||
config = litellm.DatabricksConfig().get_config()
|
||||
|
@ -382,7 +391,7 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
stream: bool = optional_params.pop("stream", None) or False
|
||||
stream: bool = optional_params.get("stream", None) or False
|
||||
optional_params["stream"] = stream
|
||||
|
||||
data = {
|
||||
|
@ -565,12 +574,14 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
model_response: Optional[litellm.utils.EmbeddingResponse] = None,
|
||||
client=None,
|
||||
aembedding=None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> EmbeddingResponse:
|
||||
api_base, headers = self._validate_environment(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
endpoint_type="embeddings",
|
||||
custom_endpoint=False,
|
||||
headers=headers,
|
||||
)
|
||||
model = model
|
||||
data = {"model": model, "input": input, **optional_params}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue