mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
feat(sagemaker.py): aioboto3 streaming support
This commit is contained in:
parent
23c410a548
commit
5de569fcb1
4 changed files with 80 additions and 13 deletions
|
@ -30,8 +30,11 @@ import json
|
||||||
|
|
||||||
|
|
||||||
class TokenIterator:
|
class TokenIterator:
|
||||||
def __init__(self, stream):
|
def __init__(self, stream, acompletion: bool):
|
||||||
self.byte_iterator = iter(stream)
|
if acompletion == False:
|
||||||
|
self.byte_iterator = iter(stream)
|
||||||
|
elif acompletion == True:
|
||||||
|
self.byte_iterator = stream
|
||||||
self.buffer = io.BytesIO()
|
self.buffer = io.BytesIO()
|
||||||
self.read_pos = 0
|
self.read_pos = 0
|
||||||
self.end_of_data = False
|
self.end_of_data = False
|
||||||
|
@ -64,6 +67,34 @@ class TokenIterator:
|
||||||
self.end_of_data = True
|
self.end_of_data = True
|
||||||
return "data: [DONE]"
|
return "data: [DONE]"
|
||||||
|
|
||||||
|
def __aiter__(self):
    """Return this object itself as the asynchronous iterator."""
    return self
|
||||||
|
|
||||||
|
async def __anext__(self):
    """
    Return the next parsed token from the asynchronous byte stream.

    Bytes taken from ``self.byte_iterator`` are appended to ``self.buffer``;
    complete newline-terminated lines are consumed from ``self.read_pos``.
    Each non-blank line is decoded as UTF-8, an optional leading ``data:``
    prefix is removed, and the remainder is parsed as JSON.

    Returns:
        ``{"text": <token text>, "is_finished": <bool>}`` for every parsed
        line. ``is_finished`` is True when the line carries a non-null
        ``generated_text`` field. If the underlying stream ends before such
        a line was seen, the string ``"data: [DONE]"`` is returned once.

    Raises:
        StopAsyncIteration: when the underlying stream is exhausted and the
            end of data has already been signalled.
    """
    try:
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b"\n"):
                # Advance by exactly what was read. Blank separator lines are
                # skipped explicitly below instead of assuming one extra byte
                # always follows, which would drop data when it does not.
                self.read_pos += len(line)
                payload = line.decode("utf-8").strip()
                if not payload:
                    continue
                # Remove the prefix as a prefix; str.lstrip("data:") would
                # strip any leading run of the characters d, a, t and ':'.
                if payload.startswith("data:"):
                    payload = payload[len("data:"):]
                line_data = json.loads(payload)
                response_obj = {"text": "", "is_finished": False}
                if line_data.get("generated_text", None) is not None:
                    self.end_of_data = True
                    response_obj["is_finished"] = True
                response_obj["text"] = line_data["token"]["text"]
                return response_obj
            # No complete line buffered yet: pull the next event. Calling
            # __anext__() directly avoids the anext() builtin (Python 3.10+).
            chunk = await self.byte_iterator.__anext__()
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
    except StopAsyncIteration:
        if self.end_of_data:
            raise  # Re-raise StopAsyncIteration: iteration is complete
        # Stream ended without a final "generated_text" line; emit the
        # terminator once, then stop on the following call.
        self.end_of_data = True
        return "data: [DONE]"
|
||||||
|
|
||||||
|
|
||||||
class SagemakerConfig:
|
class SagemakerConfig:
|
||||||
"""
|
"""
|
||||||
|
@ -197,15 +228,16 @@ def completion(
|
||||||
data = json.dumps(
|
data = json.dumps(
|
||||||
{"inputs": prompt, "parameters": inference_params, "stream": True}
|
{"inputs": prompt, "parameters": inference_params, "stream": True}
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
## LOGGING
|
if acompletion == True:
|
||||||
request_str = f"""
|
response = async_streaming(
|
||||||
response = client.invoke_endpoint_with_response_stream(
|
optional_params=optional_params,
|
||||||
EndpointName={model},
|
encoding=encoding,
|
||||||
ContentType="application/json",
|
model_response=model_response,
|
||||||
Body={data},
|
model=model,
|
||||||
CustomAttributes="accept_eula=true",
|
logging_obj=logging_obj,
|
||||||
)
|
data=data,
|
||||||
""" # type: ignore
|
)
|
||||||
|
return response
|
||||||
response = client.invoke_endpoint_with_response_stream(
|
response = client.invoke_endpoint_with_response_stream(
|
||||||
EndpointName=model,
|
EndpointName=model,
|
||||||
ContentType="application/json",
|
ContentType="application/json",
|
||||||
|
@ -311,6 +343,37 @@ def completion(
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
|
async def async_streaming(
    optional_params,
    encoding,
    model_response: ModelResponse,
    model: str,
    logging_obj: Any,
    data,
    region_name="us-west-2",
):
    """
    Invoke a SageMaker endpoint with aioboto3 and yield its streamed events.

    This is an async generator: it opens a ``sagemaker-runtime`` client, calls
    ``invoke_endpoint_with_response_stream`` and yields every item produced by
    iterating the response ``Body``, unmodified.

    Args:
        optional_params: Accepted for signature parity; not used here.
        encoding: Accepted for signature parity; not used here.
        model_response: Accepted for signature parity; not used here.
        model: Passed as ``EndpointName`` to the invoke call.
        logging_obj: Accepted for signature parity; not used here.
        data: Request payload passed as ``Body``.
        region_name: AWS region for the client. Defaults to ``"us-west-2"``,
            the value that was previously hard-coded.

    Raises:
        SagemakerError: with status code 500 if the invoke call fails; the
            original exception is attached as ``__cause__``.
    """
    import aioboto3

    session = aioboto3.Session()
    async with session.client("sagemaker-runtime", region_name=region_name) as client:
        try:
            response = await client.invoke_endpoint_with_response_stream(
                EndpointName=model,
                ContentType="application/json",
                Body=data,
                CustomAttributes="accept_eula=true",
            )
        except Exception as e:
            raise SagemakerError(status_code=500, message=f"{str(e)}") from e
        # Iterate inside the `async with` block so the client is still open
        # while the caller consumes the stream.
        async for chunk in response["Body"]:
            yield chunk
|
||||||
|
|
||||||
|
|
||||||
async def async_completion(
|
async def async_completion(
|
||||||
optional_params,
|
optional_params,
|
||||||
encoding,
|
encoding,
|
||||||
|
|
|
@ -1562,7 +1562,7 @@ def completion(
|
||||||
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
|
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
|
||||||
from .llms.sagemaker import TokenIterator
|
from .llms.sagemaker import TokenIterator
|
||||||
|
|
||||||
tokenIterator = TokenIterator(model_response)
|
tokenIterator = TokenIterator(model_response, acompletion=acompletion)
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
completion_stream=tokenIterator,
|
completion_stream=tokenIterator,
|
||||||
model=model,
|
model=model,
|
||||||
|
|
|
@ -876,7 +876,6 @@ async def test_sagemaker_streaming_async():
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add any assertions here to check the response
|
# Add any assertions here to check the response
|
||||||
print(response)
|
print(response)
|
||||||
complete_response = ""
|
complete_response = ""
|
||||||
|
@ -900,6 +899,9 @@ async def test_sagemaker_streaming_async():
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the async streaming test only when this file is executed directly,
    # so that merely importing the module does not trigger the call.
    asyncio.run(test_sagemaker_streaming_async())
|
||||||
|
|
||||||
|
|
||||||
def test_completion_sagemaker_stream():
|
def test_completion_sagemaker_stream():
|
||||||
try:
|
try:
|
||||||
response = completion(
|
response = completion(
|
||||||
|
|
|
@ -8691,6 +8691,8 @@ class CustomStreamWrapper:
|
||||||
or self.custom_llm_provider == "ollama"
|
or self.custom_llm_provider == "ollama"
|
||||||
or self.custom_llm_provider == "ollama_chat"
|
or self.custom_llm_provider == "ollama_chat"
|
||||||
or self.custom_llm_provider == "vertex_ai"
|
or self.custom_llm_provider == "vertex_ai"
|
||||||
|
or self.custom_llm_provider == "sagemaker"
|
||||||
|
or self.custom_llm_provider in litellm.openai_compatible_endpoints
|
||||||
):
|
):
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"value of async completion stream: {self.completion_stream}"
|
f"value of async completion stream: {self.completion_stream}"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue