mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
feat(sagemaker.py): add sagemaker messages api support
Closes https://github.com/BerriAI/litellm/issues/2641 Closes https://github.com/BerriAI/litellm/pull/5178
This commit is contained in:
parent
2a6aa6da7a
commit
f7aa787fe6
6 changed files with 112 additions and 18 deletions
|
@ -206,17 +206,60 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
custom_prompt_dict={},
|
||||
hf_model_name=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
acompletion: bool = False,
|
||||
use_messages_api: Optional[bool] = None,
|
||||
):
|
||||
|
||||
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
||||
credentials, aws_region_name = self._load_credentials(optional_params)
|
||||
inference_params = deepcopy(optional_params)
|
||||
stream = inference_params.pop("stream", None)
|
||||
model_id = optional_params.get("model_id", None)
|
||||
|
||||
if use_messages_api is True:
|
||||
from litellm.llms.databricks import DatabricksChatCompletion
|
||||
|
||||
openai_like_chat_completions = DatabricksChatCompletion()
|
||||
inference_params["stream"] = True if stream is True else False
|
||||
_data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**inference_params,
|
||||
}
|
||||
|
||||
prepared_request = self._prepare_request(
|
||||
model=model,
|
||||
data=_data,
|
||||
optional_params=optional_params,
|
||||
credentials=credentials,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
|
||||
return openai_like_chat_completions.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=prepared_request.url,
|
||||
api_key=None,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=inference_params,
|
||||
acompletion=acompletion,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
timeout=timeout,
|
||||
encoding=encoding,
|
||||
headers=prepared_request.headers,
|
||||
custom_endpoint=True,
|
||||
custom_llm_provider="sagemaker_chat",
|
||||
)
|
||||
|
||||
## Load Config
|
||||
config = litellm.SagemakerConfig.get_config()
|
||||
|
@ -259,8 +302,6 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
hf_model_name or model
|
||||
) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
|
||||
prompt = prompt_factory(model=hf_model_name, messages=messages)
|
||||
stream = inference_params.pop("stream", None)
|
||||
model_id = optional_params.get("model_id", None)
|
||||
|
||||
if stream is True:
|
||||
data = {"inputs": prompt, "parameters": inference_params, "stream": True}
|
||||
|
@ -275,7 +316,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
# Add model_id as InferenceComponentName header
|
||||
# boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
|
||||
prepared_request.headers.update(
|
||||
{"X-Amzn-SageMaker-Inference-Componen": model_id}
|
||||
{"X-Amzn-SageMaker-Inference-Component": model_id}
|
||||
)
|
||||
|
||||
if acompletion is True:
|
||||
|
@ -338,7 +379,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
)
|
||||
|
||||
# Async completion
|
||||
if acompletion == True:
|
||||
if acompletion is True:
|
||||
return self.async_completion(
|
||||
prepared_request=prepared_request,
|
||||
model_response=model_response,
|
||||
|
@ -354,7 +395,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
# Add model_id as InferenceComponentName header
|
||||
# boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
|
||||
prepared_request.headers.update(
|
||||
{"X-Amzn-SageMaker-Inference-Componen": model_id}
|
||||
{"X-Amzn-SageMaker-Inference-Component": model_id}
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue