llama_models.llama3_1 -> llama_models.llama3

Ashwin Bharambe 2024-08-19 10:55:37 -07:00
parent f502716cf7
commit 38244c3161
27 changed files with 44 additions and 42 deletions
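
The commit is a mechanical import-path rename: every llama_models.llama3_1.* import becomes llama_models.llama3.*, presumably tracking the corresponding module rename in the upstream llama-models package. On one representative import (taken from the diff below), the pattern is:

    # before this commit
    from llama_models.llama3_1.api.datatypes import UserMessage

    # after this commit
    from llama_models.llama3.api.datatypes import UserMessage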

@@ -13,7 +13,7 @@ import fire
import httpx
-from llama_models.llama3_1.api.datatypes import (
+from llama_models.llama3.api.datatypes import (
    BuiltinTool,
    SamplingParams,
    ToolParamDefinition,

@@ -6,16 +6,16 @@
from typing import Optional
-from llama_models.llama3_1.api.datatypes import ToolResponseMessage
-from llama_models.llama3_1.api.tool_utils import ToolUtils
+from llama_models.llama3.api.datatypes import ToolResponseMessage
+from llama_models.llama3.api.tool_utils import ToolUtils
+from termcolor import cprint
from llama_toolchain.agentic_system.api import (
    AgenticSystemTurnResponseEventType,
    StepType,
)
-from termcolor import cprint
class LogEvent:
    def __init__(

@@ -6,7 +6,7 @@
from typing import List
-from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage
+from llama_models.llama3.api.datatypes import Message, Role, UserMessage
from termcolor import cprint
from llama_toolchain.safety.api.datatypes import (

@@ -9,7 +9,7 @@ import json
from abc import abstractmethod
from typing import Dict, List
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_toolchain.agentic_system.api import *  # noqa: F403
# TODO: this is symptomatic of us needing to pull more tooling related utilities

@@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, List
-from llama_models.llama3_1.api.datatypes import StopReason, ToolResponseMessage
+from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage
from llama_toolchain.agentic_system.api import (
    AgenticSystem,

@@ -7,7 +7,7 @@
import uuid
from typing import Any, List, Optional
-from llama_models.llama3_1.api.datatypes import BuiltinTool, Message, SamplingParams
+from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams
from llama_toolchain.agentic_system.api import (
    AgenticSystemCreateRequest,

@@ -7,10 +7,10 @@
import argparse
import textwrap
-from llama_toolchain.cli.subcommand import Subcommand
from termcolor import colored
+from llama_toolchain.cli.subcommand import Subcommand
class ModelTemplate(Subcommand):
    """Llama model cli for describe a model template (message formats)"""
@@ -48,10 +48,11 @@ class ModelTemplate(Subcommand):
        )
    def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
-        from llama_models.llama3_1.api.interface import (
+        from llama_models.llama3.api.interface import (
            list_jinja_templates,
            render_jinja_template,
        )
+        from llama_toolchain.cli.table import print_table
        if args.name:

@@ -7,7 +7,7 @@
from enum import Enum
from typing import Dict, Optional
-from llama_models.llama3_1.api.datatypes import URL
+from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from llama_models.llama3_1.api.datatypes import URL
+from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel

@@ -7,7 +7,7 @@
from enum import Enum
from typing import Any, Dict, Optional
-from llama_models.llama3_1.api.datatypes import URL
+from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type

@@ -10,7 +10,7 @@ from llama_models.schema_utils import webmethod
from pydantic import BaseModel
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
from .datatypes import *  # noqa: F403
from llama_toolchain.dataset.api.datatypes import *  # noqa: F403
from llama_toolchain.common.training_types import *  # noqa: F403

@@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from typing_extensions import Annotated
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
class LogProbConfig(BaseModel):

@@ -22,11 +22,11 @@ from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)
-from llama_models.llama3_1.api.args import ModelArgs
-from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3_1.api.datatypes import Message
-from llama_models.llama3_1.api.tokenizer import Tokenizer
-from llama_models.llama3_1.reference_impl.model import Transformer
+from llama_models.llama3.api.args import ModelArgs
+from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.llama3.reference_impl.model import Transformer
from llama_models.sku_list import resolve_model
from termcolor import cprint

@@ -8,7 +8,7 @@ import asyncio
from typing import AsyncIterator, Dict, Union
-from llama_models.llama3_1.api.datatypes import StopReason
+from llama_models.llama3.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Api, ProviderSpec

@@ -10,9 +10,9 @@ from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
-from llama_models.llama3_1.api.chat_format import ChatFormat
-from llama_models.llama3_1.api.datatypes import Message
-from llama_models.llama3_1.api.tokenizer import Tokenizer
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from .config import MetaReferenceImplConfig

@@ -9,15 +9,17 @@ from typing import AsyncGenerator, Dict
import httpx
-from llama_models.llama3_1.api.datatypes import (
+from llama_models.llama3.api.datatypes import (
    BuiltinTool,
    CompletionMessage,
    Message,
    StopReason,
    ToolCall,
)
-from llama_models.llama3_1.api.tool_utils import ToolUtils
+from llama_models.llama3.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model
+from ollama import AsyncClient
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
    ChatCompletionRequest,
@@ -30,7 +32,6 @@ from llama_toolchain.inference.api import (
    ToolCallDelta,
    ToolCallParseStatus,
)
-from ollama import AsyncClient
from .config import OllamaImplConfig
@@ -64,10 +65,10 @@ class OllamaInference(Inference):
    async def initialize(self) -> None:
        try:
            await self.client.ps()
-        except httpx.ConnectError:
+        except httpx.ConnectError as e:
            raise RuntimeError(
                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
-            )
+            ) from e
    async def shutdown(self) -> None:
        pass
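
Beyond the rename, this hunk also adds explicit exception chaining (raise ... from e), so the RuntimeError keeps the original httpx.ConnectError as its __cause__ instead of discarding it. A minimal, self-contained sketch of the pattern; the ping() helper and the default URL are illustrative assumptions, not code from this repo:

    import asyncio

    import httpx


    async def ping(base_url: str = "http://localhost:11434") -> None:
        try:
            # Probe the server with a trivial request.
            async with httpx.AsyncClient(base_url=base_url) as client:
                await client.get("/")
        except httpx.ConnectError as e:
            # "from e" records the ConnectError as __cause__, so tracebacks
            # show the underlying connection failure, not just this message.
            raise RuntimeError(
                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
            ) from e


    asyncio.run(ping())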

@@ -13,7 +13,7 @@ from typing import Optional
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.llama3_1.api.model import Transformer, TransformerBlock
+from llama_models.llama3.api.model import Transformer, TransformerBlock
from llama_toolchain.inference.api.config import (
    CheckpointQuantizationFormat,

@@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_toolchain.dataset.api.datatypes import *  # noqa: F403
from llama_toolchain.common.training_types import *  # noqa: F403
from .datatypes import *  # noqa: F403

@@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
@json_schema_type

@@ -7,7 +7,7 @@
from enum import Enum
from typing import Dict, Optional, Union
-from llama_models.llama3_1.api.datatypes import ToolParamDefinition
+from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, validator

@@ -7,7 +7,7 @@
from .datatypes import *  # noqa: F403
from typing import List, Protocol
-from llama_models.llama3_1.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message
# this dependency is annoying and we need a forked up version anyway
from llama_models.schema_utils import webmethod

@@ -9,7 +9,7 @@ import asyncio
import fire
import httpx
-from llama_models.llama3_1.api.datatypes import UserMessage
+from llama_models.llama3.api.datatypes import UserMessage
from termcolor import cprint
from .api import (

@@ -7,7 +7,7 @@
from abc import ABC, abstractmethod
from typing import List, Union
-from llama_models.llama3_1.api.datatypes import Attachment, Message
+from llama_models.llama3.api.datatypes import Attachment, Message
from llama_toolchain.safety.api.datatypes import *  # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

@@ -6,7 +6,7 @@
from typing import List
-from llama_models.llama3_1.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message
from llama_toolchain.safety.meta_reference.shields.base import (
    OnViolationAction,

@@ -10,7 +10,7 @@ from string import Template
from typing import List, Optional
import torch
-from llama_models.llama3_1.api.datatypes import Message, Role
+from llama_models.llama3.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse

@@ -9,7 +9,7 @@ from typing import List
import torch
-from llama_models.llama3_1.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer

@@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_toolchain.reward_scoring.api.datatypes import *  # noqa: F403
from .datatypes import *  # noqa: F403