refactor sagemaker to be async

Ishaan Jaff 2024-08-15 18:18:02 -07:00
parent b1aed699ea
commit df4ea8fba6
5 changed files with 798 additions and 603 deletions
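For context, this refactor moves SageMaker onto an async-capable handler class, so streaming also works through litellm.acompletion. A minimal usage sketch, with the endpoint name as a placeholder and the parameters mirroring the new test further down in this diff:

import asyncio
import litellm

async def main():
    # Async streaming against a SageMaker endpoint (endpoint name is a placeholder).
    response = await litellm.acompletion(
        model="sagemaker/<your-endpoint-name>",
        messages=[{"role": "user", "content": "hi - what is ur name"}],
        stream=True,
        temperature=0.2,
        max_tokens=80,
    )
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())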

File diff suppressed because it is too large


@@ -95,7 +95,6 @@ from .llms import (
palm,
petals,
replicate,
sagemaker,
together_ai,
triton,
vertex_ai,
@@ -120,6 +119,7 @@ from .llms.prompt_templates.factory import (
prompt_factory,
stringify_json_tool_call_content,
)
from .llms.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_partner import VertexAIPartnerModels
@@ -166,6 +166,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
####### COMPLETION ENDPOINTS ################
@@ -2216,7 +2217,7 @@ def completion(
response = model_response
elif custom_llm_provider == "sagemaker":
# boto3 reads keys from .env
model_response = sagemaker.completion(
model_response = sagemaker_llm.completion(
model=model,
messages=messages,
model_response=model_response,
@@ -2230,26 +2231,13 @@ def completion(
logging_obj=logging,
acompletion=acompletion,
)
if (
"stream" in optional_params and optional_params["stream"] == True
): ## [BETA]
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
from .llms.sagemaker import TokenIterator
tokenIterator = TokenIterator(model_response, acompletion=acompletion)
response = CustomStreamWrapper(
completion_stream=tokenIterator,
model=model,
custom_llm_provider="sagemaker",
logging_obj=logging,
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=None,
original_response=response,
original_response=model_response,
)
return response
## RESPONSE OBJECT
response = model_response
@@ -3529,7 +3517,7 @@ def embedding(
model_response=EmbeddingResponse(),
)
elif custom_llm_provider == "sagemaker":
response = sagemaker.embedding(
response = sagemaker_llm.embedding(
model=model,
input=input,
encoding=encoding,

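The embedding path switches to the same handler instance. A minimal sketch of the user-facing call it serves, assuming a deployed SageMaker embedding endpoint (the name is a placeholder):

import litellm

# Embedding through the refactored SageMaker handler (endpoint name is a placeholder).
response = litellm.embedding(
    model="sagemaker/<your-embedding-endpoint>",
    input=["good morning from litellm"],
)
print(len(response.data[0]["embedding"]))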

@@ -28,6 +28,9 @@ litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
import logging
from litellm._logging import verbose_logger
def logger_fn(user_model_dict):
@@ -80,6 +83,55 @@ async def test_completion_sagemaker(sync_mode):
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True])
async def test_completion_sagemaker_stream(sync_mode):
try:
litellm.set_verbose = False
print("testing sagemaker")
verbose_logger.setLevel(logging.DEBUG)
full_text = ""
if sync_mode is True:
response = litellm.completion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
messages=[
{"role": "user", "content": "hi - what is ur name"},
],
temperature=0.2,
stream=True,
max_tokens=80,
input_cost_per_second=0.000420,
)
for chunk in response:
print(chunk)
full_text += chunk.choices[0].delta.content or ""
print("SYNC RESPONSE full text", full_text)
else:
response = await litellm.acompletion(
model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
messages=[
{"role": "user", "content": "hi - what is ur name"},
],
stream=True,
temperature=0.2,
max_tokens=80,
input_cost_per_second=0.000420,
)
print("streaming response")
async for chunk in response:
print(chunk)
full_text += chunk.choices[0].delta.content or ""
print("ASYNC RESPONSE full text", full_text)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_acompletion_sagemaker_non_stream():
mock_response = AsyncMock()


@@ -80,7 +80,7 @@ class ModelInfo(TypedDict, total=False):
supports_assistant_prefill: Optional[bool]
class GenericStreamingChunk(TypedDict):
class GenericStreamingChunk(TypedDict, total=False):
text: Required[str]
tool_use: Optional[ChatCompletionToolCallChunk]
is_finished: Required[bool]
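With total=False, only the Required keys have to be present on every chunk. For reference, a chunk in the shape the SageMaker streaming branch below consumes might look like this (field names are taken from this hunk and the CustomStreamWrapper hunk that follows; the values are made up):

# Illustrative GenericStreamingChunk payload (values are assumptions):
chunk = {
    "text": "Hello",
    "is_finished": False,
    "finish_reason": "",
    "tool_use": None,
    "usage": {"inputTokens": 12, "outputTokens": 1, "totalTokens": 13},
}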


@@ -9848,11 +9848,28 @@ class CustomStreamWrapper:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0:
if self.received_finish_reason is not None:
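The sagemaker.py changes themselves are in the suppressed diff at the top, but the wrapper branch above implies the handler now yields GenericStreamingChunk dicts rather than raw endpoint payloads. A rough sketch of such a producer, purely as an assumption about the chunk shape and not the actual implementation:

from typing import AsyncIterator, List

from litellm.types.llms.bedrock import GenericStreamingChunk

async def toy_sagemaker_stream(pieces: List[str]) -> AsyncIterator[GenericStreamingChunk]:
    # Toy producer: emits chunks in the shape the sagemaker branch of
    # CustomStreamWrapper expects (text, is_finished, finish_reason).
    for i, piece in enumerate(pieces):
        last = i == len(pieces) - 1
        yield {
            "text": piece,
            "is_finished": last,
            "finish_reason": "stop" if last else "",
        }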