llama_models.llama3_1 -> llama_models.llama3

This commit is contained in:
Ashwin Bharambe 2024-08-19 10:55:37 -07:00
parent f502716cf7
commit 38244c3161
27 changed files with 44 additions and 42 deletions

View file

@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
class LogProbConfig(BaseModel):

View file

@ -22,11 +22,11 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from llama_models.llama3_1.api.args import ModelArgs
from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from llama_models.llama3_1.reference_impl.model import Transformer
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.sku_list import resolve_model
from termcolor import cprint

View file

@ -8,7 +8,7 @@ import asyncio
from typing import AsyncIterator, Dict, Union
from llama_models.llama3_1.api.datatypes import StopReason
from llama_models.llama3.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Api, ProviderSpec

View file

@ -10,9 +10,9 @@ from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
from llama_models.llama3_1.api.chat_format import ChatFormat
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from .config import MetaReferenceImplConfig

View file

@ -9,15 +9,17 @@ from typing import AsyncGenerator, Dict
import httpx
from llama_models.llama3_1.api.datatypes import (
from llama_models.llama3.api.datatypes import (
BuiltinTool,
CompletionMessage,
Message,
StopReason,
ToolCall,
)
from llama_models.llama3_1.api.tool_utils import ToolUtils
from llama_models.llama3.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model
from ollama import AsyncClient
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
ChatCompletionRequest,
@ -30,7 +32,6 @@ from llama_toolchain.inference.api import (
ToolCallDelta,
ToolCallParseStatus,
)
from ollama import AsyncClient
from .config import OllamaImplConfig
@ -64,10 +65,10 @@ class OllamaInference(Inference):
async def initialize(self) -> None:
try:
await self.client.ps()
except httpx.ConnectError:
except httpx.ConnectError as e:
raise RuntimeError(
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
)
) from e
async def shutdown(self) -> None:
pass

View file

@ -13,7 +13,7 @@ from typing import Optional
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3_1.api.model import Transformer, TransformerBlock
from llama_models.llama3.api.model import Transformer, TransformerBlock
from llama_toolchain.inference.api.config import (
CheckpointQuantizationFormat,