This commit is contained in:
Xi Yan 2024-10-03 13:47:15 -07:00
parent 7143ecfc0d
commit 8339b2cef3
10 changed files with 174 additions and 51 deletions

View file

@ -24,6 +24,8 @@ class Api(Enum):
shields = "shields"
memory_banks = "memory_banks"
evals = "evals"
# built-in API
inspect = "inspect"

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceEvalsImplConfig # noqa
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import Api, ProviderSpec
async def get_provider_impl(
config: MetaReferenceEvalsImplConfig, deps: Dict[Api, ProviderSpec]
):
from .evals import MetaReferenceEvalsImpl
impl = MetaReferenceEvalsImpl(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class MetaReferenceEvalsImplConfig(BaseModel): ...

View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.evals import * # noqa: F403
from .config import MetaReferenceEvalsImplConfig
class MetaReferenceEvalsImpl(Evals):
def __init__(self, config: MetaReferenceEvalsImplConfig, inference_api: Inference):
self.inference_api = inference_api
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def run_evals(
self,
model: str,
dataset: str,
task: str,
) -> EvaluateResponse:
print("hi")
return EvaluateResponse(
metrics={
"accuracy": 0.5,
}
)

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.evals,
provider_type="meta-reference",
pip_packages=[
"matplotlib",
"pillow",
"pandas",
"scikit-learn",
],
module="llama_stack.providers.impls.meta_reference.evals",
config_class="llama_stack.providers.impls.meta_reference.evals.MetaReferenceEvalsImplConfig",
api_dependencies=[
Api.inference,
],
),
]