Convert SamplingParams.strategy to a union (#767)

# What does this PR do?

Cleans up how we provide sampling params. Earlier, `strategy` was an enum and all params (`top_p`, `temperature`, `top_k`) across all strategies were grouped together. We now have a `strategy` union object where each strategy (greedy, top_p, top_k) carries its own params.
Earlier:
```
class SamplingParams:
    strategy: SamplingStrategy          # enum: greedy | top_p | top_k
    temperature: float
    top_p: float
    top_k: int
    # plus max_tokens, repetition_penalty, ...
```
However, the `strategy` field was not actually used by any provider, so the exact sampling behavior was hard to infer from the params alone: you could pass `temperature`, `top_p`, and `top_k` together, and how a provider would interpret that combination was not clear.
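Concretely, under the old schema a caller could write something like the snippet below and each provider decided for itself which knobs to honor. (The `OldSamplingParams` / `OldSamplingStrategy` names and defaults here are illustrative stand-ins for the old shape, not the literal definitions.)
```
from enum import Enum
from pydantic import BaseModel


# stand-in for the old shape, for illustration only
class OldSamplingStrategy(str, Enum):
    greedy = "greedy"
    top_p = "top_p"
    top_k = "top_k"


class OldSamplingParams(BaseModel):
    strategy: OldSamplingStrategy = OldSamplingStrategy.greedy
    temperature: float = 1.0
    top_p: float = 0.9
    top_k: int = 0


# strategy says "greedy", yet temperature/top_p/top_k are all set anyway;
# a provider that ignored `strategy` would happily use whichever knobs it knew
params = OldSamplingParams(
    strategy=OldSamplingStrategy.greedy, temperature=0.9, top_p=0.5, top_k=40
)
```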

Hence we introduced a union, where each strategy and its relevant params are grouped together, to avoid this confusion.
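For reference, the new shape looks roughly like this. The strategy class names match the ones imported in the diffs below; the `type` discriminator, defaults, and exact field set are a best-guess sketch rather than the literal definitions:
```
from typing import Literal, Optional, Union

from pydantic import BaseModel


class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"


class TopPSamplingStrategy(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: Optional[float] = None
    top_p: Optional[float] = 0.95


class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int


SamplingStrategy = Union[
    GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy
]


class SamplingParams(BaseModel):
    strategy: SamplingStrategy = GreedySamplingStrategy()
    max_tokens: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0
```
A caller now picks a strategy explicitly, e.g. `SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9), max_tokens=256)`, so there is no ambiguity about which knobs apply.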

All providers, tests, notebooks, the README, and other places that used sampling params have been updated to the new format.

## Test Plan
```
pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py

# inference on ollama, fireworks and together
with-proxy pytest -v -s -k "ollama" \
  --inference-model="meta-llama/Llama-3.1-8B-Instruct" \
  llama_stack/providers/tests/inference/test_text_inference.py

# agents on fireworks
pytest -v -s -k 'fireworks and create_agent' \
  --inference-model="meta-llama/Llama-3.1-8B-Instruct" \
  llama_stack/providers/tests/agents/test_agents.py \
  --safety-shield="meta-llama/Llama-Guard-3-8B"
```

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [X] Ran pre-commit to handle lint / formatting issues.
- [X] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [X] Updated relevant documentation.
- [X] Wrote necessary unit or integration tests.

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>

```
@@ -34,6 +34,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_strategy_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
     process_chat_completion_response,
@@ -166,16 +167,13 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> Dict:
         bedrock_model = request.model

-        inference_config = {}
-        param_mapping = {
-            "max_tokens": "max_gen_len",
-            "temperature": "temperature",
-            "top_p": "top_p",
-        }
+        sampling_params = request.sampling_params
+        options = get_sampling_strategy_options(sampling_params)

-        for k, v in param_mapping.items():
-            if getattr(request.sampling_params, k):
-                inference_config[v] = getattr(request.sampling_params, k)
+        if sampling_params.max_tokens:
+            options["max_gen_len"] = sampling_params.max_tokens
+        if sampling_params.repetition_penalty > 0:
+            options["repetition_penalty"] = sampling_params.repetition_penalty

         prompt = await chat_completion_request_to_prompt(
             request, self.get_llama_model(request.model), self.formatter
@@ -185,7 +183,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             "body": json.dumps(
                 {
                     "prompt": prompt,
-                    **inference_config,
+                    **options,
                 }
             ),
         }
```
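The Bedrock change above (and the Groq one further down) routes sampling through the shared `get_sampling_strategy_options` helper in `openai_compat.py`. Its actual implementation is in the diff for that file, which is not shown here; a minimal sketch of the behavior, inferred from how the adapters consume its output, would be:
```
from llama_models.datatypes import (
    GreedySamplingStrategy,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)


def get_sampling_strategy_options(params) -> dict:
    """Flatten the strategy union into plain provider kwargs (sketch only)."""
    options = {}
    if isinstance(params.strategy, GreedySamplingStrategy):
        # assumption: greedy maps to temperature 0.0 for OpenAI-compatible APIs
        options["temperature"] = 0.0
    elif isinstance(params.strategy, TopPSamplingStrategy):
        options["temperature"] = params.strategy.temperature
        options["top_p"] = params.strategy.top_p
    elif isinstance(params.strategy, TopKSamplingStrategy):
        options["top_k"] = params.strategy.top_k
    else:
        raise ValueError(f"Unsupported sampling strategy: {params.strategy}")
    return options
```
The Bedrock adapter then layers `max_gen_len` and `repetition_penalty` on top of this dict, while the Groq adapter reads `temperature` and `top_p` back out of it with defaults.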

```
@@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from cerebras.cloud.sdk import AsyncCerebras
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import TopKSamplingStrategy
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -172,7 +173,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
-        if request.sampling_params and request.sampling_params.top_k:
+        if request.sampling_params and isinstance(
+            request.sampling_params.strategy, TopKSamplingStrategy
+        ):
             raise ValueError("`top_k` not supported by Cerebras")

         prompt = ""
```

```
@@ -48,6 +48,9 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_strategy_options,
+)


 def convert_chat_completion_request(
@@ -77,6 +80,7 @@ def convert_chat_completion_request(
     if request.tool_prompt_format != ToolPromptFormat.json:
         warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")

+    sampling_options = get_sampling_strategy_options(request.sampling_params)
     return CompletionCreateParams(
         model=request.model,
         messages=[_convert_message(message) for message in request.messages],
@@ -84,8 +88,8 @@ def convert_chat_completion_request(
         frequency_penalty=None,
         stream=request.stream,
         max_tokens=request.sampling_params.max_tokens or None,
-        temperature=request.sampling_params.temperature,
-        top_p=request.sampling_params.top_p,
+        temperature=sampling_options.get("temperature", 1.0),
+        top_p=sampling_options.get("top_p", 1.0),
         tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
         tool_choice=request.tool_choice.value if request.tool_choice else None,
     )
```

```
@@ -8,6 +8,11 @@ import json
 import warnings
 from typing import Any, AsyncGenerator, Dict, Generator, List, Optional

+from llama_models.datatypes import (
+    GreedySamplingStrategy,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
     StopReason,
@@ -263,19 +268,20 @@ def convert_chat_completion_request(
     if request.sampling_params.max_tokens:
         payload.update(max_tokens=request.sampling_params.max_tokens)

-    if request.sampling_params.strategy == "top_p":
+    strategy = request.sampling_params.strategy
+    if isinstance(strategy, TopPSamplingStrategy):
         nvext.update(top_k=-1)
-        payload.update(top_p=request.sampling_params.top_p)
-    elif request.sampling_params.strategy == "top_k":
-        if (
-            request.sampling_params.top_k != -1
-            and request.sampling_params.top_k < 1
-        ):
+        payload.update(top_p=strategy.top_p)
+        payload.update(temperature=strategy.temperature)
+    elif isinstance(strategy, TopKSamplingStrategy):
+        if strategy.top_k != -1 and strategy.top_k < 1:
             warnings.warn("top_k must be -1 or >= 1")
-        nvext.update(top_k=request.sampling_params.top_k)
-    elif request.sampling_params.strategy == "greedy":
+        nvext.update(top_k=strategy.top_k)
+    elif isinstance(strategy, GreedySamplingStrategy):
         nvext.update(top_k=-1)
-        payload.update(temperature=request.sampling_params.temperature)
+        payload.update(temperature=strategy.temperature)
+    else:
+        raise ValueError(f"Unsupported sampling strategy: {strategy}")

     return payload
```