feat(sagemaker.py): add sagemaker messages api support

Closes https://github.com/BerriAI/litellm/issues/2641

Closes https://github.com/BerriAI/litellm/pull/5178
Krrish Dholakia 2024-08-23 10:31:35 -07:00
parent 2a6aa6da7a
commit f7aa787fe6
6 changed files with 112 additions and 18 deletions
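
For context, a minimal sketch of how the new provider is expected to be called once this lands, assuming litellm's usual "<provider>/<endpoint>" model naming; the endpoint name below is a placeholder, not taken from this commit:

    import litellm

    # Hypothetical endpoint name; the "sagemaker_chat/" prefix routes through the new messages-API path.
    response = litellm.completion(
        model="sagemaker_chat/my-endpoint",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    print(response.choices[0].message.content)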


@@ -381,6 +381,7 @@ async def acompletion(
or custom_llm_provider == "vertex_ai_beta"
or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker"
+or custom_llm_provider == "sagemaker_chat"
or custom_llm_provider == "anthropic"
or custom_llm_provider == "predibase"
or custom_llm_provider == "bedrock"
@@ -945,7 +946,6 @@ def completion(
text_completion=kwargs.get("text_completion"),
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
user_continue_message=kwargs.get("user_continue_message"),
)
logging.update_environment_variables(
model=model,
@@ -2247,7 +2247,10 @@ def completion(
## RESPONSE OBJECT
response = model_response
-elif custom_llm_provider == "sagemaker":
+elif (
+    custom_llm_provider == "sagemaker"
+    or custom_llm_provider == "sagemaker_chat"
+):
# boto3 reads keys from .env
model_response = sagemaker_llm.completion(
model=model,
@@ -2262,6 +2265,9 @@ def completion(
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
+use_messages_api=(
+    True if custom_llm_provider == "sagemaker_chat" else False
+),
)
if optional_params.get("stream", False):
## LOGGING
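
The use_messages_api flag switches the request body built for the SageMaker endpoint from a text-prompt shape to an OpenAI-style messages shape. A rough sketch of the difference, assuming a TGI/LMI-style container; the field names are illustrative and not taken from this diff:

    # Classic text-generation payload (use_messages_api=False); field names are assumptions.
    text_payload = {
        "inputs": "Hello, how are you?",
        "parameters": {"max_new_tokens": 256},
    }

    # Messages-API payload (use_messages_api=True); OpenAI-style schema, endpoint name hypothetical.
    messages_payload = {
        "model": "my-endpoint",
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
        "max_tokens": 256,
    }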