Merge pull request #970 from nbaldwin98/fixing-replicate-sys-prompt

fix system prompts for replicate

commit b90fcbdac4
2 changed files with 37 additions and 30 deletions
@@ -49,8 +49,8 @@ Below are examples on how to call replicate LLMs using liteLLM

 Model Name | Function Call | Required OS Variables |
 -----------------------------|----------------------------------------------------------------|--------------------------------------|
-replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages)` | `os.environ['REPLICATE_API_KEY']` |
-a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages)`| `os.environ['REPLICATE_API_KEY']` |
+replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages, supports_system_prompt=True)` | `os.environ['REPLICATE_API_KEY']` |
+a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages, supports_system_prompt=True)`| `os.environ['REPLICATE_API_KEY']` |
 replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` |
 daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` |
 custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` |
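From the caller's side, the new flag documented in the rows above is used roughly like this (a minimal sketch based on the table; the API key value and the messages are placeholders):

```python
import os
from litellm import completion

# placeholder -- assumes a valid Replicate API key
os.environ["REPLICATE_API_KEY"] = "r8_..."

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Explain Replicate in one sentence."},
]

# supports_system_prompt=True asks the Replicate handler to pull the system
# message out of `messages` and send it as the model's `system_prompt` input
response = completion(
    model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
    messages=messages,
    supports_system_prompt=True,
)
print(response)
```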
@@ -169,6 +169,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbose
         else:
             # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
             print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")
+

 # Function to extract version ID from model string
 def model_to_version_id(model):
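The comment in this hunk refers to the polling loop around Replicate's prediction endpoint: a non-200 fetch is treated as transient, and only a returned status of "failed" means the prediction itself failed. A minimal sketch of that distinction (illustrative only, not the library code; it assumes the prediction endpoint returns a JSON body with status, output and error fields):

```python
import time
import requests

def poll_prediction(prediction_url, api_token, poll_interval=0.5):
    headers = {"Authorization": f"Token {api_token}"}
    while True:
        response = requests.get(prediction_url, headers=headers)
        if response.status_code != 200:
            # transient fetch error -- the prediction may still be running
            time.sleep(poll_interval)
            continue
        data = response.json()
        if data["status"] == "failed":
            # only this status means the Replicate request itself failed
            raise RuntimeError(data.get("error"))
        if data["status"] == "succeeded":
            return data["output"]
        time.sleep(poll_interval)
```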
@@ -194,41 +195,47 @@ def completion(
 ):
     # Start a prediction and get the prediction URL
     version_id = model_to_version_id(model)

     ## Load Config
     config = litellm.ReplicateConfig.get_config()
     for k, v in config.items():
         if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v

-    if "meta/llama-2-13b-chat" in model:
-        system_prompt = ""
-        prompt = ""
-        for message in messages:
-            if message["role"] == "system":
-                system_prompt = message["content"]
-            else:
-                prompt += message["content"]
-        input_data = {
-            "system_prompt": system_prompt,
-            "prompt": prompt,
-            **optional_params
-        }
+    system_prompt = None
+    if optional_params is not None and "supports_system_prompt" in optional_params:
+        supports_sys_prompt = optional_params.pop("supports_system_prompt")
     else:
-        if model in custom_prompt_dict:
-            # check if the model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[model]
-            prompt = custom_prompt(
-                role_dict=model_prompt_details.get("roles", {}),
-                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
-                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-                bos_token=model_prompt_details.get("bos_token", ""),
-                eos_token=model_prompt_details.get("eos_token", ""),
-                messages=messages,
-            )
-        else:
-            prompt = prompt_factory(model=model, messages=messages)
+        supports_sys_prompt = False
+
+    if supports_sys_prompt:
+        for i in range(len(messages)):
+            if messages[i]["role"] == "system":
+                first_sys_message = messages.pop(i)
+                system_prompt = first_sys_message["content"]
+                break
+
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details.get("roles", {}),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            bos_token=model_prompt_details.get("bos_token", ""),
+            eos_token=model_prompt_details.get("eos_token", ""),
+            messages=messages,
+        )
+    else:
+        prompt = prompt_factory(model=model, messages=messages)
+
+    # If system prompt is supported, and a system prompt is provided, use it
+    if system_prompt is not None:
+        input_data = {
+            "prompt": prompt,
+            "system_prompt": system_prompt
+        }
+    # Otherwise, use the prompt as is
+    else:
         input_data = {
             "prompt": prompt,
             **optional_params
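In isolation, the new input handling amounts to roughly the following standalone sketch. It mirrors the diff above but is not the library code: build_replicate_input is a hypothetical helper, and the prompt join is a trivial stand-in for prompt_factory / custom_prompt.

```python
def build_replicate_input(messages, optional_params):
    # hypothetical helper mirroring the diff above, not the library function
    system_prompt = None
    supports_sys_prompt = optional_params.pop("supports_system_prompt", False)

    if supports_sys_prompt:
        # pull the first system message out of the conversation
        for i in range(len(messages)):
            if messages[i]["role"] == "system":
                system_prompt = messages.pop(i)["content"]
                break

    # trivial stand-in for prompt_factory / custom_prompt
    prompt = "\n".join(m["content"] for m in messages)

    if system_prompt is not None:
        return {"prompt": prompt, "system_prompt": system_prompt}
    return {"prompt": prompt, **optional_params}


print(build_replicate_input(
    [{"role": "system", "content": "Be brief."},
     {"role": "user", "content": "Hi"}],
    {"supports_system_prompt": True},
))
# -> {'prompt': 'Hi', 'system_prompt': 'Be brief.'}
```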