refactor(sagemaker/): separate chat + completion routes + make them both use base llm config (#7151)

* refactor(sagemaker/): separate chat + completion routes + make them both use base llm config

Addresses https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132

* fix(main.py): pass hf model name + custom prompt dict to litellm params
Author: Krish Dholakia (2024-12-10 19:40:05 -08:00), committed by GitHub
Parent: df12f87a64
Commit: 61afdab228
14 changed files with 799 additions and 534 deletions
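For orientation before the diff: after this change, sagemaker/ models keep the completion-style route (prompt-template payloads), sagemaker_chat/ models go through the OpenAI-compatible chat route, and both resolve their parameter handling through a base llm config class. A minimal sketch of the two entry points; the endpoint names are placeholders:

import litellm

# Completion-style route: litellm renders a prompt template for the endpoint.
litellm.completion(
    model="sagemaker/<your-tgi-endpoint>",  # placeholder endpoint name
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)

# Chat route: for endpoints that accept OpenAI-style messages directly.
litellm.completion(
    model="sagemaker_chat/<your-chat-endpoint>",  # placeholder endpoint name
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)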

@@ -2076,6 +2076,8 @@ def get_litellm_params(
     user_continue_message=None,
     base_model=None,
     litellm_trace_id=None,
+    hf_model_name: Optional[str] = None,
+    custom_prompt_dict: Optional[dict] = None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2105,6 +2107,8 @@ def get_litellm_params(
         "base_model": base_model
         or _get_base_model_from_litellm_call_metadata(metadata=metadata),
         "litellm_trace_id": litellm_trace_id,
+        "hf_model_name": hf_model_name,
+        "custom_prompt_dict": custom_prompt_dict,
     }
     return litellm_params
 
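The two new keyword arguments thread the Hugging Face model name and any per-model prompt-template overrides into litellm_params, so the SageMaker handlers can read them from one place. A hedged sketch of a call site, with all other arguments left at their defaults and the prompt-dict shape only illustrative:

from litellm.utils import get_litellm_params

litellm_params = get_litellm_params(
    acompletion=False,
    hf_model_name="meta-llama/Llama-2-7b",  # HF id used to pick a prompt template
    custom_prompt_dict={  # illustrative per-model template override
        "meta-llama/Llama-2-7b": {
            "roles": {"user": {"pre_message": "[INST] ", "post_message": " [/INST]"}},
            "initial_prompt_value": "",
            "final_prompt_value": "",
        }
    },
)
assert litellm_params["hf_model_name"] == "meta-llama/Llama-2-7b"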
@@ -3145,31 +3149,16 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-        if temperature is not None:
-            if temperature == 0.0 or temperature == 0:
-                # hugging face exception raised when temp==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                if not passed_params.get("aws_sagemaker_allow_zero_temp", False):
-                    temperature = 0.01
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if n is not None:
-            optional_params["best_of"] = n
-            optional_params["do_sample"] = (
-                True  # Need to sample if you want best of for hf inference endpoints
-            )
-        if stream is not None:
-            optional_params["stream"] = stream
-        if stop is not None:
-            optional_params["stop"] = stop
-        if max_tokens is not None:
-            # HF TGI raises the following exception when max_new_tokens==0
-            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-            if max_tokens == 0:
-                max_tokens = 1
-            optional_params["max_new_tokens"] = max_tokens
-        passed_params.pop("aws_sagemaker_allow_zero_temp", None)
+        optional_params = litellm.SagemakerConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "bedrock":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
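The removed inline rules (bump temperature 0 to 0.01 unless aws_sagemaker_allow_zero_temp is set, bump max_tokens 0 to 1 and rename it max_new_tokens, map n to best_of with do_sample=True) now live behind SagemakerConfig.map_openai_params. A sketch of the expected equivalence; the output shown is inferred from the deleted block above, not from the new config's source:

import litellm

mapped = litellm.SagemakerConfig().map_openai_params(
    non_default_params={"temperature": 0, "max_tokens": 0, "n": 2},
    optional_params={},
    model="jumpstart-dft-hf-textgeneration1",  # placeholder endpoint name
    drop_params=False,
)
# If the config preserves the legacy behavior, mapped should be:
# {"temperature": 0.01, "max_new_tokens": 1, "best_of": 2, "do_sample": True}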
@@ -6295,6 +6284,10 @@ class ProviderConfigManager:
             return litellm.VertexAIAnthropicConfig()
         elif litellm.LlmProviders.CLOUDFLARE == provider:
             return litellm.CloudflareChatConfig()
+        elif litellm.LlmProviders.SAGEMAKER_CHAT == provider:
+            return litellm.SagemakerChatConfig()
+        elif litellm.LlmProviders.SAGEMAKER == provider:
+            return litellm.SagemakerConfig()
         elif litellm.LlmProviders.FIREWORKS_AI == provider:
             return litellm.FireworksAIConfig()
         elif litellm.LlmProviders.FRIENDLIAI == provider:
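With these two branches, both SageMaker providers resolve their chat config through ProviderConfigManager like every other provider. A sketch of the lookup, assuming the surrounding method is the manager's chat-config getter (get_provider_chat_config) with the model/provider signature used elsewhere in this file:

import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    model="jumpstart-dft-hf-textgeneration1",  # placeholder endpoint name
    provider=litellm.LlmProviders.SAGEMAKER,
)
assert isinstance(config, litellm.SagemakerConfig)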