fix(sagemaker.py): enable passing hf model name for prompt template

Krrish Dholakia 2023-12-05 16:31:59 -08:00
parent 20dab6f636
commit f9b74e54a3
3 changed files with 8 additions and 7 deletions

@@ -63,6 +63,7 @@ def completion(
     encoding,
     logging_obj,
     custom_prompt_dict={},
+    hf_model_name=None,
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -119,12 +120,7 @@ def completion(
             messages=messages
         )
     else:
-        hf_model_name = model
-        if "meta-textgeneration-llama-2" in model or "meta-textgenerationneuron-llama-2" in model: # llama2 model
-            if model.endswith("-f") or "-f-" in model or "chat" in model: # sagemaker default for a chat model
-                hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model
-            else:
-                hf_model_name = "meta-llama/Llama-2-7b" # apply the normal prompt template
+        hf_model_name = hf_model_name or model # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
         prompt = prompt_factory(model=hf_model_name, messages=messages)
     data = json.dumps({
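
For clarity, here is a minimal standalone sketch (not litellm code; the endpoint names are hypothetical) of the selection rule the hunk above introduces: an explicitly passed hf_model_name wins, otherwise the SageMaker endpoint name itself is handed to the prompt-template lookup.

from typing import Optional

def resolve_template_model(model: str, hf_model_name: Optional[str] = None) -> str:
    # Mirrors `hf_model_name = hf_model_name or model` from the hunk above.
    return hf_model_name or model

# Hypothetical endpoint names, for illustration only:
assert resolve_template_model("my-llama2-endpoint") == "my-llama2-endpoint"
assert resolve_template_model("my-llama2-endpoint", "meta-llama/Llama-2-7b-chat-hf") == "meta-llama/Llama-2-7b-chat-hf"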

@@ -341,11 +341,13 @@ def completion(
     final_prompt_value = kwargs.get("final_prompt_value", None)
     bos_token = kwargs.get("bos_token", None)
     eos_token = kwargs.get("eos_token", None)
+    hf_model_name = kwargs.get("hf_model_name", None)
     ### ASYNC CALLS ###
     acompletion = kwargs.get("acompletion", False)
     client = kwargs.get("client", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
@@ -1167,6 +1169,7 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging
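
The main.py change does two things: it unpacks hf_model_name from kwargs and lists it in litellm_params, so it is not treated as a model-specific parameter and forwarded to the provider. A simplified, runnable sketch of that filtering (abbreviated lists, hypothetical kwargs):

# Abbreviated versions of the lists above; the values are made up for illustration.
openai_params = ["temperature", "max_tokens"]
litellm_params = ["metadata", "hf_model_name"]
kwargs = {"temperature": 0.2, "hf_model_name": "meta-llama/Llama-2-7b-chat-hf", "top_k": 40}

default_params = openai_params + litellm_params
# Anything not recognized is passed straight through to the model/provider.
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
print(non_default_params)  # {'top_k': 40} -- hf_model_name stays out of the pass-through set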

@@ -1039,6 +1039,7 @@ def test_completion_sagemaker():
             messages=messages,
             temperature=0.2,
             max_tokens=80,
+            hf_model_name="meta-llama/Llama-2-7b",
         )
         # Add any assertions here to check the response
         print(response)
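
A hedged usage sketch of the same call outside the test suite (the SageMaker endpoint name is a placeholder and the call assumes AWS credentials are configured): passing the base model name applies the plain Llama-2 template, while a *-chat-hf name would apply the chat template instead.

import litellm

response = litellm.completion(
    model="sagemaker/my-llama2-endpoint",   # hypothetical endpoint name
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=80,
    hf_model_name="meta-llama/Llama-2-7b",  # pull the non-chat llama2 prompt template
)
print(response)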
@@ -1056,6 +1057,7 @@ def test_completion_chat_sagemaker():
             messages=messages,
             max_tokens=100,
             stream=True,
+            hf_model_name="meta-llama/Llama-2-7b-chat-hf",
         )
         # Add any assertions here to check the response
         complete_response = ""