fix system prompts for replicate

Author: chabala98
Date:   2023-12-01 13:16:35 +01:00
Parent: 1e9aa69268
Commit: c2e2e927fb
2 changed files with 37 additions and 30 deletions
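
The change lets callers pass `supports_system_prompt=True` when calling a Replicate chat model through litellm: when the flag is set, the first `system` message is popped out of `messages` and sent to Replicate as a separate `system_prompt` input instead of being folded into the prompt text. A minimal standalone sketch of that extraction step (the helper name `pop_system_prompt` is hypothetical, used here only for illustration):

```python
# Illustrative only: mirrors the message handling this commit adds, outside litellm.
def pop_system_prompt(messages):
    """Remove the first system message and return its content (or None)."""
    for i in range(len(messages)):
        if messages[i]["role"] == "system":
            return messages.pop(i)["content"]
    return None


messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is Replicate?"},
]
system_prompt = pop_system_prompt(messages)
# system_prompt == "You are a concise assistant."
# messages now holds only the user turn, which goes through the usual prompt building.
```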


@@ -49,8 +49,8 @@ Below are examples on how to call replicate LLMs using liteLLM

 Model Name | Function Call | Required OS Variables |
 -----------------------------|----------------------------------------------------------------|--------------------------------------|
-replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages)` | `os.environ['REPLICATE_API_KEY']` |
-a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages)`| `os.environ['REPLICATE_API_KEY']` |
+replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages, supports_system_prompt=True)` | `os.environ['REPLICATE_API_KEY']` |
+a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages, supports_system_prompt=True)`| `os.environ['REPLICATE_API_KEY']` |
 replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` |
 daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` |
 custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` |
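
Taken together with the table above, the new flag would be used roughly like this (a sketch under the assumption that `REPLICATE_API_KEY` holds a valid token; the model string is the llama-2-70b-chat version id shown in the docs):

```python
import os
import litellm

os.environ["REPLICATE_API_KEY"] = "r8_..."  # placeholder, not a real key

response = litellm.completion(
    model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
    messages=[
        {"role": "system", "content": "Answer in one short sentence."},
        {"role": "user", "content": "What is LiteLLM?"},
    ],
    # New in this commit: ask the Replicate handler to send the system message
    # as a separate `system_prompt` input rather than merging it into the prompt.
    supports_system_prompt=True,
)
print(response)
```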


@@ -169,6 +169,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
         else:
             # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
             print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")

+
 # Function to extract version ID from model string
 def model_to_version_id(model):
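
For context, the function touched here is the streaming poller: it keeps re-fetching the prediction URL, treats a failed status fetch as transient, and only gives up when Replicate itself reports `status == "failed"`. A rough, self-contained sketch of that polling pattern (simplified; the `Token` auth header is an assumption, and the real litellm code also streams partial output and routes messages through `print_verbose`):

```python
import time
import requests

def poll_prediction(prediction_url, api_token, timeout_s=120.0):
    """Illustrative poller: ignore transient fetch errors, stop only on a terminal status."""
    headers = {"Authorization": f"Token {api_token}"}
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        response = requests.get(prediction_url, headers=headers)
        if response.status_code != 200:
            # A bad status fetch does not mean the prediction failed; retry.
            time.sleep(0.5)
            continue
        data = response.json()
        status = data.get("status")
        if status == "failed":
            raise RuntimeError(f"Replicate prediction failed: {data.get('error')}")
        if status in ("succeeded", "canceled"):
            return data.get("output")
        time.sleep(0.5)
    raise TimeoutError("Replicate prediction did not reach a terminal status in time")
```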
@@ -194,41 +195,47 @@ def completion(
 ):
     # Start a prediction and get the prediction URL
     version_id = model_to_version_id(model)

     ## Load Config
     config = litellm.ReplicateConfig.get_config()
     for k, v in config.items():
         if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v

-    if "meta/llama-2-13b-chat" in model:
-        system_prompt = ""
-        prompt = ""
-        for message in messages:
-            if message["role"] == "system":
-                system_prompt = message["content"]
-            else:
-                prompt += message["content"]
-        input_data = {
-            "system_prompt": system_prompt,
-            "prompt": prompt,
-            **optional_params
-        }
+    system_prompt = None
+    if optional_params is not None and "supports_system_prompt" in optional_params:
+        supports_sys_prompt = optional_params.pop("supports_system_prompt")
     else:
-        if model in custom_prompt_dict:
-            # check if the model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[model]
-            prompt = custom_prompt(
-                role_dict=model_prompt_details.get("roles", {}),
-                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
-                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-                bos_token=model_prompt_details.get("bos_token", ""),
-                eos_token=model_prompt_details.get("eos_token", ""),
-                messages=messages,
-            )
-        else:
-            prompt = prompt_factory(model=model, messages=messages)
+        supports_sys_prompt = False
+
+    if supports_sys_prompt:
+        for i in range(len(messages)):
+            if messages[i]["role"] == "system":
+                first_sys_message = messages.pop(i)
+                system_prompt = first_sys_message["content"]
+                break
+
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details.get("roles", {}),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            bos_token=model_prompt_details.get("bos_token", ""),
+            eos_token=model_prompt_details.get("eos_token", ""),
+            messages=messages,
+        )
+    else:
+        prompt = prompt_factory(model=model, messages=messages)
+
+    # If system prompt is supported, and a system prompt is provided, use it
+    if system_prompt is not None:
+        input_data = {
+            "prompt": prompt,
+            "system_prompt": system_prompt
+        }
+    # Otherwise, use the prompt as is
+    else:
         input_data = {
             "prompt": prompt,
             **optional_params