forked from phoenix-oss/llama-stack-mirror
# What does this PR do? Cleans up how we provide sampling params. Earlier, strategy was an enum and all params (top_p, temperature, top_k) across all strategies were grouped. We now have a strategy union object with each strategy (greedy, top_p, top_k) having its corresponding params. Earlier, ``` class SamplingParams: strategy: enum () top_p, temperature, top_k and other params ``` However, the `strategy` field was not being used in any providers making it confusing to know the exact sampling behavior purely based on the params since you could pass temperature, top_p, top_k and how the provider would interpret those would not be clear. Hence we introduced -- a union where the strategy and relevant params are all clubbed together to avoid this confusion. Have updated all providers, tests, notebooks, readme and other places where sampling params was being used to use the new format. ## Test Plan `pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py` // inference on ollama, fireworks and together `with-proxy pytest -v -s -k "ollama" --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/inference/test_text_inference.py ` // agents on fireworks `pytest -v -s -k 'fireworks and create_agent' --inference-model="meta-llama/Llama-3.1-8B-Instruct" llama_stack/providers/tests/agents/test_agents.py --safety-shield="meta-llama/Llama-Guard-3-8B"` ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [X] Ran pre-commit to handle lint / formatting issues. - [X] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [X] Updated relevant documentation. - [X] Wrote necessary unit or integration tests. --------- Co-authored-by: Hardik Shah <hjshah@fb.com>
81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import argparse
|
|
import json
|
|
|
|
from llama_models.sku_list import resolve_model
|
|
|
|
from termcolor import colored
|
|
|
|
from llama_stack.cli.subcommand import Subcommand
|
|
from llama_stack.cli.table import print_table
|
|
|
|
|
|
class ModelDescribe(Subcommand):
    """Show details about a model.

    Registers the `llama model describe` subcommand, which looks up a model
    by its ID (either the Prompt Guard safety model or any model resolvable
    via `resolve_model`) and prints a table of its metadata.
    """

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Attach the `describe` subparser and wire it to this command's handler.

        :param subparsers: the `_SubParsersAction` of the parent `llama model` parser.
        """
        super().__init__()
        self.parser = subparsers.add_parser(
            "describe",
            prog="llama model describe",
            description="Show details about a llama model",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        # argparse dispatches to this handler when the subcommand is invoked.
        self.parser.set_defaults(func=self._run_model_describe_cmd)

    def _add_arguments(self):
        """Declare the CLI arguments accepted by `llama model describe`."""
        self.parser.add_argument(
            "-m",
            "--model-id",
            type=str,
            required=True,
        )

    def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
        """Resolve `args.model_id` and print its details as a table.

        Exits via `parser.error` (SystemExit) when the model is unknown.
        """
        # Local import keeps safety-model loading off the module import path.
        from .safety_models import prompt_guard_model_sku

        prompt_guard = prompt_guard_model_sku()
        if args.model_id == prompt_guard.model_id:
            model = prompt_guard
        else:
            model = resolve_model(args.model_id)

        if model is None:
            self.parser.error(
                f"Model {args.model_id} not found; try 'llama model list' for a list of available models."
            )
            # Defensive: parser.error() normally raises SystemExit, but if a
            # subclass overrides it to return, avoid dereferencing a None model.
            return

        rows = [
            (
                colored("Model", "white", attrs=["bold"]),
                colored(model.descriptor(), "white", attrs=["bold"]),
            ),
            ("Hugging Face ID", model.huggingface_repo or "<Not Available>"),
            ("Description", model.description),
            ("Context Length", f"{model.max_seq_length // 1024}K tokens"),
            ("Weights format", model.quantization_format.value),
            ("Model params.json", json.dumps(model.arch_args, indent=4)),
        ]

        if model.recommended_sampling_params is not None:
            sampling_params = model.recommended_sampling_params.dict()
            # These two fields are limits, not recommendations — hide them.
            # pop(..., None) rather than del: don't crash the CLI if the
            # sampling-params schema drops one of these keys in the future.
            for k in ("max_tokens", "repetition_penalty"):
                sampling_params.pop(k, None)
            rows.append(
                (
                    "Recommended sampling params",
                    json.dumps(sampling_params, indent=4),
                )
            )

        print_table(
            rows,
            separate_rows=True,
        )
|