forked from phoenix/litellm-mirror
refactor sagemaker to be async
This commit is contained in: parent b1aed699ea, commit df4ea8fba6
5 changed files with 798 additions and 603 deletions
File diff suppressed because it is too large
@@ -95,7 +95,6 @@ from .llms import (
     palm,
     petals,
     replicate,
-    sagemaker,
     together_ai,
     triton,
     vertex_ai,
@@ -120,6 +119,7 @@ from .llms.prompt_templates.factory import (
     prompt_factory,
     stringify_json_tool_call_content,
 )
+from .llms.sagemaker import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_ai_partner import VertexAIPartnerModels
@@ -166,6 +166,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 watsonxai = IBMWatsonXAI()
+sagemaker_llm = SagemakerLLM()
 ####### COMPLETION ENDPOINTS ################


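Note: the rewritten .llms.sagemaker module itself is in the diff suppressed above as too large (presumably litellm/llms/sagemaker.py), so only its interface is visible at these call sites. A minimal sketch of the handler shape the calls below rely on; the method bodies and the async_completion helper name are assumptions, not the actual implementation:

# Illustrative sketch only -- the real class lives in .llms.sagemaker.
class SagemakerLLM:
    def completion(self, model, messages, model_response, logging_obj,
                   optional_params=None, acompletion=False, **kwargs):
        # With acompletion=True the handler hands back an awaitable (or an
        # async iterator when stream=True), so litellm.acompletion can await
        # it instead of main.py special-casing sagemaker.
        if acompletion:
            return self.async_completion(  # assumed helper name
                model, messages, model_response, logging_obj,
                optional_params or {}, **kwargs,
            )
        raise NotImplementedError("illustrative sketch only")

    async def async_completion(self, model, messages, model_response,
                               logging_obj, optional_params, **kwargs):
        raise NotImplementedError("illustrative sketch only")

    def embedding(self, model, input, encoding, model_response, **kwargs):
        raise NotImplementedError("illustrative sketch only")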
@@ -2216,7 +2217,7 @@ def completion(
             response = model_response
         elif custom_llm_provider == "sagemaker":
             # boto3 reads keys from .env
-            model_response = sagemaker.completion(
+            model_response = sagemaker_llm.completion(
                 model=model,
                 messages=messages,
                 model_response=model_response,
@@ -2230,26 +2231,13 @@ def completion(
                 logging_obj=logging,
                 acompletion=acompletion,
             )
-            if (
-                "stream" in optional_params and optional_params["stream"] == True
-            ):  ## [BETA]
-                print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
-                from .llms.sagemaker import TokenIterator
-
-                tokenIterator = TokenIterator(model_response, acompletion=acompletion)
-                response = CustomStreamWrapper(
-                    completion_stream=tokenIterator,
-                    model=model,
-                    custom_llm_provider="sagemaker",
-                    logging_obj=logging,
-                )
+            if optional_params.get("stream", False):
                 ## LOGGING
                 logging.post_call(
                     input=messages,
                     api_key=None,
-                    original_response=response,
+                    original_response=model_response,
                 )
-                return response

             ## RESPONSE OBJECT
             response = model_response
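With the TokenIterator / CustomStreamWrapper block deleted from the completion() call site, the streaming wrap presumably happens inside the handler itself now. A hedged sketch of that delegation, reusing the same constructor arguments the removed block passed; the actual internals are in the suppressed sagemaker.py diff, and the helper name is hypothetical:

# Sketch: inside the (assumed) SagemakerLLM streaming path.
from litellm.utils import CustomStreamWrapper

def _wrap_sagemaker_stream(token_iterator, model, logging_obj):
    # Hypothetical helper; same arguments the removed block in completion()
    # used to pass when it wrapped the stream itself.
    return CustomStreamWrapper(
        completion_stream=token_iterator,
        model=model,
        custom_llm_provider="sagemaker",
        logging_obj=logging_obj,
    )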
@@ -3529,7 +3517,7 @@ def embedding(
                 model_response=EmbeddingResponse(),
             )
         elif custom_llm_provider == "sagemaker":
-            response = sagemaker.embedding(
+            response = sagemaker_llm.embedding(
                 model=model,
                 input=input,
                 encoding=encoding,
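The embedding path gets the same object swap and nothing else, so the user-facing call is unchanged; for example (endpoint name is a placeholder, response layout follows litellm's OpenAI-style EmbeddingResponse):

import litellm

# "<your-endpoint-name>" is a placeholder for a deployed SageMaker endpoint.
response = litellm.embedding(
    model="sagemaker/<your-endpoint-name>",
    input=["Write a short poem about the sky"],
)
print(response.data[0]["embedding"][:5])  # first few dimensions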
@@ -28,6 +28,9 @@ litellm.cache = None
 litellm.success_callback = []
 user_message = "Write a short poem about the sky"
 messages = [{"content": user_message, "role": "user"}]
+import logging
+
+from litellm._logging import verbose_logger


 def logger_fn(user_model_dict):
@@ -80,6 +83,55 @@ async def test_completion_sagemaker(sync_mode):
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True])
+async def test_completion_sagemaker_stream(sync_mode):
+    try:
+        litellm.set_verbose = False
+        print("testing sagemaker")
+        verbose_logger.setLevel(logging.DEBUG)
+        full_text = ""
+        if sync_mode is True:
+            response = litellm.completion(
+                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                messages=[
+                    {"role": "user", "content": "hi - what is ur name"},
+                ],
+                temperature=0.2,
+                stream=True,
+                max_tokens=80,
+                input_cost_per_second=0.000420,
+            )
+
+            for chunk in response:
+                print(chunk)
+                full_text += chunk.choices[0].delta.content or ""
+
+            print("SYNC RESPONSE full text", full_text)
+        else:
+            response = await litellm.acompletion(
+                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                messages=[
+                    {"role": "user", "content": "hi - what is ur name"},
+                ],
+                stream=True,
+                temperature=0.2,
+                max_tokens=80,
+                input_cost_per_second=0.000420,
+            )
+
+            print("streaming response")
+
+            async for chunk in response:
+                print(chunk)
+                full_text += chunk.choices[0].delta.content or ""
+
+            print("ASYNC RESPONSE full text", full_text)
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.asyncio
 async def test_acompletion_sagemaker_non_stream():
     mock_response = AsyncMock()
@@ -80,7 +80,7 @@ class ModelInfo(TypedDict, total=False):
     supports_assistant_prefill: Optional[bool]


-class GenericStreamingChunk(TypedDict):
+class GenericStreamingChunk(TypedDict, total=False):
     text: Required[str]
     tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
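With total=False, only the Required[...] keys must be present, so a provider (like the sagemaker stream handler below) can emit chunks that omit fields such as usage. A reduced reconstruction limited to fields visible in this diff; the Required status of the extra fields is not shown here and is left loose:

from typing import Any, Optional
from typing_extensions import Required, TypedDict

class GenericStreamingChunk(TypedDict, total=False):
    text: Required[str]          # still mandatory despite total=False
    is_finished: Required[bool]  # still mandatory despite total=False
    tool_use: Optional[Any]      # ChatCompletionToolCallChunk in the real type
    finish_reason: str
    usage: Optional[dict]        # {"inputTokens": ..., "outputTokens": ..., "totalTokens": ...}

# A provider can now emit a minimal chunk:
chunk: GenericStreamingChunk = {"text": "hello", "is_finished": False}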
@@ -9848,11 +9848,28 @@ class CustomStreamWrapper:
                     completion_obj["tool_calls"] = [response_obj["tool_use"]]

             elif self.custom_llm_provider == "sagemaker":
-                print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
-                response_obj = self.handle_sagemaker_stream(chunk)
+                from litellm.types.llms.bedrock import GenericStreamingChunk
+
+                if self.received_finish_reason is not None:
+                    raise StopIteration
+                response_obj: GenericStreamingChunk = chunk
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) is True
+                    and response_obj["usage"] is not None
+                ):
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"]["inputTokens"],
+                        completion_tokens=response_obj["usage"]["outputTokens"],
+                        total_tokens=response_obj["usage"]["totalTokens"],
+                    )
+
+                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.received_finish_reason is not None:
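The usage mapping added above only fires when the caller opts in via stream_options, so from the user side per-stream token counts would look roughly like this (assuming stream_options is forwarded to the wrapper for this provider the same way as for others; endpoint name is a placeholder):

import litellm

# include_usage mirrors the check added to CustomStreamWrapper above.
response = litellm.completion(
    model="sagemaker/<your-endpoint-name>",
    messages=[{"role": "user", "content": "hi - what is ur name"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in response:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    usage = getattr(chunk, "usage", None)
    if usage:
        print("\nprompt/completion/total:",
              usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)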