Merge branch 'eval_task_register' into mmlu_benchmark

This commit is contained in:
Xi Yan 2024-11-07 14:41:50 -08:00
commit cc6edf6287
72 changed files with 306 additions and 304 deletions

View file

@ -4,10 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from pydantic import BaseModel, Field
class MetaReferenceAgentsImplConfig(BaseModel):

View file

@ -11,9 +11,8 @@ from datetime import datetime
from typing import List, Optional
from llama_stack.apis.agents import * # noqa: F403
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStore
from pydantic import BaseModel
class AgentSessionInfo(BaseModel):

View file

@ -10,14 +10,13 @@ from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from termcolor import cprint # noqa: F401
from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
from termcolor import cprint # noqa: F401
from llama_stack.apis.inference import * # noqa: F403

View file

@ -9,8 +9,7 @@ from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin
from ..safety import ShieldRunnerMixin
from .builtin import BaseTool

View file

@ -10,9 +10,8 @@ from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
from pydantic import BaseModel, Field, field_validator
class MetaReferenceInferenceConfig(BaseModel):

View file

@ -35,13 +35,12 @@ from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
)
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from .config import (
Fp8QuantizationConfig,

View file

@ -28,13 +28,13 @@ from fairscale.nn.model_parallel.initialize import (
get_model_parallel_src_rank,
)
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .generation import TokenResult

View file

@ -20,16 +20,15 @@ from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import QuantizationType
from termcolor import cprint
from torch import nn, Tensor
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.providers.inline.meta_reference.inference.config import (
MetaReferenceQuantizedInferenceConfig,
)
from ..config import MetaReferenceQuantizedInferenceConfig
def swiglu_wrapper(

View file

@ -5,9 +5,9 @@
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
from pydantic import BaseModel, Field, field_validator
@json_schema_type

View file

@ -5,13 +5,13 @@
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from pydantic import BaseModel
@json_schema_type

View file

@ -8,10 +8,11 @@ import logging
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from numpy.typing import NDArray
import faiss
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403

View file

@ -7,11 +7,17 @@ from enum import Enum
from llama_models.llama3.api.datatypes import * # noqa: F403
from .....apis.common.job_types import Job
from .....apis.eval.eval import BenchmarkEvalTaskConfig
from .....apis.eval.eval import (
AppEvalTaskConfig,
BenchmarkEvalTaskConfig,
Eval,
EvalTaskConfig,
EvaluateResponse,
JobStatus,
)
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.eval_tasks import EvalTaskDef
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
@ -19,6 +25,10 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from .config import MetaReferenceEvalConfig
# NOTE: This is the default eval task identifier for app evals.
# It is used so that the router works for all app evals.
# App evals that use other eval providers will have a different eval task identifier.
DEFAULT_EVAL_TASK_IDENTIFIER = "meta-reference::app_eval"
@ -88,21 +98,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def run_benchmark_eval(
async def run_benchmark(
self,
benchmark_id: str,
eval_task_config: BenchmarkEvalTaskConfig,
benchmark_config: BenchmarkEvalTaskConfig,
) -> Job:
raise NotImplementedError("Benchmark eval is not implemented yet")
async def run_eval(
self,
eval_task_def: EvalTaskDef,
eval_task_config: EvalTaskConfig,
task: EvalTaskDef,
task_config: AppEvalTaskConfig,
) -> Job:
dataset_id = eval_task_def.dataset_id
candidate = eval_task_config.eval_candidate
scoring_functions = eval_task_def.scoring_functions
dataset_id = task.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task.scoring_functions
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
@ -112,7 +122,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
res = await self.evaluate_rows(
input_rows=all_rows.rows,
scoring_functions=scoring_functions,
eval_task_config=eval_task_config,
task_config=task_config,
)
# TODO: currently needs to wait for generation before returning
@ -125,9 +135,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
self,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
eval_task_config: EvalTaskConfig,
task_config: EvalTaskConfig,
eval_task_id: Optional[str] = None,
) -> EvaluateResponse:
candidate = eval_task_config.eval_candidate
candidate = task_config.eval_candidate
if candidate.type == "agent":
raise NotImplementedError(
"Evaluation with generation has not been implemented for agents"
@ -179,23 +190,33 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
for input_r, generated_r in zip(input_rows, generations)
]
if task_config.type == "app" and task_config.scoring_params is not None:
scoring_functions_dict = {
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = {
scoring_fn_id: None for scoring_fn_id in scoring_functions
}
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, job_id: str) -> Optional[JobStatus]:
async def job_status(self, job_id: str, eval_task_id: str) -> Optional[JobStatus]:
    """Report the status of a job.

    Returns ``JobStatus.completed`` when ``job_id`` is present in
    ``self.jobs`` and ``None`` otherwise. ``eval_task_id`` is accepted
    but not consulted by this method.
    """
    is_known_job = job_id in self.jobs
    return JobStatus.completed if is_known_job else None
async def job_cancel(self, job_id: str) -> None:
async def job_cancel(self, job_id: str, eval_task_id: str) -> None:
    """Cancel a job; unsupported here, so this always raises.

    Raises:
        NotImplementedError: Always.
    """
    message = "Job cancel is not implemented yet"
    raise NotImplementedError(message)
async def job_result(self, job_id: str) -> EvaluateResponse:
status = await self.job_status(job_id)
async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse:
status = await self.job_status(job_id, eval_task_id)
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")

View file

@ -1,73 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
import pytest
from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef
from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig
from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestFaissMemoryImpl:
    """Tests for initialization and bank registration of ``FaissMemoryImpl``."""

    @staticmethod
    def _make_bank():
        """Build the vector memory bank definition shared by the tests."""
        return VectorMemoryBankDef(
            identifier="test_bank",
            type=MemoryBankType.vector.value,
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
        )

    @pytest.fixture
    def faiss_impl(self, tmp_path):
        """Return a ``FaissMemoryImpl`` backed by a per-test SQLite file.

        The database lives under pytest's ``tmp_path`` rather than a
        ``NamedTemporaryFile(delete=False)``, so no file handle is left
        open and no stray file is left outside pytest's temp directory.
        """
        db_path = tmp_path / "faiss_test.db"
        config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=str(db_path)))
        return FaissMemoryImpl(config)

    @pytest.mark.asyncio
    async def test_initialize(self, faiss_impl):
        """Initialization yields an empty cache, then reloads registered banks."""
        # Test empty initialization
        await faiss_impl.initialize()
        assert len(faiss_impl.cache) == 0

        # Register a bank and reinitialize to test loading
        await faiss_impl.register_memory_bank(self._make_bank())

        # Create new instance to test initialization with existing data
        new_impl = FaissMemoryImpl(faiss_impl.config)
        await new_impl.initialize()
        assert len(new_impl.cache) == 1
        assert "test_bank" in new_impl.cache

    @pytest.mark.asyncio
    async def test_register_memory_bank(self, faiss_impl):
        """Registering a bank caches it and makes it visible to a new instance."""
        bank = self._make_bank()

        await faiss_impl.initialize()
        await faiss_impl.register_memory_bank(bank)
        assert "test_bank" in faiss_impl.cache
        assert faiss_impl.cache["test_bank"].bank == bank

        # Verify persistence
        new_impl = FaissMemoryImpl(faiss_impl.config)
        await new_impl.initialize()
        assert "test_bank" in new_impl.cache
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly
    # reports failure to the calling shell instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))

View file

@ -89,8 +89,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@ -101,7 +100,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
res = await self.score(
input_rows=all_rows.rows,
scoring_functions=scoring_functions,
scoring_params=scoring_params,
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
@ -115,17 +113,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions:
for scoring_fn_id in scoring_functions.keys():
if scoring_fn_id not in self.scoring_fn_id_impls:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn_params = None
if scoring_params is not None:
scoring_fn_params = scoring_params.get(scoring_fn_id, None)
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)