refactor sagemaker to be async

Ishaan Jaff 2024-08-15 18:18:02 -07:00
parent b1aed699ea
commit df4ea8fba6
5 changed files with 798 additions and 603 deletions

(File diff suppressed because it is too large.)
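The suppressed diff is presumably the refactored SageMaker handler itself (likely litellm/llms/sagemaker.py), so its contents are not shown here. As a rough, hypothetical sketch of the pattern the rest of this commit implies, a `SagemakerLLM` class whose completion path can be awaited instead of blocking the event loop, something like the following; everything beyond the `SagemakerLLM`/`completion` names is a placeholder, not the actual implementation:

```python
# Hypothetical sketch only -- the real refactor is in the suppressed file diff.
# It illustrates the general "make the blocking boto3 call awaitable" pattern.
import asyncio
import json

import boto3


class SagemakerLLMSketch:
    def __init__(self, region_name: str = "us-west-2"):
        # boto3 reads AWS keys from the environment / .env, as noted in main.py
        self.client = boto3.client("sagemaker-runtime", region_name=region_name)

    def completion(self, endpoint_name: str, prompt: str) -> dict:
        # Synchronous path: a blocking invoke_endpoint call.
        resp = self.client.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType="application/json",
            Body=json.dumps({"inputs": prompt}),
        )
        return json.loads(resp["Body"].read())

    async def acompletion(self, endpoint_name: str, prompt: str) -> dict:
        # Async path: run the blocking call on a worker thread so the event
        # loop is free while SageMaker generates the response.
        return await asyncio.to_thread(self.completion, endpoint_name, prompt)
```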


@@ -95,7 +95,6 @@ from .llms import (
     palm,
     petals,
     replicate,
-    sagemaker,
     together_ai,
     triton,
     vertex_ai,
@@ -120,6 +119,7 @@ from .llms.prompt_templates.factory import (
     prompt_factory,
     stringify_json_tool_call_content,
 )
+from .llms.sagemaker import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_ai_partner import VertexAIPartnerModels
@@ -166,6 +166,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 watsonxai = IBMWatsonXAI()
+sagemaker_llm = SagemakerLLM()

 ####### COMPLETION ENDPOINTS ################
@@ -2216,7 +2217,7 @@ def completion(
             response = model_response
         elif custom_llm_provider == "sagemaker":
             # boto3 reads keys from .env
-            model_response = sagemaker.completion(
+            model_response = sagemaker_llm.completion(
                 model=model,
                 messages=messages,
                 model_response=model_response,
@@ -2230,26 +2231,13 @@ def completion(
                 logging_obj=logging,
                 acompletion=acompletion,
             )
-            if (
-                "stream" in optional_params and optional_params["stream"] == True
-            ):  ## [BETA]
-                print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
-                from .llms.sagemaker import TokenIterator
-
-                tokenIterator = TokenIterator(model_response, acompletion=acompletion)
-                response = CustomStreamWrapper(
-                    completion_stream=tokenIterator,
-                    model=model,
-                    custom_llm_provider="sagemaker",
-                    logging_obj=logging,
-                )
+            if optional_params.get("stream", False):
                 ## LOGGING
                 logging.post_call(
                     input=messages,
                     api_key=None,
-                    original_response=response,
+                    original_response=model_response,
                 )
-                return response

             ## RESPONSE OBJECT
             response = model_response
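Since the `CustomStreamWrapper` construction is deleted from `completion()` here, the streaming wrapper is presumably built inside the new `SagemakerLLM` class instead (its diff is suppressed above). A hedged sketch that reuses the exact wrapper arguments removed in this hunk; `_invoke_endpoint_stream` and the class name are made-up placeholders, and importing `CustomStreamWrapper` from `litellm.utils` is an assumption based on the utils.py hunk below:

```python
# Hedged sketch -- not the actual SagemakerLLM.completion from the suppressed diff.
from litellm.utils import CustomStreamWrapper


class SagemakerStreamSketch:
    def completion(self, model, messages, optional_params, logging_obj, **kwargs):
        if optional_params.get("stream", False):
            chunk_iterator = self._invoke_endpoint_stream(model, messages, optional_params)
            return CustomStreamWrapper(
                completion_stream=chunk_iterator,
                model=model,
                custom_llm_provider="sagemaker",
                logging_obj=logging_obj,
            )
        # non-streaming path omitted
        ...

    def _invoke_endpoint_stream(self, model, messages, optional_params):
        # Placeholder: the real method would call the SageMaker runtime with
        # streaming enabled and yield GenericStreamingChunk-shaped dicts.
        yield {"text": "", "is_finished": True, "finish_reason": "stop",
               "usage": None, "tool_use": None}
```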
@@ -3529,7 +3517,7 @@ def embedding(
             model_response=EmbeddingResponse(),
         )
     elif custom_llm_provider == "sagemaker":
-        response = sagemaker.embedding(
+        response = sagemaker_llm.embedding(
             model=model,
             input=input,
             encoding=encoding,


@@ -28,6 +28,9 @@ litellm.cache = None
 litellm.success_callback = []
 user_message = "Write a short poem about the sky"
 messages = [{"content": user_message, "role": "user"}]
+import logging
+from litellm._logging import verbose_logger
+

 def logger_fn(user_model_dict):
@@ -80,6 +83,55 @@ async def test_completion_sagemaker(sync_mode):
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True])
+async def test_completion_sagemaker_stream(sync_mode):
+    try:
+        litellm.set_verbose = False
+        print("testing sagemaker")
+        verbose_logger.setLevel(logging.DEBUG)
+        full_text = ""
+        if sync_mode is True:
+            response = litellm.completion(
+                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                messages=[
+                    {"role": "user", "content": "hi - what is ur name"},
+                ],
+                temperature=0.2,
+                stream=True,
+                max_tokens=80,
+                input_cost_per_second=0.000420,
+            )
+
+            for chunk in response:
+                print(chunk)
+                full_text += chunk.choices[0].delta.content or ""
+
+            print("SYNC RESPONSE full text", full_text)
+        else:
+            response = await litellm.acompletion(
+                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                messages=[
+                    {"role": "user", "content": "hi - what is ur name"},
+                ],
+                stream=True,
+                temperature=0.2,
+                max_tokens=80,
+                input_cost_per_second=0.000420,
+            )
+
+            print("streaming response")
+            async for chunk in response:
+                print(chunk)
+                full_text += chunk.choices[0].delta.content or ""
+
+            print("ASYNC RESPONSE full text", full_text)
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.asyncio
 async def test_acompletion_sagemaker_non_stream():
     mock_response = AsyncMock()


@@ -80,7 +80,7 @@ class ModelInfo(TypedDict, total=False):
     supports_assistant_prefill: Optional[bool]


-class GenericStreamingChunk(TypedDict):
+class GenericStreamingChunk(TypedDict, total=False):
     text: Required[str]
     tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
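For reference, a hedged reconstruction of what `GenericStreamingChunk` plausibly looks like after this change. Only the keys visible in this hunk and in the utils.py hunk below (`finish_reason`, and a `usage` block with `inputTokens`/`outputTokens`/`totalTokens`) are included; the real type in litellm/types/llms/bedrock.py may carry more fields:

```python
# Hedged reconstruction, not copied from litellm -- field set is inferred from this diff.
from typing import Optional

from typing_extensions import Required, TypedDict


class ToolCallChunkSketch(TypedDict, total=False):
    # Stand-in for litellm's ChatCompletionToolCallChunk.
    id: Optional[str]
    name: Optional[str]
    arguments: Optional[str]


class UsageSketch(TypedDict):
    inputTokens: int
    outputTokens: int
    totalTokens: int


class GenericStreamingChunk(TypedDict, total=False):
    # total=False: a provider may omit any key not marked Required, which is
    # what lets the new SageMaker handler emit partial chunks.
    text: Required[str]
    tool_use: Optional[ToolCallChunkSketch]
    is_finished: Required[bool]
    finish_reason: Required[str]
    usage: Optional[UsageSketch]
```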


@@ -9848,11 +9848,28 @@ class CustomStreamWrapper:
                 completion_obj["tool_calls"] = [response_obj["tool_use"]]
             elif self.custom_llm_provider == "sagemaker":
-                print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
-                response_obj = self.handle_sagemaker_stream(chunk)
+                from litellm.types.llms.bedrock import GenericStreamingChunk
+
+                if self.received_finish_reason is not None:
+                    raise StopIteration
+                response_obj: GenericStreamingChunk = chunk
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) is True
+                    and response_obj["usage"] is not None
+                ):
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"]["inputTokens"],
+                        completion_tokens=response_obj["usage"]["outputTokens"],
+                        total_tokens=response_obj["usage"]["totalTokens"],
+                    )
+
+                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.received_finish_reason is not None:
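With `handle_sagemaker_stream` removed, the wrapper no longer parses raw SageMaker bytes; each `chunk` is expected to already be a `GenericStreamingChunk`. A hedged sketch of an iterator that would satisfy this contract; the payload field names (`token`, `details`) are assumptions for illustration, the real parsing lives in the suppressed sagemaker.py diff:

```python
# Hedged sketch of the new contract: the provider-side iterator yields
# GenericStreamingChunk-shaped dicts, and CustomStreamWrapper above only reads
# text / is_finished / finish_reason / usage / tool_use from them.
import json
from typing import AsyncIterator, Dict


async def sagemaker_chunk_iterator(raw_lines: AsyncIterator[bytes]) -> AsyncIterator[Dict]:
    async for line in raw_lines:
        if not line.strip():
            continue
        payload = json.loads(line)  # assumed: one JSON object per streamed line
        details = payload.get("details") or {}
        yield {
            "text": payload.get("token", {}).get("text", ""),
            "is_finished": bool(details),
            "finish_reason": details.get("finish_reason", "stop"),
            "usage": None,  # the real handler would fill token counts on the final chunk
            "tool_use": None,
        }
```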