Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
refactor(sagemaker/): separate chat + completion routes + make them both use base llm config (#7151)

* refactor(sagemaker/): separate chat + completion routes + make them both use base llm config

  Addresses https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132

* fix(main.py): pass hf model name + custom prompt dict to litellm params
This commit is contained in:
parent df12f87a64
commit 61afdab228

14 changed files with 799 additions and 534 deletions
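A minimal usage sketch of the two routes after this refactor (hedged: the endpoint names are illustrative, and the routing-by-model-prefix shown here is standard litellm usage rather than something introduced in this diff; boto3 reads AWS credentials from the environment):

    import litellm

    # "sagemaker_chat/..." models now dispatch to SagemakerChatHandler
    chat_response = litellm.completion(
        model="sagemaker_chat/my-llama2-chat-endpoint",  # hypothetical endpoint
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )

    # "sagemaker/..." models keep using SagemakerLLM (text-completion route)
    text_response = litellm.completion(
        model="sagemaker/my-llama2-endpoint",  # hypothetical endpoint
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )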
litellm/main.py

@@ -130,7 +130,8 @@ from .llms.prompt_templates.factory import (
     prompt_factory,
     stringify_json_tool_call_content,
 )
-from .llms.sagemaker.sagemaker import SagemakerLLM
+from .llms.sagemaker.chat.handler import SagemakerChatHandler
+from .llms.sagemaker.completion.handler import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
 from .llms.triton import TritonChatCompletion
@@ -229,6 +230,7 @@ watsonx_chat_completion = WatsonXChatHandler()
 openai_like_embedding = OpenAILikeEmbeddingHandler()
 databricks_embedding = DatabricksEmbeddingHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
+sagemaker_chat_completion = SagemakerChatHandler()
 ####### COMPLETION ENDPOINTS ################


@@ -1073,6 +1075,8 @@ def completion(  # type: ignore # noqa: PLR0915
         user_continue_message=kwargs.get("user_continue_message"),
         base_model=base_model,
         litellm_trace_id=kwargs.get("litellm_trace_id"),
+        hf_model_name=hf_model_name,
+        custom_prompt_dict=custom_prompt_dict,
     )
     logging.update_environment_variables(
         model=model,
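The two parameters added above flow from completion() kwargs into litellm_params, so SageMaker prompt templating can key off the underlying Hugging Face model. A hedged sketch of what that enables (hf_model_name is an existing litellm kwarg for SageMaker models; the endpoint name is illustrative):

    import litellm

    response = litellm.completion(
        model="sagemaker/my-llama2-endpoint",  # hypothetical endpoint
        messages=[{"role": "user", "content": "Hello"}],
        # tells litellm which HF chat template the endpoint expects
        hf_model_name="meta-llama/Llama-2-7b-chat-hf",
    )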
@@ -2513,10 +2517,23 @@ def completion(  # type: ignore # noqa: PLR0915

             ## RESPONSE OBJECT
             response = model_response
-        elif (
-            custom_llm_provider == "sagemaker"
-            or custom_llm_provider == "sagemaker_chat"
-        ):
+        elif custom_llm_provider == "sagemaker_chat":
+            # boto3 reads keys from .env
+            response = sagemaker_chat_completion.completion(
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                custom_prompt_dict=custom_prompt_dict,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                logging_obj=logging,
+                acompletion=acompletion,
+                headers=headers or {},
+            )
+        elif custom_llm_provider == "sagemaker":
             # boto3 reads keys from .env
             model_response = sagemaker_llm.completion(
                 model=model,
@@ -2531,17 +2548,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 encoding=encoding,
                 logging_obj=logging,
                 acompletion=acompletion,
-                use_messages_api=(
-                    True if custom_llm_provider == "sagemaker_chat" else False
-                ),
             )
-            if optional_params.get("stream", False):
-                ## LOGGING
-                logging.post_call(
-                    input=messages,
-                    api_key=None,
-                    original_response=model_response,
-                )

             ## RESPONSE OBJECT
             response = model_response
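The use_messages_api flag disappears because the provider string now selects a dedicated handler. A simplified, self-contained sketch of the new dispatch shape (reconstructed from this diff; not the actual litellm source, and the handler bodies are stand-ins):

    from typing import Any, Callable, Dict

    def chat_handler(**kwargs: Any) -> str:
        # stands in for SagemakerChatHandler.completion (messages API)
        return "sagemaker_chat response"

    def completion_handler(**kwargs: Any) -> str:
        # stands in for SagemakerLLM.completion (text-completion API)
        return "sagemaker response"

    HANDLERS: Dict[str, Callable[..., str]] = {
        "sagemaker_chat": chat_handler,
        "sagemaker": completion_handler,
    }

    def dispatch(custom_llm_provider: str, **kwargs: Any) -> str:
        # Before this commit, one branch handled both providers and toggled
        # use_messages_api=(custom_llm_provider == "sagemaker_chat").
        # After it, each provider string maps to its own handler.
        return HANDLERS[custom_llm_provider](**kwargs)

    print(dispatch("sagemaker_chat"))  # -> "sagemaker_chat response"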