Convert SamplingParams.strategy to a union (#767)
# What does this PR do?

Cleans up how we provide sampling params. Earlier, `strategy` was an enum and all params (`top_p`, `temperature`, `top_k`) across all strategies were grouped together. We now have a `strategy` union object, with each strategy (greedy, top_p, top_k) carrying its corresponding params.

Earlier:

```
class SamplingParams:
    strategy: enum
    top_p, temperature, top_k and other params
```

However, the `strategy` field was not being used by any provider, which made it hard to know the exact sampling behavior purely from the params: you could pass `temperature`, `top_p`, and `top_k` together, and how the provider would interpret them was unclear. Hence we introduced a union where each strategy is clubbed together with its relevant params to avoid this confusion.

Updated all providers, tests, notebooks, the README, and other places that used sampling params to use the new format.

## Test Plan

`pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py`

// inference on ollama, fireworks and together
`with-proxy pytest -v -s -k "ollama" --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/inference/test_text_inference.py`

// agents on fireworks
`pytest -v -s -k 'fireworks and create_agent' --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/agents/test_agents.py --safety-shield="meta-llama/Llama-Guard-3-8B"`

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [X] Ran pre-commit to handle lint / formatting issues.
- [X] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [X] Updated relevant documentation.
- [X] Wrote necessary unit or integration tests.

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
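
For reference, a minimal sketch of the new shape, assembled from the changes in this diff (only the strategies and fields that appear below are shown; the values are illustrative, not defaults):

```python
# Illustrative sketch only: types come from llama_models.datatypes, as imported in the diff below.
from llama_models.datatypes import (
    GreedySamplingStrategy,
    SamplingParams,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)

# Each strategy now carries only its own knobs; max_tokens and repetition_penalty
# remain on SamplingParams itself.
greedy = SamplingParams(strategy=GreedySamplingStrategy())
nucleus = SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.95))
top_k = SamplingParams(strategy=TopKSamplingStrategy(top_k=40))

# Over the wire / in client calls, the same union uses a "type" discriminator, e.g.
# {"strategy": {"type": "top_p", "temperature": 0.7, "top_p": 0.95}}
```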
parent 300e6e2702
commit a51c8b4efc
29 changed files with 611 additions and 388 deletions
@@ -13,7 +13,6 @@ from termcolor import colored
 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.cli.table import print_table
-from llama_stack.distribution.utils.serialize import EnumEncoder


 class ModelDescribe(Subcommand):

@@ -72,7 +71,7 @@ class ModelDescribe(Subcommand):
         rows.append(
             (
                 "Recommended sampling params",
-                json.dumps(sampling_params, cls=EnumEncoder, indent=4),
+                json.dumps(sampling_params, indent=4),
             )
         )
@@ -58,11 +58,6 @@ def define_eval_candidate_2():

     # Sampling Parameters
     st.markdown("##### Sampling Parameters")
-    strategy = st.selectbox(
-        "Strategy",
-        ["greedy", "top_p", "top_k"],
-        index=0,
-    )
     temperature = st.slider(
         "Temperature",
         min_value=0.0,

@@ -95,13 +90,20 @@ def define_eval_candidate_2():
         help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
     )
     if candidate_type == "model":
+        if temperature > 0.0:
+            strategy = {
+                "type": "top_p",
+                "temperature": temperature,
+                "top_p": top_p,
+            }
+        else:
+            strategy = {"type": "greedy"}
+
         eval_candidate = {
             "type": "model",
             "model": selected_model,
             "sampling_params": {
                 "strategy": strategy,
-                "temperature": temperature,
-                "top_p": top_p,
                 "max_tokens": max_tokens,
                 "repetition_penalty": repetition_penalty,
             },
@@ -95,6 +95,15 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
         message_placeholder = st.empty()
         full_response = ""
+
+        if temperature > 0.0:
+            strategy = {
+                "type": "top_p",
+                "temperature": temperature,
+                "top_p": top_p,
+            }
+        else:
+            strategy = {"type": "greedy"}

         response = llama_stack_api.client.inference.chat_completion(
             messages=[
                 {"role": "system", "content": system_prompt},

@@ -103,8 +112,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
             model_id=selected_model,
             stream=stream,
             sampling_params={
-                "temperature": temperature,
-                "top_p": top_p,
+                "strategy": strategy,
                 "max_tokens": max_tokens,
                 "repetition_penalty": repetition_penalty,
             },
@@ -118,13 +118,20 @@ def rag_chat_page():
         with st.chat_message(message["role"]):
             st.markdown(message["content"])

+    if temperature > 0.0:
+        strategy = {
+            "type": "top_p",
+            "temperature": temperature,
+            "top_p": top_p,
+        }
+    else:
+        strategy = {"type": "greedy"}
+
     agent_config = AgentConfig(
         model=selected_model,
         instructions=system_prompt,
         sampling_params={
-            "strategy": "greedy",
-            "temperature": temperature,
-            "top_p": top_p,
+            "strategy": strategy,
         },
         tools=[
             {
@@ -23,6 +23,11 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
+from llama_models.datatypes import (
+    GreedySamplingStrategy,
+    SamplingParams,
+    TopPSamplingStrategy,
+)
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
 from llama_models.llama3.api.datatypes import Model

@@ -363,11 +368,12 @@ class Llama:
             max_gen_len = self.model.params.max_seq_len - 1

         model_input = self.formatter.encode_content(request.content)
+        temperature, top_p = _infer_sampling_params(sampling_params)
         yield from self.generate(
             model_input=model_input,
             max_gen_len=max_gen_len,
-            temperature=sampling_params.temperature,
-            top_p=sampling_params.top_p,
+            temperature=temperature,
+            top_p=top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
             logits_processor=get_logits_processor(

@@ -390,14 +396,15 @@ class Llama:
         ):
             max_gen_len = self.model.params.max_seq_len - 1

+        temperature, top_p = _infer_sampling_params(sampling_params)
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
                 request.messages,
                 request.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
-            temperature=sampling_params.temperature,
-            top_p=sampling_params.top_p,
+            temperature=temperature,
+            top_p=top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
             logits_processor=get_logits_processor(
@@ -492,3 +499,15 @@ def _build_regular_tokens_list(
         is_word_start_token = len(decoded_after_0) > len(decoded_regular)
         regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
     return regular_tokens
+
+
+def _infer_sampling_params(sampling_params: SamplingParams):
+    if isinstance(sampling_params.strategy, GreedySamplingStrategy):
+        temperature = 0.0
+        top_p = 1.0
+    elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
+        temperature = sampling_params.strategy.temperature
+        top_p = sampling_params.strategy.top_p
+    else:
+        raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
+    return temperature, top_p
@@ -36,6 +36,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import Model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
     process_chat_completion_response,

@@ -126,21 +127,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         if sampling_params is None:
             return VLLMSamplingParams(max_tokens=self.config.max_tokens)

-        # TODO convert what I saw in my first test ... but surely there's more to do here
-        kwargs = {
-            "temperature": sampling_params.temperature,
-            "max_tokens": self.config.max_tokens,
-        }
-        if sampling_params.top_k:
-            kwargs["top_k"] = sampling_params.top_k
-        if sampling_params.top_p:
-            kwargs["top_p"] = sampling_params.top_p
-        if sampling_params.max_tokens:
-            kwargs["max_tokens"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            kwargs["repetition_penalty"] = sampling_params.repetition_penalty
+        options = get_sampling_options(sampling_params)
+        if "repeat_penalty" in options:
+            options["repetition_penalty"] = options["repeat_penalty"]
+            del options["repeat_penalty"]

-        return VLLMSamplingParams(**kwargs)
+        return VLLMSamplingParams(**options)

     async def unregister_model(self, model_id: str) -> None:
         pass
@@ -34,6 +34,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_strategy_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
     process_chat_completion_response,

@@ -166,16 +167,13 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> Dict:
         bedrock_model = request.model

-        inference_config = {}
-        param_mapping = {
-            "max_tokens": "max_gen_len",
-            "temperature": "temperature",
-            "top_p": "top_p",
-        }
+        sampling_params = request.sampling_params
+        options = get_sampling_strategy_options(sampling_params)

-        for k, v in param_mapping.items():
-            if getattr(request.sampling_params, k):
-                inference_config[v] = getattr(request.sampling_params, k)
+        if sampling_params.max_tokens:
+            options["max_gen_len"] = sampling_params.max_tokens
+        if sampling_params.repetition_penalty > 0:
+            options["repetition_penalty"] = sampling_params.repetition_penalty

         prompt = await chat_completion_request_to_prompt(
             request, self.get_llama_model(request.model), self.formatter

@@ -185,7 +183,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             "body": json.dumps(
                 {
                     "prompt": prompt,
-                    **inference_config,
+                    **options,
                 }
             ),
         }
@@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from cerebras.cloud.sdk import AsyncCerebras
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import TopKSamplingStrategy
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent

@@ -172,7 +173,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
-        if request.sampling_params and request.sampling_params.top_k:
+        if request.sampling_params and isinstance(
+            request.sampling_params.strategy, TopKSamplingStrategy
+        ):
             raise ValueError("`top_k` not supported by Cerebras")

         prompt = ""
@@ -48,6 +48,9 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_strategy_options,
+)


 def convert_chat_completion_request(

@@ -77,6 +80,7 @@ def convert_chat_completion_request(
     if request.tool_prompt_format != ToolPromptFormat.json:
         warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")

+    sampling_options = get_sampling_strategy_options(request.sampling_params)
     return CompletionCreateParams(
         model=request.model,
         messages=[_convert_message(message) for message in request.messages],

@@ -84,8 +88,8 @@ def convert_chat_completion_request(
         frequency_penalty=None,
         stream=request.stream,
         max_tokens=request.sampling_params.max_tokens or None,
-        temperature=request.sampling_params.temperature,
-        top_p=request.sampling_params.top_p,
+        temperature=sampling_options.get("temperature", 1.0),
+        top_p=sampling_options.get("top_p", 1.0),
         tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
         tool_choice=request.tool_choice.value if request.tool_choice else None,
     )
@@ -8,6 +8,11 @@ import json
 import warnings
 from typing import Any, AsyncGenerator, Dict, Generator, List, Optional

+from llama_models.datatypes import (
+    GreedySamplingStrategy,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
     StopReason,

@@ -263,19 +268,20 @@ def convert_chat_completion_request(
     if request.sampling_params.max_tokens:
         payload.update(max_tokens=request.sampling_params.max_tokens)

-    if request.sampling_params.strategy == "top_p":
+    strategy = request.sampling_params.strategy
+    if isinstance(strategy, TopPSamplingStrategy):
         nvext.update(top_k=-1)
-        payload.update(top_p=request.sampling_params.top_p)
-    elif request.sampling_params.strategy == "top_k":
-        if (
-            request.sampling_params.top_k != -1
-            and request.sampling_params.top_k < 1
-        ):
+        payload.update(top_p=strategy.top_p)
+        payload.update(temperature=strategy.temperature)
+    elif isinstance(strategy, TopKSamplingStrategy):
+        if strategy.top_k != -1 and strategy.top_k < 1:
             warnings.warn("top_k must be -1 or >= 1")
-        nvext.update(top_k=request.sampling_params.top_k)
-    elif request.sampling_params.strategy == "greedy":
+        nvext.update(top_k=strategy.top_k)
+    elif isinstance(strategy, GreedySamplingStrategy):
         nvext.update(top_k=-1)
-        payload.update(temperature=request.sampling_params.temperature)
+        payload.update(temperature=strategy.temperature)
+    else:
+        raise ValueError(f"Unsupported sampling strategy: {strategy}")

     return payload
@@ -7,6 +7,7 @@
 import os

 import pytest
+from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
 from llama_models.llama3.api.datatypes import BuiltinTool

 from llama_stack.apis.agents import (

@@ -22,7 +23,8 @@ from llama_stack.apis.agents import (
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
+
+from llama_stack.apis.inference import CompletionMessage, UserMessage
 from llama_stack.apis.safety import ViolationLevel
 from llama_stack.providers.datatypes import Api

@@ -42,7 +44,9 @@ def common_params(inference_model):
         model=inference_model,
         instructions="You are a helpful assistant.",
         enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        sampling_params=SamplingParams(
+            strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.95)
+        ),
         input_shields=[],
         output_shields=[],
         toolgroups=[],
@@ -21,6 +21,7 @@ from groq.types.chat.chat_completion_message_tool_call import (
     Function,
 )
 from groq.types.shared.function_definition import FunctionDefinition
+from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
 from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_stack.apis.inference import (
     ChatCompletionRequest,

@@ -152,21 +153,30 @@ class TestConvertChatCompletionRequest:

         assert converted["max_tokens"] == 100

-    def test_includes_temperature(self):
+    def _dummy_chat_completion_request(self):
+        return ChatCompletionRequest(
+            model="Llama-3.2-3B",
+            messages=[UserMessage(content="Hello World")],
+        )
+
+    def test_includes_stratgy(self):
         request = self._dummy_chat_completion_request()
-        request.sampling_params.temperature = 0.5
+        request.sampling_params.strategy = TopPSamplingStrategy(
+            temperature=0.5, top_p=0.95
+        )

         converted = convert_chat_completion_request(request)

         assert converted["temperature"] == 0.5
+        assert converted["top_p"] == 0.95

-    def test_includes_top_p(self):
+    def test_includes_greedy_strategy(self):
         request = self._dummy_chat_completion_request()
-        request.sampling_params.top_p = 0.95
+        request.sampling_params.strategy = GreedySamplingStrategy()

         converted = convert_chat_completion_request(request)

-        assert converted["top_p"] == 0.95
+        assert converted["temperature"] == 0.0

     def test_includes_tool_choice(self):
         request = self._dummy_chat_completion_request()

@@ -268,12 +278,6 @@ class TestConvertChatCompletionRequest:
             },
         ]

-    def _dummy_chat_completion_request(self):
-        return ChatCompletionRequest(
-            model="Llama-3.2-3B",
-            messages=[UserMessage(content="Hello World")],
-        )
-

 class TestConvertNonStreamChatCompletionResponse:
     def test_returns_response(self):

@@ -409,19 +413,19 @@ class TestConvertStreamChatCompletionResponse:
         iter = converted.__aiter__()
         chunk = await iter.__anext__()
         assert chunk.event.event_type == ChatCompletionResponseEventType.start
-        assert chunk.event.delta == "Hello "
+        assert chunk.event.delta.text == "Hello "

         chunk = await iter.__anext__()
         assert chunk.event.event_type == ChatCompletionResponseEventType.progress
-        assert chunk.event.delta == "World "
+        assert chunk.event.delta.text == "World "

         chunk = await iter.__anext__()
         assert chunk.event.event_type == ChatCompletionResponseEventType.progress
-        assert chunk.event.delta == " !"
+        assert chunk.event.delta.text == " !"

         chunk = await iter.__anext__()
         assert chunk.event.event_type == ChatCompletionResponseEventType.complete
-        assert chunk.event.delta == ""
+        assert chunk.event.delta.text == ""
         assert chunk.event.stop_reason == StopReason.end_of_turn

         with pytest.raises(StopAsyncIteration):
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
+from llama_stack.apis.models import Model

 from .utils import group_chunks

@@ -476,7 +477,7 @@ class TestInference:
         last = grouped[ChatCompletionResponseEventType.progress][-1]
         # assert last.event.stop_reason == expected_stop_reason
         assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
-        assert last.event.delta.content.type == "tool_call"
+        assert isinstance(last.event.delta.content, ToolCall)

         call = last.event.delta.content
         assert call.tool_name == "get_weather"
@@ -8,7 +8,13 @@ from typing import AsyncGenerator, List, Optional

 from llama_models.llama3.api.chat_format import ChatFormat

-from llama_models.llama3.api.datatypes import SamplingParams, StopReason
+from llama_models.llama3.api.datatypes import (
+    GreedySamplingStrategy,
+    SamplingParams,
+    StopReason,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from pydantic import BaseModel

 from llama_stack.apis.common.content_types import (

@@ -49,12 +55,27 @@ class OpenAICompatCompletionResponse(BaseModel):
     choices: List[OpenAICompatCompletionChoice]


+def get_sampling_strategy_options(params: SamplingParams) -> dict:
+    options = {}
+    if isinstance(params.strategy, GreedySamplingStrategy):
+        options["temperature"] = 0.0
+    elif isinstance(params.strategy, TopPSamplingStrategy):
+        options["temperature"] = params.strategy.temperature
+        options["top_p"] = params.strategy.top_p
+    elif isinstance(params.strategy, TopKSamplingStrategy):
+        options["top_k"] = params.strategy.top_k
+    else:
+        raise ValueError(f"Unsupported sampling strategy: {params.strategy}")
+
+    return options
+
+
 def get_sampling_options(params: SamplingParams) -> dict:
     options = {}
     if params:
-        for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
-            if getattr(params, attr):
-                options[attr] = getattr(params, attr)
+        options.update(get_sampling_strategy_options(params))
+        if params.max_tokens:
+            options["max_tokens"] = params.max_tokens

         if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
             options["repeat_penalty"] = params.repetition_penalty