Merge branch 'main' into add-watsonx-inference-adapter

This commit is contained in:
Sajikumar JS 2025-03-25 09:34:48 +05:30
commit 4b53171139
6 changed files with 79 additions and 4 deletions

View file

@ -195,11 +195,23 @@ register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
additional_params: Optional[dict] = {}
stop: Optional[List[str]] = None
class CheckpointQuantizationFormat(Enum):

View file

@ -147,6 +147,9 @@ def get_sampling_options(params: SamplingParams) -> dict:
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
options["repeat_penalty"] = params.repetition_penalty
if params.stop is not None:
options["stop"] = params.stop
return options