Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-18 17:29:47 +00:00)

Merge branch 'meta-llama:main' into main
Commit 54e48d555d: 110 changed files with 12606 additions and 747 deletions

@@ -28,6 +28,7 @@ class Api(Enum):
     datasetio = "datasetio"
     scoring = "scoring"
     eval = "eval"
+    post_training = "post_training"

     telemetry = "telemetry"

@@ -200,10 +201,13 @@ API responses, specify the adapter here.
         return self.adapter.provider_data_validator


-def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
+def remote_provider_spec(
+    api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
+) -> RemoteProviderSpec:
     return RemoteProviderSpec(
         api=api,
         provider_type=f"remote::{adapter.adapter_type}",
         config_class=adapter.config_class,
         adapter=adapter,
+        api_dependencies=api_dependencies or [],
     )
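
With this change, a registry entry built via remote_provider_spec can declare cross-API dependencies. An illustrative sketch (not part of the diff; the adapter values are hypothetical, though the module and config_class strings follow entries visible elsewhere in this change):

remote_provider_spec(
    api=Api.memory,
    adapter=AdapterSpec(
        adapter_type="chromadb",
        pip_packages=["chromadb-client"],
        module="llama_stack.providers.remote.memory.chroma",
        config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
    ),
    # New: the resolver can now wire the inference provider into this one.
    api_dependencies=[Api.inference],
)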

@@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403

+from llama_stack.providers.utils.inference.model_registry import build_model_alias
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.embedding_mixin import (
+    SentenceTransformerEmbeddingMixin,
+)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_media_to_url,
     request_has_media,
 )

 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator

@@ -32,12 +34,17 @@ log = logging.getLogger(__name__)

 SEMAPHORE = asyncio.Semaphore(1)


-class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class MetaReferenceInferenceImpl(
+    SentenceTransformerEmbeddingMixin,
+    Inference,
+    ModelsProtocolPrivate,
+):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
-        ModelRegistryHelper.__init__(
-            self,
+        if model is None:
+            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
+        self.model_registry_helper = ModelRegistryHelper(
             [
                 build_model_alias(
                     model.descriptor(),

@@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
                 )
             ],
         )
-        if model is None:
-            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
         # verify that the checkpoint actually is for this model lol

@@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass

+    async def register_model(self, model: Model) -> Model:
+        model = await self.model_registry_helper.register_model(model)
+        if model.model_type == ModelType.embedding:
+            self._load_sentence_transformer_model(model.provider_resource_id)
+        return model
+
     async def completion(
         self,
         model_id: str,

@@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
         for x in impl():
             yield x

-    async def embeddings(
-        self,
-        model_id: str,
-        contents: List[InterleavedTextMedia],
-    ) -> EmbeddingsResponse:
-        raise NotImplementedError()
-

 async def request_with_localized_media(
     request: Union[ChatCompletionRequest, CompletionRequest],

@@ -0,0 +1,20 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.inline.inference.sentence_transformers.config import (
    SentenceTransformersInferenceConfig,
)


async def get_provider_impl(
    config: SentenceTransformersInferenceConfig,
    _deps,
):
    from .sentence_transformers import SentenceTransformersInferenceImpl

    impl = SentenceTransformersInferenceImpl(config)
    await impl.initialize()
    return impl

@@ -0,0 +1,16 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from pydantic import BaseModel


class SentenceTransformersInferenceConfig(BaseModel):
    @classmethod
    def sample_run_config(cls) -> Dict[str, Any]:
        return {}

@@ -0,0 +1,74 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
from typing import AsyncGenerator, List, Optional, Union

from llama_stack.apis.inference import (
    CompletionResponse,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    ToolChoice,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
    SentenceTransformerEmbeddingMixin,
)

from .config import SentenceTransformersInferenceConfig

log = logging.getLogger(__name__)


class SentenceTransformersInferenceImpl(
    SentenceTransformerEmbeddingMixin,
    Inference,
    ModelsProtocolPrivate,
):
    def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_model(self, model: Model) -> Model:
        # Eagerly load the model so the first embeddings call does not pay the cost.
        _ = self._load_sentence_transformer_model(model.provider_resource_id)
        return model

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def completion(
        self,
        model_id: str,
        content: str,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncGenerator]:
        raise ValueError("Sentence transformers don't support completion")

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise ValueError("Sentence transformers don't support chat completion")
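
Taken together with the MetaReference changes above, an embedding model registered against this provider is served entirely by the SentenceTransformerEmbeddingMixin. A usage sketch (not part of the diff; the client method names and model id are illustrative and may differ from the actual client API):

# Register an embedding model against the inline provider, then embed.
client.models.register(
    model_id="all-MiniLM-L6-v2",
    provider_id="sentence-transformers",
    model_type="embedding",
    metadata={"embedding_dimension": 384},
)
response = client.inference.embeddings(
    model_id="all-MiniLM-L6-v2",
    contents=["The quick brown fox jumps over the lazy dog"],
)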

@@ -4,16 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
 from .config import FaissImplConfig


-async def get_provider_impl(config: FaissImplConfig, _deps):
+async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
     from .faiss import FaissMemoryImpl

     assert isinstance(
         config, FaissImplConfig
     ), f"Unexpected config type: {type(config)}"

-    impl = FaissMemoryImpl(config)
+    impl = FaissMemoryImpl(config, deps[Api.inference])
     await impl.initialize()
     return impl

@@ -19,11 +19,10 @@ from numpy.typing import NDArray
 from llama_models.llama3.api.datatypes import *  # noqa: F403

 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl

 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )

@@ -32,7 +31,8 @@ from .config import FaissImplConfig

 logger = logging.getLogger(__name__)

-MEMORY_BANKS_PREFIX = "memory_banks:v1::"
+MEMORY_BANKS_PREFIX = "memory_banks:v2::"
+FAISS_INDEX_PREFIX = "faiss_index:v2::"


 class FaissIndex(EmbeddingIndex):

@@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex):
         if not self.kvstore:
             return

-        index_key = f"faiss_index:v1::{self.bank_id}"
+        index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
         stored_data = await self.kvstore.get(index_key)

         if stored_data:

@@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex):
             "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
         }

-        index_key = f"faiss_index:v1::{self.bank_id}"
+        index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
         await self.kvstore.set(key=index_key, value=json.dumps(data))

     async def delete(self):
         if not self.kvstore or not self.bank_id:
             return

-        await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
+        await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")

     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        # Add dimension check
+        embedding_dim = (
+            embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
+        )
+        if embedding_dim != self.index.d:
+            raise ValueError(
+                f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
+            )
+
         indexlen = len(self.id_by_index)
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk

@@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex):


 class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: FaissImplConfig) -> None:
+    def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cache = {}
         self.kvstore = None

@@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         for bank_data in stored_banks:
             bank = VectorMemoryBank.model_validate_json(bank_data)
             index = BankWithIndex(
-                bank=bank,
-                index=await FaissIndex.create(
-                    ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
+                bank,
+                await FaissIndex.create(
+                    bank.embedding_dimension, self.kvstore, bank.identifier
                 ),
+                self.inference_api,
             )
             self.cache[bank.identifier] = index

@@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         )

         # Store in cache
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=await FaissIndex.create(
-                ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank,
+            await FaissIndex.create(
+                memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
             ),
+            self.inference_api,
         )
-        self.cache[memory_bank.identifier] = index

     async def list_memory_banks(self) -> List[MemoryBank]:
         return [i.bank for i in self.cache.values()]
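
The new add_chunks guard rejects embeddings whose width does not match the FAISS index. A standalone sketch of the same check (not part of the diff; assumes faiss-cpu and numpy):

import faiss
import numpy as np

index = faiss.IndexFlatL2(384)  # index built for 384-dim embeddings
good = np.random.rand(4, 384).astype(np.float32)
bad = np.random.rand(4, 768).astype(np.float32)

for embeddings in (good, bad):
    dim = embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
    if dim != index.d:
        # Raised for `bad`: Expected 384, got 768
        raise ValueError(f"Embedding dimension mismatch. Expected {index.d}, got {dim}")
    index.add(embeddings)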

@@ -0,0 +1,27 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict

from llama_stack.distribution.datatypes import Api, ProviderSpec

from .config import TorchtunePostTrainingConfig

# The post_training API and the torchtune provider are still experimental and under heavy development.


async def get_provider_impl(
    config: TorchtunePostTrainingConfig,
    deps: Dict[Api, ProviderSpec],
):
    from .post_training import TorchtunePostTrainingImpl

    impl = TorchtunePostTrainingImpl(
        config,
        deps[Api.datasetio],
        deps[Api.datasets],
    )
    return impl

@@ -0,0 +1,157 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import shutil
from pathlib import Path
from typing import Any, Dict, List

import torch
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
from torchtune.utils._logging import get_logger

logger = get_logger("DEBUG")


class TorchtuneCheckpointer:
    def __init__(
        self,
        model_id: str,
        training_algorithm: str,
        checkpoint_dir: str,
        checkpoint_files: List[str],
        output_dir: str,
        model_type: str,
    ) -> None:
        # Fail fast if ``checkpoint_files`` is invalid
        # TODO: support loading more than one file
        if len(checkpoint_files) != 1:
            raise ValueError(
                "Currently we only support reading from a single torchtune checkpoint file. "
                f"Got {len(checkpoint_files)} files instead."
            )
        self._checkpoint_file = checkpoint_files[0]
        self._model_id = model_id
        self._training_algorithm = training_algorithm
        self._checkpoint_dir = Path(checkpoint_dir)
        self._model_type = ModelType[model_type]
        self._output_dir = output_dir
        # get ckpt paths
        self._checkpoint_path = Path.joinpath(
            self._checkpoint_dir, self._checkpoint_file
        )

    def load_checkpoint(self) -> Dict[str, Any]:
        """
        Load Meta checkpoint from file. Currently only loading from a single file is supported.
        """
        state_dict: Dict[str, Any] = {}
        model_state_dict = safe_torch_load(self._checkpoint_path)
        if self._model_type == ModelType.LLAMA3_VISION:
            from torchtune.models.llama3_2_vision._convert_weights import (
                llama3_vision_meta_to_tune,
            )

            state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
                model_state_dict
            )
        else:
            state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
                model_state_dict
            )

        # llama3_2 has tied weights, so we need to remove the output.weight key
        if self._model_type == ModelType.LLAMA3_2:
            logger.info(
                "Identified model_type = Llama3_2. Ignoring output.weight in"
                " checkpoint in favor of the tok_embedding.weight"
                " tied weights."
            )
            state_dict[training.MODEL_KEY].pop("output.weight")

        return state_dict

    def save_checkpoint(
        self,
        state_dict: Dict[str, Any],
        epoch: int,
        adapter_only: bool = False,
    ) -> str:
        model_file_path = (
            Path(self._output_dir)
            / f"{self._model_id}-{self._training_algorithm}-{epoch}"
        )

        model_file_path.mkdir(parents=True, exist_ok=True)

        # copy the related files for inference
        shutil.copy(
            Path.joinpath(self._checkpoint_dir, "params.json"),
            Path.joinpath(model_file_path, "params.json"),
        )
        shutil.copy(
            Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
            Path.joinpath(model_file_path, "tokenizer.model"),
        )
        shutil.copy(
            Path.joinpath(self._checkpoint_dir, "orig_params.json"),
            Path.joinpath(model_file_path, "orig_params.json"),
        )

        if not adapter_only:
            model_state_dict = state_dict[training.MODEL_KEY]
            if self._model_type == ModelType.LLAMA3_VISION:
                from torchtune.models.llama3_2_vision._convert_weights import (
                    llama3_vision_tune_to_meta,
                )

                state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
                    model_state_dict
                )
            else:
                # llama3_2 has tied weights, so we need to add the output.weight key
                if (
                    self._model_type == ModelType.LLAMA3_2
                    and "output.weight" not in model_state_dict
                ):
                    model_state_dict["output.weight"] = model_state_dict[
                        "tok_embeddings.weight"
                    ]

                state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
                    model_state_dict
                )

            model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")

            torch.save(state_dict[training.MODEL_KEY], model_file_name)
            logger.info(
                "Model checkpoint of size "
                f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
                f"saved to {model_file_name}"
            )

        if training.ADAPTER_KEY in state_dict:
            adapter_file_path = model_file_path / "adapter"
            adapter_file_path.mkdir(parents=True, exist_ok=True)
            adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
            torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
            logger.info(
                "Adapter checkpoint of size "
                f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
                f"saved to {adapter_file_name}"
            )
        elif adapter_only:
            raise ValueError(
                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
            )

        print("model_file_path", str(model_file_path))

        return str(model_file_path)
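
An illustrative round trip through the checkpointer (not part of the diff; the paths and model id are hypothetical):

checkpointer = TorchtuneCheckpointer(
    model_id="Llama3.2-3B-Instruct",
    training_algorithm="sft",
    checkpoint_dir="/checkpoints/Llama3.2-3B-Instruct/original",
    checkpoint_files=["consolidated.00.pth"],
    output_dir="/checkpoints/output",
    model_type="LLAMA3_2",
)
state_dict = checkpointer.load_checkpoint()  # Meta-format weights -> torchtune format
# ... training mutates state_dict[training.MODEL_KEY] ...
output_path = checkpointer.save_checkpoint(state_dict, epoch=0)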

@@ -0,0 +1,139 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any, Callable, Dict, List

import torch
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model
from pydantic import BaseModel

from llama_stack.apis.common.type_system import *  # noqa
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Datasets

from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_2 import lora_llama3_2_3b


class ColumnName(Enum):
    instruction = "instruction"
    input = "input"
    output = "output"
    text = "text"


class ModelConfig(BaseModel):
    model_definition: Any
    tokenizer_type: Any
    checkpoint_type: str


class DatasetSchema(BaseModel):
    alpaca: List[Dict[str, ParamType]]


MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "Llama3.2-3B-Instruct": ModelConfig(
        model_definition=lora_llama3_2_3b,
        tokenizer_type=llama3_tokenizer,
        checkpoint_type="LLAMA3_2",
    ),
    "Llama-3-8B-Instruct": ModelConfig(
        model_definition=lora_llama3_8b,
        tokenizer_type=llama3_tokenizer,
        checkpoint_type="LLAMA3",
    ),
}


EXPECTED_DATASET_SCHEMA = DatasetSchema(
    alpaca=[
        {
            ColumnName.instruction.value: StringType(),
            ColumnName.input.value: StringType(),
            ColumnName.output.value: StringType(),
            ColumnName.text.value: StringType(),
        },
        {
            ColumnName.instruction.value: StringType(),
            ColumnName.input.value: StringType(),
            ColumnName.output.value: StringType(),
        },
        {
            ColumnName.instruction.value: StringType(),
            ColumnName.output.value: StringType(),
        },
    ]
)

BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]


def _validate_model_id(model_id: str) -> Model:
    model = resolve_model(model_id)
    if model is None or model.core_model_id.value not in MODEL_CONFIGS:
        raise ValueError(f"Model {model_id} is not supported.")
    return model


async def get_model_definition(
    model_id: str,
) -> BuildLoraModelCallable:
    model = _validate_model_id(model_id)
    model_config = MODEL_CONFIGS[model.core_model_id.value]
    if not hasattr(model_config, "model_definition"):
        raise ValueError(f"Model {model_id} does not have model definition.")
    return model_config.model_definition


async def get_tokenizer_type(
    model_id: str,
) -> BuildTokenizerCallable:
    model = _validate_model_id(model_id)
    model_config = MODEL_CONFIGS[model.core_model_id.value]
    if not hasattr(model_config, "tokenizer_type"):
        raise ValueError(f"Model {model_id} does not have tokenizer_type.")
    return model_config.tokenizer_type


async def get_checkpointer_model_type(
    model_id: str,
) -> str:
    """
    The checkpointer model type drives special treatment for specific model
    families in the checkpointer. For example, llama3.2 models have tied
    weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041).
    """
    model = _validate_model_id(model_id)
    model_config = MODEL_CONFIGS[model.core_model_id.value]
    if not hasattr(model_config, "checkpoint_type"):
        raise ValueError(f"Model {model_id} does not have checkpoint_type.")
    return model_config.checkpoint_type


async def validate_input_dataset_schema(
    datasets_api: Datasets,
    dataset_id: str,
    dataset_type: str,
) -> None:
    dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
    if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
        raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

    if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
        raise ValueError(f"Dataset type {dataset_type} is not supported.")

    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
        raise ValueError(
            f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
        )
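
An illustrative call (not part of the diff): the recipe below uses this helper to validate a registered dataset against the alpaca schema variants before training. The dataset id is hypothetical:

await validate_input_dataset_schema(
    datasets_api=datasets_api,
    dataset_id="alpaca-demo",
    dataset_type="alpaca",
)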

@@ -0,0 +1,13 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel


class TorchtunePostTrainingConfig(BaseModel):
    torch_seed: Optional[int] = None

@@ -0,0 +1,66 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Mapping

import numpy as np

from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages
from torchtune.modules.transforms import Transform


class SFTDataset(Dataset):
    def __init__(
        self,
        rows: List[Dict[str, Any]],
        message_transform: Transform,
        model_transform: Transform,
    ) -> None:
        self._rows = rows
        self._message_transform = message_transform
        self._model_transform = model_transform

    def __len__(self):
        return len(self._rows)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        sample = self._rows[index]
        return self._prepare_sample(sample)

    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
        transformed_sample = self._message_transform(sample)
        if "messages" in transformed_sample:
            validate_messages(transformed_sample["messages"])

        tokenized_dict = self._model_transform(transformed_sample)

        if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
            keys_str = ", ".join(tokenized_dict.keys())
            error_message = (
                "model_transform returned the following keys: "
                f"{keys_str}. Must return 'tokens' and 'mask' as keys."
            )
            raise ValueError(error_message)

        # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
        tokenized_dict["labels"] = list(
            np.where(
                tokenized_dict["mask"],
                CROSS_ENTROPY_IGNORE_IDX,
                tokenized_dict["tokens"],
            )
        )
        assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"])

        return tokenized_dict
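
A sketch of wiring rows into SFTDataset with the same transforms the recipe below uses (not part of the diff; the row content and tokenizer path are hypothetical):

from torchtune.data import AlpacaToMessages
from torchtune.models.llama3 import llama3_tokenizer

rows = [
    {
        "instruction": "Summarize the text.",
        "input": "Llama Stack standardizes the building blocks of AI apps.",
        "output": "It defines common APIs implemented by many providers.",
    }
]
tokenizer = llama3_tokenizer(path="/checkpoints/tokenizer.model")
ds = SFTDataset(
    rows,
    message_transform=AlpacaToMessages(train_on_input=False),
    model_transform=tokenizer,
)
sample = ds[0]  # dict with "tokens", "mask", and the derived "labels"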

@@ -0,0 +1,126 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import *  # noqa
from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
    LoraFinetuningSingleDevice,
)


class TorchtunePostTrainingImpl:
    def __init__(
        self,
        config: TorchtunePostTrainingConfig,
        datasetio_api: DatasetIO,
        datasets: Datasets,
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets

        # TODO: assume sync job, will need jobs API for async scheduling
        self.jobs_status = {}
        self.jobs_list = []
        self.checkpoints_dict = {}

    async def supervised_fine_tune(
        self,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        model: str,
        checkpoint_dir: Optional[str],
        algorithm_config: Optional[AlgorithmConfig],
    ) -> PostTrainingJob:
        for job in self.jobs_list:
            if job_uuid == job.job_uuid:
                raise ValueError(f"Job {job_uuid} already exists")

        post_training_job = PostTrainingJob(job_uuid=job_uuid)

        job_status_response = PostTrainingJobStatusResponse(
            job_uuid=job_uuid,
            status=JobStatus.scheduled,
            scheduled_at=datetime.now(),
        )

        self.jobs_list.append(post_training_job)
        if isinstance(algorithm_config, LoraFinetuningConfig):
            try:
                recipe = LoraFinetuningSingleDevice(
                    self.config,
                    job_uuid,
                    training_config,
                    hyperparam_search_config,
                    logger_config,
                    model,
                    checkpoint_dir,
                    algorithm_config,
                    self.datasetio_api,
                    self.datasets_api,
                )

                job_status_response.status = JobStatus.in_progress
                job_status_response.started_at = datetime.now()

                await recipe.setup()
                resources_allocated, checkpoints = await recipe.train()

                self.checkpoints_dict[job_uuid] = checkpoints
                job_status_response.resources_allocated = resources_allocated
                job_status_response.checkpoints = checkpoints
                job_status_response.status = JobStatus.completed
                job_status_response.completed_at = datetime.now()

            except Exception:
                job_status_response.status = JobStatus.failed
                raise
        else:
            raise NotImplementedError()

        self.jobs_status[job_uuid] = job_status_response

        return post_training_job

    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
    ) -> PostTrainingJob: ...

    async def get_training_jobs(self) -> List[PostTrainingJob]:
        return self.jobs_list

    @webmethod(route="/post-training/job/status")
    async def get_training_job_status(
        self, job_uuid: str
    ) -> Optional[PostTrainingJobStatusResponse]:
        if job_uuid in self.jobs_status:
            return self.jobs_status[job_uuid]
        return None

    @webmethod(route="/post-training/job/cancel")
    async def cancel_training_job(self, job_uuid: str) -> None:
        raise NotImplementedError("Job cancel is not implemented yet")

    @webmethod(route="/post-training/job/artifacts")
    async def get_training_job_artifacts(
        self, job_uuid: str
    ) -> Optional[PostTrainingJobArtifactsResponse]:
        if job_uuid in self.checkpoints_dict:
            checkpoints = self.checkpoints_dict.get(job_uuid, [])
            return PostTrainingJobArtifactsResponse(
                job_uuid=job_uuid, checkpoints=checkpoints
            )
        return None
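
Kicking off a LoRA job through this implementation looks roughly like the sketch below (not part of the diff; field values are hypothetical, and the LoraFinetuningConfig fields mirror the attributes the recipe reads):

job = await impl.supervised_fine_tune(
    job_uuid="job-1234",
    model="Llama3.2-3B-Instruct",
    checkpoint_dir=None,
    algorithm_config=LoraFinetuningConfig(
        lora_attn_modules=["q_proj", "v_proj", "output_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=False,
        rank=8,
        alpha=16,
    ),
    training_config=training_config,  # a TrainingConfig built elsewhere
    hyperparam_search_config={},
    logger_config={},
)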

@@ -0,0 +1,596 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import os
import time
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
from llama_models.sku_list import resolve_model

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import *  # noqa
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
    TorchtuneCheckpointer,
)
from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune import utils as torchtune_utils
from torchtune.data import AlpacaToMessages, padded_collate_sft
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
    get_adapter_params,
    get_adapter_state_dict,
    get_lora_module_names,
    get_merged_lora_ckpt,
    load_dora_magnitudes,
    set_trainable_params,
    validate_missing_and_unexpected_for_lora,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm

log = logging.getLogger(__name__)


class LoraFinetuningSingleDevice:
    # This recipe only supports GPU training.
    #
    # It doesn't include several training-efficiency settings from the original
    # torchtune repo, including:
    # - compile
    # - activation offloading
    #
    # Resume from checkpoint hasn't been supported yet.
    # Validation hasn't been supported yet.
    #
    # Currently logging only logs limited training metrics to local disk;
    # will figure out more logging and how it works with telemetry in future PRs.
    def __init__(
        self,
        config: TorchtunePostTrainingConfig,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        model: str,
        checkpoint_dir: Optional[str],
        algorithm_config: Optional[AlgorithmConfig],
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
    ) -> None:
        self.job_uuid = job_uuid
        self.training_config = training_config
        if not isinstance(algorithm_config, LoraFinetuningConfig):
            raise ValueError(
                "You need to specify LoraFinetuningConfig for LoRA finetuning"
            )
        self.algorithm_config = algorithm_config
        self._device = torchtune_utils.get_device(device="cuda")
        self._dtype = training.get_dtype(training_config.dtype, device=self._device)
        self.model_id = model

        def model_checkpoint_dir(model) -> str:
            checkpoint_dir = Path(model_local_dir(model.descriptor()))

            paths = [
                Path(checkpoint_dir / f"consolidated.{ext}")
                for ext in ["pth", "00.pth"]
            ]
            if not any(p.exists() for p in paths):
                checkpoint_dir = checkpoint_dir / "original"

            assert checkpoint_dir.exists(), (
                f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
                f"Please download model using `llama download --model-id {model.descriptor()}`"
            )
            return str(checkpoint_dir)

        if checkpoint_dir and checkpoint_dir != "null":
            self.checkpoint_dir = checkpoint_dir
        else:
            model = resolve_model(self.model_id)
            self.checkpoint_dir = model_checkpoint_dir(model)

        self._output_dir = str(DEFAULT_CHECKPOINT_DIR)

        self.seed = training.set_seed(seed=config.torch_seed)
        self.epochs_run = 0
        self.total_epochs = training_config.n_epochs
        self._shuffle = training_config.data_config.shuffle
        self._batch_size = training_config.data_config.batch_size

        # this is important for debugging purposes
        self.max_steps_per_epoch = training_config.max_steps_per_epoch
        self.global_step = 0

        self._gradient_accumulation_steps = training_config.gradient_accumulation_steps

        self._clip_grad_norm = 1.0
        self._enable_activation_checkpointing = (
            (training_config.efficiency_config.enable_activation_checkpointing)
            if training_config.efficiency_config
            else False
        )
        self._enable_activation_offloading = (
            (training_config.efficiency_config.enable_activation_offloading)
            if training_config.efficiency_config
            else False
        )

        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api

    async def load_checkpoint(self):
        def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
            try:
                # List all files in the given directory
                files = os.listdir(checkpoint_dir)
                # Filter files that end with .pth
                pth_files = [file for file in files if file.endswith(".pth")]
                return pth_files
            except FileNotFoundError:
                return [f"Error: The directory '{checkpoint_dir}' does not exist."]

        self._checkpointer = TorchtuneCheckpointer(
            model_id=self.model_id,
            training_algorithm="sft",
            checkpoint_dir=self.checkpoint_dir,
            checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
            output_dir=self._output_dir,
            model_type=await utils.get_checkpointer_model_type(self.model_id),
        )
        checkpoint_dict = self._checkpointer.load_checkpoint()
        return checkpoint_dict

    async def setup(self) -> None:
        checkpoint_dict = await self.load_checkpoint()

        self._model = await self._setup_model(
            enable_activation_checkpointing=self._enable_activation_checkpointing,
            enable_activation_offloading=self._enable_activation_offloading,
            base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
            lora_weights_state_dict=None,
        )
        log.info(f"Model is initialized with precision {self._dtype}.")

        self._tokenizer = await self._setup_tokenizer()
        log.info("Tokenizer is initialized.")

        self._optimizer = await self._setup_optimizer(
            optimizer_config=self.training_config.optimizer_config
        )
        log.info("Optimizer is initialized.")

        self._loss_fn = CEWithChunkedOutputLoss()
        self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
        log.info("Loss is initialized.")

        self._training_sampler, self._training_dataloader = await self._setup_data(
            dataset_id=self.training_config.data_config.dataset_id,
            tokenizer=self._tokenizer,
            shuffle=self._shuffle,
            batch_size=self._batch_size,
        )

        if self.training_config.data_config.validation_dataset_id:
            _, self._validation_dataloader = await self._setup_data(
                dataset_id=self.training_config.data_config.validation_dataset_id,
                tokenizer=self._tokenizer,
                shuffle=False,
                batch_size=self._batch_size,
            )

        log.info("Dataset and Sampler are initialized.")

        # Number of training steps in each epoch depends on the number of batches produced
        # by the dataloader and the max_steps_per_epoch param set by the user and is used
        # for logging and tracking training state. This should be computed after the dataloader
        # has been setup
        self._steps_per_epoch = (
            len(self._training_dataloader) // self._gradient_accumulation_steps
        )
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
        ):
            self._steps_per_epoch = self.max_steps_per_epoch
        self.global_step = self.epochs_run * self._steps_per_epoch

        # Learning rate scheduler can only be set up after number of steps
        # has been computed
        self._lr_scheduler = await self._setup_lr_scheduler(
            num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
            num_training_steps=self.total_epochs * self._steps_per_epoch,
            last_epoch=self.global_step - 1,
        )
        log.info("Learning rate scheduler is initialized.")

        # Used to ignore labels for loss computation
        self.ignore_labels_cache = torch.full(
            (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
        )

    async def _setup_model(
        self,
        enable_activation_checkpointing: bool,
        enable_activation_offloading: bool,
        base_model_state_dict: Dict[str, Any],
        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
    ) -> nn.Module:
        self._lora_rank = self.algorithm_config.rank
        self._lora_alpha = self.algorithm_config.alpha
        self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
        self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
        self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
        self._use_dora = self.algorithm_config.use_dora or False

        with training.set_default_dtype(self._dtype), self._device:
            model_type = await utils.get_model_definition(self.model_id)
            model = model_type(
                lora_attn_modules=self._lora_attn_modules,
                apply_lora_to_mlp=self._apply_lora_to_mlp,
                apply_lora_to_output=self._apply_lora_to_output,
                lora_rank=self._lora_rank,
                lora_alpha=self._lora_alpha,
                quantize_base=False,
                use_dora=self._use_dora,
            )

        self.adapter_params = get_adapter_params(model)
        self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])

        set_trainable_params(model, self.adapter_params)

        if enable_activation_checkpointing:
            training.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
            )

        base_missing, base_unexpected = model.load_state_dict(
            base_model_state_dict, strict=False
        )

        # This is for any adapters that need to be initialized after base weights
        # have been loaded (e.g. DoRA).
        if self._is_dora:
            for m in model.modules():
                if hasattr(m, "initialize_dora_magnitude"):
                    m.initialize_dora_magnitude()
            load_dora_magnitudes(model)
        if lora_weights_state_dict:
            lora_missing, lora_unexpected = model.load_state_dict(
                lora_weights_state_dict, strict=False
            )
        else:
            lora_missing, lora_unexpected = None, None
        validate_missing_and_unexpected_for_lora(
            lora_attn_modules=self._lora_attn_modules,
            apply_lora_to_mlp=self._apply_lora_to_mlp,
            apply_lora_to_output=self._apply_lora_to_output,
            base_missing=base_missing,
            base_unexpected=base_unexpected,
            lora_missing=lora_missing,
            lora_unexpected=lora_unexpected,
        )

        # Validate model adapter params were loaded in with the expected dtype
        training.validate_expected_param_dtype(
            self.adapter_params.items(), dtype=self._dtype
        )

        # activation offloading
        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
            model, enable_activation_offloading
        )

        memory_stats = training.get_memory_stats(device=self._device)
        training.log_memory_stats(memory_stats)

        return model

    async def _setup_tokenizer(
        self,
    ) -> Llama3Tokenizer:
        tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
        tokenizer_type = await utils.get_tokenizer_type(self.model_id)
        return tokenizer_type(path=tokenizer_path)

    async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
        optimizer = torch.optim.AdamW(
            params=self._model.parameters(),
            lr=optimizer_config.lr,
            betas=(0.9, 0.95),
            eps=1e-8,
            weight_decay=0.1,
        )
        return optimizer

    async def _setup_data(
        self,
        dataset_id: str,
        tokenizer: Llama3Tokenizer,
        shuffle: bool,
        batch_size: int,
    ) -> Tuple[DistributedSampler, DataLoader]:
        async def fetch_rows(dataset_id: str):
            return await self.datasetio_api.get_rows_paginated(
                dataset_id=dataset_id,
                rows_in_page=-1,
            )

        all_rows = await fetch_rows(dataset_id)
        rows = all_rows.rows

        # Currently only supports the alpaca instruct dataset.
        # TODO @SLR722 make the message_transform swappable and support more dataset types
        # TODO @SLR722 make the input dataset schema more flexible by exposing column_map
        await utils.validate_input_dataset_schema(
            datasets_api=self.datasets_api,
            dataset_id=dataset_id,
            dataset_type="alpaca",
        )
        ds = SFTDataset(
            rows,
            message_transform=AlpacaToMessages(train_on_input=False),
            model_transform=tokenizer,
        )

        sampler = DistributedSampler(
            ds,
            num_replicas=1,
            rank=0,
            shuffle=shuffle,
            seed=0,
        )
        dataloader = DataLoader(
            dataset=ds,
            sampler=sampler,
            batch_size=batch_size,
            # dropping last avoids shape issues with compile + flex attention
            drop_last=True,
            collate_fn=(
                partial(
                    padded_collate_sft,
                    padding_idx=self._tokenizer.pad_id,
                    ignore_idx=self._loss_fn.ignore_index,
                )
            ),
        )

        return sampler, dataloader

    async def _setup_lr_scheduler(
        self,
        num_warmup_steps: int,
        num_training_steps: int,
        last_epoch: int,
    ) -> Optimizer:
        lr_scheduler = get_cosine_schedule_with_warmup(
            self._optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            last_epoch=last_epoch,
        )
        return lr_scheduler

    async def save_checkpoint(self, epoch: int) -> str:
        ckpt_dict = {}

        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

        # Construct the full state dict with LoRA weights merged into base LLM weights
        # Move to CPU to avoid a copy on GPU
        state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}

        merged_state_dict = get_merged_lora_ckpt(
            state_dict,
            rank=self._lora_rank,
            alpha=self._lora_alpha,
        )

        ckpt_dict.update({training.MODEL_KEY: merged_state_dict})

        adapter_config = {
            "r": self._lora_rank,
            "lora_alpha": self._lora_alpha,
            "target_modules": get_lora_module_names(
                self._lora_attn_modules,
                self._apply_lora_to_mlp,
                self._apply_lora_to_output,
            ),
            "peft_type": "LORA",
        }
        ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})

        return self._checkpointer.save_checkpoint(
            ckpt_dict,
            epoch=epoch,
        )

    async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Shape [b, s], needed for the loss not the model
        labels = batch.pop("labels")
        # run model
        with self.activations_handling_ctx:
            logits = self._model(**batch)

        # Shift labels to compute loss.
        # Equivalent to doing labels[..., 1:] and logits[..., :-1, :],
        # but this way we don't need to slice the logits. We just add an ignore index to labels.
        labels = torch.hstack(
            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
        )
        if not isinstance(logits, list):
            labels = labels.reshape(-1)
            logits = logits.reshape(-1, logits.size(-1))

        loss = self._loss_fn(logits, labels)

        # free logits otherwise it peaks backward memory
        del logits

        return loss

    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
        """
        The core training loop.
        """
        # Initialize tokens count and running loss (for grad accumulation)
        t0 = time.perf_counter()
        running_loss = 0
        num_tokens = 0

        # training artifacts
        checkpoints = []
        memory_stats = {}

        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):
            # Update the sampler to ensure data is correctly shuffled across epochs
            # in case shuffle is True
            metric_logger = DiskLogger(
                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
            )
            self._training_sampler.set_epoch(curr_epoch)
            loss_to_log = 0.0

            pbar = tqdm(total=self._steps_per_epoch)
            for idx, batch in enumerate(self._training_dataloader):
                if (
                    self.max_steps_per_epoch is not None
                    and (idx // self._gradient_accumulation_steps)
                    == self.max_steps_per_epoch
                ):
                    break

                torchtune_utils.batch_to_device(batch, self._device)

                # Calculate the number of unmasked tokens in the current batch
                # and increment the total number of tokens seen in the step
                current_num_tokens = (
                    batch["labels"] != self._loss_fn.ignore_index
                ).sum()
                num_tokens += current_num_tokens

                # Loss is normalized by default so we multiply by the number of tokens
                # This way we can normalize by the total number of tokens if we're accumulating gradients
                current_loss = await self._loss_step(batch) * current_num_tokens
                running_loss += current_loss
                current_loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
                    training.scale_grads(self._model, 1 / num_tokens)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self._model.parameters(),
                        max_norm=float(self._clip_grad_norm),
                    )
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)
                    self._lr_scheduler.step()
                    # Update the number of steps when the weights are updated
                    self.global_step += 1

                    loss_to_log = running_loss.item() / num_tokens

                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
                    )

                    time_per_step = time.perf_counter() - t0
                    log_dict = {
                        "loss": loss_to_log,
                        "lr": self._optimizer.param_groups[0]["lr"],
                        "tokens_per_second_per_gpu": num_tokens / time_per_step,
                    }

                    memory_stats = training.get_memory_stats(device=self._device)
                    log_dict.update(memory_stats)

                    if self._clip_grad_norm is not None:
                        log_dict.update({"grad_norm": grad_norm})

                    metric_logger.log_dict(
                        log_dict,
                        step=self.global_step,
                    )

                    # Reset running stats for the next step
                    running_loss = 0
                    num_tokens = 0
                    t0 = time.perf_counter()

            self.epochs_run += 1
            log.info("Starting checkpoint save...")
            checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
            checkpoint = Checkpoint(
                identifier=f"{self.model_id}-sft-{curr_epoch}",
                created_at=datetime.now(),
                epoch=curr_epoch,
                post_training_job_id=self.job_uuid,
                path=checkpoint_path,
            )
            if self.training_config.data_config.validation_dataset_id:
                validation_loss, perplexity = await self.validation()
                training_metrics = PostTrainingMetric(
                    epoch=curr_epoch,
                    train_loss=loss_to_log,
                    validation_loss=validation_loss,
                    perplexity=perplexity,
                )
                checkpoint.training_metrics = training_metrics
            checkpoints.append(checkpoint)

        return (memory_stats, checkpoints)

    async def validation(self) -> Tuple[float, float]:
        total_loss = 0.0
        total_tokens = 0
        log.info("Starting validation...")
        pbar = tqdm(total=len(self._validation_dataloader))
        for idx, batch in enumerate(self._validation_dataloader):
            if idx == 10:
                break
            torchtune_utils.batch_to_device(batch, self._device)

            # Calculate the number of unmasked tokens in the current batch
            # and increment the total number of tokens seen in the step
            num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()

            # Loss is normalized by default so we multiply by the number of tokens
            # This way we can normalize by the total number of tokens if we're accumulating gradients
            loss = await self._loss_step(batch) * num_tokens

            total_loss += loss
            total_tokens += num_tokens

            pbar.update(1)
            pbar.set_description(f"validation step: {idx}")

        mean_loss = total_loss / total_tokens
        perplexity = torch.exp(torch.tensor(mean_loss))

        return mean_loss, perplexity.item()
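
The label-shift trick in _loss_step can be shown in isolation (a sketch, not part of the diff): appending an ignore-index column is equivalent to the usual labels[..., 1:] / logits[..., :-1, :] slicing, without materializing a sliced logits tensor.

import torch

IGNORE_INDEX = -100  # plays the role of CROSS_ENTROPY_IGNORE_IDX
labels = torch.tensor([[5, 6, 7, 8]])
ignore_column = torch.full((labels.shape[0], 1), IGNORE_INDEX)
shifted = torch.hstack((labels[..., 1:], ignore_column))
# shifted == tensor([[6, 7, 8, -100]]): position t now holds the target for the
# token predicted at position t, and the final position is ignored by the loss.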

@@ -18,6 +18,7 @@ META_REFERENCE_DEPS = [
     "transformers",
     "zmq",
     "lm-format-enforcer",
+    "sentence-transformers",
 ]

@@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]:
         module="llama_stack.providers.inline.inference.vllm",
         config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
     ),
+    InlineProviderSpec(
+        api=Api.inference,
+        provider_type="inline::sentence-transformers",
+        pip_packages=["sentence-transformers"],
+        module="llama_stack.providers.inline.inference.sentence_transformers",
+        config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
+    ),
     remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(

@@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
         module="llama_stack.providers.inline.memory.faiss",
         config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
         deprecation_warning="Please use the `inline::faiss` provider instead.",
+        api_dependencies=[Api.inference],
     ),
     InlineProviderSpec(
         api=Api.memory,

@@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
         pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
         module="llama_stack.providers.inline.memory.faiss",
         config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+        api_dependencies=[Api.inference],
     ),
     remote_provider_spec(
         Api.memory,

@@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.remote.memory.chroma",
             config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
         ),
+        api_dependencies=[Api.inference],
     ),
     InlineProviderSpec(
         api=Api.memory,

@@ -71,6 +74,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.remote.memory.pgvector",
             config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
         ),
+        api_dependencies=[Api.inference],
     ),
     remote_provider_spec(
         Api.memory,

@@ -81,6 +85,7 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
             provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
         ),
+        api_dependencies=[Api.inference],
     ),
     remote_provider_spec(
         api=Api.memory,

@@ -90,6 +95,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.remote.memory.sample",
             config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
         ),
+        api_dependencies=[],
     ),
     remote_provider_spec(
         Api.memory,

@@ -99,5 +105,6 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.remote.memory.qdrant",
             config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
         ),
+        api_dependencies=[Api.inference],
     ),
 ]

llama_stack/providers/registry/post_training.py (new file, 25 lines)
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.post_training,
|
||||
provider_type="inline::torchtune",
|
||||
pip_packages=["torch", "torchtune", "torchao", "numpy"],
|
||||
module="llama_stack.providers.inline.post_training.torchtune",
|
||||
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
],
|
||||
),
|
||||
]
|
||||
|
|
@@ -21,14 +21,19 @@ DATASETS_PREFIX = "datasets:"

 def load_hf_dataset(dataset_def: Dataset):
     if dataset_def.metadata.get("path", None):
-        return hf_datasets.load_dataset(**dataset_def.metadata)
+        dataset = hf_datasets.load_dataset(**dataset_def.metadata)
+    else:
+        df = get_dataframe_from_url(dataset_def.url)
+        if df is None:
+            raise ValueError(f"Failed to load dataset from {dataset_def.url}")
+        dataset = hf_datasets.Dataset.from_pandas(df)

-    df = get_dataframe_from_url(dataset_def.url)
-
-    if df is None:
-        raise ValueError(f"Failed to load dataset from {dataset_def.url}")
+    # drop columns not specified by schema
+    if dataset_def.dataset_schema:
+        dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))

-    dataset = hf_datasets.Dataset.from_pandas(df)
     return dataset
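For reference, a minimal sketch of the "path" branch above: the metadata dict is forwarded verbatim to load_dataset, so a registration carrying {"path": "tatsu-lab/alpaca", "split": "train"} (as in the fixtures later in this diff) resolves to:

import datasets as hf_datasets

# Equivalent call for the "path" branch; kwargs come straight from
# dataset_def.metadata.
dataset = hf_datasets.load_dataset(path="tatsu-lab/alpaca", split="train")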
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from typing import *  # noqa: F403
+import json

 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId

@@ -19,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.apis.inference import *  # noqa: F403

+
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media


 model_aliases = [
@@ -448,4 +451,21 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        embeddings = []
+        for content in contents:
+            assert not content_has_media(
+                content
+            ), "Bedrock does not support media for embeddings"
+            input_text = interleaved_text_media_as_str(content)
+            input_body = {"inputText": input_text}
+            body = json.dumps(input_body)
+            response = self.client.invoke_model(
+                body=body,
+                modelId=model.provider_resource_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+            response_body = json.loads(response.get("body").read())
+            embeddings.append(response_body.get("embedding"))
+        return EmbeddingsResponse(embeddings=embeddings)
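A standalone sketch of the raw Bedrock request the adapter now wraps, assuming boto3 and a Titan embedding model ("amazon.titan-embed-text-v1" is an illustrative assumption, not taken from this diff):

import json

import boto3

client = boto3.client("bedrock-runtime")
response = client.invoke_model(
    body=json.dumps({"inputText": "Hello, world!"}),
    modelId="amazon.titan-embed-text-v1",  # assumed model id for illustration
    accept="application/json",
    contentType="application/json",
)
embedding = json.loads(response["body"].read())["embedding"]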
@@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
 @json_schema_type
 class FireworksImplConfig(BaseModel):
     url: str = Field(
-        default="https://api.fireworks.ai/inference",
+        default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
     )
     api_key: Optional[str] = Field(

@@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel):
     @classmethod
     def sample_run_config(cls) -> Dict[str, Any]:
         return {
-            "url": "https://api.fireworks.ai/inference",
+            "url": "https://api.fireworks.ai/inference/v1",
             "api_key": "${env.FIREWORKS_API_KEY}",
         }
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks
 from llama_models.datatypes import CoreModelId

@@ -28,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -89,17 +90,19 @@ class FireworksInferenceAdapter(
     async def shutdown(self) -> None:
         pass

-    def _get_client(self) -> Fireworks:
-        fireworks_api_key = None
+    def _get_api_key(self) -> str:
         if self.config.api_key is not None:
-            fireworks_api_key = self.config.api_key
+            return self.config.api_key
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.fireworks_api_key:
                 raise ValueError(
                     'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
                 )
-            fireworks_api_key = provider_data.fireworks_api_key
+            return provider_data.fireworks_api_key
+
+    def _get_client(self) -> Fireworks:
+        fireworks_api_key = self._get_api_key()
         return Fireworks(api_key=fireworks_api_key)

     async def completion(
@@ -264,4 +267,19 @@ class FireworksInferenceAdapter(
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        kwargs = {}
+        if model.metadata.get("embedding_dimensions"):
+            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Fireworks does not support media for embeddings"
+        response = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
+        )
+
+        embeddings = [data.embedding for data in response.data]
+        return EmbeddingsResponse(embeddings=embeddings)
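For context, a minimal sketch of the client call path the Fireworks adapter takes (the model name is an illustrative assumption; dimensions is only passed when the registered model carries embedding_dimensions metadata):

from fireworks.client import Fireworks

client = Fireworks(api_key="YOUR_FIREWORKS_API_KEY")
response = client.embeddings.create(
    model="nomic-ai/nomic-embed-text-v1.5",  # assumed model name
    input=["Hello, world!"],
)
vectors = [data.embedding for data in response.data]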
@@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_image_media_to_url,
     request_has_media,
 )

@@ -321,9 +322,30 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Ollama does not support media for embeddings"
+        response = await self.client.embed(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+        )
+        embeddings = response["embeddings"]
+
+        return EmbeddingsResponse(embeddings=embeddings)

     async def register_model(self, model: Model) -> Model:
+        # ollama does not have embedding models running. Check if the model is in list of available models.
+        if model.model_type == ModelType.embedding:
+            response = await self.client.list()
+            available_models = [m["model"] for m in response["models"]]
+            if model.provider_resource_id not in available_models:
+                raise ValueError(
+                    f"Model '{model.provider_resource_id}' is not available in Ollama. "
+                    f"Available models: {', '.join(available_models)}"
+                )
+            return model
         model = await self.register_helper.register_model(model)
         models = await self.client.ps()
         available_models = [m["model"] for m in models["models"]]
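A hedged sketch of the underlying Ollama call; the host and the "all-minilm" tag are assumptions, and per the new register_model check the model must already be pulled into Ollama:

from ollama import AsyncClient

async def embed(texts: list[str]) -> list[list[float]]:
    client = AsyncClient(host="http://localhost:11434")  # assumed default host
    response = await client.embed(model="all-minilm", input=texts)
    return response["embeddings"]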
@@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )

@@ -253,4 +254,13 @@ class TogetherInferenceAdapter(
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Together does not support media for embeddings"
+        r = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+        )
+        embeddings = [item.embedding for item in r.data]
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )

@@ -203,4 +204,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        kwargs = {}
+        assert model.model_type == ModelType.embedding
+        assert model.metadata.get("embedding_dimensions")
+        kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "VLLM does not support media for embeddings"
+        response = self.client.embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
+        )
+
+        embeddings = [data.embedding for data in response.data]
+        return EmbeddingsResponse(embeddings=embeddings)
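Since vLLM serves an OpenAI-compatible API, the adapter's call is equivalent to the following sketch (the URL and model name are assumptions):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")  # assumed endpoint
response = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",  # assumed embedding-capable model
    input=["Hello, world!"],
)
vectors = [data.embedding for data in response.data]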
@@ -4,12 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import ChromaRemoteImplConfig


-async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps):
+async def get_adapter_impl(
+    config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]
+):
     from .chroma import ChromaMemoryAdapter

-    impl = ChromaMemoryAdapter(config)
+    impl = ChromaMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -13,8 +13,7 @@ import chromadb
 from numpy.typing import NDArray

 from llama_stack.apis.memory import *  # noqa: F403
-
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
@@ -87,10 +86,14 @@ class ChromaIndex(EmbeddingIndex):

 class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(
-        self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig]
+        self,
+        config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
+        inference_api: Api.inference,
     ) -> None:
         log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
         self.config = config
+        self.inference_api = inference_api

         self.client = None
         self.cache = {}
@@ -127,10 +130,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
                 metadata={"bank": memory_bank.model_dump_json()},
             )
         )
-        bank_index = BankWithIndex(
-            bank=memory_bank, index=ChromaIndex(self.client, collection)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, ChromaIndex(self.client, collection), self.inference_api
         )
-        self.cache[memory_bank.identifier] = bank_index

     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
@@ -166,6 +168,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         collection = await maybe_await(self.client.get_collection(bank_id))
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
-        index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
+        index = BankWithIndex(
+            bank, ChromaIndex(self.client, collection), self.inference_api
+        )
         self.cache[bank_id] = index
         return index
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import PGVectorConfig


-async def get_adapter_impl(config: PGVectorConfig, _deps):
+async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
     from .pgvector import PGVectorMemoryAdapter

-    impl = PGVectorMemoryAdapter(config)
+    impl = PGVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as

 from llama_stack.apis.memory import *  # noqa: F403

-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate

 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )
@@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex):


 class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: PGVectorConfig) -> None:
+    def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cursor = None
         self.conn = None
         self.cache = {}
@@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def register_memory_bank(
-        self,
-        memory_bank: MemoryBank,
-    ) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
         assert (
             memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"

-        upsert_models(
-            self.cursor,
-            [
-                (memory_bank.identifier, memory_bank),
-            ],
-        )
-
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[memory_bank.identifier] = index
+        upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
+        index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, index, self.inference_api
+        )

     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]
@@ -203,14 +194,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = await self._get_and_cache_bank_index(bank_id)
         return await index.query_documents(query, params)

+        self.inference_api = inference_api
+
     async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
         if bank_id in self.cache:
             return self.cache[bank_id]

         bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[bank_id] = index
-        return index
+        index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
+        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
+        return self.cache[bank_id]
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import QdrantConfig


-async def get_adapter_impl(config: QdrantConfig, _deps):
+async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorMemoryAdapter

-    impl = QdrantVectorMemoryAdapter(config)
+    impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex):


 class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: QdrantConfig) -> None:
+    def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
         self.cache = {}
+        self.inference_api = inference_api

     async def initialize(self) -> None:
         pass
@@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
+            inference_api=self.inference_api,
         )

         self.cache[memory_bank.identifier] = index
@@ -138,6 +140,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import WeaviateConfig, WeaviateRequestProviderData  # noqa: F401


-async def get_adapter_impl(config: WeaviateConfig, _deps):
+async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
     from .weaviate import WeaviateMemoryAdapter

-    impl = WeaviateMemoryAdapter(config)
+    impl = WeaviateMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -12,10 +12,11 @@ import weaviate
 import weaviate.classes as wvc
 from numpy.typing import NDArray
 from weaviate.classes.init import Auth
+from weaviate.classes.query import Filter

 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex):

         return QueryDocumentsResponse(chunks=chunks, scores=scores)

+    async def delete(self, chunk_ids: List[str]) -> None:
+        collection = self.client.collections.get(self.collection_name)
+        collection.data.delete_many(
+            where=Filter.by_property("id").contains_any(chunk_ids)
+        )
+

 class WeaviateMemoryAdapter(
-    Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
+    Memory,
+    NeedsRequestProviderData,
+    MemoryBanksProtocolPrivate,
 ):
-    def __init__(self, config: WeaviateConfig) -> None:
+    def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
@@ -117,7 +127,7 @@ class WeaviateMemoryAdapter(
         memory_bank: MemoryBank,
     ) -> None:
         assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector
+            memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"

         client = self._get_client()
@@ -135,11 +145,11 @@ class WeaviateMemoryAdapter(
             ],
         )

-        index = BankWithIndex(
-            bank=memory_bank,
-            index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank,
+            WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+            self.inference_api,
         )
-        self.cache[memory_bank.identifier] = index

     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
@@ -156,6 +166,7 @@ class WeaviateMemoryAdapter(
         index = BankWithIndex(
             bank=bank,
             index=WeaviateIndex(client=client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
@@ -156,4 +156,5 @@ pytest_plugins = [
     "llama_stack.providers.tests.datasetio.fixtures",
     "llama_stack.providers.tests.scoring.fixtures",
     "llama_stack.providers.tests.eval.fixtures",
+    "llama_stack.providers.tests.post_training.fixtures",
 ]
@@ -10,6 +10,7 @@ import pytest_asyncio
 from llama_stack.distribution.datatypes import Api, Provider

 from llama_stack.providers.tests.resolver import construct_stack_for_test

 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -18,6 +18,12 @@ def pytest_addoption(parser):
         default=None,
         help="Specify the inference model to use for testing",
     )
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=None,
+        help="Specify the embedding model to use for testing",
+    )


 def pytest_configure(config):
@@ -9,9 +9,9 @@ import os
 import pytest
 import pytest_asyncio

-from llama_stack.apis.models import ModelInput
-
+from llama_stack.apis.models import ModelInput, ModelType
 from llama_stack.distribution.datatypes import Api, Provider

 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )
@@ -47,6 +47,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
+    # If embedding dimension is set, use the 8B model for testing
+    if os.getenv("EMBEDDING_DIMENSION"):
+        inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]

     return ProviderFixture(
         providers=[
@@ -85,7 +88,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
-    if "Llama3.1-8B-Instruct" in inference_model:
+    if inference_model and "Llama3.1-8B-Instruct" in inference_model:
         pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")

     return ProviderFixture(
@@ -232,11 +235,23 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
+    model_type = ModelType.llm
+    metadata = {}
+    if os.getenv("EMBEDDING_DIMENSION"):
+        model_type = ModelType.embedding
+        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
+
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
-        models=[ModelInput(model_id=inference_model)],
+        models=[
+            ModelInput(
+                model_id=inference_model,
+                model_type=model_type,
+                metadata=metadata,
+            )
+        ],
     )

     return test_stack.impls[Api.inference], test_stack.impls[Api.models]
llama_stack/providers/tests/inference/test_embeddings.py (new file, 62 lines)

@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from llama_stack.apis.inference import EmbeddingsResponse, ModelType
+
+# How to run this test:
+# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
+
+
+class TestEmbeddings:
+    @pytest.mark.asyncio
+    async def test_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
+
+        if model.model_type != ModelType.embedding:
+            pytest.skip("This test is only applicable for embedding models")
+
+        response = await inference_impl.embeddings(
+            model_id=inference_model,
+            contents=["Hello, world!"],
+        )
+        assert isinstance(response, EmbeddingsResponse)
+        assert len(response.embeddings) > 0
+        assert all(isinstance(embedding, list) for embedding in response.embeddings)
+        assert all(
+            isinstance(value, float)
+            for embedding in response.embeddings
+            for value in embedding
+        )
+
+    @pytest.mark.asyncio
+    async def test_batch_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
+
+        if model.model_type != ModelType.embedding:
+            pytest.skip("This test is only applicable for embedding models")
+
+        texts = ["Hello, world!", "This is a test", "Testing embeddings"]
+
+        response = await inference_impl.embeddings(
+            model_id=inference_model,
+            contents=texts,
+        )
+
+        assert isinstance(response, EmbeddingsResponse)
+        assert len(response.embeddings) == len(texts)
+        assert all(isinstance(embedding, list) for embedding in response.embeddings)
+        assert all(
+            isinstance(value, float)
+            for embedding in response.embeddings
+            for value in embedding
+        )
+
+        embedding_dim = len(response.embeddings[0])
+        assert all(len(embedding) == embedding_dim for embedding in response.embeddings)
@@ -128,6 +128,61 @@ class TestInference:
         last = chunks[-1]
         assert last.stop_reason == StopReason.out_of_tokens

+    @pytest.mark.asyncio
+    async def test_completion_logprobs(self, inference_model, inference_stack):
+        inference_impl, _ = inference_stack
+
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if provider.__provider_spec__.provider_type not in (
+            # "remote::nvidia", -- provider doesn't provide all logprobs
+        ):
+            pytest.skip("Other inference providers don't support completion() yet")
+
+        response = await inference_impl.completion(
+            content="Micheael Jordan is born in ",
+            stream=False,
+            model_id=inference_model,
+            sampling_params=SamplingParams(
+                max_tokens=5,
+            ),
+            logprobs=LogProbConfig(
+                top_k=3,
+            ),
+        )
+
+        assert isinstance(response, CompletionResponse)
+        assert 1 <= len(response.logprobs) <= 5
+        assert response.logprobs, "Logprobs should not be empty"
+        assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
+
+        chunks = [
+            r
+            async for r in await inference_impl.completion(
+                content="Roses are red,",
+                stream=True,
+                model_id=inference_model,
+                sampling_params=SamplingParams(
+                    max_tokens=5,
+                ),
+                logprobs=LogProbConfig(
+                    top_k=3,
+                ),
+            )
+        ]
+
+        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
+        assert (
+            1 <= len(chunks) <= 6
+        )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
+        for chunk in chunks:
+            if chunk.delta:  # if there's a token, we expect logprobs
+                assert chunk.logprobs, "Logprobs should not be empty"
+                assert all(
+                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
+                )
+            else:  # no token, no logprobs
+                assert not chunk.logprobs, "Logprobs should be empty"
+
     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
     async def test_completion_structured_output(self, inference_model, inference_stack):
@@ -6,9 +6,65 @@

 import pytest

+from ..conftest import get_provider_fixture_overrides
+
+from ..inference.fixtures import INFERENCE_FIXTURES
 from .fixtures import MEMORY_FIXTURES


+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "meta_reference",
+            "memory": "faiss",
+        },
+        id="meta_reference",
+        marks=pytest.mark.meta_reference,
+    ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "memory": "pgvector",
+        },
+        id="ollama",
+        marks=pytest.mark.ollama,
+    ),
+    pytest.param(
+        {
+            "inference": "together",
+            "memory": "chroma",
+        },
+        id="chroma",
+        marks=pytest.mark.chroma,
+    ),
+    pytest.param(
+        {
+            "inference": "bedrock",
+            "memory": "qdrant",
+        },
+        id="qdrant",
+        marks=pytest.mark.qdrant,
+    ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "memory": "weaviate",
+        },
+        id="weaviate",
+        marks=pytest.mark.weaviate,
+    ),
+]
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default=None,
+        help="Specify the inference model to use for testing",
+    )
+
+
 def pytest_configure(config):
     for fixture_name in MEMORY_FIXTURES:
         config.addinivalue_line(

@@ -18,12 +74,22 @@ def pytest_configure(config):


 def pytest_generate_tests(metafunc):
+    if "inference_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--inference-model")
+        if not model:
+            raise ValueError(
+                "No inference model specified. Please provide a valid inference model."
+            )
+        params = [pytest.param(model, id="")]
+
+        metafunc.parametrize("inference_model", params, indirect=True)
     if "memory_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "memory_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in MEMORY_FIXTURES
-            ],
-            indirect=True,
-        )
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("memory_stack", combinations, indirect=True)
@@ -10,6 +10,8 @@ import tempfile
 import pytest
 import pytest_asyncio

+from llama_stack.apis.inference import ModelInput, ModelType
+
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig

@@ -105,14 +107,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]


 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(request):
-    fixture_name = request.param
-    fixture = request.getfixturevalue(f"memory_{fixture_name}")
+async def memory_stack(inference_model, request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["inference", "memory"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)

     test_stack = await construct_stack_for_test(
-        [Api.memory],
-        {"memory": fixture.providers},
-        fixture.provider_data,
+        [Api.memory, Api.inference],
+        providers,
+        provider_data,
+        models=[
+            ModelInput(
+                model_id=inference_model,
+                model_type=ModelType.embedding,
+                metadata={
+                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
+                },
+            )
+        ],
     )

     return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
@@ -45,12 +45,14 @@ def sample_documents():
     ]


-async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
+async def register_memory_bank(
+    banks_impl: MemoryBanks, inference_model: str
+) -> MemoryBank:
     bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
         memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
-            embedding_model="all-MiniLM-L6-v2",
+            embedding_model=inference_model,
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:

 class TestMemory:
     @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack):
+    async def test_banks_list(self, memory_stack, inference_model):
         _, banks_impl = memory_stack

         # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, inference_model)

         try:
             # Verify our bank shows up in list
@@ -84,7 +86,7 @@ class TestMemory:
         )

     @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack):
+    async def test_banks_register(self, memory_stack, inference_model):
         _, banks_impl = memory_stack

         bank_id = f"test_bank_{uuid.uuid4().hex}"
@@ -94,7 +96,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
+                embedding_model=inference_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),
@@ -109,7 +111,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
+                embedding_model=inference_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),
@@ -126,13 +128,15 @@ class TestMemory:
             await banks_impl.unregister_memory_bank(bank_id)

     @pytest.mark.asyncio
-    async def test_query_documents(self, memory_stack, sample_documents):
+    async def test_query_documents(
+        self, memory_stack, inference_model, sample_documents
+    ):
         memory_impl, banks_impl = memory_stack

         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)

-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, inference_model)
         await memory_impl.insert_documents(
             registered_bank.memory_bank_id, sample_documents
         )
@@ -165,13 +169,13 @@ class TestMemory:

         # Test case 5: Query with threshold on similarity score
         query5 = "quantum computing"  # Not directly related to any document
-        params5 = {"score_threshold": 0.2}
+        params5 = {"score_threshold": 0.01}
         response5 = await memory_impl.query_documents(
             registered_bank.memory_bank_id, query5, params5
         )
         assert_valid_response(response5)
         print("The scores are:", response5.scores)
-        assert all(score >= 0.2 for score in response5.scores)
+        assert all(score >= 0.01 for score in response5.scores)


 def assert_valid_response(response: QueryDocumentsResponse):
llama_stack/providers/tests/post_training/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
llama_stack/providers/tests/post_training/conftest.py (new file, 45 lines)

@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from ..conftest import get_provider_fixture_overrides
+
+from ..datasetio.fixtures import DATASETIO_FIXTURES
+
+from .fixtures import POST_TRAINING_FIXTURES
+
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "post_training": "torchtune",
+            "datasetio": "huggingface",
+        },
+        id="torchtune_post_training_huggingface_datasetio",
+        marks=pytest.mark.torchtune_post_training_huggingface_datasetio,
+    ),
+]
+
+
+def pytest_configure(config):
+    combined_fixtures = "torchtune_post_training_huggingface_datasetio"
+    config.addinivalue_line(
+        "markers",
+        f"{combined_fixtures}: marks tests as {combined_fixtures} specific",
+    )
+
+
+def pytest_generate_tests(metafunc):
+    if "post_training_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "eval": POST_TRAINING_FIXTURES,
+            "datasetio": DATASETIO_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("post_training_stack", combinations, indirect=True)
llama_stack/providers/tests/post_training/fixtures.py (new file, 74 lines)

@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+import pytest_asyncio
+
+from llama_models.llama3.api.datatypes import URL
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.datasets import DatasetInput
+from llama_stack.apis.models import ModelInput
+
+from llama_stack.distribution.datatypes import Api, Provider
+
+from llama_stack.providers.tests.resolver import construct_stack_for_test
+
+from ..conftest import ProviderFixture
+
+
+@pytest.fixture(scope="session")
+def post_training_torchtune() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="torchtune",
+                provider_type="inline::torchtune",
+                config={},
+            )
+        ],
+    )
+
+
+POST_TRAINING_FIXTURES = ["torchtune"]
+
+
+@pytest_asyncio.fixture(scope="session")
+async def post_training_stack(request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["post_training", "datasetio"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
+
+    test_stack = await construct_stack_for_test(
+        [Api.post_training, Api.datasetio],
+        providers,
+        provider_data,
+        models=[ModelInput(model_id="meta-llama/Llama-3.2-3B-Instruct")],
+        datasets=[
+            DatasetInput(
+                dataset_id="alpaca",
+                provider_id="huggingface",
+                url=URL(uri="https://huggingface.co/datasets/tatsu-lab/alpaca"),
+                metadata={
+                    "path": "tatsu-lab/alpaca",
+                    "split": "train",
+                },
+                dataset_schema={
+                    "instruction": StringType(),
+                    "input": StringType(),
+                    "output": StringType(),
+                    "text": StringType(),
+                },
+            ),
+        ],
+    )
+
+    return test_stack.impls[Api.post_training]
llama_stack/providers/tests/post_training/test_post_training.py (new file, 92 lines)

@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import pytest
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.post_training import *  # noqa: F403
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+# How to run this test:
+#
+# pytest llama_stack/providers/tests/post_training/test_post_training.py
+#   -m "torchtune_post_training_huggingface_datasetio"
+#   -v -s --tb=short --disable-warnings
+
+
+class TestPostTraining:
+    @pytest.mark.asyncio
+    async def test_supervised_fine_tune(self, post_training_stack):
+        algorithm_config = LoraFinetuningConfig(
+            type="LoRA",
+            lora_attn_modules=["q_proj", "v_proj", "output_proj"],
+            apply_lora_to_mlp=True,
+            apply_lora_to_output=False,
+            rank=8,
+            alpha=16,
+        )
+
+        data_config = DataConfig(
+            dataset_id="alpaca",
+            batch_size=1,
+            shuffle=False,
+        )
+
+        optimizer_config = OptimizerConfig(
+            optimizer_type="adamw",
+            lr=3e-4,
+            lr_min=3e-5,
+            weight_decay=0.1,
+            num_warmup_steps=100,
+        )
+
+        training_config = TrainingConfig(
+            n_epochs=1,
+            data_config=data_config,
+            optimizer_config=optimizer_config,
+            max_steps_per_epoch=1,
+            gradient_accumulation_steps=1,
+        )
+        post_training_impl = post_training_stack
+        response = await post_training_impl.supervised_fine_tune(
+            job_uuid="1234",
+            model="Llama3.2-3B-Instruct",
+            algorithm_config=algorithm_config,
+            training_config=training_config,
+            hyperparam_search_config={},
+            logger_config={},
+            checkpoint_dir="null",
+        )
+        assert isinstance(response, PostTrainingJob)
+        assert response.job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_jobs(self, post_training_stack):
+        post_training_impl = post_training_stack
+        jobs_list = await post_training_impl.get_training_jobs()
+        assert isinstance(jobs_list, List)
+        assert jobs_list[0].job_uuid == "1234"
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_status(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_status = await post_training_impl.get_training_job_status("1234")
+        assert isinstance(job_status, PostTrainingJobStatusResponse)
+        assert job_status.job_uuid == "1234"
+        assert job_status.status == JobStatus.completed
+        assert isinstance(job_status.checkpoints[0], Checkpoint)
+
+    @pytest.mark.asyncio
+    async def test_get_training_job_artifacts(self, post_training_stack):
+        post_training_impl = post_training_stack
+        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
+        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
+        assert job_artifacts.job_uuid == "1234"
+        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
+        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
+        assert job_artifacts.checkpoints[0].epoch == 0
+        assert (
+            "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
+            in job_artifacts.checkpoints[0].path
+        )
llama_stack/providers/utils/inference/embedding_mixin.py (new file, 47 lines)

@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+from typing import List
+
+from llama_models.llama3.api.datatypes import InterleavedTextMedia
+
+from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore
+
+EMBEDDING_MODELS = {}
+
+
+log = logging.getLogger(__name__)
+
+
+class SentenceTransformerEmbeddingMixin:
+    model_store: ModelStore
+
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        model = await self.model_store.get_model(model_id)
+        embedding_model = self._load_sentence_transformer_model(
+            model.provider_resource_id
+        )
+        embeddings = embedding_model.encode(contents)
+        return EmbeddingsResponse(embeddings=embeddings)
+
+    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
+        global EMBEDDING_MODELS
+
+        loaded_model = EMBEDDING_MODELS.get(model)
+        if loaded_model is not None:
+            return loaded_model
+
+        log.info(f"Loading sentence transformer for {model}...")
+        from sentence_transformers import SentenceTransformer
+
+        loaded_model = SentenceTransformer(model)
+        EMBEDDING_MODELS[model] = loaded_model
+        return loaded_model
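The mixin keeps a process-wide cache so each model is loaded once and reused across requests. A standalone sketch of that pattern ("all-MiniLM-L6-v2" is the model name the old memory tests hard-coded):

from sentence_transformers import SentenceTransformer

_CACHE: dict = {}

def get_cached_model(name: str) -> SentenceTransformer:
    if name not in _CACHE:
        _CACHE[name] = SentenceTransformer(name)  # expensive load, done once
    return _CACHE[name]

vectors = get_cached_model("all-MiniLM-L6-v2").encode(["Hello, world!"])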
@@ -9,6 +9,7 @@ from typing import List, Optional

 from llama_models.sku_list import all_registered_models

+from llama_stack.apis.models.models import ModelType
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate

 from llama_stack.providers.utils.inference import (
@@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return None

     async def register_model(self, model: Model) -> Model:
-        provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+        if model.model_type == ModelType.embedding:
+            # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
+            provider_resource_id = model.provider_resource_id
+        else:
+            provider_resource_id = self.get_provider_model_id(
+                model.provider_resource_id
+            )
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
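Distilled, the new registration rule is: embedding models keep their provider-native id, while LLM ids still go through the alias table. A hedged sketch with illustrative names (not the registry's actual API):

from typing import Callable, Optional

def resolve_provider_id(
    model_type: str,
    provider_resource_id: str,
    alias_lookup: Callable[[str], Optional[str]],
) -> Optional[str]:
    if model_type == "embedding":
        return provider_resource_id  # no mapping to a llama model
    return alias_lookup(provider_resource_id)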
@@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.datatypes import Api

 log = logging.getLogger(__name__)

-ALL_MINILM_L6_V2_DIMENSION = 384
-
-EMBEDDING_MODELS = {}
-
-
-def get_embedding_model(model: str) -> "SentenceTransformer":
-    global EMBEDDING_MODELS
-
-    loaded_model = EMBEDDING_MODELS.get(model)
-    if loaded_model is not None:
-        return loaded_model
-
-    log.info(f"Loading sentence transformer for {model}...")
-    from sentence_transformers import SentenceTransformer
-
-    loaded_model = SentenceTransformer(model)
-    EMBEDDING_MODELS[model] = loaded_model
-    return loaded_model
-

 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
@@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
 class BankWithIndex:
     bank: VectorMemoryBank
     index: EmbeddingIndex
+    inference_api: Api.inference

     async def insert_documents(
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model(self.bank.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(

@@ -183,7 +165,10 @@ class BankWithIndex:
             )
             if not chunks:
                 continue
-            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
+            embeddings_response = await self.inference_api.embeddings(
+                self.bank.embedding_model, [x.content for x in chunks]
+            )
+            embeddings = np.array(embeddings_response.embeddings)

             await self.index.add_chunks(chunks, embeddings)
@@ -208,6 +193,8 @@ class BankWithIndex:
         else:
             query_str = _process(query)

-        model = get_embedding_model(self.bank.embedding_model)
-        query_vector = model.encode([query_str])[0].astype(np.float32)
+        embeddings_response = await self.inference_api.embeddings(
+            self.bank.embedding_model, [query_str]
+        )
+        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
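The net effect of these two hunks: BankWithIndex no longer loads a local sentence-transformers model; ingestion and query both go through the inference API. A minimal sketch of that shared step:

import numpy as np

async def embed_via_inference_api(inference_api, embedding_model: str, texts: list) -> np.ndarray:
    # Both insert_documents and query_documents now funnel through this call.
    response = await inference_api.embeddings(embedding_model, texts)
    return np.array(response.embeddings, dtype=np.float32)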