mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 15:30:01 +00:00
Update Strategy in SamplingParams to be a union
This commit is contained in:
parent
300e6e2702
commit
dea575c994
28 changed files with 600 additions and 377 deletions
|
|
@ -263,19 +263,18 @@ def convert_chat_completion_request(
|
|||
if request.sampling_params.max_tokens:
|
||||
payload.update(max_tokens=request.sampling_params.max_tokens)
|
||||
|
||||
if request.sampling_params.strategy == "top_p":
|
||||
strategy = request.sampling_params.strategy
|
||||
if isinstance(strategy, TopPSamplingStrategy):
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(top_p=request.sampling_params.top_p)
|
||||
elif request.sampling_params.strategy == "top_k":
|
||||
if (
|
||||
request.sampling_params.top_k != -1
|
||||
and request.sampling_params.top_k < 1
|
||||
):
|
||||
payload.update(top_p=strategy.top_p)
|
||||
payload.update(temperature=strategy.temperature)
|
||||
elif isinstance(strategy, TopKSamplingStrategy):
|
||||
if strategy.top_k != -1 and strategy.top_k < 1:
|
||||
warnings.warn("top_k must be -1 or >= 1")
|
||||
nvext.update(top_k=request.sampling_params.top_k)
|
||||
elif request.sampling_params.strategy == "greedy":
|
||||
nvext.update(top_k=strategy.top_k)
|
||||
elif strategy.strategy == "greedy":
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(temperature=request.sampling_params.temperature)
|
||||
payload.update(temperature=strategy.temperature)
|
||||
|
||||
return payload
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue