Updated SamplingParams

Sajikumar JS 2025-04-26 18:25:58 +05:30
parent 03a25a7753
commit 3588c5bcd7
4 changed files with 13 additions and 66 deletions

View file

@@ -4191,31 +4191,6 @@
             "type": "string"
           },
           "description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
-        },
-        "additional_params": {
-          "type": "object",
-          "additionalProperties": {
-            "oneOf": [
-              {
-                "type": "null"
-              },
-              {
-                "type": "boolean"
-              },
-              {
-                "type": "number"
-              },
-              {
-                "type": "string"
-              },
-              {
-                "type": "array"
-              },
-              {
-                "type": "object"
-              }
-            ]
-          }
         }
       },
       "additionalProperties": false,

View file

@@ -2907,16 +2907,6 @@ components:
           description: >-
             Up to 4 sequences where the API will stop generating further tokens. The
             returned text will not contain the stop sequence.
-        additional_params:
-          type: object
-          additionalProperties:
-            oneOf:
-              - type: 'null'
-              - type: boolean
-              - type: number
-              - type: string
-              - type: array
-              - type: object
       additionalProperties: false
       required:
         - strategy

View file

@@ -82,7 +82,6 @@ class SamplingParams(BaseModel):
     max_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop: Optional[List[str]] = None
-    additional_params: Optional[Dict[str, Any]] = {}
 
 
 class LogProbConfig(BaseModel):
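With `additional_params` removed from the model, sampling options travel through the typed `strategy` field that `SamplingParams` already carries, using the same strategy classes the watsonx adapter imports below. A minimal sketch of that shape, assuming the usual Pydantic discriminated-union layout; field names follow the adapter code, while the defaults shown are assumptions:

```python
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field


class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"


class TopPSamplingStrategy(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: Optional[float] = None
    top_p: Optional[float] = 0.95  # default is an assumption


class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int = 40  # default is an assumption


SamplingStrategy = Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy]


class SamplingParams(BaseModel):
    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
    max_tokens: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None
```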

View file

@@ -34,6 +34,9 @@ from llama_stack.apis.inference.inference import (
     OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
+    GreedySamplingStrategy,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
 )
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
@@ -230,36 +233,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
             input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
         if request.sampling_params.repetition_penalty:
             input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
-        if request.sampling_params.additional_params.get("top_p"):
-            input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
-        if request.sampling_params.additional_params.get("top_k"):
-            input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
-        if request.sampling_params.additional_params.get("temperature"):
-            input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
-        if request.sampling_params.additional_params.get("length_penalty"):
-            input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[
-                "length_penalty"
-            ]
-        if request.sampling_params.additional_params.get("random_seed"):
-            input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
-        if request.sampling_params.additional_params.get("min_new_tokens"):
-            input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[
-                "min_new_tokens"
-            ]
-        if request.sampling_params.additional_params.get("stop_sequences"):
-            input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[
-                "stop_sequences"
-            ]
-        if request.sampling_params.additional_params.get("time_limit"):
-            input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
-        if request.sampling_params.additional_params.get("truncate_input_tokens"):
-            input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[
-                "truncate_input_tokens"
-            ]
-        if request.sampling_params.additional_params.get("return_options"):
-            input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[
-                "return_options"
-            ]
+        if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
+            input_dict["top_p"] = request.sampling_params.strategy.top_p
+            input_dict["temperature"] = request.sampling_params.strategy.temperature
+        if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
+            input_dict["extra_body"]["top_k"] = request.sampling_params.strategy.top_k
+        if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
+            input_dict["temperature"] = 1.0
+
+        input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
+
         params = {
             **input_dict,
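For reference, hedged call sites that would exercise each branch of the rewritten mapping; the module path is taken from the import hunk above, and greedy being the default strategy is an assumption:

```python
from llama_stack.apis.inference.inference import (
    SamplingParams,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)

# Nucleus sampling: the adapter copies top_p and temperature into input_dict.
nucleus = SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9))

# Top-k sampling: the adapter writes top_k under input_dict["extra_body"].
top_k_params = SamplingParams(strategy=TopKSamplingStrategy(top_k=50))

# Greedy (assumed default): the adapter pins temperature to 1.0.
greedy = SamplingParams()
```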