From cb6c734460ce9ebc35703d7d8811d52dda291c6b Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Tue, 14 Jan 2025 18:00:30 -0800
Subject: [PATCH] fix nvidia sampling logic

---
 .../providers/remote/inference/nvidia/openai_utils.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index 977c0704f..8db7f9197 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -8,6 +8,11 @@
 import json
 import warnings
 from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
+from llama_models.datatypes import (
+    GreedySamplingStrategy,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
     StopReason,
@@ -272,9 +277,11 @@ def convert_chat_completion_request(
             if strategy.top_k != -1 and strategy.top_k < 1:
                 warnings.warn("top_k must be -1 or >= 1")
             nvext.update(top_k=strategy.top_k)
-        elif strategy.strategy == "greedy":
+        elif isinstance(strategy, GreedySamplingStrategy):
             nvext.update(top_k=-1)
             payload.update(temperature=strategy.temperature)
+        else:
+            raise ValueError(f"Unsupported sampling strategy: {strategy}")

     return payload
