mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 06:53:47 +00:00
dataset client
This commit is contained in:
parent
c5db025320
commit
bb43369521
6 changed files with 156 additions and 10 deletions
|
@ -3,16 +3,20 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MetaReferenceScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceScoringConfig,
|
||||
_deps,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
):
|
||||
print("get_provider_impl", deps)
|
||||
from .scoring import MetaReferenceScoringImpl
|
||||
|
||||
impl = MetaReferenceScoringImpl(config)
|
||||
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -7,6 +7,9 @@ from typing import List
|
|||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
|
||||
|
@ -14,9 +17,12 @@ from .config import MetaReferenceScoringConfig
|
|||
|
||||
|
||||
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||
def __init__(self, config: MetaReferenceScoringConfig) -> None:
|
||||
def __init__(
|
||||
self, config: MetaReferenceScoringConfig, datasetio_api: DatasetIO
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.dataset_infos = {}
|
||||
self.datasetio_api = datasetio_api
|
||||
cprint(f"!!! MetaReferenceScoringImpl init {config} {datasetio_api}", "red")
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue