dataset client

This commit is contained in:
Xi Yan 2024-10-23 13:53:58 -07:00
parent c5db025320
commit bb43369521
6 changed files with 156 additions and 10 deletions

View file

@ -3,16 +3,20 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceScoringConfig
async def get_provider_impl(
config: MetaReferenceScoringConfig,
_deps,
deps: Dict[Api, ProviderSpec],
):
print("get_provider_impl", deps)
from .scoring import MetaReferenceScoringImpl
impl = MetaReferenceScoringImpl(config)
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio])
await impl.initialize()
return impl

View file

@ -7,6 +7,9 @@ from typing import List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from termcolor import cprint
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
@ -14,9 +17,12 @@ from .config import MetaReferenceScoringConfig
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(self, config: MetaReferenceScoringConfig) -> None:
def __init__(
self, config: MetaReferenceScoringConfig, datasetio_api: DatasetIO
) -> None:
self.config = config
self.dataset_infos = {}
self.datasetio_api = datasetio_api
cprint(f"!!! MetaReferenceScoringImpl init {config} {datasetio_api}", "red")
async def initialize(self) -> None: ...