Fix NVIDIA sampling logic

This commit is contained in:
Hardik Shah 2025-01-14 18:00:30 -08:00 committed by Ashwin Bharambe
parent 0edd3ce78b
commit cb6c734460

View file

@ -8,6 +8,11 @@ import json
import warnings
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
from llama_models.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_models.llama3.api.datatypes import (
BuiltinTool,
StopReason,
@ -272,9 +277,11 @@ def convert_chat_completion_request(
if strategy.top_k != -1 and strategy.top_k < 1:
warnings.warn("top_k must be -1 or >= 1")
nvext.update(top_k=strategy.top_k)
elif strategy.strategy == "greedy":
elif isinstance(strategy, GreedySamplingStrategy):
nvext.update(top_k=-1)
payload.update(temperature=strategy.temperature)
else:
raise ValueError(f"Unsupported sampling strategy: {strategy}")
return payload