refactor(sagemaker/): separate chat + completion routes + make them both use base llm config (#7151)

* refactor(sagemaker/): separate chat + completion routes + make them both use base llm config

Addresses https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132

* fix(main.py): pass hf model name + custom prompt dict to litellm params
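
For context, a minimal usage sketch of the two routes after this split (endpoint names below are placeholders, and `hf_model_name` is an optional kwarg that only affects prompt templating on the completion route):

import litellm

# Completion route: prompt-based /invocations call; hf_model_name / custom_prompt_dict
# now travel through litellm params to pick the prompt template.
response = litellm.completion(
    model="sagemaker/<your-endpoint-name>",  # placeholder endpoint
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    hf_model_name="meta-llama/Llama-2-7b-chat-hf",
)

# Chat route: endpoints that speak the HF Messages API go through SagemakerChatHandler.
response = litellm.completion(
    model="sagemaker_chat/<your-endpoint-name>",  # placeholder endpoint
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)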
Krish Dholakia 2024-12-10 19:40:05 -08:00 committed by GitHub
parent df12f87a64
commit 61afdab228
14 changed files with 799 additions and 534 deletions

View file

@@ -1103,7 +1103,8 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transfor
VertexAIAi21Config,
)
- from .llms.sagemaker.sagemaker import SagemakerConfig
+ from .llms.sagemaker.completion.transformation import SagemakerConfig
+ from .llms.sagemaker.chat.transformation import SagemakerChatConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig

View file

@@ -182,7 +182,7 @@ def get_supported_openai_params( # noqa: PLR0915
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "sagemaker":
- return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+ return litellm.SagemakerConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "aleph_alpha":
return [
"max_tokens",

View file

@@ -182,7 +182,11 @@ class OpenAIGPTConfig(BaseConfig):
Returns:
dict: The transformed request. Sent as the body of the API call.
"""
- raise NotImplementedError
+ return {
+     "model": model,
+     "messages": messages,
+     **optional_params,
+ }
def transform_response(
self,

View file

@@ -34,7 +34,7 @@ class BaseLLMException(Exception):
self,
status_code: int,
message: str,
- headers: Optional[Union[Dict, httpx.Headers]] = None,
+ headers: Optional[Union[httpx.Headers, Dict]] = None,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):

View file

@@ -0,0 +1,179 @@
import json
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Union
import httpx
from litellm.utils import ModelResponse, get_secret
from ...base_aws_llm import BaseAWSLLM
from ...prompt_templates.factory import custom_prompt, prompt_factory
from ..common_utils import AWSEventStreamDecoder
from .transformation import SagemakerChatConfig
class SagemakerChatHandler(BaseAWSLLM):
def _load_credentials(
self,
optional_params: dict,
):
try:
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
aws_sts_endpoint=aws_sts_endpoint,
)
return credentials, aws_region_name
def _prepare_request(
self,
credentials,
model: str,
data: dict,
optional_params: dict,
aws_region_name: str,
extra_headers: Optional[dict] = None,
):
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
if optional_params.get("stream") is True:
api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream"
else:
api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations"
sagemaker_base_url = optional_params.get("sagemaker_base_url", None)
if sagemaker_base_url is not None:
api_base = sagemaker_base_url
encoded_data = json.dumps(data).encode("utf-8")
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=api_base, data=encoded_data, headers=headers
)
sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped_request = request.prepare()
return prepped_request
def completion(
self,
model: str,
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
litellm_params: dict,
timeout: Optional[Union[float, httpx.Timeout]] = None,
custom_prompt_dict={},
logger_fn=None,
acompletion: bool = False,
headers: dict = {},
):
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
credentials, aws_region_name = self._load_credentials(optional_params)
inference_params = deepcopy(optional_params)
stream = inference_params.pop("stream", None)
from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
openai_like_chat_completions = OpenAILikeChatHandler()
inference_params["stream"] = True if stream is True else False
_data = SagemakerChatConfig().transform_request(
model=model,
messages=messages,
optional_params=inference_params,
litellm_params=litellm_params,
headers=headers,
)
prepared_request = self._prepare_request(
model=model,
data=_data,
optional_params=optional_params,
credentials=credentials,
aws_region_name=aws_region_name,
)
custom_stream_decoder = AWSEventStreamDecoder(model="", is_messages_api=True)
return openai_like_chat_completions.completion(
model=model,
messages=messages,
api_base=prepared_request.url,
api_key=None,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
logging_obj=logging_obj,
optional_params=inference_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout,
encoding=encoding,
headers=prepared_request.headers, # type: ignore
custom_endpoint=True,
custom_llm_provider="sagemaker_chat",
streaming_decoder=custom_stream_decoder, # type: ignore
)

View file

@@ -0,0 +1,26 @@
"""
Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invocations` API
Called if Sagemaker endpoint supports HF Messages API.
LiteLLM Docs: https://docs.litellm.ai/docs/providers/aws_sagemaker#sagemaker-messages-api
Huggingface Docs: https://huggingface.co/docs/text-generation-inference/en/messages_api
"""
from typing import Union
from httpx._models import Headers
from litellm.llms.base_llm.transformation import BaseLLMException
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import SagemakerError
class SagemakerChatConfig(OpenAIGPTConfig):
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return SagemakerError(
status_code=status_code, message=error_message, headers=headers
)

View file

@@ -0,0 +1,198 @@
import json
from typing import AsyncIterator, Iterator, List, Optional, Union
import httpx
from litellm import verbose_logger
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import StreamingChatCompletionChunk
_response_stream_shape_cache = None
class SagemakerError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[dict, httpx.Headers]] = None,
):
super().__init__(status_code=status_code, message=message, headers=headers)
class AWSEventStreamDecoder:
def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
from botocore.parsers import EventStreamJSONParser
self.model = model
self.parser = EventStreamJSONParser()
self.content_blocks: List = []
self.is_messages_api = is_messages_api
def _chunk_parser_messages_api(
self, chunk_data: dict
) -> StreamingChatCompletionChunk:
openai_chunk = StreamingChatCompletionChunk(**chunk_data)
return openai_chunk
def _chunk_parser(self, chunk_data: dict) -> GChunk:
verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
_token = chunk_data.get("token", {}) or {}
_index = chunk_data.get("index", None) or 0
is_finished = False
finish_reason = ""
_text = _token.get("text", "")
if _text == "<|endoftext|>":
return GChunk(
text="",
index=_index,
is_finished=True,
finish_reason="stop",
usage=None,
)
return GChunk(
text=_text,
index=_index,
is_finished=is_finished,
finish_reason=finish_reason,
usage=None,
)
def iter_bytes(
self, iterator: Iterator[bytes]
) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
accumulated_json = ""
for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
# remove data: prefix and "\n\n" at the end
message = message.replace("data:", "").replace("\n\n", "")
# Accumulate JSON data
accumulated_json += message
# Try to parse the accumulated JSON
try:
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
# Reset accumulated_json after successful parsing
accumulated_json = ""
except json.JSONDecodeError:
# If it's not valid JSON yet, continue to the next event
continue
# Handle any remaining data after the iterator is exhausted
if accumulated_json:
try:
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError:
# Handle or log any unparseable data at the end
verbose_logger.error(
f"Warning: Unparseable JSON data remained: {accumulated_json}"
)
yield None
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
accumulated_json = ""
async for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
verbose_logger.debug("sagemaker parsed chunk bytes %s", message)
# remove data: prefix and "\n\n" at the end
message = message.replace("data:", "").replace("\n\n", "")
# Accumulate JSON data
accumulated_json += message
# Try to parse the accumulated JSON
try:
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
# Reset accumulated_json after successful parsing
accumulated_json = ""
except json.JSONDecodeError:
# If it's not valid JSON yet, continue to the next event
continue
# Handle any remaining data after the iterator is exhausted
if accumulated_json:
try:
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError:
# Handle or log any unparseable data at the end
verbose_logger.error(
f"Warning: Unparseable JSON data remained: {accumulated_json}"
)
yield None
def _parse_message_from_event(self, event) -> Optional[str]:
response_dict = event.to_response_dict()
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
if "chunk" in parsed_response:
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]
else:
chunk = response_dict.get("body")
if not chunk:
return None
return chunk.decode() # type: ignore[no-any-return]
def get_response_stream_shape():
global _response_stream_shape_cache
if _response_stream_shape_cache is None:
from botocore.loaders import Loader
from botocore.model import ServiceModel
loader = Loader()
sagemaker_service_dict = loader.load_service_model(
"sagemaker-runtime", "service-2"
)
sagemaker_service_model = ServiceModel(sagemaker_service_dict)
_response_stream_shape_cache = sagemaker_service_model.shape_for(
"InvokeEndpointWithResponseStreamOutput"
)
return _response_stream_shape_cache

View file

@@ -22,12 +22,7 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
- from litellm.types.llms.openai import (
- ChatCompletionToolCallChunk,
- ChatCompletionUsageBlock,
- )
- from litellm.types.utils import GenericStreamingChunk as GChunk
- from litellm.types.utils import StreamingChatCompletionChunk
+ from litellm.types.llms.openai import AllMessageValues
from litellm.utils import (
CustomStreamWrapper,
EmbeddingResponse,
@@ -36,65 +31,12 @@ from litellm.utils import (
get_secret,
)
- from ..base_aws_llm import BaseAWSLLM
- from ..prompt_templates.factory import custom_prompt, prompt_factory
+ from ...base_aws_llm import BaseAWSLLM
+ from ...prompt_templates.factory import custom_prompt, prompt_factory
+ from ..common_utils import AWSEventStreamDecoder, SagemakerError
+ from .transformation import SagemakerConfig
_response_stream_shape_cache = None
class SagemakerError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class SagemakerConfig:
"""
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
"""
max_new_tokens: Optional[int] = None
top_p: Optional[float] = None
temperature: Optional[float] = None
return_full_text: Optional[bool] = None
def __init__(
self,
max_new_tokens: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
return_full_text: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
sagemaker_config = SagemakerConfig()
""" """
SAGEMAKER AUTH Keys/Vars SAGEMAKER AUTH Keys/Vars
@@ -166,6 +108,7 @@ class SagemakerLLM(BaseAWSLLM):
credentials,
model: str,
data: dict,
+ messages: List[AllMessageValues],
optional_params: dict,
aws_region_name: str,
extra_headers: Optional[dict] = None,
@@ -189,9 +132,12 @@ class SagemakerLLM(BaseAWSLLM):
api_base = sagemaker_base_url
encoded_data = json.dumps(data).encode("utf-8")
- headers = {"Content-Type": "application/json"}
- if extra_headers is not None:
- headers = {"Content-Type": "application/json", **extra_headers}
+ headers = sagemaker_config.validate_environment(
+ headers=extra_headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
request = AWSRequest(
method="POST", url=api_base, data=encoded_data, headers=headers
)
@@ -205,49 +151,6 @@ class SagemakerLLM(BaseAWSLLM):
return prepped_request
def _transform_prompt(
self,
model: str,
messages: List,
custom_prompt_dict: dict,
hf_model_name: Optional[str],
) -> str:
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages,
)
elif hf_model_name in custom_prompt_dict:
# check if the base huggingface model has a registered custom prompt
model_prompt_details = custom_prompt_dict[hf_model_name]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages,
)
else:
if hf_model_name is None:
if "llama-2" in model.lower(): # llama-2 model
if "chat" in model.lower(): # apply llama2 chat template
hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
else: # apply regular llama2 template
hf_model_name = "meta-llama/Llama-2-7b"
hf_model_name = (
hf_model_name or model
) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
prompt: str = prompt_factory(model=hf_model_name, messages=messages) # type: ignore
return prompt
def completion( # noqa: PLR0915
self,
model: str,
@@ -257,13 +160,13 @@ class SagemakerLLM(BaseAWSLLM):
encoding,
logging_obj,
optional_params: dict,
+ litellm_params: dict,
timeout: Optional[Union[float, httpx.Timeout]] = None,
custom_prompt_dict={},
hf_model_name=None,
- litellm_params=None,
logger_fn=None,
acompletion: bool = False,
- use_messages_api: Optional[bool] = None,
+ headers: dict = {},
):
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
@@ -272,50 +175,6 @@ class SagemakerLLM(BaseAWSLLM):
stream = inference_params.pop("stream", None)
model_id = optional_params.get("model_id", None)
if use_messages_api is True:
from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
openai_like_chat_completions = OpenAILikeChatHandler()
inference_params["stream"] = True if stream is True else False
_data: Dict[str, Any] = {
"model": model,
"messages": messages,
**inference_params,
}
prepared_request = self._prepare_request(
model=model,
data=_data,
optional_params=optional_params,
credentials=credentials,
aws_region_name=aws_region_name,
)
custom_stream_decoder = AWSEventStreamDecoder(
model="", is_messages_api=True
)
return openai_like_chat_completions.completion(
model=model,
messages=messages,
api_base=prepared_request.url,
api_key=None,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
logging_obj=logging_obj,
optional_params=inference_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout,
encoding=encoding,
headers=prepared_request.headers, # type: ignore
custom_endpoint=True,
custom_llm_provider="sagemaker_chat",
streaming_decoder=custom_stream_decoder, # type: ignore
)
## Load Config
config = litellm.SagemakerConfig.get_config()
for k, v in config.items():
@@ -325,21 +184,6 @@ class SagemakerLLM(BaseAWSLLM):
inference_params[k] = v
if stream is True:
data = {"parameters": inference_params, "stream": True}
prepared_request = self._prepare_request(
model=model,
data=data,
optional_params=optional_params,
credentials=credentials,
aws_region_name=aws_region_name,
)
if model_id is not None:
# Add model_id as InferenceComponentName header
# boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
prepared_request.headers.update(
{"X-Amzn-SageMaker-Inference-Component": model_id}
)
if acompletion is True:
response = self.async_streaming(
messages=messages,
@@ -350,23 +194,25 @@ class SagemakerLLM(BaseAWSLLM):
encoding=encoding,
model_response=model_response,
logging_obj=logging_obj,
- data=data,
model_id=model_id,
aws_region_name=aws_region_name,
credentials=credentials,
+ headers=headers,
+ litellm_params=litellm_params,
)
return response
else:
- prompt = self._transform_prompt(
+ data = sagemaker_config.transform_request(
model=model,
messages=messages,
- custom_prompt_dict=custom_prompt_dict,
- hf_model_name=hf_model_name,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers,
)
- data["inputs"] = prompt
prepared_request = self._prepare_request(
model=model,
data=data,
+ messages=messages,
optional_params=optional_params,
credentials=credentials,
aws_region_name=aws_region_name,
@@ -388,7 +234,7 @@ class SagemakerLLM(BaseAWSLLM):
if sync_response.status_code != 200:
raise SagemakerError(
status_code=sync_response.status_code,
- message=sync_response.read(),
+ message=str(sync_response.read()),
)
decoder = AWSEventStreamDecoder(model="")
@@ -413,14 +259,6 @@ class SagemakerLLM(BaseAWSLLM):
return streaming_response
# Non-Streaming Requests
_data = {"parameters": inference_params}
prepared_request_args = {
"model": model,
"data": _data,
"optional_params": optional_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
}
# Async completion
if acompletion is True:
@@ -432,21 +270,30 @@ class SagemakerLLM(BaseAWSLLM):
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
- data=_data,
model_id=model_id,
optional_params=optional_params,
credentials=credentials,
aws_region_name=aws_region_name,
+ headers=headers,
+ litellm_params=litellm_params,
)
- prompt = self._transform_prompt(
- model=model,
- messages=messages,
- custom_prompt_dict=custom_prompt_dict,
- hf_model_name=hf_model_name,
- )
- _data["inputs"] = prompt
## Non-Streaming completion CALL
+ _data = sagemaker_config.transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers,
+ )
+ prepared_request_args = {
+ "model": model,
+ "data": _data,
+ "optional_params": optional_params,
+ "credentials": credentials,
+ "aws_region_name": aws_region_name,
+ "messages": messages,
+ }
prepared_request = self._prepare_request(**prepared_request_args)
try:
if model_id is not None:
@@ -507,53 +354,16 @@ class SagemakerLLM(BaseAWSLLM):
error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
raise SagemakerError(status_code=status_code, message=error_message)
- completion_response = sync_response.json()
- ## LOGGING
- logging_obj.post_call(
- input=prompt,
- api_key="",
- original_response=completion_response,
- additional_args={"complete_input_dict": _data},
- )
+ return sagemaker_config.transform_response(
+ model=model,
+ raw_response=sync_response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ request_data=_data,
+ messages=messages,
+ optional_params=optional_params,
+ encoding=encoding,
+ )
print_verbose(f"raw model_response: {completion_response}")
## RESPONSE OBJECT
try:
if isinstance(completion_response, list):
completion_response_choices = completion_response[0]
else:
completion_response_choices = completion_response
completion_output = ""
if "generation" in completion_response_choices:
completion_output += completion_response_choices["generation"]
elif "generated_text" in completion_response_choices:
completion_output += completion_response_choices["generated_text"]
# check if the prompt template is part of output, if so - filter it out
if completion_output.startswith(prompt) and "<s>" in prompt:
completion_output = completion_output.replace(prompt, "", 1)
model_response.choices[0].message.content = completion_output # type: ignore
except Exception:
raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
status_code=500,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
async def make_async_call(
self,
@@ -605,7 +415,7 @@ class SagemakerLLM(BaseAWSLLM):
async def async_streaming(
self,
- messages: list,
+ messages: List[AllMessageValues],
model: str,
custom_prompt_dict: dict,
hf_model_name: Optional[str],
@@ -616,13 +426,15 @@ class SagemakerLLM(BaseAWSLLM):
model_response: ModelResponse,
model_id: Optional[str],
logging_obj: Any,
- data,
+ litellm_params: dict,
+ headers: dict,
):
- data["inputs"] = self._transform_prompt(
+ data = await sagemaker_config.async_transform_request(
model=model,
messages=messages,
- custom_prompt_dict=custom_prompt_dict,
- hf_model_name=hf_model_name,
+ optional_params={**optional_params, "stream": True},
+ litellm_params=litellm_params,
+ headers=headers,
)
asyncified_prepare_request = asyncify(self._prepare_request)
prepared_request_args = {
@@ -631,6 +443,7 @@ class SagemakerLLM(BaseAWSLLM):
"optional_params": optional_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
+ "messages": messages,
}
prepared_request = await asyncified_prepare_request(**prepared_request_args)
completion_stream = await self.make_async_call(
@@ -658,7 +471,7 @@ class SagemakerLLM(BaseAWSLLM):
async def async_completion(
self,
- messages: list,
+ messages: List[AllMessageValues],
model: str,
custom_prompt_dict: dict,
hf_model_name: Optional[str],
@@ -668,22 +481,23 @@ class SagemakerLLM(BaseAWSLLM):
model_response: ModelResponse,
optional_params: dict,
logging_obj: Any,
- data: dict,
model_id: Optional[str],
+ headers: dict,
+ litellm_params: dict,
):
timeout = 300.0
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.SAGEMAKER
)
- async_transform_prompt = asyncify(self._transform_prompt)
- data["inputs"] = await async_transform_prompt(
+ data = await sagemaker_config.async_transform_request(
model=model,
messages=messages,
- custom_prompt_dict=custom_prompt_dict,
- hf_model_name=hf_model_name,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers,
)
asyncified_prepare_request = asyncify(self._prepare_request)
prepared_request_args = {
"model": model,
@@ -691,6 +505,7 @@ class SagemakerLLM(BaseAWSLLM):
"optional_params": optional_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
+ "messages": messages,
}
prepared_request = await asyncified_prepare_request(**prepared_request_args)
@@ -738,52 +553,16 @@ class SagemakerLLM(BaseAWSLLM):
if "Inference Component Name header is required" in error_message:
error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
raise SagemakerError(status_code=500, message=error_message)
- completion_response = response.json()
- ## LOGGING
- logging_obj.post_call(
- input=data["inputs"],
- api_key="",
- original_response=response,
- additional_args={"complete_input_dict": data},
- )
+ return sagemaker_config.transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ request_data=data,
+ messages=messages,
+ optional_params=optional_params,
+ encoding=encoding,
+ )
## RESPONSE OBJECT
try:
if isinstance(completion_response, list):
completion_response_choices = completion_response[0]
else:
completion_response_choices = completion_response
completion_output = ""
if "generation" in completion_response_choices:
completion_output += completion_response_choices["generation"]
elif "generated_text" in completion_response_choices:
completion_output += completion_response_choices["generated_text"]
# check if the prompt template is part of output, if so - filter it out
if completion_output.startswith(data["inputs"]) and "<s>" in data["inputs"]:
completion_output = completion_output.replace(data["inputs"], "", 1)
model_response.choices[0].message.content = completion_output # type: ignore
except Exception:
raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
status_code=500,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(data["inputs"]))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def embedding(
self,
@@ -928,180 +707,3 @@ class SagemakerLLM(BaseAWSLLM):
)
return model_response
def get_response_stream_shape():
global _response_stream_shape_cache
if _response_stream_shape_cache is None:
from botocore.loaders import Loader
from botocore.model import ServiceModel
loader = Loader()
sagemaker_service_dict = loader.load_service_model(
"sagemaker-runtime", "service-2"
)
sagemaker_service_model = ServiceModel(sagemaker_service_dict)
_response_stream_shape_cache = sagemaker_service_model.shape_for(
"InvokeEndpointWithResponseStreamOutput"
)
return _response_stream_shape_cache
class AWSEventStreamDecoder:
def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
from botocore.parsers import EventStreamJSONParser
self.model = model
self.parser = EventStreamJSONParser()
self.content_blocks: List = []
self.is_messages_api = is_messages_api
def _chunk_parser_messages_api(
self, chunk_data: dict
) -> StreamingChatCompletionChunk:
openai_chunk = StreamingChatCompletionChunk(**chunk_data)
return openai_chunk
def _chunk_parser(self, chunk_data: dict) -> GChunk:
verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
_token = chunk_data.get("token", {}) or {}
_index = chunk_data.get("index", None) or 0
is_finished = False
finish_reason = ""
_text = _token.get("text", "")
if _text == "<|endoftext|>":
return GChunk(
text="",
index=_index,
is_finished=True,
finish_reason="stop",
usage=None,
)
return GChunk(
text=_text,
index=_index,
is_finished=is_finished,
finish_reason=finish_reason,
usage=None,
)
def iter_bytes(
self, iterator: Iterator[bytes]
) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
accumulated_json = ""
for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
# remove data: prefix and "\n\n" at the end
message = message.replace("data:", "").replace("\n\n", "")
# Accumulate JSON data
accumulated_json += message
# Try to parse the accumulated JSON
try:
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
# Reset accumulated_json after successful parsing
accumulated_json = ""
except json.JSONDecodeError:
# If it's not valid JSON yet, continue to the next event
continue
# Handle any remaining data after the iterator is exhausted
if accumulated_json:
try:
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError:
# Handle or log any unparseable data at the end
verbose_logger.error(
f"Warning: Unparseable JSON data remained: {accumulated_json}"
)
yield None
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
accumulated_json = ""
async for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
verbose_logger.debug("sagemaker parsed chunk bytes %s", message)
# remove data: prefix and "\n\n" at the end
message = message.replace("data:", "").replace("\n\n", "")
# Accumulate JSON data
accumulated_json += message
# Try to parse the accumulated JSON
try:
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
# Reset accumulated_json after successful parsing
accumulated_json = ""
except json.JSONDecodeError:
# If it's not valid JSON yet, continue to the next event
continue
# Handle any remaining data after the iterator is exhausted
if accumulated_json:
try:
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError:
# Handle or log any unparseable data at the end
verbose_logger.error(
f"Warning: Unparseable JSON data remained: {accumulated_json}"
)
yield None
def _parse_message_from_event(self, event) -> Optional[str]:
response_dict = event.to_response_dict()
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
if "chunk" in parsed_response:
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]
else:
chunk = response_dict.get("body")
if not chunk:
return None
return chunk.decode() # type: ignore[no-any-return]

View file

@@ -0,0 +1,272 @@
"""
Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke`
In the Huggingface TGI format.
"""
import json
import time
import types
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from httpx._models import Headers, Response
import litellm
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Usage
from ..common_utils import SagemakerError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class SagemakerConfig(BaseConfig):
"""
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
"""
max_new_tokens: Optional[int] = None
top_p: Optional[float] = None
temperature: Optional[float] = None
return_full_text: Optional[bool] = None
def __init__(
self,
max_new_tokens: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
return_full_text: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def _transform_messages(
self,
messages: List[AllMessageValues],
) -> List[AllMessageValues]:
return messages
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return SagemakerError(
message=error_message, status_code=status_code, headers=headers
)
def get_supported_openai_params(self, model: str) -> List:
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "temperature":
if value == 0.0 or value == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
if not non_default_params.get(
"aws_sagemaker_allow_zero_temp", False
):
value = 0.01
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
non_default_params.pop("aws_sagemaker_allow_zero_temp", None)
return optional_params
def _transform_prompt(
self,
model: str,
messages: List,
custom_prompt_dict: dict,
hf_model_name: Optional[str],
) -> str:
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages,
)
elif hf_model_name in custom_prompt_dict:
# check if the base huggingface model has a registered custom prompt
model_prompt_details = custom_prompt_dict[hf_model_name]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages,
)
else:
if hf_model_name is None:
if "llama-2" in model.lower(): # llama-2 model
if "chat" in model.lower(): # apply llama2 chat template
hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
else: # apply regular llama2 template
hf_model_name = "meta-llama/Llama-2-7b"
hf_model_name = (
hf_model_name or model
) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
prompt: str = prompt_factory(model=hf_model_name, messages=messages) # type: ignore
return prompt
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
inference_params = optional_params.copy()
stream = inference_params.pop("stream", False)
data: Dict = {"parameters": inference_params}
if stream is True:
data["stream"] = True
custom_prompt_dict = (
litellm_params.get("custom_prompt_dict", None) or litellm.custom_prompt_dict
)
hf_model_name = litellm_params.get("hf_model_name", None)
prompt = self._transform_prompt(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
hf_model_name=hf_model_name,
)
data["inputs"] = prompt
return data
async def async_transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
return await asyncify(self.transform_request)(
model, messages, optional_params, litellm_params, headers
)
def transform_response(
self,
model: str,
raw_response: Response,
model_response: litellm.ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> litellm.ModelResponse:
completion_response = raw_response.json()
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_response,
additional_args={"complete_input_dict": request_data},
)
prompt = request_data["inputs"]
## RESPONSE OBJECT
try:
if isinstance(completion_response, list):
completion_response_choices = completion_response[0]
else:
completion_response_choices = completion_response
completion_output = ""
if "generation" in completion_response_choices:
completion_output += completion_response_choices["generation"]
elif "generated_text" in completion_response_choices:
completion_output += completion_response_choices["generated_text"]
# check if the prompt template is part of output, if so - filter it out
if completion_output.startswith(prompt) and "<s>" in prompt:
completion_output = completion_output.replace(prompt, "", 1)
model_response.choices[0].message.content = completion_output # type: ignore
except Exception:
raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
status_code=500,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def validate_environment(
self,
headers: Optional[dict],
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
headers = {"Content-Type": "application/json"}
if headers is not None:
headers = {"Content-Type": "application/json", **headers}
return headers

View file

@@ -130,7 +130,8 @@ from .llms.prompt_templates.factory import (
prompt_factory,
stringify_json_tool_call_content,
)
- from .llms.sagemaker.sagemaker import SagemakerLLM
+ from .llms.sagemaker.chat.handler import SagemakerChatHandler
+ from .llms.sagemaker.completion.handler import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton import TritonChatCompletion
@@ -229,6 +230,7 @@ watsonx_chat_completion = WatsonXChatHandler()
openai_like_embedding = OpenAILikeEmbeddingHandler()
databricks_embedding = DatabricksEmbeddingHandler()
base_llm_http_handler = BaseLLMHTTPHandler()
+ sagemaker_chat_completion = SagemakerChatHandler()
####### COMPLETION ENDPOINTS ################
@@ -1073,6 +1075,8 @@ def completion( # type: ignore # noqa: PLR0915
user_continue_message=kwargs.get("user_continue_message"),
base_model=base_model,
litellm_trace_id=kwargs.get("litellm_trace_id"),
+ hf_model_name=hf_model_name,
+ custom_prompt_dict=custom_prompt_dict,
)
logging.update_environment_variables(
model=model,
@@ -2513,10 +2517,23 @@ def completion( # type: ignore # noqa: PLR0915
## RESPONSE OBJECT
response = model_response
- elif (
- custom_llm_provider == "sagemaker"
- or custom_llm_provider == "sagemaker_chat"
- ):
+ elif custom_llm_provider == "sagemaker_chat":
+ # boto3 reads keys from .env
+ response = sagemaker_chat_completion.completion(
+ model=model,
+ messages=messages,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ custom_prompt_dict=custom_prompt_dict,
+ logger_fn=logger_fn,
+ encoding=encoding,
+ logging_obj=logging,
+ acompletion=acompletion,
+ headers=headers or {},
+ )
+ elif custom_llm_provider == "sagemaker":
# boto3 reads keys from .env
model_response = sagemaker_llm.completion(
model=model,
@@ -2531,17 +2548,7 @@ def completion( # type: ignore # noqa: PLR0915
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
- use_messages_api=(
- True if custom_llm_provider == "sagemaker_chat" else False
- ),
)
- if optional_params.get("stream", False):
- ## LOGGING
- logging.post_call(
- input=messages,
- api_key=None,
- original_response=model_response,
- )
## RESPONSE OBJECT
response = model_response

View file

@@ -2076,6 +2076,8 @@ def get_litellm_params(
user_continue_message=None,
base_model=None,
litellm_trace_id=None,
+ hf_model_name: Optional[str] = None,
+ custom_prompt_dict: Optional[dict] = None,
):
litellm_params = {
"acompletion": acompletion,
@@ -2105,6 +2107,8 @@ def get_litellm_params(
"base_model": base_model
or _get_base_model_from_litellm_call_metadata(metadata=metadata),
"litellm_trace_id": litellm_trace_id,
+ "hf_model_name": hf_model_name,
+ "custom_prompt_dict": custom_prompt_dict,
}
return litellm_params
@@ -3145,31 +3149,16 @@ def get_optional_params( # noqa: PLR0915
)
_check_valid_arg(supported_params=supported_params)
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
- if temperature is not None:
- if temperature == 0.0 or temperature == 0:
- # hugging face exception raised when temp==0
- # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
- if not passed_params.get("aws_sagemaker_allow_zero_temp", False):
- temperature = 0.01
- optional_params["temperature"] = temperature
- if top_p is not None:
- optional_params["top_p"] = top_p
- if n is not None:
- optional_params["best_of"] = n
- optional_params["do_sample"] = (
- True # Need to sample if you want best of for hf inference endpoints
- )
- if stream is not None:
- optional_params["stream"] = stream
- if stop is not None:
- optional_params["stop"] = stop
- if max_tokens is not None:
- # HF TGI raises the following exception when max_new_tokens==0
- # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
- if max_tokens == 0:
- max_tokens = 1
- optional_params["max_new_tokens"] = max_tokens
- passed_params.pop("aws_sagemaker_allow_zero_temp", None)
+ optional_params = litellm.SagemakerConfig().map_openai_params(
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ model=model,
+ drop_params=(
+ drop_params
+ if drop_params is not None and isinstance(drop_params, bool)
+ else False
+ ),
+ )
elif custom_llm_provider == "bedrock":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
@@ -6295,6 +6284,10 @@ class ProviderConfigManager:
return litellm.VertexAIAnthropicConfig()
elif litellm.LlmProviders.CLOUDFLARE == provider:
return litellm.CloudflareChatConfig()
+ elif litellm.LlmProviders.SAGEMAKER_CHAT == provider:
+ return litellm.SagemakerChatConfig()
+ elif litellm.LlmProviders.SAGEMAKER == provider:
+ return litellm.SagemakerConfig()
elif litellm.LlmProviders.FIREWORKS_AI == provider:
return litellm.FireworksAIConfig()
elif litellm.LlmProviders.FRIENDLIAI == provider:

View file

@@ -246,23 +246,6 @@ async def test_hf_completion_tgi():
# test_get_cloudflare_response_streaming()
@pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.asyncio
async def test_completion_sagemaker():
# litellm.set_verbose=True
try:
response = await acompletion(
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=[{"content": "Hello, how are you?", "role": "user"}],
)
# Add any assertions here to check the response
print(response)
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_get_response_streaming():
import asyncio

View file

@@ -129,7 +129,7 @@ async def test_completion_sagemaker_messages_api(sync_mode):
"sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
],
)
- @pytest.mark.flaky(retries=3, delay=1)
+ # @pytest.mark.flaky(retries=3, delay=1)
async def test_completion_sagemaker_stream(sync_mode, model):
try:
litellm.set_verbose = False

View file

@@ -1750,7 +1750,7 @@ def test_sagemaker_weird_response():
try:
import json
- from litellm.llms.sagemaker.sagemaker import TokenIterator
+ from litellm.llms.sagemaker.completion.handler import TokenIterator
chunk = """<s>[INST] Hey, how's it going? [/INST],
I'm doing well, thanks for asking! How about you? Is there anything you'd like to chat about or ask? I'm here to help with any questions you might have."""
I'm doing well, thanks for asking! How about you? Is there anything you'd like to chat about or ask? I'm here to help with any questions you might have.""" I'm doing well, thanks for asking! How about you? Is there anything you'd like to chat about or ask? I'm here to help with any questions you might have."""