llama_models.llama3_1 -> llama_models.llama3

Ashwin Bharambe 2024-08-19 10:55:37 -07:00
parent f502716cf7
commit 38244c3161
27 changed files with 44 additions and 42 deletions
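The change is a mechanical module rename: every `llama_models.llama3_1` import path becomes `llama_models.llama3`. A minimal sketch of how such a rename can be applied across a checkout (an assumption for illustration, not the exact command used for this commit; run from the repository root):

# Rewrite the old module path to the new one in every Python file.
from pathlib import Path

OLD, NEW = "llama_models.llama3_1", "llama_models.llama3"

for path in Path(".").rglob("*.py"):
    text = path.read_text()
    if OLD in text:
        path.write_text(text.replace(OLD, NEW))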

View file

@@ -13,7 +13,7 @@ import fire
 import httpx
-from llama_models.llama3_1.api.datatypes import (
+from llama_models.llama3.api.datatypes import (
     BuiltinTool,
     SamplingParams,
     ToolParamDefinition,

View file

@@ -6,16 +6,16 @@
 from typing import Optional
-from llama_models.llama3_1.api.datatypes import ToolResponseMessage
-from llama_models.llama3_1.api.tool_utils import ToolUtils
+from llama_models.llama3.api.datatypes import ToolResponseMessage
+from llama_models.llama3.api.tool_utils import ToolUtils
+from termcolor import cprint
 from llama_toolchain.agentic_system.api import (
     AgenticSystemTurnResponseEventType,
     StepType,
 )
-from termcolor import cprint
 class LogEvent:
     def __init__(
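Alongside the rename, several hunks in this commit also move unrelated imports, consistent with an import sorter that treats `llama_toolchain` as first-party and places it after third-party packages such as `termcolor` and `ollama`. A hypothetical invocation that produces this grouping (an assumption; the repo's actual sorter configuration is not shown in this diff, and the file name is a placeholder):

# Assumption: isort 5.x, with llama_toolchain marked as first-party
# so it sorts into its own group after third-party imports.
import isort

isort.file("some_module.py", known_first_party=["llama_toolchain"])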

View file

@@ -6,7 +6,7 @@
 from typing import List
-from llama_models.llama3_1.api.datatypes import Message, Role, UserMessage
+from llama_models.llama3.api.datatypes import Message, Role, UserMessage
 from termcolor import cprint
 from llama_toolchain.safety.api.datatypes import (

View file

@@ -9,7 +9,7 @@ import json
 from abc import abstractmethod
 from typing import Dict, List
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_toolchain.agentic_system.api import *  # noqa: F403
 # TODO: this is symptomatic of us needing to pull more tooling related utilities

View file

@@ -6,7 +6,7 @@
 from typing import Any, AsyncGenerator, List
-from llama_models.llama3_1.api.datatypes import StopReason, ToolResponseMessage
+from llama_models.llama3.api.datatypes import StopReason, ToolResponseMessage
 from llama_toolchain.agentic_system.api import (
     AgenticSystem,

View file

@@ -7,7 +7,7 @@
 import uuid
 from typing import Any, List, Optional
-from llama_models.llama3_1.api.datatypes import BuiltinTool, Message, SamplingParams
+from llama_models.llama3.api.datatypes import BuiltinTool, Message, SamplingParams
 from llama_toolchain.agentic_system.api import (
     AgenticSystemCreateRequest,

View file

@@ -7,10 +7,10 @@
 import argparse
 import textwrap
-from llama_toolchain.cli.subcommand import Subcommand
 from termcolor import colored
+from llama_toolchain.cli.subcommand import Subcommand
 class ModelTemplate(Subcommand):
     """Llama model cli for describe a model template (message formats)"""
@@ -48,10 +48,11 @@ class ModelTemplate(Subcommand):
         )
     def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
-        from llama_models.llama3_1.api.interface import (
+        from llama_models.llama3.api.interface import (
             list_jinja_templates,
             render_jinja_template,
         )
         from llama_toolchain.cli.table import print_table
         if args.name:

View file

@@ -7,7 +7,7 @@
 from enum import Enum
 from typing import Dict, Optional
-from llama_models.llama3_1.api.datatypes import URL
+from llama_models.llama3.api.datatypes import URL
 from llama_models.schema_utils import json_schema_type

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_models.llama3_1.api.datatypes import URL
+from llama_models.llama3.api.datatypes import URL
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel

View file

@@ -7,7 +7,7 @@
 from enum import Enum
 from typing import Any, Dict, Optional
-from llama_models.llama3_1.api.datatypes import URL
+from llama_models.llama3.api.datatypes import URL
 from llama_models.schema_utils import json_schema_type

View file

@@ -10,7 +10,7 @@ from llama_models.schema_utils import webmethod
 from pydantic import BaseModel
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 from .datatypes import *  # noqa: F403
 from llama_toolchain.dataset.api.datatypes import *  # noqa: F403
 from llama_toolchain.common.training_types import *  # noqa: F403

View file

@@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 class LogProbConfig(BaseModel):

View file

@@ -22,11 +22,11 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from llama_models.llama3_1.api.args import ModelArgs
-from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3_1.api.datatypes import Message
-from llama_models.llama3_1.api.tokenizer import Tokenizer
-from llama_models.llama3_1.reference_impl.model import Transformer
+from llama_models.llama3.api.args import ModelArgs
+from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.sku_list import resolve_model
 from termcolor import cprint

View file

@@ -8,7 +8,7 @@ import asyncio
 from typing import AsyncIterator, Dict, Union
-from llama_models.llama3_1.api.datatypes import StopReason
+from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model
 from llama_toolchain.distribution.datatypes import Api, ProviderSpec

View file

@@ -10,9 +10,9 @@ from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional
-from llama_models.llama3_1.api.chat_format import ChatFormat
-from llama_models.llama3_1.api.datatypes import Message
-from llama_models.llama3_1.api.tokenizer import Tokenizer
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 from .config import MetaReferenceImplConfig

View file

@@ -9,15 +9,17 @@ from typing import AsyncGenerator, Dict
 import httpx
-from llama_models.llama3_1.api.datatypes import (
+from llama_models.llama3.api.datatypes import (
     BuiltinTool,
     CompletionMessage,
     Message,
     StopReason,
     ToolCall,
 )
-from llama_models.llama3_1.api.tool_utils import ToolUtils
+from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
+from ollama import AsyncClient
 from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import (
     ChatCompletionRequest,
@@ -30,7 +32,6 @@ from llama_toolchain.inference.api import (
     ToolCallDelta,
     ToolCallParseStatus,
 )
-from ollama import AsyncClient
 from .config import OllamaImplConfig
@@ -64,10 +65,10 @@ class OllamaInference(Inference):
     async def initialize(self) -> None:
         try:
             await self.client.ps()
-        except httpx.ConnectError:
+        except httpx.ConnectError as e:
             raise RuntimeError(
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
-            )
+            ) from e
     async def shutdown(self) -> None:
         pass
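Besides the rename, the hunk above also changes exception handling: the `ConnectError` is now bound as `e` and re-raised with `raise ... from e`, which records the original error as `__cause__` so both tracebacks are printed. A standalone sketch of the effect, with hypothetical names:

# With "from e", the traceback shows the original ConnectionError,
# then "The above exception was the direct cause of the following
# exception:", then the RuntimeError.
def connect() -> None:
    raise ConnectionError("connection refused")

try:
    connect()
except ConnectionError as e:
    raise RuntimeError("server is not running") from e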

View file

@@ -13,7 +13,7 @@ from typing import Optional
 import torch
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.llama3_1.api.model import Transformer, TransformerBlock
+from llama_models.llama3.api.model import Transformer, TransformerBlock
 from llama_toolchain.inference.api.config import (
     CheckpointQuantizationFormat,

View file

@@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_toolchain.dataset.api.datatypes import *  # noqa: F403
 from llama_toolchain.common.training_types import *  # noqa: F403
 from .datatypes import *  # noqa: F403

View file

@@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 @json_schema_type

View file

@@ -7,7 +7,7 @@
 from enum import Enum
 from typing import Dict, Optional, Union
-from llama_models.llama3_1.api.datatypes import ToolParamDefinition
+from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, validator

View file

@@ -7,7 +7,7 @@
 from .datatypes import *  # noqa: F403
 from typing import List, Protocol
-from llama_models.llama3_1.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message
 # this dependency is annoying and we need a forked up version anyway
 from llama_models.schema_utils import webmethod

View file

@@ -9,7 +9,7 @@ import asyncio
 import fire
 import httpx
-from llama_models.llama3_1.api.datatypes import UserMessage
+from llama_models.llama3.api.datatypes import UserMessage
 from termcolor import cprint
 from .api import (

View file

@@ -7,7 +7,7 @@
 from abc import ABC, abstractmethod
 from typing import List, Union
-from llama_models.llama3_1.api.datatypes import Attachment, Message
+from llama_models.llama3.api.datatypes import Attachment, Message
 from llama_toolchain.safety.api.datatypes import *  # noqa: F403
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

View file

@@ -6,7 +6,7 @@
 from typing import List
-from llama_models.llama3_1.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message
 from llama_toolchain.safety.meta_reference.shields.base import (
     OnViolationAction,

View file

@@ -10,7 +10,7 @@ from string import Template
 from typing import List, Optional
 import torch
-from llama_models.llama3_1.api.datatypes import Message, Role
+from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse

View file

@@ -9,7 +9,7 @@ from typing import List
 import torch
-from llama_models.llama3_1.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message
 from termcolor import cprint
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

View file

@@ -10,7 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel
-from llama_models.llama3_1.api.datatypes import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_toolchain.reward_scoring.api.datatypes import *  # noqa: F403
 from .datatypes import *  # noqa: F403