mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00

commit 2a992d4f05: Merge remote-tracking branch 'origin/main' into support_more_data_format

10 changed files with 76 additions and 55 deletions
Library client (LlamaStackAsLibraryClient):

@@ -7,6 +7,7 @@
 import asyncio
 import inspect
+import json
 import logging
 import os
 import queue
 import threading
@@ -16,7 +17,6 @@ from pathlib import Path
 from typing import Any, Generator, get_args, get_origin, Optional, TypeVar

 import httpx
-
 import yaml
 from llama_stack_client import (
     APIResponse,
@@ -28,7 +28,6 @@ from llama_stack_client import (
 )
 from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
-
 from termcolor import cprint

 from llama_stack.distribution.build import print_pip_install_help
@@ -42,7 +41,6 @@ from llama_stack.distribution.stack import (
     redact_sensitive_fields,
     replace_env_vars,
 )
-
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
     setup_logger,
@@ -174,6 +172,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
         self,
         config_path_or_template_name: str,
+        skip_logger_removal: bool = False,
         custom_provider_registry: Optional[ProviderRegistry] = None,
     ):
         super().__init__()
@@ -181,15 +180,28 @@
             config_path_or_template_name, custom_provider_registry
         )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
+        self.skip_logger_removal = skip_logger_removal

     def initialize(self):
         if in_notebook():
             import nest_asyncio

             nest_asyncio.apply()
+        if not self.skip_logger_removal:
+            self._remove_root_logger_handlers()

         return asyncio.run(self.async_client.initialize())

+    def _remove_root_logger_handlers(self):
+        """
+        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+        """
+        root_logger = logging.getLogger()
+
+        for handler in root_logger.handlers[:]:
+            root_logger.removeHandler(handler)
+            print(f"Removed handler {handler.__class__.__name__} from root logger")
+
     def _get_path(
         self,
         cast_to: Any,
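With the new skip_logger_removal flag, an embedding application can keep its own logging configuration instead of having the client strip the root logger's handlers during initialize(). A minimal usage sketch, assuming the class is importable as shown and using an illustrative template name:

import logging

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

# Configure application logging first; skip_logger_removal=True tells
# initialize() to leave the root logger's handlers alone.
logging.basicConfig(level=logging.INFO)

client = LlamaStackAsLibraryClient(
    "ollama",  # illustrative template name; a config path also works
    skip_logger_removal=True,
)
client.initialize()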
Meta-reference eval provider:

@@ -18,8 +18,8 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate

 from llama_stack.providers.utils.common.data_schema_validator import (
     ColumnName,
-    DataSchemaValidatorMixin,
     get_valid_schemas,
+    validate_dataset_schema,
 )
 from llama_stack.providers.utils.kvstore import kvstore_impl

@@ -31,7 +31,10 @@ from .config import MetaReferenceEvalConfig
 EVAL_TASKS_PREFIX = "eval_tasks:"


-class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
+class MetaReferenceEvalImpl(
+    Eval,
+    EvalTasksProtocolPrivate,
+):
     def __init__(
         self,
         config: MetaReferenceEvalConfig,
@@ -85,7 +88,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
         candidate = task_config.eval_candidate
         scoring_functions = task_def.scoring_functions
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema(
+        validate_dataset_schema(
             dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
         )
         all_rows = await self.datasetio_api.get_rows_paginated(
Torchtune checkpointer:

@@ -90,18 +90,24 @@ class TorchtuneCheckpointer:
         model_file_path.mkdir(parents=True, exist_ok=True)

         # copy the related files for inference
-        shutil.copy(
-            Path.joinpath(self._checkpoint_dir, "params.json"),
-            Path.joinpath(model_file_path, "params.json"),
-        )
-        shutil.copy(
-            Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
-            Path.joinpath(model_file_path, "tokenizer.model"),
-        )
-        shutil.copy(
-            Path.joinpath(self._checkpoint_dir, "orig_params.json"),
-            Path.joinpath(model_file_path, "orig_params.json"),
-        )
+        source_path = Path.joinpath(self._checkpoint_dir, "params.json")
+        if source_path.exists():
+            shutil.copy(
+                source_path,
+                Path.joinpath(model_file_path, "params.json"),
+            )
+        source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
+        if source_path.exists():
+            shutil.copy(
+                source_path,
+                Path.joinpath(model_file_path, "tokenizer.model"),
+            )
+        source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
+        if source_path.exists():
+            shutil.copy(
+                source_path,
+                Path.joinpath(model_file_path, "orig_params.json"),
+            )

         if not adapter_only:
             model_state_dict = state_dict[training.MODEL_KEY]
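The same guard-then-copy pattern now repeats three times, so it could be collapsed into a small helper. A sketch of that idea (copy_if_exists is a hypothetical helper, not part of the codebase):

import shutil
from pathlib import Path

def copy_if_exists(src_dir: Path, dst_dir: Path, filename: str) -> bool:
    """Copy src_dir/filename to dst_dir/filename, skipping missing sources."""
    source_path = Path.joinpath(src_dir, filename)
    if source_path.exists():
        shutil.copy(source_path, Path.joinpath(dst_dir, filename))
        return True
    return False

# The three call sites above would then collapse to:
# for name in ("params.json", "tokenizer.model", "orig_params.json"):
#     copy_if_exists(self._checkpoint_dir, model_file_path, name)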
Torchtune model configs:

@@ -29,8 +29,9 @@ from torchtune.data._messages import (
     ShareGPTToMessages,
 )

-from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
+from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
+from torchtune.models.llama3_1 import lora_llama3_1_8b
 from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform

@@ -63,8 +64,8 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
         tokenizer_type=llama3_tokenizer,
         checkpoint_type="LLAMA3_2",
     ),
-    "Llama-3-8B-Instruct": ModelConfig(
-        model_definition=lora_llama3_8b,
+    "Llama3.1-8B-Instruct": ModelConfig(
+        model_definition=lora_llama3_1_8b,
         tokenizer_type=llama3_tokenizer,
         checkpoint_type="LLAMA3",
     ),
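MODEL_CONFIGS maps a model identifier to its LoRA builder and tokenizer, so this hunk both renames the key and points it at the llama3_1 builder imported above. A lookup sketch (get_model_config and its error message are illustrative, not the provider's actual helper):

# Assumes MODEL_CONFIGS and ModelConfig from the file above are in scope.
def get_model_config(model_name: str) -> ModelConfig:
    if model_name not in MODEL_CONFIGS:
        raise ValueError(
            f"Unsupported model: {model_name}. Supported: {sorted(MODEL_CONFIGS)}"
        )
    return MODEL_CONFIGS[model_name]

# "Llama3.1-8B-Instruct" now resolves to lora_llama3_1_8b with
# checkpoint_type="LLAMA3"; the old "Llama-3-8B-Instruct" key is gone.
config = get_model_config("Llama3.1-8B-Instruct")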
Basic scoring provider:

@@ -18,8 +18,8 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
-    DataSchemaValidatorMixin,
     get_valid_schemas,
+    validate_dataset_schema,
 )
 from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
@@ -30,7 +30,8 @@ FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]


 class BasicScoringImpl(
-    Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
+    Scoring,
+    ScoringFunctionsProtocolPrivate,
 ):
     def __init__(
         self,
@@ -75,7 +76,7 @@ class BasicScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema(
+        validate_dataset_schema(
             dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
         )

Braintrust scoring provider:

@@ -35,8 +35,9 @@ from llama_stack.distribution.datatypes import Api
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
-    DataSchemaValidatorMixin,
     get_valid_schemas,
+    validate_dataset_schema,
+    validate_row_schema,
 )

 from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
@@ -111,7 +112,6 @@ class BraintrustScoringImpl(
     Scoring,
     ScoringFunctionsProtocolPrivate,
     NeedsRequestProviderData,
-    DataSchemaValidatorMixin,
 ):
     def __init__(
         self,
@@ -171,7 +171,7 @@ class BraintrustScoringImpl(
         await self.set_api_key()

         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema(
+        validate_dataset_schema(
             dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
         )
@@ -194,7 +194,7 @@ class BraintrustScoringImpl(
     async def score_row(
         self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
     ) -> ScoringResultRow:
-        self.validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
+        validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
         await self.set_api_key()
         assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
         expected_answer = input_row["expected_answer"]
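Row validation is laxer than dataset validation: a row passes if it contains every key of at least one expected schema, so extra keys are allowed. A standalone illustration of validate_row_schema's logic (the schema dicts are made up; real ones come from get_valid_schemas):

from typing import Any, Dict, List

def validate_row_schema(
    input_row: Dict[str, Any],
    expected_schemas: List[Dict[str, Any]],
):
    # Accept the row as soon as it has every key of one expected schema.
    for schema in expected_schemas:
        if all(key in input_row for key in schema):
            return
    raise ValueError(
        f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
    )

schemas = [{"input_query": "string", "expected_answer": "string"}]
validate_row_schema(
    {"input_query": "2+2?", "expected_answer": "4", "extra": "ok"}, schemas
)  # passes: extra keys are ignored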
LLM-as-judge scoring provider:

@@ -19,8 +19,8 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
-    DataSchemaValidatorMixin,
     get_valid_schemas,
+    validate_dataset_schema,
 )

 from .config import LlmAsJudgeScoringConfig
@@ -31,7 +31,8 @@ LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]


 class LlmAsJudgeScoringImpl(
-    Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
+    Scoring,
+    ScoringFunctionsProtocolPrivate,
 ):
     def __init__(
         self,
@@ -79,7 +80,7 @@ class LlmAsJudgeScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        self.validate_dataset_schema(
+        validate_dataset_schema(
             dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
         )

Groq inference adapter:

@@ -140,7 +140,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):

     def _get_client(self) -> Groq:
         if self._config.api_key is not None:
-            return Groq(api_key=self.config.api_key)
+            return Groq(api_key=self._config.api_key)
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.groq_api_key:
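The one-character fix matters because the adapter stores its configuration as self._config; self.config does not exist, so the old line raised AttributeError precisely when an API key was configured. A simplified reproduction of the pitfall (stand-in classes, not the real adapter):

class GroqConfig:
    def __init__(self, api_key=None):
        self.api_key = api_key

class Adapter:
    def __init__(self, config):
        self._config = config  # stored under a private name

    def get_key(self):
        # return self.config.api_key  # buggy: AttributeError, no self.config
        return self._config.api_key   # fixed: use the private attribute

assert Adapter(GroqConfig(api_key="sk-test")).get_key() == "sk-test"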
vLLM inference adapter (assertion message previously copy-pasted from the Together adapter):

@@ -193,10 +193,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             assert (
                 not media_present
-            ), "Together does not support media for Completion requests"
+            ), "vLLM does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(
                 request,
                 self.register_helper.get_llama_model(request.model),
                 self.formatter,
             )
-
Data schema validator utilities:

@@ -62,26 +62,24 @@ def get_valid_schemas(api_str: str):
     raise ValueError(f"Invalid API string: {api_str}")


-class DataSchemaValidatorMixin:
-    def validate_dataset_schema(
-        self,
-        dataset_schema: Dict[str, Any],
-        expected_schemas: List[Dict[str, Any]],
-    ):
-        if dataset_schema not in expected_schemas:
-            raise ValueError(
-                f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
-            )
-
-    def validate_row_schema(
-        self,
-        input_row: Dict[str, Any],
-        expected_schemas: List[Dict[str, Any]],
-    ):
-        for schema in expected_schemas:
-            if all(key in input_row for key in schema):
-                return
-
-        raise ValueError(
-            f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
-        )
+def validate_dataset_schema(
+    dataset_schema: Dict[str, Any],
+    expected_schemas: List[Dict[str, Any]],
+):
+    if dataset_schema not in expected_schemas:
+        raise ValueError(
+            f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
+        )
+
+
+def validate_row_schema(
+    input_row: Dict[str, Any],
+    expected_schemas: List[Dict[str, Any]],
+):
+    for schema in expected_schemas:
+        if all(key in input_row for key in schema):
+            return
+
+    raise ValueError(
+        f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
+    )
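Moving the validators from DataSchemaValidatorMixin to module-level functions drops an inheritance edge whose only purpose was namespacing; call sites now import exactly what they use, as the hunks above show. A usage sketch with illustrative schemas (real ones come from get_valid_schemas):

from llama_stack.providers.utils.common.data_schema_validator import (
    validate_dataset_schema,
)

valid_schemas = [{"input_query": "string", "expected_answer": "string"}]

# Dataset validation is strict: the schema must equal one expected schema.
validate_dataset_schema(
    {"input_query": "string", "expected_answer": "string"}, valid_schemas
)  # passes

try:
    validate_dataset_schema({"input_query": "string"}, valid_schemas)
except ValueError as err:
    print(err)  # Dataset {...} does not have a correct input schema in [...]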