update llama_models.datatypes

Ashwin Bharambe 2025-02-13 21:46:16 -08:00
parent 4e7d652f0b
commit 15a247b728
23 changed files with 45 additions and 52 deletions
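
Every file in this change follows the same import-path migration: symbols previously imported from the external llama_models.datatypes package now come from the in-tree llama_stack.models.llama.datatypes module. A minimal sketch of the pattern, using symbols that appear in the hunks below (the exact names imported vary per file):

    # before this commit (removed)
    # from llama_models.datatypes import CoreModelId, SamplingParams

    # after this commit (added)
    from llama_stack.models.llama.datatypes import CoreModelId, SamplingParams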

@@ -16,7 +16,6 @@ from pathlib import Path
 from typing import Dict, List, Optional
 import httpx
-from llama_models.datatypes import Model
 from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel, ConfigDict
 from rich.console import Console
@@ -31,6 +30,7 @@ from rich.progress import (
 from termcolor import cprint
 from llama_stack.cli.subcommand import Subcommand
+from llama_stack.models.llama.datatypes import Model
 class Download(Subcommand):

@@ -8,9 +8,8 @@ import argparse
 import textwrap
 from io import StringIO
-from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
 from llama_stack.cli.subcommand import Subcommand
+from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
 class ModelPromptFormat(Subcommand):

@@ -6,11 +6,10 @@
 from typing import Any, Dict, Optional
-from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel, ConfigDict, Field
-from llama_stack.models.llama.datatypes import SamplingParams
+from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
 class PromptGuardModel(BaseModel):

@@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from llama_models.datatypes import (
-    GreedySamplingStrategy,
-    SamplingParams,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -46,7 +41,12 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
 )
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.models.llama.datatypes import Model
+from llama_stack.models.llama.datatypes import (
+    GreedySamplingStrategy,
+    Model,
+    SamplingParams,
+    TopPSamplingStrategy,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,

@@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional
 import torch
 from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from llama_models.sku_list import resolve_model
@@ -22,6 +21,7 @@ from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 from llama_stack.apis.inference import QuantizationType
+from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
 from ..config import MetaReferenceQuantizedInferenceConfig

@@ -13,7 +13,6 @@
 from typing import Any, Callable, Dict
 import torch
-from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
@@ -24,6 +23,7 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform
 from llama_stack.apis.post_training import DatasetFormat
+from llama_stack.models.llama.datatypes import Model
 class ModelConfig(BaseModel):

@@ -8,8 +8,6 @@ import re
 from string import Template
 from typing import Any, Dict, List, Optional
-from llama_models.datatypes import CoreModelId
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
@@ -25,7 +23,7 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
-from llama_stack.models.llama.datatypes import Role
+from llama_stack.models.llama.datatypes import CoreModelId, Role
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,

@@ -8,7 +8,6 @@ import json
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 from botocore.client import BaseClient
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -28,6 +27,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 from llama_stack.providers.utils.inference.model_registry import (

@@ -7,7 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union
 from cerebras.cloud.sdk import AsyncCerebras
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -27,7 +26,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.models.llama.datatypes import TopKSamplingStrategy
+from llama_stack.models.llama.datatypes import CoreModelId, TopKSamplingStrategy
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,

@@ -6,7 +6,6 @@
 from typing import AsyncGenerator, List, Optional
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
@@ -25,6 +24,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,

@@ -7,7 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union
 from fireworks.client import Fireworks
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -30,6 +29,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,

@@ -9,7 +9,6 @@ from typing import AsyncIterator, List, Optional, Union
 import groq
 from groq import Groq
-from llama_models.datatypes import SamplingParams
 from llama_models.sku_list import CoreModelId
 from llama_stack.apis.inference import (
@@ -28,7 +27,7 @@ from llama_stack.apis.inference import (
     ToolConfig,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,

@@ -7,7 +7,6 @@
 import warnings
 from typing import AsyncIterator, List, Optional, Union
-from llama_models.datatypes import SamplingParams
 from llama_models.sku_list import CoreModelId
 from openai import APIConnectionError, AsyncOpenAI
@@ -27,7 +26,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
 )
-from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,

@@ -8,11 +8,6 @@ import json
 import warnings
 from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
-from llama_models.datatypes import (
-    GreedySamplingStrategy,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
 from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
@@ -83,9 +78,12 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
+    GreedySamplingStrategy,
     StopReason,
     ToolCall,
     ToolDefinition,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,

@@ -8,7 +8,6 @@ import logging
 from typing import AsyncGenerator, List, Optional, Union
 import httpx
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
@@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,

@@ -7,12 +7,6 @@
 import json
 from typing import AsyncGenerator
-from llama_models.datatypes import (
-    CoreModelId,
-    GreedySamplingStrategy,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
@@ -23,6 +17,12 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
 )
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.models.llama.datatypes import (
+    CoreModelId,
+    GreedySamplingStrategy,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,

@@ -6,7 +6,6 @@
 from typing import AsyncGenerator, List, Optional, Union
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from together import Together
@@ -29,6 +28,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,

@@ -7,7 +7,6 @@
 import os
 import pytest
-from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
 from llama_stack.apis.agents import (
     AgentConfig,
@@ -24,7 +23,7 @@ from llama_stack.apis.agents import (
 )
 from llama_stack.apis.inference import CompletionMessage, UserMessage
 from llama_stack.apis.safety import ViolationLevel
-from llama_stack.models.llama.datatypes import BuiltinTool
+from llama_stack.models.llama.datatypes import BuiltinTool, SamplingParams, TopPSamplingStrategy
 from llama_stack.providers.datatypes import Api
 # How to run this test:

@@ -23,7 +23,6 @@ from groq.types.chat.chat_completion_message_tool_call import (
     Function,
 )
 from groq.types.shared.function_definition import FunctionDefinition
-from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
 from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import (
@@ -37,7 +36,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     UserMessage,
 )
-from llama_stack.models.llama.datatypes import ToolParamDefinition
+from llama_stack.models.llama.datatypes import GreedySamplingStrategy, ToolParamDefinition, TopPSamplingStrategy
 from llama_stack.providers.remote.inference.groq.groq_utils import (
     convert_chat_completion_request,
     convert_chat_completion_response,

@@ -9,11 +9,12 @@ from collections import defaultdict
 from pathlib import Path
 import pytest
-from llama_models.datatypes import CoreModelId
 from llama_models.sku_list import all_registered_models
 from pytest import ExitCode
 from pytest_html.basereport import _process_outcome
+from llama_stack.models.llama.datatypes import CoreModelId
 INFERENCE_APIS = ["chat_completion"]
 FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
 SUPPORTED_MODELS = {

@@ -6,9 +6,10 @@
 from typing import List
-from llama_models.datatypes import *  # noqa: F403
 from llama_models.sku_list import all_registered_models
+from llama_stack.models.llama.datatypes import *  # noqa: F403
 def is_supported_safety_model(model: Model) -> bool:
     if model.quantization_format != CheckpointQuantizationFormat.bf16:

@@ -7,12 +7,6 @@ import json
 import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union
-from llama_models.datatypes import (
-    GreedySamplingStrategy,
-    SamplingParams,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.chat_format import ChatFormat
 from openai.types.chat import ChatCompletionMessageToolCall
 from pydantic import BaseModel
@@ -36,7 +30,14 @@ from llama_stack.apis.inference import (
     Message,
     TokenLogProbs,
 )
-from llama_stack.models.llama.datatypes import StopReason, ToolCall
+from llama_stack.models.llama.datatypes import (
+    GreedySamplingStrategy,
+    SamplingParams,
+    StopReason,
+    ToolCall,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )

@@ -13,7 +13,6 @@ import re
 from typing import List, Optional, Tuple, Union
 import httpx
-from llama_models.datatypes import ModelFamily, is_multimodal
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.prompt_templates import (
     BuiltinToolGenerator,
@@ -43,6 +42,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.models.llama.datatypes import (
+    ModelFamily,
     RawContent,
     RawContentItem,
     RawMediaItem,
@@ -50,6 +50,7 @@ from llama_stack.models.llama.datatypes import (
     RawTextItem,
     Role,
     ToolPromptFormat,
+    is_multimodal,
 )
 from llama_stack.providers.utils.inference import supported_inference_models