fix(replicate.py): pass version if passed in

This commit is contained in:
Krrish Dholakia 2024-04-26 17:11:21 -07:00
parent e05764bdb7
commit 93463565fb

View file

@ -112,10 +112,16 @@ def start_prediction(
}
initial_prediction_data = {
"version": version_id,
"input": input_data,
}
if ":" in version_id and len(version_id) > 64:
model_parts = version_id.split(":")
if (
len(model_parts) > 1 and len(model_parts[1]) == 64
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
initial_prediction_data["version"] = model_parts[1]
## LOGGING
logging_obj.pre_call(
input=input_data["prompt"],