mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 00:14:31 +00:00
Update Strategy in SamplingParams to be a union
This commit is contained in:
parent
300e6e2702
commit
dea575c994
28 changed files with 600 additions and 377 deletions
|
|
@ -13,7 +13,6 @@ from termcolor import colored
|
|||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.distribution.utils.serialize import EnumEncoder
|
||||
|
||||
|
||||
class ModelDescribe(Subcommand):
|
||||
|
|
@ -72,7 +71,7 @@ class ModelDescribe(Subcommand):
|
|||
rows.append(
|
||||
(
|
||||
"Recommended sampling params",
|
||||
json.dumps(sampling_params, cls=EnumEncoder, indent=4),
|
||||
json.dumps(sampling_params, indent=4),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -58,11 +58,6 @@ def define_eval_candidate_2():
|
|||
|
||||
# Sampling Parameters
|
||||
st.markdown("##### Sampling Parameters")
|
||||
strategy = st.selectbox(
|
||||
"Strategy",
|
||||
["greedy", "top_p", "top_k"],
|
||||
index=0,
|
||||
)
|
||||
temperature = st.slider(
|
||||
"Temperature",
|
||||
min_value=0.0,
|
||||
|
|
@ -95,13 +90,20 @@ def define_eval_candidate_2():
|
|||
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
|
||||
)
|
||||
if candidate_type == "model":
|
||||
if temperature > 0.0:
|
||||
strategy = {
|
||||
"type": "top_p",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
eval_candidate = {
|
||||
"type": "model",
|
||||
"model": selected_model,
|
||||
"sampling_params": {
|
||||
"strategy": strategy,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"max_tokens": max_tokens,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -95,6 +95,15 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
|
|||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
|
||||
if temperature > 0.0:
|
||||
strategy = {
|
||||
"type": "top_p",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
response = llama_stack_api.client.inference.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
|
|
@ -103,8 +112,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
|
|||
model_id=selected_model,
|
||||
stream=stream,
|
||||
sampling_params={
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"strategy": strategy,
|
||||
"max_tokens": max_tokens,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -118,13 +118,20 @@ def rag_chat_page():
|
|||
with st.chat_message(message["role"]):
|
||||
st.markdown(message["content"])
|
||||
|
||||
if temperature > 0.0:
|
||||
strategy = {
|
||||
"type": "top_p",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
agent_config = AgentConfig(
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": "greedy",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"strategy": strategy,
|
||||
},
|
||||
tools=[
|
||||
{
|
||||
|
|
|
|||
|
|
@ -23,6 +23,11 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from llama_models.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
|
||||
from llama_models.llama3.api.datatypes import Model
|
||||
|
|
@ -363,11 +368,12 @@ class Llama:
|
|||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
model_input = self.formatter.encode_content(request.content)
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.generate(
|
||||
model_input=model_input,
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=sampling_params.temperature,
|
||||
top_p=sampling_params.top_p,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
logits_processor=get_logits_processor(
|
||||
|
|
@ -390,14 +396,15 @@ class Llama:
|
|||
):
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(
|
||||
request.messages,
|
||||
request.tool_prompt_format,
|
||||
),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=sampling_params.temperature,
|
||||
top_p=sampling_params.top_p,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
logits_processor=get_logits_processor(
|
||||
|
|
@ -492,3 +499,15 @@ def _build_regular_tokens_list(
|
|||
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
|
||||
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
|
||||
return regular_tokens
|
||||
|
||||
|
||||
def _infer_sampling_params(sampling_params: SamplingParams):
|
||||
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
|
||||
temperature = 0.0
|
||||
top_p = 1.0
|
||||
elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
|
||||
temperature = sampling_params.strategy.temperature
|
||||
top_p = sampling_params.strategy.top_p
|
||||
else:
|
||||
raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
|
||||
return temperature, top_p
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.models import Model
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
|
|
@ -126,21 +127,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
if sampling_params is None:
|
||||
return VLLMSamplingParams(max_tokens=self.config.max_tokens)
|
||||
|
||||
# TODO convert what I saw in my first test ... but surely there's more to do here
|
||||
kwargs = {
|
||||
"temperature": sampling_params.temperature,
|
||||
"max_tokens": self.config.max_tokens,
|
||||
}
|
||||
if sampling_params.top_k:
|
||||
kwargs["top_k"] = sampling_params.top_k
|
||||
if sampling_params.top_p:
|
||||
kwargs["top_p"] = sampling_params.top_p
|
||||
if sampling_params.max_tokens:
|
||||
kwargs["max_tokens"] = sampling_params.max_tokens
|
||||
if sampling_params.repetition_penalty > 0:
|
||||
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
|
||||
options = get_sampling_options(sampling_params)
|
||||
if "repeat_penalty" in options:
|
||||
options["repetition_penalty"] = options["repeat_penalty"]
|
||||
del options["repeat_penalty"]
|
||||
|
||||
return VLLMSamplingParams(**kwargs)
|
||||
return VLLMSamplingParams(**options)
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_strategy_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
|
|
@ -166,16 +167,13 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
) -> Dict:
|
||||
bedrock_model = request.model
|
||||
|
||||
inference_config = {}
|
||||
param_mapping = {
|
||||
"max_tokens": "max_gen_len",
|
||||
"temperature": "temperature",
|
||||
"top_p": "top_p",
|
||||
}
|
||||
sampling_params = request.sampling_params
|
||||
options = get_sampling_strategy_options(sampling_params)
|
||||
|
||||
for k, v in param_mapping.items():
|
||||
if getattr(request.sampling_params, k):
|
||||
inference_config[v] = getattr(request.sampling_params, k)
|
||||
if sampling_params.max_tokens:
|
||||
options["max_gen_len"] = sampling_params.max_tokens
|
||||
if sampling_params.repetition_penalty > 0:
|
||||
options["repetition_penalty"] = sampling_params.repetition_penalty
|
||||
|
||||
prompt = await chat_completion_request_to_prompt(
|
||||
request, self.get_llama_model(request.model), self.formatter
|
||||
|
|
@ -185,7 +183,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
"body": json.dumps(
|
||||
{
|
||||
"prompt": prompt,
|
||||
**inference_config,
|
||||
**options,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
from llama_models.datatypes import CoreModelId
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import TopKSamplingStrategy
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
|
|
@ -172,7 +173,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def _get_params(
|
||||
self, request: Union[ChatCompletionRequest, CompletionRequest]
|
||||
) -> dict:
|
||||
if request.sampling_params and request.sampling_params.top_k:
|
||||
if request.sampling_params and isinstance(
|
||||
request.sampling_params.strategy, TopKSamplingStrategy
|
||||
):
|
||||
raise ValueError("`top_k` not supported by Cerebras")
|
||||
|
||||
prompt = ""
|
||||
|
|
|
|||
|
|
@ -48,6 +48,9 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_strategy_options,
|
||||
)
|
||||
|
||||
|
||||
def convert_chat_completion_request(
|
||||
|
|
@ -77,6 +80,7 @@ def convert_chat_completion_request(
|
|||
if request.tool_prompt_format != ToolPromptFormat.json:
|
||||
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
|
||||
|
||||
sampling_options = get_sampling_strategy_options(request.sampling_params)
|
||||
return CompletionCreateParams(
|
||||
model=request.model,
|
||||
messages=[_convert_message(message) for message in request.messages],
|
||||
|
|
@ -84,8 +88,8 @@ def convert_chat_completion_request(
|
|||
frequency_penalty=None,
|
||||
stream=request.stream,
|
||||
max_tokens=request.sampling_params.max_tokens or None,
|
||||
temperature=request.sampling_params.temperature,
|
||||
top_p=request.sampling_params.top_p,
|
||||
temperature=sampling_options.get("temperature", 1.0),
|
||||
top_p=sampling_options.get("top_p", 1.0),
|
||||
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
|
||||
tool_choice=request.tool_choice.value if request.tool_choice else None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -263,19 +263,18 @@ def convert_chat_completion_request(
|
|||
if request.sampling_params.max_tokens:
|
||||
payload.update(max_tokens=request.sampling_params.max_tokens)
|
||||
|
||||
if request.sampling_params.strategy == "top_p":
|
||||
strategy = request.sampling_params.strategy
|
||||
if isinstance(strategy, TopPSamplingStrategy):
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(top_p=request.sampling_params.top_p)
|
||||
elif request.sampling_params.strategy == "top_k":
|
||||
if (
|
||||
request.sampling_params.top_k != -1
|
||||
and request.sampling_params.top_k < 1
|
||||
):
|
||||
payload.update(top_p=strategy.top_p)
|
||||
payload.update(temperature=strategy.temperature)
|
||||
elif isinstance(strategy, TopKSamplingStrategy):
|
||||
if strategy.top_k != -1 and strategy.top_k < 1:
|
||||
warnings.warn("top_k must be -1 or >= 1")
|
||||
nvext.update(top_k=request.sampling_params.top_k)
|
||||
elif request.sampling_params.strategy == "greedy":
|
||||
nvext.update(top_k=strategy.top_k)
|
||||
elif strategy.strategy == "greedy":
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(temperature=request.sampling_params.temperature)
|
||||
payload.update(temperature=strategy.temperature)
|
||||
|
||||
return payload
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,12 @@ from llama_stack.apis.agents import (
|
|||
ToolExecutionStep,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import ViolationLevel
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
|
@ -42,7 +47,9 @@ def common_params(inference_model):
|
|||
model=inference_model,
|
||||
instructions="You are a helpful assistant.",
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
sampling_params=SamplingParams(
|
||||
strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.95)
|
||||
),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
toolgroups=[],
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from groq.types.chat.chat_completion_message_tool_call import (
|
|||
Function,
|
||||
)
|
||||
from groq.types.shared.function_definition import FunctionDefinition
|
||||
from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
|
||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
|
|
@ -152,21 +153,30 @@ class TestConvertChatCompletionRequest:
|
|||
|
||||
assert converted["max_tokens"] == 100
|
||||
|
||||
def test_includes_temperature(self):
|
||||
def _dummy_chat_completion_request(self):
|
||||
return ChatCompletionRequest(
|
||||
model="Llama-3.2-3B",
|
||||
messages=[UserMessage(content="Hello World")],
|
||||
)
|
||||
|
||||
def test_includes_stratgy(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.temperature = 0.5
|
||||
request.sampling_params.strategy = TopPSamplingStrategy(
|
||||
temperature=0.5, top_p=0.95
|
||||
)
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["temperature"] == 0.5
|
||||
assert converted["top_p"] == 0.95
|
||||
|
||||
def test_includes_top_p(self):
|
||||
def test_includes_greedy_strategy(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.top_p = 0.95
|
||||
request.sampling_params.strategy = GreedySamplingStrategy()
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["top_p"] == 0.95
|
||||
assert converted["temperature"] == 0.0
|
||||
|
||||
def test_includes_tool_choice(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
|
|
|
|||
|
|
@ -8,7 +8,13 @@ from typing import AsyncGenerator, List, Optional
|
|||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import SamplingParams, StopReason
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
|
@ -49,12 +55,26 @@ class OpenAICompatCompletionResponse(BaseModel):
|
|||
choices: List[OpenAICompatCompletionChoice]
|
||||
|
||||
|
||||
def get_sampling_strategy_options(params: SamplingParams) -> dict:
|
||||
options = {}
|
||||
if isinstance(params.strategy, GreedySamplingStrategy):
|
||||
options["temperature"] = 0.0
|
||||
elif isinstance(params.strategy, TopPSamplingStrategy):
|
||||
options["temperature"] = params.strategy.temperature
|
||||
options["top_p"] = params.strategy.top_p
|
||||
elif isinstance(params.strategy, TopKSamplingStrategy):
|
||||
options["top_k"] = params.strategy.top_k
|
||||
else:
|
||||
raise ValueError(f"Unsupported sampling strategy: {params.strategy}")
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def get_sampling_options(params: SamplingParams) -> dict:
|
||||
options = {}
|
||||
if params:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(params, attr):
|
||||
options[attr] = getattr(params, attr)
|
||||
options.update(get_sampling_strategy_options(params))
|
||||
options["max_tokens"] = params.max_tokens
|
||||
|
||||
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
|
||||
options["repeat_penalty"] = params.repetition_penalty
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue