Mirror of https://github.com/BerriAI/litellm.git
fix(sagemaker.py): enable passing hf model name for prompt template

parent 20dab6f636
commit f9b74e54a3

3 changed files with 8 additions and 7 deletions
litellm/llms/sagemaker.py

@@ -63,6 +63,7 @@ def completion(
     encoding,
     logging_obj,
     custom_prompt_dict={},
+    hf_model_name=None,
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -119,12 +120,7 @@ def completion(
             messages=messages
         )
     else:
-        hf_model_name = model
-        if "meta-textgeneration-llama-2" in model or "meta-textgenerationneuron-llama-2" in model: # llama2 model
-            if model.endswith("-f") or "-f-" in model or "chat" in model: # sagemaker default for a chat model
-                hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model
-            else:
-                hf_model_name = "meta-llama/Llama-2-7b" # apply the normal prompt template
+        hf_model_name = hf_model_name or model # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
         prompt = prompt_factory(model=hf_model_name, messages=messages)

     data = json.dumps({
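The hunk above drops the hard-coded llama-2 endpoint detection in favor of a simple fallback: use the caller-supplied hf_model_name if present, otherwise fall back to the SageMaker endpoint name. A minimal, self-contained sketch of that fallback for illustration only; resolve_hf_model_name and the endpoint name are hypothetical, not part of the repo:

from typing import Optional

def resolve_hf_model_name(model: str, hf_model_name: Optional[str] = None) -> str:
    # Mirrors `hf_model_name = hf_model_name or model` in sagemaker.py:
    # an explicit Hugging Face model name wins; otherwise the SageMaker
    # endpoint name is used to look up the prompt template.
    return hf_model_name or model

# Hypothetical endpoint name used for illustration.
assert resolve_hf_model_name("jumpstart-dft-meta-textgeneration-llama-2-7b") == "jumpstart-dft-meta-textgeneration-llama-2-7b"
assert resolve_hf_model_name("jumpstart-dft-meta-textgeneration-llama-2-7b", "meta-llama/Llama-2-7b-chat-hf") == "meta-llama/Llama-2-7b-chat-hf"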
litellm/main.py

@@ -341,11 +341,13 @@ def completion(
     final_prompt_value = kwargs.get("final_prompt_value", None)
     bos_token = kwargs.get("bos_token", None)
     eos_token = kwargs.get("eos_token", None)
+    hf_model_name = kwargs.get("hf_model_name", None)
+    ### ASYNC CALLS ###
     acompletion = kwargs.get("acompletion", False)
     client = kwargs.get("client", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
@@ -1167,6 +1169,7 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
             logger_fn=logger_fn,
             encoding=encoding,
             logging_obj=logging
litellm/tests/test_completion.py

@@ -1039,6 +1039,7 @@ def test_completion_sagemaker():
             messages=messages,
             temperature=0.2,
             max_tokens=80,
+            hf_model_name="meta-llama/Llama-2-7b",
         )
         # Add any assertions here to check the response
         print(response)
@@ -1056,6 +1057,7 @@ def test_completion_chat_sagemaker():
             messages=messages,
             max_tokens=100,
             stream=True,
+            hf_model_name="meta-llama/Llama-2-7b-chat-hf",
         )
         # Add any assertions here to check the response
         complete_response = ""
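Based on the test changes above, a hedged usage sketch of the new parameter; the SageMaker endpoint name is a placeholder and the credential setup is an assumption, only hf_model_name comes from this commit:

import litellm

# Placeholder endpoint name; AWS credentials are expected in the environment.
response = litellm.completion(
    model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=80,
    hf_model_name="meta-llama/Llama-2-7b",  # apply this HF model's prompt template instead of guessing from the endpoint name
)
print(response)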