refactor sagemaker to be async

This commit is contained in:
Ishaan Jaff 2024-08-15 18:18:02 -07:00
parent 70b7370774
commit 953a67ba4c
5 changed files with 798 additions and 603 deletions

View file

@ -95,7 +95,6 @@ from .llms import (
palm,
petals,
replicate,
sagemaker,
together_ai,
triton,
vertex_ai,
@ -120,6 +119,7 @@ from .llms.prompt_templates.factory import (
prompt_factory,
stringify_json_tool_call_content,
)
from .llms.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_partner import VertexAIPartnerModels
@ -166,6 +166,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
####### COMPLETION ENDPOINTS ################
@ -2216,7 +2217,7 @@ def completion(
response = model_response
elif custom_llm_provider == "sagemaker":
# boto3 reads keys from .env
model_response = sagemaker.completion(
model_response = sagemaker_llm.completion(
model=model,
messages=messages,
model_response=model_response,
@ -2230,26 +2231,13 @@ def completion(
logging_obj=logging,
acompletion=acompletion,
)
if (
"stream" in optional_params and optional_params["stream"] == True
): ## [BETA]
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
from .llms.sagemaker import TokenIterator
tokenIterator = TokenIterator(model_response, acompletion=acompletion)
response = CustomStreamWrapper(
completion_stream=tokenIterator,
model=model,
custom_llm_provider="sagemaker",
logging_obj=logging,
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=None,
original_response=response,
original_response=model_response,
)
return response
## RESPONSE OBJECT
response = model_response
@ -3529,7 +3517,7 @@ def embedding(
model_response=EmbeddingResponse(),
)
elif custom_llm_provider == "sagemaker":
response = sagemaker.embedding(
response = sagemaker_llm.embedding(
model=model,
input=input,
encoding=encoding,