Updated parameters

This commit is contained in:
Sajikumar JS 2025-04-26 18:29:06 +05:30
parent 3588c5bcd7
commit 7f1e4bf075

View file

@ -235,12 +235,13 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy): if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
input_dict["top_p"] = request.sampling_params.strategy.top_p input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
input_dict["temperature"] = request.sampling_params.strategy.temperature input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy): if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
input_dict["extra_body"]["top_k"] = request.sampling_params.strategy.top_k input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy): if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
input_dict["temperature"] = 1.0 input_dict["params"][GenParams.TEMPERATURE] = 0.0
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"] input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]