mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
refactor sagemaker to be async
This commit is contained in:
parent
70b7370774
commit
953a67ba4c
5 changed files with 798 additions and 603 deletions
|
@ -95,7 +95,6 @@ from .llms import (
|
|||
palm,
|
||||
petals,
|
||||
replicate,
|
||||
sagemaker,
|
||||
together_ai,
|
||||
triton,
|
||||
vertex_ai,
|
||||
|
@ -120,6 +119,7 @@ from .llms.prompt_templates.factory import (
|
|||
prompt_factory,
|
||||
stringify_json_tool_call_content,
|
||||
)
|
||||
from .llms.sagemaker import SagemakerLLM
|
||||
from .llms.text_completion_codestral import CodestralTextCompletion
|
||||
from .llms.triton import TritonChatCompletion
|
||||
from .llms.vertex_ai_partner import VertexAIPartnerModels
|
||||
|
@ -166,6 +166,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
|
|||
vertex_chat_completion = VertexLLM()
|
||||
vertex_partner_models_chat_completion = VertexAIPartnerModels()
|
||||
watsonxai = IBMWatsonXAI()
|
||||
sagemaker_llm = SagemakerLLM()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@ -2216,7 +2217,7 @@ def completion(
|
|||
response = model_response
|
||||
elif custom_llm_provider == "sagemaker":
|
||||
# boto3 reads keys from .env
|
||||
model_response = sagemaker.completion(
|
||||
model_response = sagemaker_llm.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
model_response=model_response,
|
||||
|
@ -2230,26 +2231,13 @@ def completion(
|
|||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
if (
|
||||
"stream" in optional_params and optional_params["stream"] == True
|
||||
): ## [BETA]
|
||||
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
|
||||
from .llms.sagemaker import TokenIterator
|
||||
|
||||
tokenIterator = TokenIterator(model_response, acompletion=acompletion)
|
||||
response = CustomStreamWrapper(
|
||||
completion_stream=tokenIterator,
|
||||
model=model,
|
||||
custom_llm_provider="sagemaker",
|
||||
logging_obj=logging,
|
||||
)
|
||||
if optional_params.get("stream", False):
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=None,
|
||||
original_response=response,
|
||||
original_response=model_response,
|
||||
)
|
||||
return response
|
||||
|
||||
## RESPONSE OBJECT
|
||||
response = model_response
|
||||
|
@ -3529,7 +3517,7 @@ def embedding(
|
|||
model_response=EmbeddingResponse(),
|
||||
)
|
||||
elif custom_llm_provider == "sagemaker":
|
||||
response = sagemaker.embedding(
|
||||
response = sagemaker_llm.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
encoding=encoding,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue