refactor(sagemaker/): separate chat + completion routes + make them both use base llm config (#7151)
* refactor(sagemaker/): separate chat + completion routes + make them both use base llm config
  Addresses https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132
* fix(main.py): pass hf model name + custom prompt dict to litellm params
parent 1e87782215
commit e903fe6038
14 changed files with 799 additions and 534 deletions
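The commit title describes splitting SageMaker support into separate chat and completion provider routes. A minimal usage sketch, not taken from this commit, of how a caller would pick between the two routes via the standard litellm.completion call; the endpoint name is a placeholder and AWS credentials for the SageMaker runtime are assumed to be configured:

import litellm

# Completion-style route (text-generation payload).
resp = litellm.completion(
    model="sagemaker/my-endpoint",  # placeholder SageMaker endpoint name
    messages=[{"role": "user", "content": "Hello"}],
)

# Chat-style route (messages/chat payload).
resp_chat = litellm.completion(
    model="sagemaker_chat/my-endpoint",  # placeholder SageMaker endpoint name
    messages=[{"role": "user", "content": "Hello"}],
)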
@@ -2076,6 +2076,8 @@ def get_litellm_params(
     user_continue_message=None,
     base_model=None,
     litellm_trace_id=None,
+    hf_model_name: Optional[str] = None,
+    custom_prompt_dict: Optional[dict] = None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2105,6 +2107,8 @@ def get_litellm_params(
         "base_model": base_model
         or _get_base_model_from_litellm_call_metadata(metadata=metadata),
         "litellm_trace_id": litellm_trace_id,
+        "hf_model_name": hf_model_name,
+        "custom_prompt_dict": custom_prompt_dict,
     }

     return litellm_params
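The two hunks above thread hf_model_name and custom_prompt_dict through get_litellm_params, matching the "fix(main.py)" part of the commit message. A minimal sketch of a call site, assuming get_litellm_params is importable from litellm.utils and that the remaining parameters keep their defaults; the argument values are illustrative, not from the diff:

from litellm.utils import get_litellm_params

litellm_params = get_litellm_params(
    hf_model_name="meta-llama/Llama-2-7b-chat-hf",  # illustrative HF model id
    custom_prompt_dict={},  # per-model custom prompt templates, if any
)
# The returned dict now carries both new keys alongside the existing params.
assert "hf_model_name" in litellm_params
assert "custom_prompt_dict" in litellm_params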
@@ -3145,31 +3149,16 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-        if temperature is not None:
-            if temperature == 0.0 or temperature == 0:
-                # hugging face exception raised when temp==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                if not passed_params.get("aws_sagemaker_allow_zero_temp", False):
-                    temperature = 0.01
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if n is not None:
-            optional_params["best_of"] = n
-            optional_params["do_sample"] = (
-                True  # Need to sample if you want best of for hf inference endpoints
-            )
-        if stream is not None:
-            optional_params["stream"] = stream
-        if stop is not None:
-            optional_params["stop"] = stop
-        if max_tokens is not None:
-            # HF TGI raises the following exception when max_new_tokens==0
-            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-            if max_tokens == 0:
-                max_tokens = 1
-            optional_params["max_new_tokens"] = max_tokens
-        passed_params.pop("aws_sagemaker_allow_zero_temp", None)
+        optional_params = litellm.SagemakerConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "bedrock":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
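The hunk above replaces the inline SageMaker handling of temperature, top_p, n/best_of, stream, stop, and max_tokens with a single call to litellm.SagemakerConfig().map_openai_params(), following the base-config pattern named in the commit title. A small sketch of that call in isolation, with illustrative inputs; the exact output keys depend on SagemakerConfig, which is not shown in this diff:

import litellm

optional_params = litellm.SagemakerConfig().map_openai_params(
    non_default_params={"temperature": 0.0, "max_tokens": 256},
    optional_params={},
    model="my-sagemaker-endpoint",  # placeholder endpoint name
    drop_params=False,
)
# Judging from the removed inline logic, one would expect the zero temperature
# to be bumped to a small positive value and max_tokens to map to
# max_new_tokens, but that behavior now lives inside SagemakerConfig.
print(optional_params)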
@@ -6295,6 +6284,10 @@ class ProviderConfigManager:
             return litellm.VertexAIAnthropicConfig()
         elif litellm.LlmProviders.CLOUDFLARE == provider:
             return litellm.CloudflareChatConfig()
+        elif litellm.LlmProviders.SAGEMAKER_CHAT == provider:
+            return litellm.SagemakerChatConfig()
+        elif litellm.LlmProviders.SAGEMAKER == provider:
+            return litellm.SagemakerConfig()
         elif litellm.LlmProviders.FIREWORKS_AI == provider:
             return litellm.FireworksAIConfig()
         elif litellm.LlmProviders.FRIENDLIAI == provider:
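This hunk registers SAGEMAKER_CHAT and SAGEMAKER in the provider-to-config lookup, so both routes resolve to their own config class. A hedged sketch of that resolution; the method name get_provider_chat_config is assumed from litellm's usual ProviderConfigManager interface and is not visible in this excerpt:

import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(  # assumed method name
    model="my-sagemaker-endpoint",  # placeholder model name
    provider=litellm.LlmProviders.SAGEMAKER_CHAT,
)
print(type(config).__name__)  # expected: SagemakerChatConfig, per the hunk above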