Update Strategy in SamplingParams to be a union

This commit is contained in:
Hardik Shah 2025-01-14 15:56:02 -08:00 committed by Ashwin Bharambe
parent 300e6e2702
commit dea575c994
28 changed files with 600 additions and 377 deletions

View file

@ -263,19 +263,18 @@ def convert_chat_completion_request(
if request.sampling_params.max_tokens:
payload.update(max_tokens=request.sampling_params.max_tokens)
if request.sampling_params.strategy == "top_p":
strategy = request.sampling_params.strategy
if isinstance(strategy, TopPSamplingStrategy):
nvext.update(top_k=-1)
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if (
request.sampling_params.top_k != -1
and request.sampling_params.top_k < 1
):
payload.update(top_p=strategy.top_p)
payload.update(temperature=strategy.temperature)
elif isinstance(strategy, TopKSamplingStrategy):
if strategy.top_k != -1 and strategy.top_k < 1:
warnings.warn("top_k must be -1 or >= 1")
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=strategy.top_k)
elif strategy.strategy == "greedy":
nvext.update(top_k=-1)
payload.update(temperature=request.sampling_params.temperature)
payload.update(temperature=strategy.temperature)
return payload