mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 09:59:59 +00:00
Update Strategy in SamplingParams to be a union
This commit is contained in:
parent
300e6e2702
commit
dea575c994
28 changed files with 600 additions and 377 deletions
|
|
@ -22,7 +22,12 @@ from llama_stack.apis.agents import (
|
|||
ToolExecutionStep,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import ViolationLevel
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
|
@ -42,7 +47,9 @@ def common_params(inference_model):
|
|||
model=inference_model,
|
||||
instructions="You are a helpful assistant.",
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
sampling_params=SamplingParams(
|
||||
strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.95)
|
||||
),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
toolgroups=[],
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from groq.types.chat.chat_completion_message_tool_call import (
|
|||
Function,
|
||||
)
|
||||
from groq.types.shared.function_definition import FunctionDefinition
|
||||
from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
|
||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
|
|
@ -152,21 +153,30 @@ class TestConvertChatCompletionRequest:
|
|||
|
||||
assert converted["max_tokens"] == 100
|
||||
|
||||
def test_includes_temperature(self):
|
||||
def _dummy_chat_completion_request(self):
|
||||
return ChatCompletionRequest(
|
||||
model="Llama-3.2-3B",
|
||||
messages=[UserMessage(content="Hello World")],
|
||||
)
|
||||
|
||||
def test_includes_stratgy(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.temperature = 0.5
|
||||
request.sampling_params.strategy = TopPSamplingStrategy(
|
||||
temperature=0.5, top_p=0.95
|
||||
)
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["temperature"] == 0.5
|
||||
assert converted["top_p"] == 0.95
|
||||
|
||||
def test_includes_top_p(self):
|
||||
def test_includes_greedy_strategy(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.top_p = 0.95
|
||||
request.sampling_params.strategy = GreedySamplingStrategy()
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["top_p"] == 0.95
|
||||
assert converted["temperature"] == 0.0
|
||||
|
||||
def test_includes_tool_choice(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue