forked from phoenix/litellm-mirror
fix system prompts for replicate
This commit is contained in:
parent
1e9aa69268
commit
c2e2e927fb
2 changed files with 37 additions and 30 deletions
|
@ -49,8 +49,8 @@ Below are examples on how to call replicate LLMs using liteLLM
|
|||
|
||||
Model Name | Function Call | Required OS Variables |
|
||||
-----------------------------|----------------------------------------------------------------|--------------------------------------|
|
||||
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages)`| `os.environ['REPLICATE_API_KEY']` |
|
||||
replicate/llama-2-70b-chat | `completion(model='replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf', messages, supports_system_prompt=True)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
a16z-infra/llama-2-13b-chat| `completion(model='replicate/a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', messages, supports_system_prompt=True)`| `os.environ['REPLICATE_API_KEY']` |
|
||||
replicate/vicuna-13b | `completion(model='replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
daanelson/flan-t5-large | `completion(model='replicate/daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
custom-llm | `completion(model='replicate/custom-llm-version-id', messages)` | `os.environ['REPLICATE_API_KEY']` |
|
||||
|
|
|
@ -169,6 +169,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
|
|||
else:
|
||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||
print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")
|
||||
|
||||
|
||||
# Function to extract version ID from model string
|
||||
def model_to_version_id(model):
|
||||
|
@ -194,41 +195,47 @@ def completion(
|
|||
):
|
||||
# Start a prediction and get the prediction URL
|
||||
version_id = model_to_version_id(model)
|
||||
|
||||
## Load Config
|
||||
config = litellm.ReplicateConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
if "meta/llama-2-13b-chat" in model:
|
||||
system_prompt = ""
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
system_prompt = message["content"]
|
||||
else:
|
||||
prompt += message["content"]
|
||||
input_data = {
|
||||
"system_prompt": system_prompt,
|
||||
"prompt": prompt,
|
||||
**optional_params
|
||||
}
|
||||
|
||||
system_prompt = None
|
||||
if optional_params is not None and "supports_system_prompt" in optional_params:
|
||||
supports_sys_prompt = optional_params.pop("supports_system_prompt")
|
||||
else:
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details.get("roles", {}),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
bos_token=model_prompt_details.get("bos_token", ""),
|
||||
eos_token=model_prompt_details.get("eos_token", ""),
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
supports_sys_prompt = False
|
||||
|
||||
if supports_sys_prompt:
|
||||
for i in range(len(messages)):
|
||||
if messages[i]["role"] == "system":
|
||||
first_sys_message = messages.pop(i)
|
||||
system_prompt = first_sys_message["content"]
|
||||
break
|
||||
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details.get("roles", {}),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
bos_token=model_prompt_details.get("bos_token", ""),
|
||||
eos_token=model_prompt_details.get("eos_token", ""),
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
# If system prompt is supported, and a system prompt is provided, use it
|
||||
if system_prompt is not None:
|
||||
input_data = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt
|
||||
}
|
||||
# Otherwise, use the prompt as is
|
||||
else:
|
||||
input_data = {
|
||||
"prompt": prompt,
|
||||
**optional_params
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue