From 34fec77fa6775b00df7142d8f5683335b8f04a63 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 13 Feb 2025 21:47:58 -0800
Subject: [PATCH] update llama_models.sku_list

---
 llama_stack/cli/download.py                                  | 4 ++--
 llama_stack/cli/model/describe.py                            | 2 +-
 llama_stack/cli/model/list.py                                | 3 +--
 llama_stack/cli/model/safety_models.py                       | 2 +-
 .../providers/inline/inference/meta_reference/generation.py  | 2 +-
 .../providers/inline/inference/meta_reference/inference.py   | 3 +--
 .../inline/inference/meta_reference/model_parallel.py        | 2 +-
 .../inline/inference/meta_reference/quantization/loader.py   | 2 +-
 llama_stack/providers/inline/inference/vllm/vllm.py          | 2 +-
 .../providers/inline/post_training/torchtune/common/utils.py | 2 +-
 .../torchtune/recipes/lora_finetuning_single_device.py       | 2 +-
 llama_stack/providers/remote/inference/groq/groq.py          | 2 +-
 llama_stack/providers/remote/inference/nvidia/nvidia.py      | 3 +--
 llama_stack/providers/remote/inference/tgi/tgi.py            | 2 +-
 llama_stack/providers/remote/inference/vllm/vllm.py          | 4 ++--
 llama_stack/providers/tests/report.py                        | 2 +-
 llama_stack/providers/utils/inference/__init__.py            | 3 +--
 llama_stack/providers/utils/inference/model_registry.py      | 3 +--
 llama_stack/providers/utils/inference/prompt_adapter.py      | 2 +-
 llama_stack/templates/bedrock/bedrock.py                     | 3 +--
 llama_stack/templates/cerebras/cerebras.py                   | 3 +--
 llama_stack/templates/fireworks/fireworks.py                 | 3 +--
 llama_stack/templates/nvidia/nvidia.py                       | 3 +--
 llama_stack/templates/sambanova/sambanova.py                 | 3 +--
 llama_stack/templates/together/together.py                   | 3 +--
 25 files changed, 27 insertions(+), 38 deletions(-)

diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py
index 16ecc3ad4..6b0463c10 100644
--- a/llama_stack/cli/download.py
+++ b/llama_stack/cli/download.py
@@ -16,7 +16,6 @@ from pathlib import Path
 from typing import Dict, List, Optional
 
 import httpx
-from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel, ConfigDict
 from rich.console import Console
 from rich.progress import (
@@ -31,6 +30,7 @@ from termcolor import cprint
 
 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.models.llama.datatypes import Model
+from llama_stack.models.llama.sku_list import LlamaDownloadInfo
 
 
 class Download(Subcommand):
@@ -454,7 +454,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
     # Handle comma-separated model IDs
     model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
 
-    from llama_models.sku_list import llama_meta_net_info, resolve_model
+    from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model
 
     from .model.safety_models import (
         prompt_guard_download_info,
diff --git a/llama_stack/cli/model/describe.py b/llama_stack/cli/model/describe.py
index 3e55052c5..d8f4e035c 100644
--- a/llama_stack/cli/model/describe.py
+++ b/llama_stack/cli/model/describe.py
@@ -7,11 +7,11 @@
 import argparse
 import json
 
-from llama_models.sku_list import resolve_model
 from termcolor import colored
 
 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.cli.table import print_table
+from llama_stack.models.llama.sku_list import resolve_model
 
 
 class ModelDescribe(Subcommand):
diff --git a/llama_stack/cli/model/list.py b/llama_stack/cli/model/list.py
index 9b5ebb1a5..4fe28751e 100644
--- a/llama_stack/cli/model/list.py
+++ b/llama_stack/cli/model/list.py
@@ -6,10 +6,9 @@
 
 import argparse
 
-from llama_models.sku_list import all_registered_models
-
 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.cli.table import print_table
+from llama_stack.models.llama.sku_list import all_registered_models
 
 
 class ModelList(Subcommand):
diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py
index 4fad2c09f..c81783f60 100644
--- a/llama_stack/cli/model/safety_models.py
+++ b/llama_stack/cli/model/safety_models.py
@@ -6,10 +6,10 @@
 
 from typing import Any, Dict, Optional
 
-from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel, ConfigDict, Field
 
 from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
+from llama_stack.models.llama.sku_list import LlamaDownloadInfo
 
 
 class PromptGuardModel(BaseModel):
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index ed7bb7b79..2d2ec5c8f 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -30,7 +30,6 @@ from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
-from llama_models.sku_list import resolve_model
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from pydantic import BaseModel
 
@@ -47,6 +46,7 @@ from llama_stack.models.llama.datatypes import (
     SamplingParams,
     TopPSamplingStrategy,
 )
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 2a66986d1..c79f97def 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -8,8 +8,6 @@ import asyncio
 import logging
 from typing import AsyncGenerator, List, Optional, Union
 
-from llama_models.sku_list import resolve_model
-
 from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
@@ -41,6 +39,7 @@ from llama_stack.models.llama.datatypes import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
index 4f6ad017f..64f94a69d 100644
--- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
+++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
@@ -11,9 +11,9 @@ from typing import Any, Generator
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
 
 from llama_stack.models.llama.datatypes import Model
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
index d63fff458..a2dc00916 100644
--- a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
+++ b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
@@ -16,12 +16,12 @@ from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallel
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
-from llama_models.sku_list import resolve_model
 from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
 from llama_stack.apis.inference import QuantizationType
 from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
+from llama_stack.models.llama.sku_list import resolve_model
 
 from ..config import MetaReferenceQuantizedInferenceConfig
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index e75a9aac3..5536ea3a5 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -11,7 +11,6 @@ from typing import AsyncGenerator, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams as VLLMSamplingParams
@@ -35,6 +34,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index fee82d5d2..98e16f9d7 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -13,7 +13,6 @@ from typing import Any, Callable, Dict
 
 import torch
-from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
 from torchtune.models.llama3 import llama3_tokenizer
@@ -24,6 +23,7 @@ from torchtune.modules.transforms import Transform
 
 from llama_stack.apis.post_training import DatasetFormat
 from llama_stack.models.llama.datatypes import Model
+from llama_stack.models.llama.sku_list import resolve_model
 
 
 class ModelConfig(BaseModel):
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index ef379aff2..4ab59fec4 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -14,7 +14,6 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
-from llama_models.sku_list import resolve_model
 from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
@@ -46,6 +45,7 @@ from llama_stack.apis.post_training import (
 )
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.inline.post_training.common.validator import (
     validate_input_dataset_schema,
 )
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 5335e6ad3..441b6af5c 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -9,7 +9,6 @@ from typing import AsyncIterator, List, Optional, Union
 
 import groq
 from groq import Groq
-from llama_models.sku_list import CoreModelId
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -28,6 +27,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.sku_list import CoreModelId
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 1276e29e0..0c5b7c454 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -7,7 +7,6 @@
 import warnings
 from typing import AsyncIterator, List, Optional, Union
 
-from llama_models.sku_list import CoreModelId
 from openai import APIConnectionError, AsyncOpenAI
 
 from llama_stack.apis.inference import (
@@ -26,7 +25,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
 )
-from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import CoreModelId, SamplingParams, ToolDefinition, ToolPromptFormat
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 72eaa6c31..1909e01f8 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -11,7 +11,6 @@ from typing import AsyncGenerator, List, Optional
 from huggingface_hub import AsyncInferenceClient, HfApi
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import all_registered_models
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -31,6 +30,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
+from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 8f9cf68a8..b22284302 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -7,10 +7,9 @@
 import json
 import logging
 from typing import AsyncGenerator, List, Optional, Union
 
-from llama_models.llama3.api import StopReason, ToolCall
+from llama_models.datatypes import StopReason, ToolCall
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import all_registered_models
 from openai import OpenAI
 
 from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
@@ -37,6 +36,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
diff --git a/llama_stack/providers/tests/report.py b/llama_stack/providers/tests/report.py
index 991696af4..febd13045 100644
--- a/llama_stack/providers/tests/report.py
+++ b/llama_stack/providers/tests/report.py
@@ -9,11 +9,11 @@ from collections import defaultdict
 from pathlib import Path
 
 import pytest
-from llama_models.sku_list import all_registered_models
 from pytest import ExitCode
 from pytest_html.basereport import _process_outcome
 
 from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_list import all_registered_models
 
 INFERENCE_APIS = ["chat_completion"]
 FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py
index ef3f17f68..cab3725da 100644
--- a/llama_stack/providers/utils/inference/__init__.py
+++ b/llama_stack/providers/utils/inference/__init__.py
@@ -6,9 +6,8 @@
 
 from typing import List
 
-from llama_models.sku_list import all_registered_models
-
 from llama_stack.models.llama.datatypes import *  # noqa: F403
+from llama_stack.models.llama.sku_list import all_registered_models
 
 
 def is_supported_safety_model(model: Model) -> bool:
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 9345da949..c5f6cd6b5 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -7,9 +7,8 @@
 from collections import namedtuple
 from typing import List, Optional
 
-from llama_models.sku_list import all_registered_models
-
 from llama_stack.apis.models.models import ModelType
+from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference import (
     ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 069f9c5bd..269f84d6c 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -21,7 +21,6 @@ from llama_models.llama3.prompt_templates import (
     PythonListCustomToolGenerator,
     SystemDefaultGenerator,
 )
-from llama_models.sku_list import resolve_model
 from PIL import Image as PIL_Image
 
 from llama_stack.apis.common.content_types import (
@@ -52,6 +51,7 @@ from llama_stack.models.llama.datatypes import (
     ToolPromptFormat,
     is_multimodal,
 )
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference import supported_inference_models
 
 log = logging.getLogger(__name__)
diff --git a/llama_stack/templates/bedrock/bedrock.py b/llama_stack/templates/bedrock/bedrock.py
index af1d48b7f..0b294824d 100644
--- a/llama_stack/templates/bedrock/bedrock.py
+++ b/llama_stack/templates/bedrock/bedrock.py
@@ -6,10 +6,9 @@
 
 from pathlib import Path
 
-from llama_models.sku_list import all_registered_models
-
 from llama_stack.apis.models import ModelInput
 from llama_stack.distribution.datatypes import Provider, ToolGroupInput
+from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
diff --git a/llama_stack/templates/cerebras/cerebras.py b/llama_stack/templates/cerebras/cerebras.py
index 870240feb..4f6d0c8f3 100644
--- a/llama_stack/templates/cerebras/cerebras.py
+++ b/llama_stack/templates/cerebras/cerebras.py
@@ -6,10 +6,9 @@
 
 from pathlib import Path
 
-from llama_models.sku_list import all_registered_models
-
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
+from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
diff --git a/llama_stack/templates/fireworks/fireworks.py b/llama_stack/templates/fireworks/fireworks.py
index e2e2ca99c..a6809fef6 100644
--- a/llama_stack/templates/fireworks/fireworks.py
+++ b/llama_stack/templates/fireworks/fireworks.py
@@ -6,8 +6,6 @@
 
 from pathlib import Path
 
-from llama_models.sku_list import all_registered_models
-
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import (
     ModelInput,
@@ -15,6 +13,7 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
+from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py
index d24c9ed48..ee22b5555 100644
--- a/llama_stack/templates/nvidia/nvidia.py
+++ b/llama_stack/templates/nvidia/nvidia.py
@@ -6,9 +6,8 @@
 
 from pathlib import Path
 
-from llama_models.sku_list import all_registered_models
-
 from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
+from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py
index 6d7477c8e..c7a9428af 100644
--- a/llama_stack/templates/sambanova/sambanova.py
+++ b/llama_stack/templates/sambanova/sambanova.py
@@ -6,14 +6,13 @@
 
 from pathlib import Path
 
-from llama_models.sku_list import all_registered_models
-
 from llama_stack.distribution.datatypes import (
     ModelInput,
     Provider,
     ShieldInput,
     ToolGroupInput,
 )
+from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
 from llama_stack.providers.remote.inference.sambanova.sambanova import MODEL_ALIASES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
diff --git a/llama_stack/templates/together/together.py b/llama_stack/templates/together/together.py
index 9ec5b38ba..f7b18e32a 100644
--- a/llama_stack/templates/together/together.py
+++ b/llama_stack/templates/together/together.py
@@ -6,8 +6,6 @@
 
 from pathlib import Path
 
-from llama_models.sku_list import all_registered_models
-
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import (
     ModelInput,
@@ -15,6 +13,7 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
+from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
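-- 
Note for reviewers (outside the applied patch): every hunk is the same
mechanical rewrite — the sku_list helpers move from the external llama_models
package to the vendored llama_stack.models.llama.sku_list module, with import
lines re-sorted accordingly. A minimal smoke test of the new import path is
sketched below; it assumes the vendored module re-exports the same helpers as
llama_models.sku_list, and the descriptor string is illustrative, not taken
from this patch:

    from llama_stack.models.llama.sku_list import all_registered_models, resolve_model

    # resolve_model maps a model descriptor string to a registered Model
    # entry, returning None for unknown models (signature assumed unchanged
    # from llama_models.sku_list).
    model = resolve_model("Llama3.1-8B-Instruct")
    print(model.descriptor() if model else "unknown model")

    # all_registered_models enumerates every known Llama SKU.
    print(len(all_registered_models()))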