fix replicate error

This commit is contained in:
Krrish Dholakia 2023-10-02 21:35:16 -07:00
parent 345a14483e
commit 5a1a1908c1
7 changed files with 29 additions and 14 deletions

View file

@ -24,7 +24,6 @@ def start_prediction(version_id, input_data, api_token, logging_obj):
initial_prediction_data = {
"version": version_id,
"input": input_data,
"max_new_tokens": 500,
}
## LOGGING
@ -109,17 +108,32 @@ def completion(
litellm_params=None,
logger_fn=None,
):
# Convert messages to prompt
prompt = ""
for message in messages:
prompt += message["content"]
# Start a prediction and get the prediction URL
version_id = model_to_version_id(model)
input_data = {
"prompt": prompt,
**optional_params
}
if "meta/llama-2-13b-chat" in model:
system_prompt = ""
prompt = ""
for message in messages:
if message["role"] == "system":
system_prompt = message["content"]
else:
prompt += message["content"]
input_data = {
"system_prompt": system_prompt,
"prompt": prompt,
**optional_params
}
else:
# Convert messages to prompt
prompt = ""
for message in messages:
prompt += message["content"]
input_data = {
"prompt": prompt,
**optional_params
}
## COMPLETION CALL
## Replicate Completion calls have 2 steps