mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 11:50:01 +00:00
Update Strategy in SamplingParams to be a union
This commit is contained in:
parent
300e6e2702
commit
dea575c994
28 changed files with 600 additions and 377 deletions
|
|
@ -23,6 +23,11 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from llama_models.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
|
||||
from llama_models.llama3.api.datatypes import Model
|
||||
|
|
@ -363,11 +368,12 @@ class Llama:
|
|||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
model_input = self.formatter.encode_content(request.content)
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.generate(
|
||||
model_input=model_input,
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=sampling_params.temperature,
|
||||
top_p=sampling_params.top_p,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
logits_processor=get_logits_processor(
|
||||
|
|
@ -390,14 +396,15 @@ class Llama:
|
|||
):
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(
|
||||
request.messages,
|
||||
request.tool_prompt_format,
|
||||
),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=sampling_params.temperature,
|
||||
top_p=sampling_params.top_p,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
logits_processor=get_logits_processor(
|
||||
|
|
@ -492,3 +499,15 @@ def _build_regular_tokens_list(
|
|||
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
|
||||
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
|
||||
return regular_tokens
|
||||
|
||||
|
||||
def _infer_sampling_params(sampling_params: SamplingParams):
    """Resolve a sampling strategy into a ``(temperature, top_p)`` pair.

    Args:
        sampling_params: Sampling configuration whose ``strategy`` attribute
            selects how tokens are sampled.

    Returns:
        A ``(temperature, top_p)`` tuple: ``(0.0, 1.0)`` for a
        ``GreedySamplingStrategy``, or the strategy's own ``temperature``
        and ``top_p`` for a ``TopPSamplingStrategy``.

    Raises:
        ValueError: If the strategy is neither of the two handled types.
    """
    strategy = sampling_params.strategy

    # Greedy decoding is expressed as fixed values rather than read from
    # the strategy object.
    if isinstance(strategy, GreedySamplingStrategy):
        return 0.0, 1.0

    if isinstance(strategy, TopPSamplingStrategy):
        return strategy.temperature, strategy.top_p

    raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue