diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index c35da8d55..4c5393947 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4191,31 +4191,6 @@
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
- },
- "additional_params": {
- "type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "null"
- },
- {
- "type": "boolean"
- },
- {
- "type": "number"
- },
- {
- "type": "string"
- },
- {
- "type": "array"
- },
- {
- "type": "object"
- }
- ]
- }
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 591a296d0..a24f1a9db 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2907,16 +2907,6 @@ components:
description: >-
Up to 4 sequences where the API will stop generating further tokens. The
returned text will not contain the stop sequence.
- additional_params:
- type: object
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
additionalProperties: false
required:
- strategy
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index b91169fbb..309171f20 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -82,7 +82,6 @@ class SamplingParams(BaseModel):
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
- additional_params: Optional[Dict[str, Any]] = {}
class LogProbConfig(BaseModel):
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 63484c888..9feeba932 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -34,6 +34,9 @@ from llama_stack.apis.inference.inference import (
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
+ GreedySamplingStrategy,
+ TopKSamplingStrategy,
+ TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
@@ -230,36 +233,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
- if request.sampling_params.additional_params.get("top_p"):
- input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
- if request.sampling_params.additional_params.get("top_k"):
- input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
- if request.sampling_params.additional_params.get("temperature"):
- input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
- if request.sampling_params.additional_params.get("length_penalty"):
- input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[
- "length_penalty"
- ]
- if request.sampling_params.additional_params.get("random_seed"):
- input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
- if request.sampling_params.additional_params.get("min_new_tokens"):
- input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[
- "min_new_tokens"
- ]
- if request.sampling_params.additional_params.get("stop_sequences"):
- input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[
- "stop_sequences"
- ]
- if request.sampling_params.additional_params.get("time_limit"):
- input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
- if request.sampling_params.additional_params.get("truncate_input_tokens"):
- input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[
- "truncate_input_tokens"
- ]
- if request.sampling_params.additional_params.get("return_options"):
- input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[
- "return_options"
- ]
+
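+        # Map the request's sampling strategy onto the corresponding watsonx generation parameters.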
+        if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
+            input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
+            input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
+        if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
+            input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
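+        # Greedy decoding: a temperature of 0.0 disables sampling.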
+        if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
+            input_dict["params"][GenParams.TEMPERATURE] = 0.0
+
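+        # Always include the end-of-text token as a stop sequence.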
+ input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
params = {
**input_dict,