diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py
index 835ce4cee..92271fb60 100644
--- a/llama_stack/apis/agents/event_logger.py
+++ b/llama_stack/apis/agents/event_logger.py
@@ -6,13 +6,13 @@
 
 from typing import Optional
 
-from llama_models.llama3.api.datatypes import ToolPromptFormat
 from llama_models.llama3.api.tool_utils import ToolUtils
 from termcolor import cprint
 
 from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
 from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import ToolResponseMessage
+from llama_stack.models.llama.datatypes import ToolPromptFormat
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py
index 5017fa2f5..0d0afa894 100644
--- a/llama_stack/apis/common/content_types.py
+++ b/llama_stack/apis/common/content_types.py
@@ -7,9 +7,9 @@
 from enum import Enum
 from typing import Annotated, List, Literal, Optional, Union
 
-from llama_models.llama3.api.datatypes import ToolCall
 from pydantic import BaseModel, Field, model_validator
 
+from llama_stack.models.llama.datatypes import ToolCall
 from llama_stack.schema_utils import json_schema_type, register_schema
 
 
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index a5bc1d72f..433ba3274 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -17,7 +17,13 @@ from typing import (
     runtime_checkable,
 )
 
-from llama_models.llama3.api.datatypes import (
+from pydantic import BaseModel, Field, field_validator
+from typing_extensions import Annotated
+
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
+from llama_stack.apis.models import Model
+from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     SamplingParams,
     StopReason,
@@ -25,12 +31,6 @@ from llama_models.llama3.api.datatypes import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from pydantic import BaseModel, Field, field_validator
-from typing_extensions import Annotated
-
-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
-from llama_stack.apis.models import Model
-from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py
index f8dbcb5eb..d010a7e3b 100644
--- a/llama_stack/apis/telemetry/telemetry.py
+++ b/llama_stack/apis/telemetry/telemetry.py
@@ -17,10 +17,10 @@ from typing import (
     runtime_checkable,
 )
 
-from llama_models.llama3.api.datatypes import Primitive
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
+from llama_stack.models.llama.datatypes import Primitive
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
 # Add this constant near the top of the file, after the imports
diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py
index 2321c4615..314f1639e 100644
--- a/llama_stack/cli/model/safety_models.py
+++ b/llama_stack/cli/model/safety_models.py
@@ -7,10 +7,11 @@
 from typing import Any, Dict, Optional
 
 from llama_models.datatypes import CheckpointQuantizationFormat
-from llama_models.llama3.api.datatypes import SamplingParams
 from llama_models.sku_list import LlamaDownloadInfo
 from pydantic import BaseModel, ConfigDict, Field
 
+from llama_stack.models.llama.datatypes import SamplingParams
+
 
 class PromptGuardModel(BaseModel):
     """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py
index b99c90d75..a5dc9ac4a 100644
--- a/llama_stack/models/llama/datatypes.py
+++ b/llama_stack/models/llama/datatypes.py
@@ -14,7 +14,8 @@
 from enum import Enum
 from typing import Any, Dict, Literal, Optional, Union
 
-from llama_models.datatypes import BuiltinTool, ToolCall
+# import all for backwards compatibility
+from llama_models.datatypes import *  # noqa: F403
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 from typing_extensions import Annotated
 
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 8ba7885cd..fc597d0f7 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
 from urllib.parse import urlparse
 
 import httpx
-from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
 from pydantic import TypeAdapter
 
 from llama_stack.apis.agents import (
@@ -63,6 +62,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
index 4e3951ad3..b802937b6 100644
--- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
+++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
@@ -8,7 +8,6 @@ import tempfile
 from typing import AsyncIterator, List, Optional, Union
 
 import pytest
-from llama_models.llama3.api.datatypes import BuiltinTool
 
 from llama_stack.apis.agents import (
     AgentConfig,
@@ -41,6 +40,7 @@ from llama_stack.apis.tools import (
     ToolInvocationResult,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse
+from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index e60c3b1be..16f76721c 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -30,7 +30,6 @@ from llama_models.datatypes import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
-from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
@@ -47,6 +46,7 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
 )
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.models.llama.datatypes import Model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 61f0ee3f4..2a66986d1 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -8,12 +8,6 @@ import asyncio
 import logging
 from typing import AsyncGenerator, List, Optional, Union
 
-from llama_models.llama3.api.datatypes import (
-    SamplingParams,
-    StopReason,
-    ToolDefinition,
-    ToolPromptFormat,
-)
 from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.common.content_types import (
@@ -41,6 +35,12 @@ from llama_stack.apis.inference import (
     ToolConfig,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.models.llama.datatypes import (
+    SamplingParams,
+    StopReason,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
index ef133274c..4f6ad017f 100644
--- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
+++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
@@ -10,10 +10,10 @@ from functools import partial
 from typing import Any, Generator
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
+from llama_stack.models.llama.datatypes import Model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 32d6d5100..b186c8b02 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -9,7 +9,6 @@ from string import Template
 from typing import Any, Dict, List, Optional
 
 from llama_models.datatypes import CoreModelId
-from llama_models.llama3.api.datatypes import Role
 
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
@@ -26,6 +25,7 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
+from llama_stack.models.llama.datatypes import Role
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 2158fc5b4..3ba2c37c5 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -9,7 +9,6 @@ from typing import AsyncGenerator, List, Optional, Union
 from cerebras.cloud.sdk import AsyncCerebras
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import TopKSamplingStrategy
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.common.content_types import InterleavedContent
@@ -28,6 +27,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import TopKSamplingStrategy
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 59ec8b0d2..c45b8ee42 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -10,7 +10,6 @@ from typing import AsyncIterator, List, Optional, Union
 import groq
 from groq import Groq
 from llama_models.datatypes import SamplingParams
-from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
 from llama_models.sku_list import CoreModelId
 
 from llama_stack.apis.inference import (
@@ -29,6 +28,7 @@ from llama_stack.apis.inference import (
     ToolConfig,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py
index 2445c1b39..f1138e789 100644
--- a/llama_stack/providers/remote/inference/groq/groq_utils.py
+++ b/llama_stack/providers/remote/inference/groq/groq_utils.py
@@ -24,7 +24,6 @@ from groq.types.chat.chat_completion_user_message_param import (
 )
 from groq.types.chat.completion_create_params import CompletionCreateParams
 from groq.types.shared.function_definition import FunctionDefinition
-from llama_models.llama3.api.datatypes import ToolParamDefinition
 
 from llama_stack.apis.common.content_types import (
     TextDelta,
@@ -44,6 +43,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import ToolParamDefinition
 from llama_stack.providers.utils.inference.openai_compat import (
     UnparseableToolCall,
     convert_tool_call,
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 82343513f..4d30a0a9c 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -8,7 +8,6 @@ import warnings
 from typing import AsyncIterator, List, Optional, Union
 
 from llama_models.datatypes import SamplingParams
-from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
 from llama_models.sku_list import CoreModelId
 from openai import APIConnectionError, AsyncOpenAI
 
@@ -28,6 +27,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
 )
+from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index c757c562c..a6c5086de 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -13,12 +13,6 @@ from llama_models.datatypes import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
-from llama_models.llama3.api.datatypes import (
-    BuiltinTool,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-)
 from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
@@ -87,6 +81,12 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.models.llama.datatypes import (
+    BuiltinTool,
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index a3c615418..1abb17336 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -6,11 +6,11 @@
 from typing import AsyncGenerator
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.models.llama.datatypes import Message
 
 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
index 564f76088..8ef9f5705 100644
--- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
+++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
@@ -7,7 +7,6 @@
 from typing import Any, Dict, List, Optional
 
 import requests
-from llama_models.llama3.api.datatypes import BuiltinTool
 
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
@@ -18,6 +17,7 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
 
 from .config import BraveSearchToolConfig
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index 45b276cc3..68868ee52 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -8,7 +8,6 @@ import os
 
 import pytest
 from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
-from llama_models.llama3.api.datatypes import BuiltinTool
 
 from llama_stack.apis.agents import (
     AgentConfig,
@@ -25,6 +24,7 @@ from llama_stack.apis.agents import (
 )
 from llama_stack.apis.inference import CompletionMessage, UserMessage
 from llama_stack.apis.safety import ViolationLevel
+from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.datatypes import Api
 
 # How to run this test:
diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
index 3eba991c1..5f0278c20 100644
--- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py
+++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
@@ -24,7 +24,6 @@ from groq.types.chat.chat_completion_message_tool_call import (
 )
 from groq.types.shared.function_definition import FunctionDefinition
 from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
-from llama_models.llama3.api.datatypes import ToolParamDefinition
 
 from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import (
@@ -38,6 +37,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     UserMessage,
 )
+from llama_stack.models.llama.datatypes import ToolParamDefinition
 from llama_stack.providers.remote.inference.groq.groq_utils import (
     convert_chat_completion_request,
     convert_chat_completion_response,
diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py
index c087c5df2..323c6cb6a 100644
--- a/llama_stack/providers/tests/inference/test_prompt_adapter.py
+++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py
@@ -6,19 +6,18 @@
 
 import unittest
 
-from llama_models.llama3.api.datatypes import (
-    BuiltinTool,
-    ToolDefinition,
-    ToolParamDefinition,
-    ToolPromptFormat,
-)
-
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     SystemMessage,
     ToolConfig,
     UserMessage,
 )
+from llama_stack.models.llama.datatypes import (
+    BuiltinTool,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 14ed2fc4b..f25b95004 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -6,14 +6,6 @@
 
 import pytest
 
-from llama_models.llama3.api.datatypes import (
-    SamplingParams,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-    ToolParamDefinition,
-    ToolPromptFormat,
-)
 from pydantic import BaseModel, ValidationError
 
 from llama_stack.apis.common.content_types import ToolCallParseStatus
@@ -30,6 +22,14 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.models import ListModelsResponse, Model
+from llama_stack.models.llama.datatypes import (
+    SamplingParams,
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolPromptFormat,
+)
 
 from .utils import group_chunks
 
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 33f0f4e22..128c21849 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -14,7 +14,6 @@ from llama_models.datatypes import (
     TopPSamplingStrategy,
 )
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import StopReason, ToolCall
 from openai.types.chat import ChatCompletionMessageToolCall
 from pydantic import BaseModel
 
@@ -37,6 +36,7 @@ from llama_stack.apis.inference import (
     Message,
     TokenLogProbs,
 )
+from llama_stack.models.llama.datatypes import StopReason, ToolCall
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 15149e059..b90704d66 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -15,15 +15,6 @@ from typing import List, Optional, Tuple, Union
 import httpx
 from llama_models.datatypes import ModelFamily, is_multimodal
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import (
-    RawContent,
-    RawContentItem,
-    RawMediaItem,
-    RawMessage,
-    RawTextItem,
-    Role,
-    ToolPromptFormat,
-)
 from llama_models.llama3.prompt_templates import (
     BuiltinToolGenerator,
     FunctionTagCustomToolGenerator,
@@ -51,6 +42,15 @@ from llama_stack.apis.inference import (
     ToolChoice,
     UserMessage,
 )
+from llama_stack.models.llama.datatypes import (
+    RawContent,
+    RawContentItem,
+    RawMediaItem,
+    RawMessage,
+    RawTextItem,
+    Role,
+    ToolPromptFormat,
+)
 from llama_stack.providers.utils.inference import supported_inference_models
 
 log = logging.getLogger(__name__)
diff --git a/llama_stack/providers/utils/telemetry/trace_protocol.py b/llama_stack/providers/utils/telemetry/trace_protocol.py
index 80c58a2c7..924274c42 100644
--- a/llama_stack/providers/utils/telemetry/trace_protocol.py
+++ b/llama_stack/providers/utils/telemetry/trace_protocol.py
@@ -9,9 +9,10 @@ import inspect
 from functools import wraps
 from typing import Any, AsyncGenerator, Callable, Type, TypeVar
 
-from llama_models.llama3.api.datatypes import Primitive
 from pydantic import BaseModel
 
+from llama_stack.models.llama.datatypes import Primitive
+
 T = TypeVar("T")
 
 
diff --git a/llama_stack/scripts/generate_prompt_format.py b/llama_stack/scripts/generate_prompt_format.py
index c529b0a5f..ecdde900f 100644
--- a/llama_stack/scripts/generate_prompt_format.py
+++ b/llama_stack/scripts/generate_prompt_format.py
@@ -17,7 +17,7 @@ from typing import Optional
 
 import fire
 
-# from llama_models.llama3.api.datatypes import *  # noqa: F403
+# from llama_stack.models.llama.datatypes import *  # noqa: F403
 from llama_models.llama3.reference_impl.generation import Llama
 
 THIS_DIR = Path(__file__).parent.resolve()
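Note on the one non-mechanical hunk: `llama_stack/models/llama/datatypes.py` swaps its named imports for a wildcard re-export of `llama_models.datatypes`, so every public name from the old home stays importable from the new path, while the remaining hunks mechanically rewrite `from llama_models.llama3.api.datatypes import ...` call sites to `from llama_stack.models.llama.datatypes import ...`. A minimal sketch of what this buys downstream code (assumes an environment with `llama-models` and a `llama-stack` build that includes this change):

```python
# Both import paths resolve to the same objects after this change,
# because the new module wildcard-re-exports the old one.
import llama_models.datatypes as old_home

from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall

assert ToolCall is old_home.ToolCall        # same class object, not a copy
assert BuiltinTool is old_home.BuiltinTool  # enums come through unchanged

# Migrated call sites in this diff only had to change the module path:
# before: from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall
# after:  from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall
```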