Merge remote-tracking branch 'origin/main' into support_more_data_format

Botao Chen 2025-01-06 14:19:10 -08:00
commit 2a992d4f05
10 changed files with 76 additions and 55 deletions

View file

@@ -7,6 +7,7 @@
import asyncio
import inspect
import json
+import logging
import os
import queue
import threading
@@ -16,7 +17,6 @@ from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
import httpx
import yaml
from llama_stack_client import (
APIResponse,
@@ -28,7 +28,6 @@ from llama_stack_client import (
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
@@ -42,7 +41,6 @@ from llama_stack.distribution.stack import (
redact_sensitive_fields,
replace_env_vars,
)
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
@@ -174,6 +172,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
+skip_logger_removal: bool = False,
custom_provider_registry: Optional[ProviderRegistry] = None,
):
super().__init__()
@@ -181,15 +180,28 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
config_path_or_template_name, custom_provider_registry
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
+self.skip_logger_removal = skip_logger_removal
def initialize(self):
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
+if not self.skip_logger_removal:
+self._remove_root_logger_handlers()
return asyncio.run(self.async_client.initialize())
+def _remove_root_logger_handlers(self):
+"""
+Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+"""
+root_logger = logging.getLogger()
+for handler in root_logger.handlers[:]:
+root_logger.removeHandler(handler)
+print(f"Removed handler {handler.__class__.__name__} from root logger")
def _get_path(
self,
cast_to: Any,
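The new skip_logger_removal flag lets a library caller keep its own root-logger handlers instead of having initialize() strip them. A minimal usage sketch, assuming the class is imported from its usual llama_stack.distribution.library_client location; the "ollama" template name is illustrative, not part of this diff:

    import logging

    from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

    # Caller-owned logging setup that should survive client initialization.
    logging.basicConfig(level=logging.INFO)

    client = LlamaStackAsLibraryClient("ollama", skip_logger_removal=True)
    client.initialize()  # existing root-logger handlers are left untouched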

View file

@@ -18,8 +18,8 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
-DataSchemaValidatorMixin,
get_valid_schemas,
+validate_dataset_schema,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -31,7 +31,10 @@ from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "eval_tasks:"
-class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
+class MetaReferenceEvalImpl(
+Eval,
+EvalTasksProtocolPrivate,
+):
def __init__(
self,
config: MetaReferenceEvalConfig,
@@ -85,7 +88,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-self.validate_dataset_schema(
+validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(

View file

@@ -90,16 +90,22 @@ class TorchtuneCheckpointer:
model_file_path.mkdir(parents=True, exist_ok=True)
# copy the related files for inference
+source_path = Path.joinpath(self._checkpoint_dir, "params.json")
+if source_path.exists():
shutil.copy(
-Path.joinpath(self._checkpoint_dir, "params.json"),
+source_path,
Path.joinpath(model_file_path, "params.json"),
)
+source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
+if source_path.exists():
shutil.copy(
-Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
+source_path,
Path.joinpath(model_file_path, "tokenizer.model"),
)
+source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
+if source_path.exists():
shutil.copy(
-Path.joinpath(self._checkpoint_dir, "orig_params.json"),
+source_path,
Path.joinpath(model_file_path, "orig_params.json"),
)
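Each copy is now guarded by an existence check, so checkpoints that ship without one of these auxiliary files no longer fail the copy step. A sketch of how the repeated pattern could be consolidated (copy_if_exists is a hypothetical helper, not part of this commit; checkpoint_dir and model_file_path are stand-ins for self._checkpoint_dir and the local destination path):

    import shutil
    from pathlib import Path

    def copy_if_exists(checkpoint_dir: Path, model_file_path: Path, filename: str) -> None:
        # Copy an auxiliary checkpoint file only when the source actually exists.
        source_path = Path.joinpath(checkpoint_dir, filename)
        if source_path.exists():
            shutil.copy(source_path, Path.joinpath(model_file_path, filename))

    checkpoint_dir = Path("/path/to/checkpoint")      # stands in for self._checkpoint_dir
    model_file_path = Path("/path/to/model_output")   # destination directory
    for name in ("params.json", "tokenizer.model", "orig_params.json"):
        copy_if_exists(checkpoint_dir, model_file_path, name)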

View file

@@ -29,8 +29,9 @@ from torchtune.data._messages import (
ShareGPTToMessages,
)
-from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
+from torchtune.models.llama3 import llama3_tokenizer
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
+from torchtune.models.llama3_1 import lora_llama3_1_8b
from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform
@@ -63,8 +64,8 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3_2",
),
-"Llama-3-8B-Instruct": ModelConfig(
-model_definition=lora_llama3_8b,
+"Llama3.1-8B-Instruct": ModelConfig(
+model_definition=lora_llama3_1_8b,
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3",
),

View file

@@ -18,8 +18,8 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
-DataSchemaValidatorMixin,
get_valid_schemas,
+validate_dataset_schema,
)
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
@@ -30,7 +30,8 @@ FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
class BasicScoringImpl(
-Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
+Scoring,
+ScoringFunctionsProtocolPrivate,
):
def __init__(
self,
@@ -75,7 +76,7 @@ class BasicScoringImpl(
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-self.validate_dataset_schema(
+validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)

View file

@@ -35,8 +35,9 @@ from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
-DataSchemaValidatorMixin,
get_valid_schemas,
+validate_dataset_schema,
+validate_row_schema,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
@@ -111,7 +112,6 @@ class BraintrustScoringImpl(
Scoring,
ScoringFunctionsProtocolPrivate,
NeedsRequestProviderData,
-DataSchemaValidatorMixin,
):
def __init__(
self,
@@ -171,7 +171,7 @@ class BraintrustScoringImpl(
await self.set_api_key()
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-self.validate_dataset_schema(
+validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
@@ -194,7 +194,7 @@ class BraintrustScoringImpl(
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
-self.validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
+validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]

View file

@@ -19,8 +19,8 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
-DataSchemaValidatorMixin,
get_valid_schemas,
+validate_dataset_schema,
)
from .config import LlmAsJudgeScoringConfig
@@ -31,7 +31,8 @@ LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
class LlmAsJudgeScoringImpl(
-Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
+Scoring,
+ScoringFunctionsProtocolPrivate,
):
def __init__(
self,
@@ -79,7 +80,7 @@ class LlmAsJudgeScoringImpl(
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-self.validate_dataset_schema(
+validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)

View file

@@ -140,7 +140,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
def _get_client(self) -> Groq:
if self._config.api_key is not None:
-return Groq(api_key=self.config.api_key)
+return Groq(api_key=self._config.api_key)
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.groq_api_key:
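The guard on the line above already reads the key from self._config, so the old self.config.api_key lookup was inconsistent and would most likely fail with an AttributeError at runtime; the fix reads the same attribute in both places. A minimal, self-contained reproduction of that kind of mismatch (illustrative only, not the adapter's code):

    class Adapter:
        def __init__(self, config):
            self._config = config  # stored only under the underscore-prefixed name

        def get_key(self):
            return self.config.api_key  # AttributeError: 'Adapter' object has no attribute 'config'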

View file

@@ -193,10 +193,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
assert (
not media_present
-), "Together does not support media for Completion requests"
+), "vLLM does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(
request,
-self.register_helper.get_llama_model(request.model),
self.formatter,
)

View file

@@ -62,22 +62,20 @@ def get_valid_schemas(api_str: str):
raise ValueError(f"Invalid API string: {api_str}")
-class DataSchemaValidatorMixin:
-def validate_dataset_schema(
-self,
+def validate_dataset_schema(
dataset_schema: Dict[str, Any],
expected_schemas: List[Dict[str, Any]],
):
if dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
)
-def validate_row_schema(
-self,
+def validate_row_schema(
input_row: Dict[str, Any],
expected_schemas: List[Dict[str, Any]],
):
for schema in expected_schemas:
if all(key in input_row for key in schema):
return
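With the mixin gone, providers import the validators as plain module-level functions and call them directly, as the scoring and eval diffs above show. A small usage sketch of the new call pattern, using the scoring API as the example (Api.scoring.value resolves to the scoring API's string identifier):

    from llama_stack.distribution.datatypes import Api
    from llama_stack.providers.utils.common.data_schema_validator import (
        get_valid_schemas,
        validate_dataset_schema,
    )

    valid = get_valid_schemas(Api.scoring.value)    # schemas accepted by the scoring API
    dataset_schema = valid[0]                       # any accepted schema passes validation
    validate_dataset_schema(dataset_schema, valid)  # raises ValueError for a schema not in the list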