refactor(sagemaker/): separate chat + completion routes + make them both use base llm config (#7151)

* refactor(sagemaker/): separate chat + completion routes + make them both use base llm config

Addresses https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132

* fix(main.py): pass hf model name + custom prompt dict to litellm params
This commit is contained in:
Krish Dholakia 2024-12-10 19:40:05 -08:00 committed by GitHub
parent df12f87a64
commit 61afdab228
14 changed files with 799 additions and 534 deletions

View file

@ -130,7 +130,8 @@ from .llms.prompt_templates.factory import (
prompt_factory,
stringify_json_tool_call_content,
)
from .llms.sagemaker.sagemaker import SagemakerLLM
from .llms.sagemaker.chat.handler import SagemakerChatHandler
from .llms.sagemaker.completion.handler import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton import TritonChatCompletion
@ -229,6 +230,7 @@ watsonx_chat_completion = WatsonXChatHandler()
# Handler objects instantiated once and bound to names that the rest of the
# file can reference (presumably at module level -- indentation is not
# visible in this view; confirm against the full file).
openai_like_embedding = OpenAILikeEmbeddingHandler()
databricks_embedding = DatabricksEmbeddingHandler()
base_llm_http_handler = BaseLLMHTTPHandler()
# Handler for the `sagemaker_chat` provider; completion() calls its
# `.completion(...)` method when custom_llm_provider == "sagemaker_chat".
sagemaker_chat_completion = SagemakerChatHandler()
####### COMPLETION ENDPOINTS ################
@ -1073,6 +1075,8 @@ def completion( # type: ignore # noqa: PLR0915
user_continue_message=kwargs.get("user_continue_message"),
base_model=base_model,
litellm_trace_id=kwargs.get("litellm_trace_id"),
hf_model_name=hf_model_name,
custom_prompt_dict=custom_prompt_dict,
)
logging.update_environment_variables(
model=model,
@ -2513,10 +2517,23 @@ def completion( # type: ignore # noqa: PLR0915
## RESPONSE OBJECT
response = model_response
elif (
custom_llm_provider == "sagemaker"
or custom_llm_provider == "sagemaker_chat"
):
elif custom_llm_provider == "sagemaker_chat":
# boto3 reads keys from .env
response = sagemaker_chat_completion.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
custom_prompt_dict=custom_prompt_dict,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
headers=headers or {},
)
elif custom_llm_provider == "sagemaker":
# boto3 reads keys from .env
model_response = sagemaker_llm.completion(
model=model,
@ -2531,17 +2548,7 @@ def completion( # type: ignore # noqa: PLR0915
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
use_messages_api=(
True if custom_llm_provider == "sagemaker_chat" else False
),
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=None,
original_response=model_response,
)
## RESPONSE OBJECT
response = model_response