refactor sagemaker to be async

Ishaan Jaff 2024-08-15 18:18:02 -07:00
parent b1aed699ea
commit df4ea8fba6
5 changed files with 798 additions and 603 deletions
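After this refactor, SageMaker requests are signed with SigV4 and sent through LiteLLM's httpx handlers instead of boto3/aioboto3, so streaming works natively for both `litellm.completion` and `litellm.acompletion`. A minimal usage sketch, assuming AWS credentials are available via environment variables and using a placeholder endpoint name (mirrors the new test added in this commit):

import asyncio

import litellm


async def main():
    # "<your-endpoint-name>" is a placeholder for a real SageMaker endpoint
    response = await litellm.acompletion(
        model="sagemaker/<your-endpoint-name>",
        messages=[{"role": "user", "content": "hi - what is ur name"}],
        stream=True,
        temperature=0.2,
        max_tokens=80,
    )
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())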


@@ -7,16 +7,38 @@ import traceback
import types
from copy import deepcopy
from enum import Enum
from typing import Any, Callable, Optional
from functools import partial
from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, get_secret
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_async_httpx_client,
_get_httpx_client,
)
from litellm.types.llms.openai import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.utils import (
CustomStreamWrapper,
EmbeddingResponse,
ModelResponse,
Usage,
get_secret,
)
from .base_aws_llm import BaseAWSLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
_response_stream_shape_cache = None
class SagemakerError(Exception):
def __init__(self, status_code, message):
@@ -31,73 +53,6 @@ class SagemakerError(Exception):
) # Call the base class constructor with the parameters it needs
class TokenIterator:
def __init__(self, stream, acompletion: bool = False):
if acompletion == False:
self.byte_iterator = iter(stream)
elif acompletion == True:
self.byte_iterator = stream
self.buffer = io.BytesIO()
self.read_pos = 0
self.end_of_data = False
def __iter__(self):
return self
def __next__(self):
try:
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
response_obj = {"text": "", "is_finished": False}
self.read_pos += len(line) + 1
full_line = line[:-1].decode("utf-8")
line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
if line_data.get("generated_text", None) is not None:
self.end_of_data = True
response_obj["is_finished"] = True
response_obj["text"] = line_data["token"]["text"]
return response_obj
chunk = next(self.byte_iterator)
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
except StopIteration as e:
if self.end_of_data == True:
raise e # Re-raise StopIteration
else:
self.end_of_data = True
return "data: [DONE]"
def __aiter__(self):
return self
async def __anext__(self):
try:
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
response_obj = {"text": "", "is_finished": False}
self.read_pos += len(line) + 1
full_line = line[:-1].decode("utf-8")
line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
if line_data.get("generated_text", None) is not None:
self.end_of_data = True
response_obj["is_finished"] = True
response_obj["text"] = line_data["token"]["text"]
return response_obj
chunk = await self.byte_iterator.__anext__()
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
except StopAsyncIteration as e:
if self.end_of_data == True:
raise e # Re-raise StopAsyncIteration
else:
self.end_of_data = True
return "data: [DONE]"
class SagemakerConfig:
"""
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
@@ -145,10 +100,89 @@ os.environ['AWS_ACCESS_KEY_ID'] = ""
os.environ['AWS_SECRET_ACCESS_KEY'] = ""
"""
# set os.environ['AWS_REGION_NAME'] = <your-region_name>
# set os.environ['AWS_REGION_NAME'] = <your-region_name>
class SagemakerLLM(BaseAWSLLM):
def _prepare_request(
self,
model: str,
data: dict,
optional_params: dict,
extra_headers: Optional[dict] = None,
):
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError as e:
raise ImportError("Missing boto3 to call sagemaker. Run 'pip install boto3'.")
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
aws_sts_endpoint=aws_sts_endpoint,
)
sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
if optional_params.get("stream") is True:
api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream"
else:
api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations"
encoded_data = json.dumps(data).encode("utf-8")
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=api_base, data=encoded_data, headers=headers
)
sigv4.add_auth(request)
prepped_request = request.prepare()
return prepped_request
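The `_prepare_request` helper replaces boto3's client plumbing: it resolves credentials, signs a raw HTTPS request to the SageMaker runtime endpoint with SigV4, and returns a prepared request that any HTTP client can send. A condensed standalone sketch of the same flow (endpoint name, region, and parameters are placeholders, and credentials are assumed to come from the default boto3 chain rather than litellm's `get_credentials`):

import json

import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

region = "us-west-2"  # placeholder region
endpoint_name = "my-endpoint"  # placeholder SageMaker endpoint
url = f"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{endpoint_name}/invocations"

# resolve credentials from the default chain (env vars, profile, instance role, ...)
credentials = boto3.Session().get_credentials()
sigv4 = SigV4Auth(credentials, "sagemaker", region)

body = json.dumps({"inputs": "Hello", "parameters": {"max_new_tokens": 80}}).encode("utf-8")
request = AWSRequest(method="POST", url=url, data=body, headers={"Content-Type": "application/json"})
sigv4.add_auth(request)  # adds Authorization / X-Amz-Date headers in place

prepped = request.prepare()
# prepped.url, prepped.headers, and prepped.body can now be sent with httpx or requests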
def completion(
self,
model: str,
messages: list,
model_response: ModelResponse,
@@ -162,38 +196,6 @@ def completion(
logger_fn=None,
acompletion: bool = False,
):
import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
model_id = optional_params.pop("model_id", None)
if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
client = boto3.client(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables
# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME")
or aws_region_name # get region from config file if specified
or "us-west-2" # default to us-west-2 if region not specified
)
client = boto3.client(
service_name="sagemaker-runtime",
region_name=region_name,
)
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
inference_params = deepcopy(optional_params)
@@ -206,13 +208,14 @@ def completion(
): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
model = model
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages,
)
@@ -221,7 +224,9 @@ def completion(
model_prompt_details = custom_prompt_dict[hf_model_name]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages,
)
@@ -237,12 +242,25 @@ def completion(
) # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
prompt = prompt_factory(model=hf_model_name, messages=messages)
stream = inference_params.pop("stream", None)
if stream == True:
data = json.dumps(
{"inputs": prompt, "parameters": inference_params, "stream": True}
).encode("utf-8")
if acompletion == True:
response = async_streaming(
model_id = optional_params.get("model_id", None)
if stream is True:
data = {"inputs": prompt, "parameters": inference_params, "stream": True}
prepared_request = self._prepare_request(
model=model,
data=data,
optional_params=optional_params,
)
if model_id is not None:
# Add model_id as InferenceComponentName header
# boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
prepared_request.headers.update(
{"X-Amzn-SageMaker-Inference-Componen": model_id}
)
if acompletion is True:
response = self.async_streaming(
prepared_request=prepared_request,
optional_params=optional_params,
encoding=encoding,
model_response=model_response,
@@ -250,99 +268,104 @@ def completion(
logging_obj=logging_obj,
data=data,
model_id=model_id,
aws_secret_access_key=aws_secret_access_key,
aws_access_key_id=aws_access_key_id,
aws_region_name=aws_region_name,
)
return response
if model_id is not None:
response = client.invoke_endpoint_with_response_stream(
EndpointName=model,
InferenceComponentName=model_id,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
else:
response = client.invoke_endpoint_with_response_stream(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
if stream is not None and stream == True:
sync_handler = _get_httpx_client()
sync_response = sync_handler.post(
url=prepared_request.url,
headers=prepared_request.headers, # type: ignore
json=data,
stream=stream,
)
return response["Body"]
elif acompletion == True:
if sync_response.status_code != 200:
raise SagemakerError(
status_code=sync_response.status_code,
message=sync_response.read(),
)
decoder = AWSEventStreamDecoder(model="")
completion_stream = decoder.iter_bytes(
sync_response.iter_bytes(chunk_size=1024)
)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="sagemaker",
logging_obj=logging_obj,
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=streaming_response,
additional_args={"complete_input_dict": data},
)
return streaming_response
# Non-Streaming Requests
_data = {"inputs": prompt, "parameters": inference_params}
return async_completion(
prepared_request = self._prepare_request(
model=model,
data=_data,
optional_params=optional_params,
encoding=encoding,
)
# Async completion
if acompletion == True:
return self.async_completion(
prepared_request=prepared_request,
model_response=model_response,
encoding=encoding,
model=model,
logging_obj=logging_obj,
data=_data,
model_id=model_id,
aws_secret_access_key=aws_secret_access_key,
aws_access_key_id=aws_access_key_id,
aws_region_name=aws_region_name,
)
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
"utf-8"
)
## COMPLETION CALL
## Non-Streaming completion CALL
try:
if model_id is not None:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
InferenceComponentName={model_id},
ContentType="application/json",
Body={data}, # type: ignore
CustomAttributes="accept_eula=true",
# Add model_id as InferenceComponentName header
# boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
prepared_request.headers.update(
{"X-Amzn-SageMaker-Inference-Componen": model_id}
)
""" # type: ignore
## LOGGING
timeout = 300.0
sync_handler = _get_httpx_client()
## LOGGING
logging_obj.pre_call(
input=prompt,
input=[],
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
"hf_model_name": hf_model_name,
"complete_input_dict": _data,
"api_base": prepared_request.url,
"headers": prepared_request.headers,
},
)
response = client.invoke_endpoint(
EndpointName=model,
InferenceComponentName=model_id,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
# make sync httpx post request here
try:
sync_response = sync_handler.post(
url=prepared_request.url,
headers=prepared_request.headers,
json=_data,
timeout=timeout,
)
else:
except Exception as e:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
ContentType="application/json",
Body={data}, # type: ignore
CustomAttributes="accept_eula=true",
)
""" # type: ignore
logging_obj.pre_call(
input=prompt,
logging_obj.post_call(
input=[],
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
"hf_model_name": hf_model_name,
},
)
response = client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
original_response=str(e),
additional_args={"complete_input_dict": _data},
)
raise e
except Exception as e:
status_code = (
getattr(e, "response", {})
@@ -356,17 +379,16 @@ def completion(
error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
raise SagemakerError(status_code=status_code, message=error_message)
response = response["Body"].read().decode("utf8")
completion_response = sync_response.json()
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response,
additional_args={"complete_input_dict": data},
original_response=completion_response,
additional_args={"complete_input_dict": _data},
)
print_verbose(f"raw model_response: {response}")
## RESPONSE OBJECT
completion_response = json.loads(response)
try:
if isinstance(completion_response, list):
completion_response_choices = completion_response[0]
@@ -405,8 +427,57 @@ def completion(
setattr(model_response, "usage", usage)
return model_response
async def make_async_call(
self,
api_base: str,
headers: dict,
data: str,
logging_obj,
client=None,
):
try:
if client is None:
client = (
_get_async_httpx_client()
) # Create a new client if none provided
response = await client.post(
api_base,
headers=headers,
json=data,
stream=True,
)
if response.status_code != 200:
raise SagemakerError(
status_code=response.status_code, message=response.text
)
decoder = AWSEventStreamDecoder(model="")
completion_stream = decoder.aiter_bytes(
response.aiter_bytes(chunk_size=1024)
)
# LOGGING
logging_obj.post_call(
input=[],
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise SagemakerError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
raise SagemakerError(status_code=408, message="Timeout error occurred.")
except Exception as e:
raise SagemakerError(status_code=500, message=str(e))
async def async_streaming(
self,
prepared_request,
optional_params,
encoding,
model_response: ModelResponse,
@@ -414,170 +485,83 @@ async def async_streaming(
model_id: Optional[str],
logging_obj: Any,
data,
aws_secret_access_key: Optional[str],
aws_access_key_id: Optional[str],
aws_region_name: Optional[str],
):
"""
Use aioboto3
"""
import aioboto3
session = aioboto3.Session()
if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
_client = session.client(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables
# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME")
or aws_region_name # get region from config file if specified
or "us-west-2" # default to us-west-2 if region not specified
)
_client = session.client(
service_name="sagemaker-runtime",
region_name=region_name,
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
self.make_async_call,
api_base=prepared_request.url,
headers=prepared_request.headers,
data=data,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="sagemaker",
logging_obj=logging_obj,
)
async with _client as client:
try:
if model_id is not None:
response = await client.invoke_endpoint_with_response_stream(
EndpointName=model,
InferenceComponentName=model_id,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
# LOGGING
logging_obj.post_call(
input=[],
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
else:
response = await client.invoke_endpoint_with_response_stream(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}")
response = response["Body"]
async for chunk in response:
yield chunk
return streaming_response
async def async_completion(
optional_params,
self,
prepared_request,
encoding,
model_response: ModelResponse,
model: str,
logging_obj: Any,
data: dict,
model_id: Optional[str],
aws_secret_access_key: Optional[str],
aws_access_key_id: Optional[str],
aws_region_name: Optional[str],
):
"""
Use aioboto3
"""
import aioboto3
session = aioboto3.Session()
if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
_client = session.client(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
timeout = 300.0
async_handler = _get_async_httpx_client()
## LOGGING
logging_obj.pre_call(
input=[],
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepared_request.url,
"headers": prepared_request.headers,
},
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables
# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME")
or aws_region_name # get region from config file if specified
or "us-west-2" # default to us-west-2 if region not specified
)
_client = session.client(
service_name="sagemaker-runtime",
region_name=region_name,
)
async with _client as client:
encoded_data = json.dumps(data).encode("utf-8")
try:
if model_id is not None:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
InferenceComponentName={model_id},
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
# Add model_id as InferenceComponentName header
# boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
prepared_request.headers.update(
{"X-Amzn-SageMaker-Inference-Componen": model_id}
)
""" # type: ignore
logging_obj.pre_call(
# make async httpx post request here
try:
response = await async_handler.post(
url=prepared_request.url,
headers=prepared_request.headers,
json=data,
timeout=timeout,
)
except Exception as e:
## LOGGING
logging_obj.post_call(
input=data["inputs"],
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
response = await client.invoke_endpoint(
EndpointName=model,
InferenceComponentName=model_id,
ContentType="application/json",
Body=encoded_data,
CustomAttributes="accept_eula=true",
)
else:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
)
""" # type: ignore
logging_obj.pre_call(
input=data["inputs"],
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
response = await client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
Body=encoded_data,
CustomAttributes="accept_eula=true",
original_response=str(e),
additional_args={"complete_input_dict": data},
)
raise e
except Exception as e:
error_message = f"{str(e)}"
if "Inference Component Name header is required" in error_message:
error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
raise SagemakerError(status_code=500, message=error_message)
response = await response["Body"].read()
response = response.decode("utf8")
completion_response = response.json()
## LOGGING
logging_obj.post_call(
input=data["inputs"],
@@ -586,7 +570,6 @@ async def async_completion(
additional_args={"complete_input_dict": data},
)
## RESPONSE OBJECT
completion_response = json.loads(response)
try:
if isinstance(completion_response, list):
completion_response_choices = completion_response[0]
@@ -625,8 +608,8 @@ async def async_completion(
setattr(model_response, "usage", usage)
return model_response
def embedding(
self,
model: str,
input: list,
model_response: EmbeddingResponse,
@@ -732,12 +715,15 @@ def embedding(
print_verbose(f"raw model_response: {response}")
if "embedding" not in response:
raise SagemakerError(status_code=500, message="embedding not found in response")
raise SagemakerError(
status_code=500, message="embedding not found in response"
)
embeddings = response["embedding"]
if not isinstance(embeddings, list):
raise SagemakerError(
status_code=422, message=f"Response not in expected format - {embeddings}"
status_code=422,
message=f"Response not in expected format - {embeddings}",
)
output_data = []
@@ -758,8 +744,160 @@ def embedding(
model_response,
"usage",
Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens,
),
)
return model_response
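The embedding path appears unchanged in this commit apart from formatting (it still goes through boto3), and it is reached through `litellm.embedding` when the provider is `sagemaker`. A usage sketch, assuming AWS credentials in env vars and a placeholder endpoint name:

import litellm

response = litellm.embedding(
    model="sagemaker/<your-embedding-endpoint>",  # placeholder endpoint name
    input=["good morning from litellm"],
)
print(response.data[0]["embedding"][:5])
print(response.usage)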
def get_response_stream_shape():
global _response_stream_shape_cache
if _response_stream_shape_cache is None:
from botocore.loaders import Loader
from botocore.model import ServiceModel
loader = Loader()
sagemaker_service_dict = loader.load_service_model(
"sagemaker-runtime", "service-2"
)
sagemaker_service_model = ServiceModel(sagemaker_service_dict)
_response_stream_shape_cache = sagemaker_service_model.shape_for(
"InvokeEndpointWithResponseStreamOutput"
)
return _response_stream_shape_cache
class AWSEventStreamDecoder:
def __init__(self, model: str) -> None:
from botocore.parsers import EventStreamJSONParser
self.model = model
self.parser = EventStreamJSONParser()
self.content_blocks: List = []
def _chunk_parser(self, chunk_data: dict) -> GChunk:
verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
_token = chunk_data["token"]
_index = chunk_data["index"]
is_finished = False
finish_reason = ""
if _token["text"] == "<|endoftext|>":
return GChunk(
text="",
index=_index,
is_finished=True,
finish_reason="stop",
)
return GChunk(
text=_token["text"],
index=_index,
is_finished=is_finished,
finish_reason=finish_reason,
)
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]:
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
accumulated_json = ""
for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
# remove data: prefix and "\n\n" at the end
message = message.replace("data:", "").replace("\n\n", "")
# Accumulate JSON data
accumulated_json += message
# Try to parse the accumulated JSON
try:
_data = json.loads(accumulated_json)
yield self._chunk_parser(chunk_data=_data)
# Reset accumulated_json after successful parsing
accumulated_json = ""
except json.JSONDecodeError:
# If it's not valid JSON yet, continue to the next event
continue
# Handle any remaining data after the iterator is exhausted
if accumulated_json:
try:
_data = json.loads(accumulated_json)
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError:
# Handle or log any unparseable data at the end
verbose_logger.error(
f"Warning: Unparseable JSON data remained: {accumulated_json}"
)
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
) -> AsyncIterator[GChunk]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
accumulated_json = ""
async for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
verbose_logger.debug("sagemaker parsed chunk bytes %s", message)
# remove data: prefix and "\n\n" at the end
message = message.replace("data:", "").replace("\n\n", "")
# Accumulate JSON data
accumulated_json += message
# Try to parse the accumulated JSON
try:
_data = json.loads(accumulated_json)
yield self._chunk_parser(chunk_data=_data)
# Reset accumulated_json after successful parsing
accumulated_json = ""
except json.JSONDecodeError:
# If it's not valid JSON yet, continue to the next event
continue
# Handle any remaining data after the iterator is exhausted
if accumulated_json:
try:
_data = json.loads(accumulated_json)
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError:
# Handle or log any unparseable data at the end
verbose_logger.error(
f"Warning: Unparseable JSON data remained: {accumulated_json}"
)
def _parse_message_from_event(self, event) -> Optional[str]:
response_dict = event.to_response_dict()
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
if "chunk" in parsed_response:
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]
else:
chunk = response_dict.get("body")
if not chunk:
return None
return chunk.decode() # type: ignore[no-any-return]
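Taken together, the streaming path is: sign a request against `.../invocations-response-stream`, POST it, feed the raw AWS event-stream bytes through `AWSEventStreamDecoder`, and let `CustomStreamWrapper` turn each `GenericStreamingChunk` into an OpenAI-style delta. A condensed sketch of the sync variant, assuming `prepared_request` comes from `_prepare_request` above and using httpx directly instead of litellm's `HTTPHandler`:

import httpx

decoder = AWSEventStreamDecoder(model="")

with httpx.stream(
    "POST",
    prepared_request.url,
    headers=dict(prepared_request.headers),
    content=prepared_request.body,
) as resp:
    if resp.status_code != 200:
        raise SagemakerError(status_code=resp.status_code, message=resp.read())
    for chunk in decoder.iter_bytes(resp.iter_bytes(chunk_size=1024)):
        # each chunk is a GenericStreamingChunk dict: text / index / is_finished / finish_reason
        print(chunk["text"], end="")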


@@ -95,7 +95,6 @@ from .llms import (
palm,
petals,
replicate,
sagemaker,
together_ai,
triton,
vertex_ai,
@@ -120,6 +119,7 @@ from .llms.prompt_templates.factory import (
prompt_factory,
stringify_json_tool_call_content,
)
from .llms.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_partner import VertexAIPartnerModels
@@ -166,6 +166,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
####### COMPLETION ENDPOINTS ################
@@ -2216,7 +2217,7 @@ def completion(
response = model_response
elif custom_llm_provider == "sagemaker":
# boto3 reads keys from .env
model_response = sagemaker.completion(
model_response = sagemaker_llm.completion(
model=model,
messages=messages,
model_response=model_response,
@@ -2230,26 +2231,13 @@ def completion(
logging_obj=logging,
acompletion=acompletion,
)
if (
"stream" in optional_params and optional_params["stream"] == True
): ## [BETA]
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
from .llms.sagemaker import TokenIterator
tokenIterator = TokenIterator(model_response, acompletion=acompletion)
response = CustomStreamWrapper(
completion_stream=tokenIterator,
model=model,
custom_llm_provider="sagemaker",
logging_obj=logging,
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=None,
original_response=response,
original_response=model_response,
)
return response
## RESPONSE OBJECT
response = model_response
@@ -3529,7 +3517,7 @@ def embedding(
model_response=EmbeddingResponse(),
)
elif custom_llm_provider == "sagemaker":
response = sagemaker.embedding(
response = sagemaker_llm.embedding(
model=model,
input=input,
encoding=encoding,
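With `TokenIterator` gone, `sagemaker_llm.completion` already returns a `CustomStreamWrapper` when `stream=True`, so main.py no longer re-wraps the response. A sync usage sketch mirroring the new test below (placeholder endpoint name, credentials assumed in env vars):

import litellm

response = litellm.completion(
    model="sagemaker/<your-endpoint-name>",  # placeholder
    messages=[{"role": "user", "content": "hi - what is ur name"}],
    stream=True,
    temperature=0.2,
    max_tokens=80,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")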


@@ -28,6 +28,9 @@ litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
import logging
from litellm._logging import verbose_logger
def logger_fn(user_model_dict):
@@ -80,6 +83,55 @@ async def test_completion_sagemaker(sync_mode):
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True])
async def test_completion_sagemaker_stream(sync_mode):
try:
litellm.set_verbose = False
print("testing sagemaker")
verbose_logger.setLevel(logging.DEBUG)
full_text = ""
if sync_mode is True:
response = litellm.completion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
messages=[
{"role": "user", "content": "hi - what is ur name"},
],
temperature=0.2,
stream=True,
max_tokens=80,
input_cost_per_second=0.000420,
)
for chunk in response:
print(chunk)
full_text += chunk.choices[0].delta.content or ""
print("SYNC RESPONSE full text", full_text)
else:
response = await litellm.acompletion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
messages=[
{"role": "user", "content": "hi - what is ur name"},
],
stream=True,
temperature=0.2,
max_tokens=80,
input_cost_per_second=0.000420,
)
print("streaming response")
async for chunk in response:
print(chunk)
full_text += chunk.choices[0].delta.content or ""
print("ASYNC RESPONSE full text", full_text)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_acompletion_sagemaker_non_stream():
mock_response = AsyncMock()


@@ -80,7 +80,7 @@ class ModelInfo(TypedDict, total=False):
supports_assistant_prefill: Optional[bool]
class GenericStreamingChunk(TypedDict):
class GenericStreamingChunk(TypedDict, total=False):
text: Required[str]
tool_use: Optional[ChatCompletionToolCallChunk]
is_finished: Required[bool]
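With `total=False`, optional fields such as `usage` and `tool_use` may be omitted from a chunk. The SageMaker decoder's `_chunk_parser` emits chunks of roughly this shape (a minimal sketch; key names follow the decoder above):

streaming_chunk = {
    "text": "Hello",
    "index": 0,
    "is_finished": False,
    "finish_reason": "",
}
final_chunk = {
    "text": "",
    "index": 0,
    "is_finished": True,
    "finish_reason": "stop",
}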


@@ -9848,11 +9848,28 @@ class CustomStreamWrapper:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0:
if self.received_finish_reason is not None: