test_scoring

This commit is contained in:
Xi Yan 2024-10-23 13:01:49 -07:00
parent 7c280e18fb
commit 92e32f80ad
15 changed files with 240 additions and 5 deletions

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceScoringConfig
async def get_provider_impl(
config: MetaReferenceScoringConfig,
_deps,
):
from .scoring import MetaReferenceScoringImpl
impl = MetaReferenceScoringImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring import * # noqa: F401, F403
class MetaReferenceScoringConfig(BaseModel): ...

View file

@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from .config import MetaReferenceScoringConfig
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(self, config: MetaReferenceScoringConfig) -> None:
self.config = config
self.dataset_infos = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
) -> ScoreBatchResponse:
print("score_batch")
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
print("score")