Updated SamplingParams

This commit is contained in:
Sajikumar JS 2025-04-26 18:25:58 +05:30
parent 03a25a7753
commit 3588c5bcd7
4 changed files with 13 additions and 66 deletions

View file

@ -82,7 +82,6 @@ class SamplingParams(BaseModel):
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
additional_params: Optional[Dict[str, Any]] = {}
class LogProbConfig(BaseModel):

View file

@ -34,6 +34,9 @@ from llama_stack.apis.inference.inference import (
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
@ -230,36 +233,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
if request.sampling_params.additional_params.get("top_p"):
input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
if request.sampling_params.additional_params.get("top_k"):
input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
if request.sampling_params.additional_params.get("temperature"):
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
if request.sampling_params.additional_params.get("length_penalty"):
input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[
"length_penalty"
]
if request.sampling_params.additional_params.get("random_seed"):
input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
if request.sampling_params.additional_params.get("min_new_tokens"):
input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[
"min_new_tokens"
]
if request.sampling_params.additional_params.get("stop_sequences"):
input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[
"stop_sequences"
]
if request.sampling_params.additional_params.get("time_limit"):
input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
if request.sampling_params.additional_params.get("truncate_input_tokens"):
input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[
"truncate_input_tokens"
]
if request.sampling_params.additional_params.get("return_options"):
input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[
"return_options"
]
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
input_dict["top_p"] = request.sampling_params.strategy.top_p
input_dict["temperature"] = request.sampling_params.strategy.temperature
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
input_dict["extra_body"]["top_k"] = request.sampling_params.strategy.top_k
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
input_dict["temperature"] = 1.0
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
params = {
**input_dict,