Merge branch 'eval_task_register' into mmlu_benchmark

Xi Yan 2024-11-07 14:41:50 -08:00
commit cc6edf6287
72 changed files with 306 additions and 304 deletions


@@ -1,17 +1,14 @@
# What does this PR do?
Closes # (issue)
In short, provide a summary of what this PR does and why. Usually, the relevant context should be present in a linked issue.
- [ ] Addresses issue (#issue)
## Feature/Issue validation/testing/test plan
Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration or test plan.
- [ ] Test A
Logs for Test A
- [ ] Test B
Logs for Test B
Please describe:
- tests you ran to verify your changes with result summaries.
- instructions so the results can be reproduced.
## Sources
@@ -20,12 +17,10 @@ Please link relevant resources if necessary.
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
Thanks for contributing 🎉!
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.


@@ -38,14 +38,16 @@ EvalCandidate = Annotated[
@json_schema_type
class BenchmarkEvalTaskConfig(BaseModel):
eval_candidate: EvalCandidate # type: ignore
type: Literal["benchmark"] = "benchmark"
eval_candidate: EvalCandidate
@json_schema_type
class AppEvalTaskConfig(BaseModel):
eval_candidate: EvalCandidate # type: ignore
scoring_params: Dict[str, ScoringFnParams] = Field( # type: ignore
description="Map between scoring function id and parameters",
type: Literal["app"] = "app"
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
)
# we could optionally add any specific dataset config here
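For concreteness, an app eval config under the new tagged-union shape is built like this — a sketch lifted from the test changes later in this commit (ModelCandidate and SamplingParams come from the existing eval candidate and inference APIs):

task_config = AppEvalTaskConfig(
    eval_candidate=ModelCandidate(
        model="Llama3.2-3B-Instruct",
        sampling_params=SamplingParams(),
    ),
    # scoring_params defaults to {}; populate it to override per-function params
)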
@@ -64,18 +66,18 @@ class EvaluateResponse(BaseModel):
class Eval(Protocol):
@webmethod(route="/eval/run_benchmark_eval", method="POST")
async def run_benchmark_eval(
@webmethod(route="/eval/run_benchmark", method="POST")
async def run_benchmark(
self,
benchmark_id: str,
eval_task_config: BenchmarkEvalTaskConfig,
benchmark_config: BenchmarkEvalTaskConfig,
) -> Job: ...
@webmethod(route="/eval/run_eval", method="POST")
async def run_eval(
self,
eval_task_def: EvalTaskDef,
eval_task_config: EvalTaskConfig,
task: EvalTaskDef,
task_config: AppEvalTaskConfig,
) -> Job: ...
@webmethod(route="/eval/evaluate_rows", method="POST")
@@ -83,14 +85,17 @@ class Eval(Protocol):
self,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
eval_task_config: EvalTaskConfig, # type: ignore
task_config: EvalTaskConfig,
eval_task_id: Optional[str] = None,
) -> EvaluateResponse: ...
@webmethod(route="/eval/job/status", method="GET")
async def job_status(self, job_id: str) -> Optional[JobStatus]: ...
async def job_status(
self, job_id: str, eval_task_id: str
) -> Optional[JobStatus]: ...
@webmethod(route="/eval/job/cancel", method="POST")
async def job_cancel(self, job_id: str) -> None: ...
async def job_cancel(self, job_id: str, eval_task_id: str) -> None: ...
@webmethod(route="/eval/job/result", method="GET")
async def job_result(self, job_id: str) -> EvaluateResponse: ...
async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse: ...
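Putting the renamed surface together, a client-side sketch of the new flow (hedged: eval_impl and candidate are illustrative stand-ins, and the task identifier mirrors the app-eval default used in the tests below):

job = await eval_impl.run_eval(
    task=EvalTaskDef(
        identifier="meta-reference::app_eval",
        dataset_id="test_dataset_for_eval",
        scoring_functions=["meta-reference::equality"],
    ),
    task_config=AppEvalTaskConfig(eval_candidate=candidate),
)
# job control now requires the eval task id alongside the job id
status = await eval_impl.job_status(job.job_id, "meta-reference::app_eval")
if status == JobStatus.completed:
    result = await eval_impl.job_result(job.job_id, "meta-reference::app_eval")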


@@ -48,8 +48,7 @@ class Scoring(Protocol):
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
@@ -57,6 +56,5 @@ class Scoring(Protocol):
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
) -> ScoreResponse: ...
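Both scoring entry points now take a single mapping that selects the functions and, optionally, parameterizes them; a None value keeps a function's defaults. A sketch mirroring the updated tests at the bottom of this diff:

response = await scoring_impl.score(
    input_rows=rows.rows,
    scoring_functions={
        "meta-reference::equality": None,  # run with default params
        "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
            judge_model="Llama3.1-405B-Instruct",
        ),
    },
)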


@@ -33,7 +33,7 @@ class ScoringConfigType(Enum):
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringConfigType.llm_as_judge.value] = ( # type: ignore
type: Literal[ScoringConfigType.llm_as_judge.value] = (
ScoringConfigType.llm_as_judge.value
)
judge_model: str
@@ -46,7 +46,7 @@ class LLMAsJudgeScoringFnParams(BaseModel):
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringConfigType.regex_parser.value] = ( # type: ignore
type: Literal[ScoringConfigType.regex_parser.value] = (
ScoringConfigType.regex_parser.value
)
parsing_regexes: Optional[List[str]] = Field(
@@ -75,8 +75,8 @@ class ScoringFnDef(BaseModel):
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
params: Optional[ScoringFnParams] = Field( # type: ignore
description="The parameters for the scoring function for benchmark eval, we could override this for app eval",
params: Optional[ScoringFnParams] = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
# We can optionally add information here to support packaging of code, etc.


@@ -8,6 +8,8 @@ import inspect
from typing import Any, Dict, List, Set
from termcolor import cprint
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
@@ -100,6 +102,12 @@ async def resolve_impls(
)
p = provider_registry[api][provider.provider_type]
if p.deprecation_warning:
cprint(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
"red",
attrs=["bold"],
)
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
spec=p,
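Given the `meta-reference` memory provider that this commit deprecates further down, the check above would print, in bold red, a line of the form:

Provider `meta-reference` for API `memory` is deprecated and will be removed in a future release: Please use the `faiss` provider instead.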


@@ -216,18 +216,16 @@ class ScoringRouter(Scoring):
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
res = {}
for fn_identifier in scoring_functions:
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score_batch(
dataset_id=dataset_id,
scoring_functions=[fn_identifier],
scoring_params=scoring_params,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
@@ -241,18 +239,16 @@ class ScoringRouter(Scoring):
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions:
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score(
input_rows=input_rows,
scoring_functions=[fn_identifier],
scoring_params=scoring_params,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
@@ -272,24 +268,21 @@ class EvalRouter(Eval):
async def shutdown(self) -> None:
pass
async def run_benchmark_eval(
async def run_benchmark(
self,
benchmark_id: str,
eval_task_config: BenchmarkEvalTaskConfig,
benchmark_config: BenchmarkEvalTaskConfig,
) -> Job:
pass
async def run_eval(
self,
eval_task_def: EvalTaskDef,
eval_task_config: EvalTaskConfig,
task: EvalTaskDef,
task_config: AppEvalTaskConfig,
) -> Job:
# NOTE: We need to use DEFAULT_EVAL_TASK_IDENTIFIER to make the router work for all app evals
return await self.routing_table.get_provider_impl(
DEFAULT_EVAL_TASK_IDENTIFIER
).run_eval(
eval_task_def=eval_task_def,
eval_task_config=eval_task_config,
return await self.routing_table.get_provider_impl(task.identifier).run_eval(
task=task,
task_config=task_config,
)
@webmethod(route="/eval/evaluate_rows", method="POST")
@@ -297,29 +290,42 @@ class EvalRouter(Eval):
self,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
eval_task_config: EvalTaskConfig, # type: ignore
task_config: EvalTaskConfig,
eval_task_id: Optional[str] = None,
) -> EvaluateResponse:
# NOTE: This is to deal with the case where we do not pre-register an eval benchmark_task
# We use default DEFAULT_EVAL_TASK_IDENTIFIER as identifier
return await self.routing_table.get_provider_impl(
DEFAULT_EVAL_TASK_IDENTIFIER
).evaluate_rows(
if eval_task_id is None:
eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER
return await self.routing_table.get_provider_impl(eval_task_id).evaluate_rows(
input_rows=input_rows,
scoring_functions=scoring_functions,
eval_task_config=eval_task_config,
task_config=task_config,
)
async def job_status(self, job_id: str) -> Optional[JobStatus]:
return await self.routing_table.get_provider_impl(
DEFAULT_EVAL_TASK_IDENTIFIER
).job_status(job_id)
async def job_status(
self,
job_id: str,
eval_task_id: str,
) -> Optional[JobStatus]:
return await self.routing_table.get_provider_impl(eval_task_id).job_status(
job_id, eval_task_id
)
async def job_cancel(self, job_id: str) -> None:
await self.routing_table.get_provider_impl(
DEFAULT_EVAL_TASK_IDENTIFIER
).job_cancel(job_id)
async def job_cancel(
self,
job_id: str,
eval_task_id: str,
) -> None:
await self.routing_table.get_provider_impl(eval_task_id).job_cancel(
job_id, eval_task_id
)
async def job_result(self, job_id: str) -> EvaluateResponse:
return await self.routing_table.get_provider_impl(
DEFAULT_EVAL_TASK_IDENTIFIER
).job_result(job_id)
async def job_result(
self,
job_id: str,
eval_task_id: str,
) -> EvaluateResponse:
return await self.routing_table.get_provider_impl(eval_task_id).job_result(
job_id, eval_task_id
)


@@ -31,7 +31,7 @@ from llama_stack.distribution.distribution import (
get_provider_registry,
)
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@@ -42,8 +42,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from .endpoints import get_all_api_endpoints
@@ -281,21 +279,8 @@ def main(
config = StackRunConfig(**yaml.safe_load(fp))
app = FastAPI()
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store))
else:
dist_kvstore = asyncio.run(
kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
)
)
)
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
if Api.telemetry in impls:


@@ -9,9 +9,17 @@ from typing import Dict, List, Protocol
import pydantic
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
from llama_stack.distribution.datatypes import (
RoutableObjectWithProvider,
StackRunConfig,
)
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.kvstore import (
KVStore,
kvstore_impl,
SqliteKVStoreConfig,
)
class DistributionRegistry(Protocol):
@@ -133,3 +141,21 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
self.cache[obj.identifier].append(obj)
return success
async def create_dist_registry(
config: StackRunConfig,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = await kvstore_impl(config.metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
)
)
return CachedDiskDistributionRegistry(dist_kvstore), dist_kvstore
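With this helper, server startup collapses to one awaited call plus the resolver hand-off, as the server.py hunk above shows; a minimal sketch assuming a parsed StackRunConfig:

config = StackRunConfig(**yaml.safe_load(fp))
dist_registry, dist_kvstore = await create_dist_registry(config)
impls = await resolve_impls(config, get_provider_registry(), dist_registry)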


@@ -90,6 +90,10 @@ class ProviderSpec(BaseModel):
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
deprecation_warning: Optional[str] = Field(
default=None,
description="If this provider is deprecated, specify the warning message here",
)
# used internally by the resolver; this is a hack for now
deps__: List[str] = Field(default_factory=list)


@@ -4,10 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from pydantic import BaseModel, Field
class MetaReferenceAgentsImplConfig(BaseModel):


@@ -11,9 +11,8 @@ from datetime import datetime
from typing import List, Optional
from llama_stack.apis.agents import * # noqa: F403
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStore
from pydantic import BaseModel
class AgentSessionInfo(BaseModel):


@@ -10,14 +10,13 @@ from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from termcolor import cprint # noqa: F401
from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
from termcolor import cprint # noqa: F401
from llama_stack.apis.inference import * # noqa: F403


@@ -9,8 +9,7 @@ from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin
from ..safety import ShieldRunnerMixin
from .builtin import BaseTool


@@ -10,9 +10,8 @@ from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
from pydantic import BaseModel, Field, field_validator
class MetaReferenceInferenceConfig(BaseModel):


@@ -35,13 +35,12 @@ from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
)
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from .config import (
Fp8QuantizationConfig,


@@ -28,13 +28,13 @@ from fairscale.nn.model_parallel.initialize import (
get_model_parallel_src_rank,
)
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .generation import TokenResult


@@ -20,16 +20,15 @@ from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import QuantizationType
from termcolor import cprint
from torch import nn, Tensor
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.providers.inline.meta_reference.inference.config import (
MetaReferenceQuantizedInferenceConfig,
)
from ..config import MetaReferenceQuantizedInferenceConfig
def swiglu_wrapper(


@@ -5,9 +5,9 @@
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
from pydantic import BaseModel, Field, field_validator
@json_schema_type


@@ -5,13 +5,13 @@
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from pydantic import BaseModel
@json_schema_type


@@ -8,10 +8,11 @@ import logging
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from numpy.typing import NDArray
import faiss
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403


@@ -7,11 +7,17 @@ from enum import Enum
from llama_models.llama3.api.datatypes import * # noqa: F403
from .....apis.common.job_types import Job
from .....apis.eval.eval import BenchmarkEvalTaskConfig
from .....apis.eval.eval import (
AppEvalTaskConfig,
BenchmarkEvalTaskConfig,
Eval,
EvalTaskConfig,
EvaluateResponse,
JobStatus,
)
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.eval_tasks import EvalTaskDef
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
@@ -19,6 +25,10 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from .config import MetaReferenceEvalConfig
# NOTE: this is the default eval task identifier for app eval
# it is used to make the router work for all app evals
# For app eval using other eval providers, the eval task identifier will be different
DEFAULT_EVAL_TASK_IDENTIFIER = "meta-reference::app_eval"
@@ -88,21 +98,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def run_benchmark_eval(
async def run_benchmark(
self,
benchmark_id: str,
eval_task_config: BenchmarkEvalTaskConfig,
benchmark_config: BenchmarkEvalTaskConfig,
) -> Job:
raise NotImplementedError("Benchmark eval is not implemented yet")
async def run_eval(
self,
eval_task_def: EvalTaskDef,
eval_task_config: EvalTaskConfig,
task: EvalTaskDef,
task_config: AppEvalTaskConfig,
) -> Job:
dataset_id = eval_task_def.dataset_id
candidate = eval_task_config.eval_candidate
scoring_functions = eval_task_def.scoring_functions
dataset_id = task.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task.scoring_functions
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
@@ -112,7 +122,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
res = await self.evaluate_rows(
input_rows=all_rows.rows,
scoring_functions=scoring_functions,
eval_task_config=eval_task_config,
task_config=task_config,
)
# TODO: currently needs to wait for generation before returning
@@ -125,9 +135,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
self,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
eval_task_config: EvalTaskConfig,
task_config: EvalTaskConfig,
eval_task_id: Optional[str] = None,
) -> EvaluateResponse:
candidate = eval_task_config.eval_candidate
candidate = task_config.eval_candidate
if candidate.type == "agent":
raise NotImplementedError(
"Evaluation with generation has not been implemented for agents"
@@ -179,23 +190,33 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
for input_r, generated_r in zip(input_rows, generations)
]
if task_config.type == "app" and task_config.scoring_params is not None:
scoring_functions_dict = {
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = {
scoring_fn_id: None for scoring_fn_id in scoring_functions
}
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, job_id: str) -> Optional[JobStatus]:
async def job_status(self, job_id: str, eval_task_id: str) -> Optional[JobStatus]:
if job_id in self.jobs:
return JobStatus.completed
return None
async def job_cancel(self, job_id: str) -> None:
async def job_cancel(self, job_id: str, eval_task_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, job_id: str) -> EvaluateResponse:
status = await self.job_status(job_id)
async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse:
status = await self.job_status(job_id, eval_task_id)
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status}")


@@ -1,73 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
import pytest
from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef
from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig
from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestFaissMemoryImpl:
@pytest.fixture
def faiss_impl(self):
# Create a temporary SQLite database file
temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name))
return FaissMemoryImpl(config)
@pytest.mark.asyncio
async def test_initialize(self, faiss_impl):
# Test empty initialization
await faiss_impl.initialize()
assert len(faiss_impl.cache) == 0
# Test initialization with existing banks
bank = VectorMemoryBankDef(
identifier="test_bank",
type=MemoryBankType.vector.value,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
# Register a bank and reinitialize to test loading
await faiss_impl.register_memory_bank(bank)
# Create new instance to test initialization with existing data
new_impl = FaissMemoryImpl(faiss_impl.config)
await new_impl.initialize()
assert len(new_impl.cache) == 1
assert "test_bank" in new_impl.cache
@pytest.mark.asyncio
async def test_register_memory_bank(self, faiss_impl):
bank = VectorMemoryBankDef(
identifier="test_bank",
type=MemoryBankType.vector.value,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
await faiss_impl.initialize()
await faiss_impl.register_memory_bank(bank)
assert "test_bank" in faiss_impl.cache
assert faiss_impl.cache["test_bank"].bank == bank
# Verify persistence
new_impl = FaissMemoryImpl(faiss_impl.config)
await new_impl.initialize()
assert "test_bank" in new_impl.cache
if __name__ == "__main__":
pytest.main([__file__])


@@ -89,8 +89,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -101,7 +100,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
res = await self.score(
input_rows=all_rows.rows,
scoring_functions=scoring_functions,
scoring_params=scoring_params,
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
@@ -115,17 +113,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions:
for scoring_fn_id in scoring_functions.keys():
if scoring_fn_id not in self.scoring_fn_id_impls:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn_params = None
if scoring_params is not None:
scoring_fn_params = scoring_params.get(scoring_fn_id, None)
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)


@@ -22,8 +22,8 @@ def available_providers() -> List[ProviderSpec]:
"scikit-learn",
]
+ kvstore_dependencies(),
module="llama_stack.providers.inline.meta_reference.agents",
config_class="llama_stack.providers.inline.meta_reference.agents.MetaReferenceAgentsImplConfig",
module="llama_stack.providers.inline.agents.meta_reference",
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
api_dependencies=[
Api.inference,
Api.safety,


@@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]:
api=Api.inference,
provider_type="meta-reference",
pip_packages=META_REFERENCE_DEPS,
module="llama_stack.providers.inline.meta_reference.inference",
config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceInferenceConfig",
module="llama_stack.providers.inline.inference.meta_reference",
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
),
InlineProviderSpec(
api=Api.inference,
@@ -40,8 +40,17 @@ def available_providers() -> List[ProviderSpec]:
"torchao==0.5.0",
]
),
module="llama_stack.providers.inline.meta_reference.inference",
config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
module="llama_stack.providers.inline.inference.meta_reference",
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
),
InlineProviderSpec(
api=Api.inference,
provider_type="vllm",
pip_packages=[
"vllm",
],
module="llama_stack.providers.inline.inference.vllm",
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
),
remote_provider_spec(
api=Api.inference,
@@ -117,7 +126,7 @@ def available_providers() -> List[ProviderSpec]:
],
module="llama_stack.providers.remote.inference.together",
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator",
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
),
),
remote_provider_spec(
@@ -140,13 +149,4 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
),
),
InlineProviderSpec(
api=Api.inference,
provider_type="vllm",
pip_packages=[
"vllm",
],
module="llama_stack.providers.inline.vllm",
config_class="llama_stack.providers.inline.vllm.VLLMConfig",
),
]


@@ -36,8 +36,16 @@ def available_providers() -> List[ProviderSpec]:
api=Api.memory,
provider_type="meta-reference",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.meta_reference.memory",
config_class="llama_stack.providers.inline.meta_reference.memory.FaissImplConfig",
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
deprecation_warning="Please use the `faiss` provider instead.",
),
InlineProviderSpec(
api=Api.memory,
provider_type="faiss",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
),
remote_provider_spec(
Api.memory,


@@ -24,8 +24,8 @@ def available_providers() -> List[ProviderSpec]:
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.inline.meta_reference.safety",
config_class="llama_stack.providers.inline.meta_reference.safety.SafetyConfig",
module="llama_stack.providers.inline.safety.meta_reference",
config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
api_dependencies=[
Api.inference,
],
@@ -54,8 +54,8 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=[
"codeshield",
],
module="llama_stack.providers.inline.meta_reference.codeshield",
config_class="llama_stack.providers.inline.meta_reference.codeshield.CodeShieldConfig",
module="llama_stack.providers.inline.safety.meta_reference",
config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig",
api_dependencies=[],
),
]


@@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@@ -14,7 +16,7 @@ class FireworksImplConfig(BaseModel):
default="https://api.fireworks.ai/inference",
description="The URL for the Fireworks server",
)
api_key: str = Field(
default="",
api_key: Optional[str] = Field(
default=None,
description="The Fireworks.ai API Key",
)


@@ -9,12 +9,11 @@ from typing import AsyncGenerator
from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
@@ -32,7 +31,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
@@ -41,10 +39,13 @@ FIREWORKS_SUPPORTED_MODELS = {
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
"Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
"Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
"Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
}
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
class FireworksInferenceAdapter(
ModelRegistryHelper, Inference, NeedsRequestProviderData
):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
@@ -53,11 +54,24 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
pass
async def shutdown(self) -> None:
pass
def _get_client(self) -> Fireworks:
fireworks_api_key = None
if self.config.api_key is not None:
fireworks_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
raise ValueError(
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
)
fireworks_api_key = provider_data.fireworks_api_key
return Fireworks(api_key=fireworks_api_key)
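A deployment that leaves `api_key` unset in the config can now supply the key per request; judging from the error message above, the expected header looks like this (key value illustrative):

X-LlamaStack-ProviderData: {"fireworks_api_key": "fw-your-api-key"}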
async def completion(
self,
model: str,
@@ -75,28 +89,53 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
stream=stream,
logprobs=logprobs,
)
client = Fireworks(api_key=self.config.api_key)
if stream:
return self._stream_completion(request, client)
return self._stream_completion(request)
else:
return await self._nonstream_completion(request, client)
return await self._nonstream_completion(request)
async def _nonstream_completion(
self, request: CompletionRequest, client: Fireworks
self, request: CompletionRequest
) -> CompletionResponse:
params = await self._get_params(request)
r = await client.completion.acreate(**params)
r = await self._get_client().completion.acreate(**params)
return process_completion_response(r, self.formatter)
async def _stream_completion(
self, request: CompletionRequest, client: Fireworks
) -> AsyncGenerator:
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = client.completion.acreate(**params)
# Wrap the client's synchronous stream in an async generator
async def _to_async_generator():
stream = self._get_client().completion.create(**params)
for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream, self.formatter):
yield chunk
def _build_options(
self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
) -> dict:
options = get_sampling_options(sampling_params)
options.setdefault("max_tokens", 512)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
return options
async def chat_completion(
self,
model: str,
@@ -121,32 +160,35 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
logprobs=logprobs,
)
client = Fireworks(api_key=self.config.api_key)
if stream:
return self._stream_chat_completion(request, client)
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await client.chat.completions.acreate(**params)
r = await self._get_client().chat.completions.acreate(**params)
else:
r = await client.completion.acreate(**params)
r = await self._get_client().completion.acreate(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = await self._get_params(request)
if "messages" in params:
stream = client.chat.completions.acreate(**params)
else:
stream = client.completion.acreate(**params)
async def _to_async_generator():
if "messages" in params:
stream = await self._get_client().chat.completions.acreate(**params)
else:
stream = self._get_client().completion.create(**params)
for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
@@ -167,41 +209,22 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
input_dict["prompt"] = chat_completion_request_to_prompt(
request, self.formatter
)
elif isinstance(request, CompletionRequest):
else:
assert (
not media_present
), "Fireworks does not support media for Completion requests"
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
else:
raise ValueError(f"Unknown request type {type(request)}")
# Fireworks always prepends with BOS
if "prompt" in input_dict:
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
options = get_sampling_options(request.sampling_params)
options.setdefault("max_tokens", 512)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
return {
"model": self.map_to_provider_model(request.model),
**input_dict,
"stream": request.stream,
**options,
**self._build_options(request.sampling_params, request.response_format),
}
async def embeddings(
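For a JSON-schema response format, the extracted `_build_options` helper above yields request options along these lines, on top of whatever get_sampling_options returns (a sketch; the schema shown is illustrative):

{
    "max_tokens": 512,  # the default set in _build_options
    "response_format": {
        "type": "json_object",
        "schema": {"type": "object", "properties": {"score": {"type": "integer"}}},
    },
}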


@@ -4,9 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import TogetherImplConfig
class TogetherProviderDataValidator(BaseModel):
together_api_key: str
async def get_adapter_impl(config: TogetherImplConfig, _deps):
from .together import TogetherInferenceAdapter


@@ -11,7 +11,7 @@ import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.meta_reference.agents import (
from llama_stack.providers.inline.agents.meta_reference import (
MetaReferenceAgentsImplConfig,
)


@@ -52,7 +52,7 @@ class Testeval:
response = await eval_impl.evaluate_rows(
input_rows=rows.rows,
scoring_functions=scoring_functions,
eval_task_config=AppEvalTaskConfig(
task_config=AppEvalTaskConfig(
eval_candidate=ModelCandidate(
model="Llama3.2-3B-Instruct",
sampling_params=SamplingParams(),
@@ -76,13 +76,13 @@ class Testeval:
]
response = await eval_impl.run_eval(
eval_task_def=EvalTaskDef(
task=EvalTaskDef(
# NOTE: this is needed to make the router work for all app evals
identifier="meta-reference::app_eval",
dataset_id="test_dataset_for_eval",
scoring_functions=scoring_functions,
),
eval_task_config=AppEvalTaskConfig(
task_config=AppEvalTaskConfig(
eval_candidate=ModelCandidate(
model="Llama3.2-3B-Instruct",
sampling_params=SamplingParams(),
@@ -90,9 +90,13 @@ class Testeval:
),
)
assert response.job_id == "0"
job_status = await eval_impl.job_status(response.job_id)
job_status = await eval_impl.job_status(
response.job_id, "meta-reference::app_eval"
)
assert job_status and job_status.value == "completed"
eval_response = await eval_impl.job_result(response.job_id)
eval_response = await eval_impl.job_result(
response.job_id, "meta-reference::app_eval"
)
assert eval_response is not None
assert len(eval_response.generations) == 5


@@ -10,7 +10,7 @@ import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.meta_reference.inference import (
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
)


@@ -11,7 +11,7 @@ import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.meta_reference.memory import FaissImplConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig


@@ -8,7 +8,7 @@ import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.meta_reference.safety import (
from llama_stack.providers.inline.safety.meta_reference import (
LlamaGuardShieldConfig,
SafetyConfig,
)


@@ -44,10 +44,10 @@ class TestScoring:
)
assert len(rows.rows) == 3
scoring_functions = [
"meta-reference::llm_as_judge_8b_correctness",
"meta-reference::equality",
]
scoring_functions = {
"meta-reference::llm_as_judge_8b_correctness": None,
"meta-reference::equality": None,
}
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
@@ -83,7 +83,7 @@ class TestScoring:
)
assert len(rows.rows) == 3
params = {
scoring_functions = {
"meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
judge_model="Llama3.1-405B-Instruct",
prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
@@ -91,13 +91,9 @@
)
}
scoring_functions = [
"meta-reference::llm_as_judge_8b_correctness",
]
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
scoring_params=params,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
@@ -108,7 +104,6 @@ class TestScoring:
response = await scoring_impl.score_batch(
dataset_id="test_dataset",
scoring_functions=scoring_functions,
scoring_params=params,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions: