From 3588c5bcd769f310c4b02bfdbe2c73955c5d1707 Mon Sep 17 00:00:00 2001
From: Sajikumar JS
Date: Sat, 26 Apr 2025 18:25:58 +0530
Subject: [PATCH] Updated SamplingParams

---
 docs/_static/llama-stack-spec.html      | 25 -----------
 docs/_static/llama-stack-spec.yaml      | 10 -----
 llama_stack/apis/inference/inference.py |  1 -
 .../remote/inference/watsonx/watsonx.py | 43 ++++++-------
 4 files changed, 13 insertions(+), 66 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index c35da8d55..4c5393947 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4191,31 +4191,6 @@
                             "type": "string"
                         },
                         "description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
-                    },
-                    "additional_params": {
-                        "type": "object",
-                        "additionalProperties": {
-                            "oneOf": [
-                                {
-                                    "type": "null"
-                                },
-                                {
-                                    "type": "boolean"
-                                },
-                                {
-                                    "type": "number"
-                                },
-                                {
-                                    "type": "string"
-                                },
-                                {
-                                    "type": "array"
-                                },
-                                {
-                                    "type": "object"
-                                }
-                            ]
-                        }
                     }
                 },
                 "additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 591a296d0..a24f1a9db 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2907,16 +2907,6 @@ components:
           description: >-
             Up to 4 sequences where the API will stop generating further tokens. The
             returned text will not contain the stop sequence.
-        additional_params:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: 'null'
-              - type: boolean
-              - type: number
-              - type: string
-              - type: array
-              - type: object
       additionalProperties: false
       required:
         - strategy
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index b91169fbb..309171f20 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -82,7 +82,6 @@ class SamplingParams(BaseModel):
     max_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop: Optional[List[str]] = None
-    additional_params: Optional[Dict[str, Any]] = {}


 class LogProbConfig(BaseModel):
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 63484c888..9feeba932 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -34,6 +34,9 @@ from llama_stack.apis.inference.inference import (
     OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
+    GreedySamplingStrategy,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
 )
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
@@ -230,36 +233,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
             input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
         if request.sampling_params.repetition_penalty:
             input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
-        if request.sampling_params.additional_params.get("top_p"):
-            input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
-        if request.sampling_params.additional_params.get("top_k"):
-            input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
-        if request.sampling_params.additional_params.get("temperature"):
-            input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"] - if request.sampling_params.additional_params.get("length_penalty"): - input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[ - "length_penalty" - ] - if request.sampling_params.additional_params.get("random_seed"): - input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"] - if request.sampling_params.additional_params.get("min_new_tokens"): - input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[ - "min_new_tokens" - ] - if request.sampling_params.additional_params.get("stop_sequences"): - input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[ - "stop_sequences" - ] - if request.sampling_params.additional_params.get("time_limit"): - input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"] - if request.sampling_params.additional_params.get("truncate_input_tokens"): - input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[ - "truncate_input_tokens" - ] - if request.sampling_params.additional_params.get("return_options"): - input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[ - "return_options" - ] + + if isinstance(request.sampling_params.strategy, TopPSamplingStrategy): + input_dict["top_p"] = request.sampling_params.strategy.top_p + input_dict["temperature"] = request.sampling_params.strategy.temperature + if isinstance(request.sampling_params.strategy, TopKSamplingStrategy): + input_dict["extra_body"]["top_k"] = request.sampling_params.strategy.top_k + if isinstance(request.sampling_params.strategy, GreedySamplingStrategy): + input_dict["temperature"] = 1.0 + + input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"] params = { **input_dict,