forked from phoenix/litellm-mirror
refactor sagemaker to be async
parent b1aed699ea
commit df4ea8fba6
5 changed files with 798 additions and 603 deletions
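The diff replaces the module-level `sagemaker.completion(...)` call with a `SagemakerLLM` class instantiated once as `sagemaker_llm`, and keeps passing the `acompletion` flag into it so the same entry point can serve both sync and async callers. Below is a minimal sketch of that dispatch pattern with purely illustrative names; it is not the actual litellm implementation.

# Illustrative sketch only: a provider class that returns a coroutine when an
# async completion was requested, instead of blocking the event loop.
import asyncio
from typing import Any, Coroutine, Dict, List, Union


class ProviderLLM:
    def completion(
        self,
        model: str,
        messages: List[Dict[str, str]],
        acompletion: bool = False,
    ) -> Union[Dict[str, Any], Coroutine[Any, Any, Dict[str, Any]]]:
        # Async callers get a coroutine back; sync callers keep the old behaviour.
        if acompletion:
            return self._acompletion(model, messages)
        return {"model": model, "text": "sync response"}

    async def _acompletion(
        self, model: str, messages: List[Dict[str, str]]
    ) -> Dict[str, Any]:
        await asyncio.sleep(0)  # placeholder for a non-blocking endpoint call
        return {"model": model, "text": "async response"}


provider = ProviderLLM()
print(provider.completion("my-endpoint", [{"role": "user", "content": "hi"}]))
print(
    asyncio.run(
        provider.completion(
            "my-endpoint", [{"role": "user", "content": "hi"}], acompletion=True
        )
    )
)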
File diff suppressed because it is too large
@@ -95,7 +95,6 @@ from .llms import (
     palm,
     petals,
     replicate,
-    sagemaker,
     together_ai,
     triton,
     vertex_ai,
@@ -120,6 +119,7 @@ from .llms.prompt_templates.factory import (
     prompt_factory,
     stringify_json_tool_call_content,
 )
+from .llms.sagemaker import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_ai_partner import VertexAIPartnerModels
@@ -166,6 +166,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 watsonxai = IBMWatsonXAI()
+sagemaker_llm = SagemakerLLM()
 ####### COMPLETION ENDPOINTS ################


@@ -2216,7 +2217,7 @@ def completion(
             response = model_response
         elif custom_llm_provider == "sagemaker":
             # boto3 reads keys from .env
-            model_response = sagemaker.completion(
+            model_response = sagemaker_llm.completion(
                 model=model,
                 messages=messages,
                 model_response=model_response,
@@ -2230,26 +2231,13 @@
                 logging_obj=logging,
                 acompletion=acompletion,
             )
-            if (
-                "stream" in optional_params and optional_params["stream"] == True
-            ): ## [BETA]
-                print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
-                from .llms.sagemaker import TokenIterator
-
-                tokenIterator = TokenIterator(model_response, acompletion=acompletion)
-                response = CustomStreamWrapper(
-                    completion_stream=tokenIterator,
-                    model=model,
-                    custom_llm_provider="sagemaker",
-                    logging_obj=logging,
-                )
+            if optional_params.get("stream", False):
                 ## LOGGING
                 logging.post_call(
                     input=messages,
                     api_key=None,
-                    original_response=response,
+                    original_response=model_response,
                 )
-                return response

             ## RESPONSE OBJECT
             response = model_response
@@ -3529,7 +3517,7 @@ def embedding(
                 model_response=EmbeddingResponse(),
             )
         elif custom_llm_provider == "sagemaker":
-            response = sagemaker.embedding(
+            response = sagemaker_llm.embedding(
                 model=model,
                 input=input,
                 encoding=encoding,
@@ -28,6 +28,9 @@ litellm.cache = None
 litellm.success_callback = []
 user_message = "Write a short poem about the sky"
 messages = [{"content": user_message, "role": "user"}]
+import logging
+
+from litellm._logging import verbose_logger


 def logger_fn(user_model_dict):
@@ -80,6 +83,55 @@ async def test_completion_sagemaker(sync_mode):
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True])
+async def test_completion_sagemaker_stream(sync_mode):
+    try:
+        litellm.set_verbose = False
+        print("testing sagemaker")
+        verbose_logger.setLevel(logging.DEBUG)
+        full_text = ""
+        if sync_mode is True:
+            response = litellm.completion(
+                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                messages=[
+                    {"role": "user", "content": "hi - what is ur name"},
+                ],
+                temperature=0.2,
+                stream=True,
+                max_tokens=80,
+                input_cost_per_second=0.000420,
+            )
+
+            for chunk in response:
+                print(chunk)
+                full_text += chunk.choices[0].delta.content or ""
+
+            print("SYNC RESPONSE full text", full_text)
+        else:
+            response = await litellm.acompletion(
+                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                messages=[
+                    {"role": "user", "content": "hi - what is ur name"},
+                ],
+                stream=True,
+                temperature=0.2,
+                max_tokens=80,
+                input_cost_per_second=0.000420,
+            )
+
+            print("streaming response")
+
+            async for chunk in response:
+                print(chunk)
+                full_text += chunk.choices[0].delta.content or ""
+
+            print("ASYNC RESPONSE full text", full_text)
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.asyncio
 async def test_acompletion_sagemaker_non_stream():
     mock_response = AsyncMock()
@@ -80,7 +80,7 @@ class ModelInfo(TypedDict, total=False):
     supports_assistant_prefill: Optional[bool]


-class GenericStreamingChunk(TypedDict):
+class GenericStreamingChunk(TypedDict, total=False):
    text: Required[str]
    tool_use: Optional[ChatCompletionToolCallChunk]
    is_finished: Required[bool]
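For context on the `total=False` change above: with a `TypedDict`, `total=False` makes every key optional unless it is explicitly wrapped in `Required[...]`, so a streaming chunk only has to carry the required fields. A small sketch of those semantics; the `finish_reason` field is included only for illustration and is not part of this hunk.

# Sketch of TypedDict total=False semantics; the field set is illustrative.
from typing import Optional

from typing_extensions import Required, TypedDict


class ChunkSketch(TypedDict, total=False):
    text: Required[str]           # must always be present
    is_finished: Required[bool]   # must always be present
    finish_reason: Optional[str]  # may be omitted because of total=False


ok: ChunkSketch = {"text": "hello", "is_finished": False}  # type-checks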
@@ -9848,11 +9848,28 @@ class CustomStreamWrapper:
                     completion_obj["tool_calls"] = [response_obj["tool_use"]]

             elif self.custom_llm_provider == "sagemaker":
-                print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
-                response_obj = self.handle_sagemaker_stream(chunk)
+                from litellm.types.llms.bedrock import GenericStreamingChunk
+
+                if self.received_finish_reason is not None:
+                    raise StopIteration
+                response_obj: GenericStreamingChunk = chunk
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) is True
+                    and response_obj["usage"] is not None
+                ):
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"]["inputTokens"],
+                        completion_tokens=response_obj["usage"]["outputTokens"],
+                        total_tokens=response_obj["usage"]["totalTokens"],
+                    )
+
+                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.received_finish_reason is not None:
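After this change the sagemaker branch of CustomStreamWrapper consumes each incoming chunk as a ready-made GenericStreamingChunk instead of parsing raw payloads via handle_sagemaker_stream. The provider-side code that produces those chunks is not shown here (one file's diff is suppressed above as too large), so the sketch below only illustrates the assumed shape of what it yields; every name in it is hypothetical.

# Illustrative sketch only: an async generator yielding GenericStreamingChunk-shaped
# dicts, i.e. the shape the consumer code above reads ("text", "is_finished",
# "finish_reason", "usage"). Not the actual SagemakerLLM implementation.
import asyncio
from typing import Any, AsyncIterator, Dict


async def fake_sagemaker_stream(text: str) -> AsyncIterator[Dict[str, Any]]:
    words = text.split()
    for i, word in enumerate(words):
        last = i == len(words) - 1
        yield {
            "text": word + " ",
            "is_finished": last,
            "finish_reason": "stop" if last else "",
            "usage": None,
        }


async def main() -> None:
    async for chunk in fake_sagemaker_stream("hello from the async sagemaker sketch"):
        print(chunk["text"], end="")
        if chunk["is_finished"]:
            print()


asyncio.run(main())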