Updated SamplingParams

This commit is contained in:
Sajikumar JS 2025-04-26 18:25:58 +05:30
parent 03a25a7753
commit 3588c5bcd7
4 changed files with 13 additions and 66 deletions

View file

@ -4191,31 +4191,6 @@
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
},
"additional_params": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,

View file

@ -2907,16 +2907,6 @@ components:
description: >-
Up to 4 sequences where the API will stop generating further tokens. The
returned text will not contain the stop sequence.
additional_params:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
additionalProperties: false
required:
- strategy

View file

@ -82,7 +82,6 @@ class SamplingParams(BaseModel):
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
additional_params: Optional[Dict[str, Any]] = {}
class LogProbConfig(BaseModel):

View file

@ -34,6 +34,9 @@ from llama_stack.apis.inference.inference import (
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
@ -230,36 +233,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
if request.sampling_params.additional_params.get("top_p"):
input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
if request.sampling_params.additional_params.get("top_k"):
input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
if request.sampling_params.additional_params.get("temperature"):
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
if request.sampling_params.additional_params.get("length_penalty"):
input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[
"length_penalty"
]
if request.sampling_params.additional_params.get("random_seed"):
input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
if request.sampling_params.additional_params.get("min_new_tokens"):
input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[
"min_new_tokens"
]
if request.sampling_params.additional_params.get("stop_sequences"):
input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[
"stop_sequences"
]
if request.sampling_params.additional_params.get("time_limit"):
input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
if request.sampling_params.additional_params.get("truncate_input_tokens"):
input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[
"truncate_input_tokens"
]
if request.sampling_params.additional_params.get("return_options"):
input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[
"return_options"
]
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
input_dict["top_p"] = request.sampling_params.strategy.top_p
input_dict["temperature"] = request.sampling_params.strategy.temperature
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
input_dict["extra_body"]["top_k"] = request.sampling_params.strategy.top_k
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
input_dict["temperature"] = 1.0
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
params = {
**input_dict,