Update Strategy in SamplingParams to be a union

This commit is contained in:
Hardik Shah 2025-01-14 15:56:02 -08:00 committed by Ashwin Bharambe
parent 300e6e2702
commit dea575c994
28 changed files with 600 additions and 377 deletions

View file

@ -22,7 +22,12 @@ from llama_stack.apis.agents import (
ToolExecutionStep,
Turn,
)
from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
from llama_stack.apis.inference import (
CompletionMessage,
SamplingParams,
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.apis.safety import ViolationLevel
from llama_stack.providers.datatypes import Api
@ -42,7 +47,9 @@ def common_params(inference_model):
model=inference_model,
instructions="You are a helpful assistant.",
enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
sampling_params=SamplingParams(
strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.95)
),
input_shields=[],
output_shields=[],
toolgroups=[],

View file

@ -21,6 +21,7 @@ from groq.types.chat.chat_completion_message_tool_call import (
Function,
)
from groq.types.shared.function_definition import FunctionDefinition
from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.inference import (
ChatCompletionRequest,
@ -152,21 +153,30 @@ class TestConvertChatCompletionRequest:
assert converted["max_tokens"] == 100
def test_includes_temperature(self):
def _dummy_chat_completion_request(self):
return ChatCompletionRequest(
model="Llama-3.2-3B",
messages=[UserMessage(content="Hello World")],
)
def test_includes_stratgy(self):
request = self._dummy_chat_completion_request()
request.sampling_params.temperature = 0.5
request.sampling_params.strategy = TopPSamplingStrategy(
temperature=0.5, top_p=0.95
)
converted = convert_chat_completion_request(request)
assert converted["temperature"] == 0.5
assert converted["top_p"] == 0.95
def test_includes_top_p(self):
def test_includes_greedy_strategy(self):
request = self._dummy_chat_completion_request()
request.sampling_params.top_p = 0.95
request.sampling_params.strategy = GreedySamplingStrategy()
converted = convert_chat_completion_request(request)
assert converted["top_p"] == 0.95
assert converted["temperature"] == 0.0
def test_includes_tool_choice(self):
request = self._dummy_chat_completion_request()