fix(replicate.py): fix custom prompt formatting

This commit is contained in:
Krrish Dholakia 2023-11-29 19:44:02 -08:00
parent c05da0797b
commit 1f5a1122fc
5 changed files with 177 additions and 80 deletions

View file

@@ -6,6 +6,7 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
import litellm
import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt
class ReplicateError(Exception):
def __init__(self, status_code, message):
@@ -186,6 +187,7 @@ def completion(
logging_obj,
api_key,
encoding,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None,
@@ -213,10 +215,19 @@ def completion(
**optional_params
}
else:
# Convert messages to prompt
prompt = ""
for message in messages:
prompt += message["content"]
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
input_data = {
"prompt": prompt,
@@ -245,7 +256,7 @@ def completion(
input=prompt,
api_key="",
original_response=result,
additional_args={"complete_input_dict": input_data,"logs": logs},
additional_args={"complete_input_dict": input_data,"logs": logs, "api_base": prediction_url, },
)
print_verbose(f"raw model_response: {result}")