Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
Updated SamplingParams
commit 3588c5bcd7 (parent 03a25a7753)
4 changed files with 13 additions and 66 deletions
docs/_static/llama-stack-spec.html (vendored): 25 changes

@@ -4191,31 +4191,6 @@
               "type": "string"
             },
             "description": "Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence."
-          },
-          "additional_params": {
-            "type": "object",
-            "additionalProperties": {
-              "oneOf": [
-                {
-                  "type": "null"
-                },
-                {
-                  "type": "boolean"
-                },
-                {
-                  "type": "number"
-                },
-                {
-                  "type": "string"
-                },
-                {
-                  "type": "array"
-                },
-                {
-                  "type": "object"
-                }
-              ]
-            }
-          }
          }
        },
        "additionalProperties": false,
docs/_static/llama-stack-spec.yaml (vendored): 10 changes

@@ -2907,16 +2907,6 @@ components:
         description: >-
           Up to 4 sequences where the API will stop generating further tokens. The
           returned text will not contain the stop sequence.
-      additional_params:
-        type: object
-        additionalProperties:
-          oneOf:
-            - type: 'null'
-            - type: boolean
-            - type: number
-            - type: string
-            - type: array
-            - type: object
     additionalProperties: false
     required:
       - strategy
@@ -82,7 +82,6 @@ class SamplingParams(BaseModel):
     max_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop: Optional[List[str]] = None
-    additional_params: Optional[Dict[str, Any]] = {}
 
 
 class LogProbConfig(BaseModel):
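With additional_params removed, callers would pass sampling options through the typed strategy object on SamplingParams rather than a free-form dict. A minimal usage sketch, assuming the import path used by the adapter change below and the field names visible in this diff (strategy, max_tokens, repetition_penalty, stop; top_p and temperature on TopPSamplingStrategy):

    # Sketch only; the import location of SamplingParams is assumed, not shown in this diff.
    from llama_stack.apis.inference.inference import (
        SamplingParams,
        TopPSamplingStrategy,
    )

    sampling_params = SamplingParams(
        strategy=TopPSamplingStrategy(top_p=0.9, temperature=0.7),
        max_tokens=256,
        repetition_penalty=1.0,
        stop=["<|endoftext|>"],
    )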
@@ -34,6 +34,9 @@ from llama_stack.apis.inference.inference import (
     OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
+    GreedySamplingStrategy,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
 )
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
@@ -230,36 +233,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
             input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
         if request.sampling_params.repetition_penalty:
             input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
-        if request.sampling_params.additional_params.get("top_p"):
-            input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
-        if request.sampling_params.additional_params.get("top_k"):
-            input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
-        if request.sampling_params.additional_params.get("temperature"):
-            input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
-        if request.sampling_params.additional_params.get("length_penalty"):
-            input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[
-                "length_penalty"
-            ]
-        if request.sampling_params.additional_params.get("random_seed"):
-            input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
-        if request.sampling_params.additional_params.get("min_new_tokens"):
-            input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[
-                "min_new_tokens"
-            ]
-        if request.sampling_params.additional_params.get("stop_sequences"):
-            input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[
-                "stop_sequences"
-            ]
-        if request.sampling_params.additional_params.get("time_limit"):
-            input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
-        if request.sampling_params.additional_params.get("truncate_input_tokens"):
-            input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[
-                "truncate_input_tokens"
-            ]
-        if request.sampling_params.additional_params.get("return_options"):
-            input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[
-                "return_options"
-            ]
+
+        if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
+            input_dict["top_p"] = request.sampling_params.strategy.top_p
+            input_dict["temperature"] = request.sampling_params.strategy.temperature
+        if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
+            input_dict["extra_body"]["top_k"] = request.sampling_params.strategy.top_k
+        if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
+            input_dict["temperature"] = 1.0
+
+        input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
 
         params = {
             **input_dict,
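The net effect of the adapter change is that watsonx generation options now come from the typed strategy on SamplingParams instead of a free-form additional_params dict. A self-contained sketch of that dispatch, using dataclass stand-ins for the strategy classes and plain string keys in place of the adapter's GenParams constants:

    from dataclasses import dataclass

    # Stand-ins for illustration only; the real classes are the models
    # imported from llama_stack.apis.inference.inference in the hunk above
    # and may carry extra fields and validation.
    @dataclass
    class GreedySamplingStrategy:
        pass

    @dataclass
    class TopPSamplingStrategy:
        top_p: float = 0.9
        temperature: float = 0.7

    @dataclass
    class TopKSamplingStrategy:
        top_k: int = 40

    def strategy_to_watsonx_kwargs(strategy) -> dict:
        # Mirrors the dispatch added in this commit, with plain keys instead
        # of the watsonx GenParams constants used by the adapter.
        kwargs: dict = {"extra_body": {}}
        if isinstance(strategy, TopPSamplingStrategy):
            kwargs["top_p"] = strategy.top_p
            kwargs["temperature"] = strategy.temperature
        if isinstance(strategy, TopKSamplingStrategy):
            kwargs["extra_body"]["top_k"] = strategy.top_k
        if isinstance(strategy, GreedySamplingStrategy):
            kwargs["temperature"] = 1.0
        return kwargs

    print(strategy_to_watsonx_kwargs(TopPSamplingStrategy(top_p=0.95, temperature=0.2)))
    # {'extra_body': {}, 'top_p': 0.95, 'temperature': 0.2}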