mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 22:24:31 +00:00
Update Strategy in SamplingParams to be a union
This commit is contained in:
parent
300e6e2702
commit
dea575c994
28 changed files with 600 additions and 377 deletions
|
|
@ -8,7 +8,13 @@ from typing import AsyncGenerator, List, Optional
|
|||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import SamplingParams, StopReason
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
|
@ -49,12 +55,26 @@ class OpenAICompatCompletionResponse(BaseModel):
|
|||
choices: List[OpenAICompatCompletionChoice]
|
||||
|
||||
|
||||
def get_sampling_strategy_options(params: SamplingParams) -> dict:
|
||||
options = {}
|
||||
if isinstance(params.strategy, GreedySamplingStrategy):
|
||||
options["temperature"] = 0.0
|
||||
elif isinstance(params.strategy, TopPSamplingStrategy):
|
||||
options["temperature"] = params.strategy.temperature
|
||||
options["top_p"] = params.strategy.top_p
|
||||
elif isinstance(params.strategy, TopKSamplingStrategy):
|
||||
options["top_k"] = params.strategy.top_k
|
||||
else:
|
||||
raise ValueError(f"Unsupported sampling strategy: {params.strategy}")
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def get_sampling_options(params: SamplingParams) -> dict:
|
||||
options = {}
|
||||
if params:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(params, attr):
|
||||
options[attr] = getattr(params, attr)
|
||||
options.update(get_sampling_strategy_options(params))
|
||||
options["max_tokens"] = params.max_tokens
|
||||
|
||||
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
|
||||
options["repeat_penalty"] = params.repetition_penalty
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue