feat(sagemaker.py): add sagemaker messages api support

Closes https://github.com/BerriAI/litellm/issues/2641

Closes https://github.com/BerriAI/litellm/pull/5178
Krrish Dholakia 2024-08-23 10:31:35 -07:00
parent 2a6aa6da7a
commit f7aa787fe6
6 changed files with 112 additions and 18 deletions
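For context, a minimal usage sketch of what this change enables. The endpoint name below is hypothetical, and the `sagemaker_chat/` model prefix is an assumption inferred from the `custom_llm_provider="sagemaker_chat"` value set in the diff:

```python
import litellm

# Hypothetical endpoint name; the "sagemaker_chat/" prefix is assumed to
# route through the new messages-API branch added in this commit.
response = litellm.completion(
    model="sagemaker_chat/my-chat-endpoint",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response.choices[0].message.content)
```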


@@ -206,17 +206,60 @@ class SagemakerLLM(BaseAWSLLM):
print_verbose: Callable,
encoding,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]] = None,
custom_prompt_dict={},
hf_model_name=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
use_messages_api: Optional[bool] = None,
):
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
credentials, aws_region_name = self._load_credentials(optional_params)
inference_params = deepcopy(optional_params)
stream = inference_params.pop("stream", None)
model_id = optional_params.get("model_id", None)
if use_messages_api is True:
from litellm.llms.databricks import DatabricksChatCompletion
openai_like_chat_completions = DatabricksChatCompletion()
inference_params["stream"] = True if stream is True else False
_data = {
"model": model,
"messages": messages,
**inference_params,
}
prepared_request = self._prepare_request(
model=model,
data=_data,
optional_params=optional_params,
credentials=credentials,
aws_region_name=aws_region_name,
)
return openai_like_chat_completions.completion(
model=model,
messages=messages,
api_base=prepared_request.url,
api_key=None,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
logging_obj=logging_obj,
optional_params=inference_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout,
encoding=encoding,
headers=prepared_request.headers,
custom_endpoint=True,
custom_llm_provider="sagemaker_chat",
)
## Load Config
config = litellm.SagemakerConfig.get_config()
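The new branch prepares a SigV4-signed request against the SageMaker runtime endpoint and then hands it to the OpenAI-compatible chat handler, passing `prepared_request.url` as `api_base` and the signed headers in place of an API key (hence `api_key=None`). A hedged sketch of what `_prepare_request` is presumably doing here; the URL shape and helper name are assumptions, not the actual implementation:

```python
import json

from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

def sign_sagemaker_request(data: dict, credentials, aws_region_name: str, endpoint_name: str):
    # Hypothetical URL shape for invoking a SageMaker endpoint
    url = (
        f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com"
        f"/endpoints/{endpoint_name}/invocations"
    )
    request = AWSRequest(
        method="POST",
        url=url,
        data=json.dumps(data),
        headers={"Content-Type": "application/json"},
    )
    # Sign in place: the SigV4 headers replace bearer-token auth,
    # which is why the chat handler receives api_key=None.
    SigV4Auth(credentials, "sagemaker", aws_region_name).add_auth(request)
    return request.prepare()  # exposes .url and .headers for the handler
```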
@@ -259,8 +302,6 @@ class SagemakerLLM(BaseAWSLLM):
hf_model_name or model
) # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
prompt = prompt_factory(model=hf_model_name, messages=messages)
stream = inference_params.pop("stream", None)
model_id = optional_params.get("model_id", None)
if stream is True:
data = {"inputs": prompt, "parameters": inference_params, "stream": True}
@@ -275,7 +316,7 @@ class SagemakerLLM(BaseAWSLLM):
# Add model_id as InferenceComponentName header
# boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
prepared_request.headers.update(
{"X-Amzn-SageMaker-Inference-Componen": model_id}
{"X-Amzn-SageMaker-Inference-Component": model_id}
)
if acompletion is True:
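The two header hunks fix a truncated header name: per the InvokeEndpoint docs linked in the comment, requests are routed to a specific inference component via `X-Amzn-SageMaker-Inference-Component`. For comparison, boto3 expresses the same routing through a named parameter rather than the raw header (endpoint and component names below are illustrative):

```python
import boto3

client = boto3.client("sagemaker-runtime", region_name="us-west-2")
response = client.invoke_endpoint(
    EndpointName="my-endpoint",             # illustrative
    InferenceComponentName="my-component",  # boto3 sets the header above for you
    ContentType="application/json",
    Body=b'{"inputs": "hello"}',
)
print(response["Body"].read())
```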
@@ -338,7 +379,7 @@ class SagemakerLLM(BaseAWSLLM):
)
# Async completion
if acompletion == True:
if acompletion is True:
return self.async_completion(
prepared_request=prepared_request,
model_response=model_response,
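The `== True` → `is True` change in this hunk is worth a note: for an `Optional[bool]` flag, identity comparison reads unambiguously and avoids flake8's E712 warning, since `==` also matches truthy non-bool values. A minimal illustration:

```python
acompletion = 1  # a truthy non-bool sneaking in

print(acompletion == True)  # True  (1 == True in Python; flagged by flake8 E712)
print(acompletion is True)  # False (only the True singleton matches)
```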
@@ -354,7 +395,7 @@ class SagemakerLLM(BaseAWSLLM):
# Add model_id as InferenceComponentName header
# boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
prepared_request.headers.update(
{"X-Amzn-SageMaker-Inference-Componen": model_id}
{"X-Amzn-SageMaker-Inference-Component": model_id}
)
## LOGGING