diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index 5b8053af9..154bca614 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -13,7 +13,7 @@ import fire import httpx -from llama_models.llama3_1.api.datatypes import ( +from llama_models.llama3.api.datatypes import ( BuiltinTool, SamplingParams, ToolParamDefinition, diff --git a/llama_toolchain/agentic_system/event_logger.py b/llama_toolchain/agentic_system/event_logger.py index 1bf669a0a..22d961a10 100644 --- a/llama_toolchain/agentic_system/event_logger.py +++ b/llama_toolchain/agentic_system/event_logger.py @@ -6,16 +6,16 @@ from typing import Optional -from llama_models.llama3_1.api.datatypes import ToolResponseMessage -from llama_models.llama3_1.api.tool_utils import ToolUtils +from llama_models.llama3.api.datatypes import ToolResponseMessage +from llama_models.llama3.api.tool_utils import ToolUtils + +from termcolor import cprint from llama_toolchain.agentic_system.api import ( AgenticSystemTurnResponseEventType, StepType, ) -from termcolor import cprint - class LogEvent: def __init__( diff --git a/llama_toolchain/agentic_system/meta_reference/safety.py b/llama_toolchain/agentic_system/meta_reference/safety.py index ff3633f18..683ae622d 100644 --- a/llama_toolchain/agentic_system/meta_reference/safety.py +++ b/llama_toolchain/agentic_system/meta_reference/safety.py @@ -6,7 +6,7 @@ from typing import List -from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage +from llama_models.llama3.api.datatypes import Message, Role, UserMessage from termcolor import cprint from llama_toolchain.safety.api.datatypes import ( diff --git a/llama_toolchain/agentic_system/tools/custom/datatypes.py b/llama_toolchain/agentic_system/tools/custom/datatypes.py index ee46114e8..174b55241 100644 --- a/llama_toolchain/agentic_system/tools/custom/datatypes.py +++ b/llama_toolchain/agentic_system/tools/custom/datatypes.py @@ -9,7 +9,7 @@ import json from abc import abstractmethod from typing import Dict, List -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.agentic_system.api import * # noqa: F403 # TODO: this is symptomatic of us needing to pull more tooling related utilities diff --git a/llama_toolchain/agentic_system/tools/custom/execute.py b/llama_toolchain/agentic_system/tools/custom/execute.py index 987aee4e2..4729d35a7 100644 --- a/llama_toolchain/agentic_system/tools/custom/execute.py +++ b/llama_toolchain/agentic_system/tools/custom/execute.py @@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, List -from llama_models.llama3_1.api.datatypes import StopReason, ToolResponseMessage +from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage from llama_toolchain.agentic_system.api import ( AgenticSystem, diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py index 3ae5c67b6..9613b45df 100644 --- a/llama_toolchain/agentic_system/utils.py +++ b/llama_toolchain/agentic_system/utils.py @@ -7,7 +7,7 @@ import uuid from typing import Any, List, Optional -from llama_models.llama3_1.api.datatypes import BuiltinTool, Message, SamplingParams +from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams from llama_toolchain.agentic_system.api import ( AgenticSystemCreateRequest, diff --git a/llama_toolchain/cli/model/template.py b/llama_toolchain/cli/model/template.py index 58b245035..1915e87d3 100644 --- a/llama_toolchain/cli/model/template.py +++ b/llama_toolchain/cli/model/template.py @@ -7,10 +7,10 @@ import argparse import textwrap -from llama_toolchain.cli.subcommand import Subcommand - from termcolor import colored +from llama_toolchain.cli.subcommand import Subcommand + class ModelTemplate(Subcommand): """Llama model cli for describe a model template (message formats)""" @@ -48,10 +48,11 @@ class ModelTemplate(Subcommand): ) def _run_model_template_cmd(self, args: argparse.Namespace) -> None: - from llama_models.llama3_1.api.interface import ( + from llama_models.llama3.api.interface import ( list_jinja_templates, render_jinja_template, ) + from llama_toolchain.cli.table import print_table if args.name: diff --git a/llama_toolchain/common/deployment_types.py b/llama_toolchain/common/deployment_types.py index e5117cf2c..8b67eff0d 100644 --- a/llama_toolchain/common/deployment_types.py +++ b/llama_toolchain/common/deployment_types.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Dict, Optional -from llama_models.llama3_1.api.datatypes import URL +from llama_models.llama3.api.datatypes import URL from llama_models.schema_utils import json_schema_type diff --git a/llama_toolchain/common/training_types.py b/llama_toolchain/common/training_types.py index 9c8d786fd..fd74293eb 100644 --- a/llama_toolchain/common/training_types.py +++ b/llama_toolchain/common/training_types.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_models.llama3_1.api.datatypes import URL +from llama_models.llama3.api.datatypes import URL from llama_models.schema_utils import json_schema_type from pydantic import BaseModel diff --git a/llama_toolchain/dataset/api/datatypes.py b/llama_toolchain/dataset/api/datatypes.py index 5724023e9..32109b37c 100644 --- a/llama_toolchain/dataset/api/datatypes.py +++ b/llama_toolchain/dataset/api/datatypes.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Any, Dict, Optional -from llama_models.llama3_1.api.datatypes import URL +from llama_models.llama3.api.datatypes import URL from llama_models.schema_utils import json_schema_type diff --git a/llama_toolchain/evaluations/api/endpoints.py b/llama_toolchain/evaluations/api/endpoints.py index 39b9a28e0..fd5b68bbe 100644 --- a/llama_toolchain/evaluations/api/endpoints.py +++ b/llama_toolchain/evaluations/api/endpoints.py @@ -10,7 +10,7 @@ from llama_models.schema_utils import webmethod from pydantic import BaseModel -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 from .datatypes import * # noqa: F403 from llama_toolchain.dataset.api.datatypes import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403 diff --git a/llama_toolchain/inference/api/datatypes.py b/llama_toolchain/inference/api/datatypes.py index 5b0bc7170..571ecc3ea 100644 --- a/llama_toolchain/inference/api/datatypes.py +++ b/llama_toolchain/inference/api/datatypes.py @@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field from typing_extensions import Annotated -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 class LogProbConfig(BaseModel): diff --git a/llama_toolchain/inference/meta_reference/generation.py b/llama_toolchain/inference/meta_reference/generation.py index f4d3c210b..058874702 100644 --- a/llama_toolchain/inference/meta_reference/generation.py +++ b/llama_toolchain/inference/meta_reference/generation.py @@ -22,11 +22,11 @@ from fairscale.nn.model_parallel.initialize import ( initialize_model_parallel, model_parallel_is_initialized, ) -from llama_models.llama3_1.api.args import ModelArgs -from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput -from llama_models.llama3_1.api.datatypes import Message -from llama_models.llama3_1.api.tokenizer import Tokenizer -from llama_models.llama3_1.reference_impl.model import Transformer +from llama_models.llama3.api.args import ModelArgs +from llama_models.llama3.api.chat_format import ChatFormat, ModelInput +from llama_models.llama3.api.datatypes import Message +from llama_models.llama3.api.tokenizer import Tokenizer +from llama_models.llama3.reference_impl.model import Transformer from llama_models.sku_list import resolve_model from termcolor import cprint diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_toolchain/inference/meta_reference/inference.py index 4bd7a80bc..84caf1ecf 100644 --- a/llama_toolchain/inference/meta_reference/inference.py +++ b/llama_toolchain/inference/meta_reference/inference.py @@ -8,7 +8,7 @@ import asyncio from typing import AsyncIterator, Dict, Union -from llama_models.llama3_1.api.datatypes import StopReason +from llama_models.llama3.api.datatypes import StopReason from llama_models.sku_list import resolve_model from llama_toolchain.distribution.datatypes import Api, ProviderSpec diff --git a/llama_toolchain/inference/meta_reference/model_parallel.py b/llama_toolchain/inference/meta_reference/model_parallel.py index dee05d8d5..3de4a6381 100644 --- a/llama_toolchain/inference/meta_reference/model_parallel.py +++ b/llama_toolchain/inference/meta_reference/model_parallel.py @@ -10,9 +10,9 @@ from dataclasses import dataclass from functools import partial from typing import Generator, List, Optional -from llama_models.llama3_1.api.chat_format import ChatFormat -from llama_models.llama3_1.api.datatypes import Message -from llama_models.llama3_1.api.tokenizer import Tokenizer +from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.datatypes import Message +from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model from .config import MetaReferenceImplConfig diff --git a/llama_toolchain/inference/ollama/ollama.py b/llama_toolchain/inference/ollama/ollama.py index 64f24bee4..8901d5c02 100644 --- a/llama_toolchain/inference/ollama/ollama.py +++ b/llama_toolchain/inference/ollama/ollama.py @@ -9,15 +9,17 @@ from typing import AsyncGenerator, Dict import httpx -from llama_models.llama3_1.api.datatypes import ( +from llama_models.llama3.api.datatypes import ( BuiltinTool, CompletionMessage, Message, StopReason, ToolCall, ) -from llama_models.llama3_1.api.tool_utils import ToolUtils +from llama_models.llama3.api.tool_utils import ToolUtils from llama_models.sku_list import resolve_model +from ollama import AsyncClient + from llama_toolchain.distribution.datatypes import Api, ProviderSpec from llama_toolchain.inference.api import ( ChatCompletionRequest, @@ -30,7 +32,6 @@ from llama_toolchain.inference.api import ( ToolCallDelta, ToolCallParseStatus, ) -from ollama import AsyncClient from .config import OllamaImplConfig @@ -64,10 +65,10 @@ class OllamaInference(Inference): async def initialize(self) -> None: try: await self.client.ps() - except httpx.ConnectError: + except httpx.ConnectError as e: raise RuntimeError( "Ollama Server is not running, start it using `ollama serve` in a separate terminal" - ) + ) from e async def shutdown(self) -> None: pass diff --git a/llama_toolchain/inference/quantization/loader.py b/llama_toolchain/inference/quantization/loader.py index 583123df6..3645344aa 100644 --- a/llama_toolchain/inference/quantization/loader.py +++ b/llama_toolchain/inference/quantization/loader.py @@ -13,7 +13,7 @@ from typing import Optional import torch from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region -from llama_models.llama3_1.api.model import Transformer, TransformerBlock +from llama_models.llama3.api.model import Transformer, TransformerBlock from llama_toolchain.inference.api.config import ( CheckpointQuantizationFormat, diff --git a/llama_toolchain/post_training/api/endpoints.py b/llama_toolchain/post_training/api/endpoints.py index 0512003d3..e451def17 100644 --- a/llama_toolchain/post_training/api/endpoints.py +++ b/llama_toolchain/post_training/api/endpoints.py @@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.dataset.api.datatypes import * # noqa: F403 from llama_toolchain.common.training_types import * # noqa: F403 from .datatypes import * # noqa: F403 diff --git a/llama_toolchain/reward_scoring/api/datatypes.py b/llama_toolchain/reward_scoring/api/datatypes.py index 3359d4fc9..2ce698d47 100644 --- a/llama_toolchain/reward_scoring/api/datatypes.py +++ b/llama_toolchain/reward_scoring/api/datatypes.py @@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 @json_schema_type diff --git a/llama_toolchain/safety/api/datatypes.py b/llama_toolchain/safety/api/datatypes.py index c0d23f589..5deecc2b3 100644 --- a/llama_toolchain/safety/api/datatypes.py +++ b/llama_toolchain/safety/api/datatypes.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Dict, Optional, Union -from llama_models.llama3_1.api.datatypes import ToolParamDefinition +from llama_models.llama3.api.datatypes import ToolParamDefinition from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, validator diff --git a/llama_toolchain/safety/api/endpoints.py b/llama_toolchain/safety/api/endpoints.py index 11c1282a1..a282a7968 100644 --- a/llama_toolchain/safety/api/endpoints.py +++ b/llama_toolchain/safety/api/endpoints.py @@ -7,7 +7,7 @@ from .datatypes import * # noqa: F403 from typing import List, Protocol -from llama_models.llama3_1.api.datatypes import Message +from llama_models.llama3.api.datatypes import Message # this dependency is annoying and we need a forked up version anyway from llama_models.schema_utils import webmethod diff --git a/llama_toolchain/safety/client.py b/llama_toolchain/safety/client.py index 2bceebc68..5d86f9291 100644 --- a/llama_toolchain/safety/client.py +++ b/llama_toolchain/safety/client.py @@ -9,7 +9,7 @@ import asyncio import fire import httpx -from llama_models.llama3_1.api.datatypes import UserMessage +from llama_models.llama3.api.datatypes import UserMessage from termcolor import cprint from .api import ( diff --git a/llama_toolchain/safety/meta_reference/shields/base.py b/llama_toolchain/safety/meta_reference/shields/base.py index ce19a3676..245373b13 100644 --- a/llama_toolchain/safety/meta_reference/shields/base.py +++ b/llama_toolchain/safety/meta_reference/shields/base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from typing import List, Union -from llama_models.llama3_1.api.datatypes import Attachment, Message +from llama_models.llama3.api.datatypes import Attachment, Message from llama_toolchain.safety.api.datatypes import * # noqa: F403 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" diff --git a/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py b/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py index 789fa5f07..61a5977ed 100644 --- a/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py +++ b/llama_toolchain/safety/meta_reference/shields/contrib/third_party_shield.py @@ -6,7 +6,7 @@ from typing import List -from llama_models.llama3_1.api.datatypes import Message +from llama_models.llama3.api.datatypes import Message from llama_toolchain.safety.meta_reference.shields.base import ( OnViolationAction, diff --git a/llama_toolchain/safety/meta_reference/shields/llama_guard.py b/llama_toolchain/safety/meta_reference/shields/llama_guard.py index 56126abde..a78b8127d 100644 --- a/llama_toolchain/safety/meta_reference/shields/llama_guard.py +++ b/llama_toolchain/safety/meta_reference/shields/llama_guard.py @@ -10,7 +10,7 @@ from string import Template from typing import List, Optional import torch -from llama_models.llama3_1.api.datatypes import Message, Role +from llama_models.llama3.api.datatypes import Message, Role from transformers import AutoModelForCausalLM, AutoTokenizer from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse diff --git a/llama_toolchain/safety/meta_reference/shields/prompt_guard.py b/llama_toolchain/safety/meta_reference/shields/prompt_guard.py index 0acc1e488..b9f5dd5a5 100644 --- a/llama_toolchain/safety/meta_reference/shields/prompt_guard.py +++ b/llama_toolchain/safety/meta_reference/shields/prompt_guard.py @@ -9,7 +9,7 @@ from typing import List import torch -from llama_models.llama3_1.api.datatypes import Message +from llama_models.llama3.api.datatypes import Message from termcolor import cprint from transformers import AutoModelForSequenceClassification, AutoTokenizer diff --git a/llama_toolchain/synthetic_data_generation/api/endpoints.py b/llama_toolchain/synthetic_data_generation/api/endpoints.py index 8eada05cf..91585a943 100644 --- a/llama_toolchain/synthetic_data_generation/api/endpoints.py +++ b/llama_toolchain/synthetic_data_generation/api/endpoints.py @@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from llama_models.llama3_1.api.datatypes import * # noqa: F403 +from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.reward_scoring.api.datatypes import * # noqa: F403 from .datatypes import * # noqa: F403