diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index 9feeba932..ef0ca41e1 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -235,12 +235,13 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty if isinstance(request.sampling_params.strategy, TopPSamplingStrategy): - input_dict["top_p"] = request.sampling_params.strategy.top_p - input_dict["temperature"] = request.sampling_params.strategy.temperature + input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p + input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature if isinstance(request.sampling_params.strategy, TopKSamplingStrategy): - input_dict["extra_body"]["top_k"] = request.sampling_params.strategy.top_k + input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k if isinstance(request.sampling_params.strategy, GreedySamplingStrategy): - input_dict["temperature"] = 1.0 + input_dict["params"][GenParams.TEMPERATURE] = 0.0 + input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]