diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py
index 3ea534277..16ecc3ad4 100644
--- a/llama_stack/cli/download.py
+++ b/llama_stack/cli/download.py
@@ -16,7 +16,6 @@ from pathlib import Path
 from typing import Dict, List, Optional

 import httpx
-from llama_models.datatypes import Model
 from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel, ConfigDict
 from rich.console import Console
@@ -31,6 +30,7 @@ from rich.progress import (
 from termcolor import cprint

 from llama_stack.cli.subcommand import Subcommand
+from llama_stack.models.llama.datatypes import Model


 class Download(Subcommand):
diff --git a/llama_stack/cli/model/prompt_format.py b/llama_stack/cli/model/prompt_format.py
index 2e1e1601e..ea9596ba5 100644
--- a/llama_stack/cli/model/prompt_format.py
+++ b/llama_stack/cli/model/prompt_format.py
@@ -8,9 +8,8 @@ import argparse
 import textwrap
 from io import StringIO

-from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
-
 from llama_stack.cli.subcommand import Subcommand
+from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family


 class ModelPromptFormat(Subcommand):
diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py
index 314f1639e..4fad2c09f 100644
--- a/llama_stack/cli/model/safety_models.py
+++ b/llama_stack/cli/model/safety_models.py
@@ -6,11 +6,10 @@

 from typing import Any, Dict, Optional

-from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel, ConfigDict, Field

-from llama_stack.models.llama.datatypes import SamplingParams
+from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams


 class PromptGuardModel(BaseModel):
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index 16f76721c..ed7bb7b79 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from llama_models.datatypes import (
-    GreedySamplingStrategy,
-    SamplingParams,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -46,7 +41,12 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
 )
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.models.llama.datatypes import Model
+from llama_stack.models.llama.datatypes import (
+    GreedySamplingStrategy,
+    Model,
+    SamplingParams,
+    TopPSamplingStrategy,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
index 9be35ae70..d63fff458 100644
--- a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
+++ b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
@@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional
 import torch
 from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from llama_models.sku_list import resolve_model
@@ -22,6 +21,7 @@ from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

 from llama_stack.apis.inference import QuantizationType
+from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat

 from ..config import MetaReferenceQuantizedInferenceConfig
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index 735af8c79..fee82d5d2 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -13,7 +13,6 @@ from typing import Any, Callable, Dict

 import torch
-from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
@@ -24,6 +23,7 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform

 from llama_stack.apis.post_training import DatasetFormat
+from llama_stack.models.llama.datatypes import Model


 class ModelConfig(BaseModel):
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index b186c8b02..af0987fa8 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -8,8 +8,6 @@ import re
 from string import Template
 from typing import Any, Dict, List, Optional

-from llama_models.datatypes import CoreModelId
-
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
@@ -25,7 +23,7 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
-from llama_stack.models.llama.datatypes import Role
+from llama_stack.models.llama.datatypes import CoreModelId, Role
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 917ac7a25..e896f0597 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -8,7 +8,6 @@ import json
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from botocore.client import BaseClient
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -28,6 +27,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 from llama_stack.providers.utils.inference.model_registry import (
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 3ba2c37c5..1ce267e8d 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -7,7 +7,6 @@ from typing import AsyncGenerator, List, Optional, Union

 from cerebras.cloud.sdk import AsyncCerebras
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -27,7 +26,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.models.llama.datatypes import TopKSamplingStrategy
+from llama_stack.models.llama.datatypes import CoreModelId, TopKSamplingStrategy
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index d56be1465..3d306e61f 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -6,7 +6,6 @@

 from typing import AsyncGenerator, List, Optional

-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
@@ -25,6 +24,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 7e8f85313..acf37b248 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -7,7 +7,6 @@ from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -30,6 +29,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index c45b8ee42..5335e6ad3 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -9,7 +9,6 @@ from typing import AsyncIterator, List, Optional, Union

 import groq
 from groq import Groq
-from llama_models.datatypes import SamplingParams
 from llama_models.sku_list import CoreModelId

 from llama_stack.apis.inference import (
@@ -28,7 +27,7 @@ from llama_stack.apis.inference import (
     ToolConfig,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 4d30a0a9c..1276e29e0 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -7,7 +7,6 @@ import warnings
 from typing import AsyncIterator, List, Optional, Union

-from llama_models.datatypes import SamplingParams
 from llama_models.sku_list import CoreModelId
 from openai import APIConnectionError, AsyncOpenAI
@@ -27,7 +26,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
 )
-from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index a6c5086de..9799eedcc 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -8,11 +8,6 @@ import json
 import warnings
 from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union

-from llama_models.datatypes import (
-    GreedySamplingStrategy,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
 from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
@@ -83,9 +78,12 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
+    GreedySamplingStrategy,
     StopReason,
     ToolCall,
     ToolDefinition,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 1c12d0d91..f524c0734 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -8,7 +8,6 @@ import logging
 from typing import AsyncGenerator, List, Optional, Union

 import httpx
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
@@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 3546ee977..b906e0dcb 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -7,12 +7,6 @@ import json
 from typing import AsyncGenerator

-from llama_models.datatypes import (
-    CoreModelId,
-    GreedySamplingStrategy,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
@@ -23,6 +17,12 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
 )
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.models.llama.datatypes import (
+    CoreModelId,
+    GreedySamplingStrategy,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 916e64ad4..054501da8 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -6,7 +6,6 @@

 from typing import AsyncGenerator, List, Optional, Union

-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from together import Together
@@ -29,6 +28,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index 68868ee52..2e7bd537f 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -7,7 +7,6 @@ import os

 import pytest
-from llama_models.datatypes import SamplingParams, TopPSamplingStrategy

 from llama_stack.apis.agents import (
     AgentConfig,
@@ -24,7 +23,7 @@ from llama_stack.apis.agents import (
 )
 from llama_stack.apis.inference import CompletionMessage, UserMessage
 from llama_stack.apis.safety import ViolationLevel
-from llama_stack.models.llama.datatypes import BuiltinTool
+from llama_stack.models.llama.datatypes import BuiltinTool, SamplingParams, TopPSamplingStrategy
 from llama_stack.providers.datatypes import Api

 # How to run this test:
diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
index 5f0278c20..34725e957 100644
--- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py
+++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
@@ -23,7 +23,6 @@ from groq.types.chat.chat_completion_message_tool_call import (
     Function,
 )
 from groq.types.shared.function_definition import FunctionDefinition
-from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy

 from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import (
@@ -37,7 +36,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     UserMessage,
 )
-from llama_stack.models.llama.datatypes import ToolParamDefinition
+from llama_stack.models.llama.datatypes import GreedySamplingStrategy, ToolParamDefinition, TopPSamplingStrategy
 from llama_stack.providers.remote.inference.groq.groq_utils import (
     convert_chat_completion_request,
     convert_chat_completion_response,
diff --git a/llama_stack/providers/tests/report.py b/llama_stack/providers/tests/report.py
index 3901dc2e3..991696af4 100644
--- a/llama_stack/providers/tests/report.py
+++ b/llama_stack/providers/tests/report.py
@@ -9,11 +9,12 @@ from collections import defaultdict
 from pathlib import Path

 import pytest
-from llama_models.datatypes import CoreModelId
 from llama_models.sku_list import all_registered_models
 from pytest import ExitCode
 from pytest_html.basereport import _process_outcome

+from llama_stack.models.llama.datatypes import CoreModelId
+
 INFERENCE_APIS = ["chat_completion"]
 FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
 SUPPORTED_MODELS = {
diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py
index 64fe30f55..ef3f17f68 100644
--- a/llama_stack/providers/utils/inference/__init__.py
+++ b/llama_stack/providers/utils/inference/__init__.py
@@ -6,9 +6,10 @@

 from typing import List

-from llama_models.datatypes import *  # noqa: F403
 from llama_models.sku_list import all_registered_models

+from llama_stack.models.llama.datatypes import *  # noqa: F403
+

 def is_supported_safety_model(model: Model) -> bool:
     if model.quantization_format != CheckpointQuantizationFormat.bf16:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 128c21849..da8e3ce2d 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -7,12 +7,6 @@ import json
 import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union

-from llama_models.datatypes import (
-    GreedySamplingStrategy,
-    SamplingParams,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.chat_format import ChatFormat
 from openai.types.chat import ChatCompletionMessageToolCall
 from pydantic import BaseModel
@@ -36,7 +30,14 @@ from llama_stack.apis.inference import (
     Message,
     TokenLogProbs,
 )
-from llama_stack.models.llama.datatypes import StopReason, ToolCall
+from llama_stack.models.llama.datatypes import (
+    GreedySamplingStrategy,
+    SamplingParams,
+    StopReason,
+    ToolCall,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index b90704d66..069f9c5bd 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -13,7 +13,6 @@ import re
 from typing import List, Optional, Tuple, Union

 import httpx
-from llama_models.datatypes import ModelFamily, is_multimodal
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.prompt_templates import (
     BuiltinToolGenerator,
@@ -43,6 +42,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.models.llama.datatypes import (
+    ModelFamily,
     RawContent,
     RawContentItem,
     RawMediaItem,
@@ -50,6 +50,7 @@ from llama_stack.models.llama.datatypes import (
     RawTextItem,
     Role,
     ToolPromptFormat,
+    is_multimodal,
 )
 from llama_stack.providers.utils.inference import supported_inference_models