fix(replicate.py): pass version if passed in

This commit is contained in:
Krrish Dholakia 2024-04-26 17:11:21 -07:00
parent 069d1f863d
commit 92bf686b10

View file

@ -112,10 +112,16 @@ def start_prediction(
} }
initial_prediction_data = { initial_prediction_data = {
"version": version_id,
"input": input_data, "input": input_data,
} }
if ":" in version_id and len(version_id) > 64:
model_parts = version_id.split(":")
if (
len(model_parts) > 1 and len(model_parts[1]) == 64
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
initial_prediction_data["version"] = model_parts[1]
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=input_data["prompt"], input=input_data["prompt"],