Xi Yan 2024-10-03 13:47:15 -07:00
parent 7143ecfc0d
commit 8339b2cef3
10 changed files with 174 additions and 51 deletions

View file

@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

import fire
import httpx
from termcolor import cprint

from .evals import *  # noqa: F403


class EvaluationClient(Evals):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def run_evals(self, model: str, dataset: str, task: str) -> EvaluateResponse:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/evals/run",
                json={
                    "model": model,
                    "dataset": dataset,
                    "task": task,
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return EvaluateResponse(**response.json())


async def run_main(host: str, port: int):
    client = EvaluationClient(f"http://{host}:{port}")

    response = await client.run_evals(
        "Llama3.1-8B-Instruct",
        "mmlu.csv",
        "mmlu",
    )
    cprint(f"evaluate response={response}", "green")


def main(host: str, port: int):
    asyncio.run(run_main(host, port))


if __name__ == "__main__":
    fire.Fire(main)
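
Since the script hands `main` to `fire.Fire`, host and port arrive as positional command-line arguments. A minimal invocation sketch, assuming the file lands at `llama_stack/apis/evals/client.py` (the diff view does not show file paths) and a stack server is already listening; note the hardcoded model string must match a `routing_key` the server actually serves, and the sample run.yaml below now routes Llama3.2-1B rather than Llama3.1-8B-Instruct:

# Shell: python -m llama_stack.apis.evals.client localhost 5000  (module path assumed)
# Or drive it programmatically:
import asyncio

from llama_stack.apis.evals.client import run_main  # path assumed

asyncio.run(run_main("localhost", 5000))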

View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
 from typing import List, Protocol

 from llama_models.schema_utils import webmethod
@@ -16,22 +15,6 @@ from llama_stack.apis.dataset import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403


-class TextGenerationMetric(Enum):
-    perplexity = "perplexity"
-    rouge = "rouge"
-    bleu = "bleu"
-
-
-class QuestionAnsweringMetric(Enum):
-    em = "em"
-    f1 = "f1"
-
-
-class SummarizationMetric(Enum):
-    rouge = "rouge"
-    bleu = "bleu"
-
-
 class EvaluationJob(BaseModel):
     job_uuid: str
@@ -54,28 +37,7 @@ class EvaluateTaskRequestCommon(BaseModel):
 class EvaluateResponse(BaseModel):
     """Scores for evaluation."""

-    scores = Dict[str, str]
+    metrics: Dict[str, float]


-@json_schema_type
-class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate text generation."""
-
-    metrics: List[TextGenerationMetric]
-
-
-@json_schema_type
-class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate question answering."""
-
-    metrics: List[QuestionAnsweringMetric]
-
-
-@json_schema_type
-class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate summarization."""
-
-    metrics: List[SummarizationMetric]
-
-
 @json_schema_type
@@ -97,33 +59,36 @@ class EvaluationJobCreateResponse(BaseModel):
     job_uuid: str


-class Evaluations(Protocol):
-    @webmethod(route="/evaluate")
-    async def evaluate(
-        self, model: str, dataset: str, task: str
+class Evals(Protocol):
+    @webmethod(route="/evals/run")
+    async def run_evals(
+        self,
+        model: str,
+        dataset: str,
+        task: str,
     ) -> EvaluateResponse: ...

-    @webmethod(route="/evaluate/jobs")
+    @webmethod(route="/evals/jobs")
     def get_evaluation_jobs(self) -> List[EvaluationJob]: ...

-    @webmethod(route="/evaluate/job/create")
+    @webmethod(route="/evals/job/create")
     async def create_evaluation_job(
         self, model: str, dataset: str, task: str
     ) -> EvaluationJob: ...

-    @webmethod(route="/evaluate/job/status")
+    @webmethod(route="/evals/job/status")
     def get_evaluation_job_status(
         self, job_uuid: str
     ) -> EvaluationJobStatusResponse: ...

     # sends SSE stream of logs
-    @webmethod(route="/evaluate/job/logs")
+    @webmethod(route="/evals/job/logs")
     def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...

-    @webmethod(route="/evaluate/job/cancel")
+    @webmethod(route="/evals/job/cancel")
     def cancel_evaluation_job(self, job_uuid: str) -> None: ...

-    @webmethod(route="/evaluate/job/artifacts")
+    @webmethod(route="/evals/job/artifacts")
     def get_evaluation_job_artifacts(
         self, job_uuid: str
     ) -> EvaluationJobArtifactsResponse: ...
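
Two things worth noting in this hunk. First, the old `scores = Dict[str, str]` was a plain class attribute that assigned the type object itself, with no annotation, so pydantic would have silently ignored it; `metrics: Dict[str, float]` is a real validated field. A quick sketch of what the new response serializes to:

from typing import Dict

from pydantic import BaseModel


class EvaluateResponse(BaseModel):
    """Scores for evaluation."""

    metrics: Dict[str, float]


# -> {"metrics":{"accuracy":0.5}}  (pydantic v2 API; use .json() on v1)
print(EvaluateResponse(metrics={"accuracy": 0.5}).model_dump_json())

Second, the rename from `Evaluations` with `/evaluate/...` routes to `Evals` with `/evals/...` routes keeps the URL prefix aligned with the new `Api.evals` enum value introduced below.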

View file

@@ -8,6 +8,7 @@ import importlib
 from typing import Any, Dict, List, Set

 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.distribution import (
     builtin_automatically_routed_apis,
     get_provider_registry,

View file

@@ -10,6 +10,7 @@ from typing import Dict, List
 from pydantic import BaseModel

 from llama_stack.apis.agents import Agents
+from llama_stack.apis.evals import Evals
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.memory import Memory
@@ -41,6 +42,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.shields: Shields,
         Api.memory_banks: MemoryBanks,
         Api.inspect: Inspect,
+        Api.evals: Evals,
     }

     for api, protocol in protocols.items():

View file

@@ -24,6 +24,8 @@ class Api(Enum):
     shields = "shields"
     memory_banks = "memory_banks"

+    evals = "evals"
+
     # built-in API
     inspect = "inspect"

View file

@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import MetaReferenceEvalsImplConfig  # noqa
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.distribution.datatypes import Api, ProviderSpec


async def get_provider_impl(
    config: MetaReferenceEvalsImplConfig, deps: Dict[Api, ProviderSpec]
):
    from .evals import MetaReferenceEvalsImpl

    impl = MetaReferenceEvalsImpl(config, deps[Api.inference])
    await impl.initialize()
    return impl
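
For context, the stack resolver imports this module and calls `get_provider_impl` with the parsed provider config plus the already-resolved dependencies; the registry entry below declares `Api.inference`, so that key is guaranteed to be present in `deps`. A hedged wiring sketch, with `inference_impl` assumed to be resolved elsewhere and import paths taken from the registry entry below:

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.impls.meta_reference.evals import (
    MetaReferenceEvalsImplConfig,
    get_provider_impl,
)


async def wire_up(inference_impl):
    # deps maps each declared api_dependency to its live implementation.
    return await get_provider_impl(
        MetaReferenceEvalsImplConfig(),
        {Api.inference: inference_impl},
    )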

View file

@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel


class MetaReferenceEvalsImplConfig(BaseModel): ...
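
The empty model body is deliberate: it gives the provider a typed config class to grow into while keeping the `config: {}` block in the sample run.yaml below valid, since pydantic parses an empty mapping into a field-less model without complaint:

cfg = MetaReferenceEvalsImplConfig(**{})  # accepts the empty `config: {}` mapping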

View file

@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.evals import *  # noqa: F403

from .config import MetaReferenceEvalsImplConfig


class MetaReferenceEvalsImpl(Evals):
    def __init__(self, config: MetaReferenceEvalsImplConfig, inference_api: Inference):
        self.inference_api = inference_api

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def run_evals(
        self,
        model: str,
        dataset: str,
        task: str,
    ) -> EvaluateResponse:
        print("hi")
        return EvaluateResponse(
            metrics={
                "accuracy": 0.5,
            }
        )
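
The implementation is a stub at this point (note the leftover debug `print`): it never touches `inference_api` and returns a fixed score, which is just enough to exercise the new route end to end. A hedged sanity check; passing `inference_api=None` works only because the stub ignores it:

import asyncio


async def check_stub() -> None:
    impl = MetaReferenceEvalsImpl(MetaReferenceEvalsImplConfig(), inference_api=None)
    await impl.initialize()
    response = await impl.run_evals(model="Llama3.2-1B", dataset="mmlu.csv", task="mmlu")
    assert response.metrics == {"accuracy": 0.5}  # hard-coded stub value


asyncio.run(check_stub())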

View file

@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.distribution.datatypes import *  # noqa: F403


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.evals,
            provider_type="meta-reference",
            pip_packages=[
                "matplotlib",
                "pillow",
                "pandas",
                "scikit-learn",
            ],
            module="llama_stack.providers.impls.meta_reference.evals",
            config_class="llama_stack.providers.impls.meta_reference.evals.MetaReferenceEvalsImplConfig",
            api_dependencies=[
                Api.inference,
            ],
        ),
    ]
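
This spec tells the stack everything it needs to build and load the provider: the pip packages are baked into the distribution, `module` is imported to find `get_provider_impl`, `config_class` is instantiated from the YAML `config:` block, and `api_dependencies` drives the `deps` dict shown earlier. A quick registration check, assuming this registry module is importable as `llama_stack.providers.registry.evals`:

from llama_stack.providers.registry.evals import available_providers  # path assumed

[spec] = available_providers()
print(spec.provider_type, spec.api_dependencies)  # meta-reference [<Api.inference: 'inference'>]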

View file

@@ -10,7 +10,11 @@ apis_to_serve:
 - memory_banks
 - inference
 - safety
+- evals
 api_providers:
+  evals:
+    provider_type: meta-reference
+    config: {}
   inference:
     providers:
     - meta-reference
@@ -34,12 +38,12 @@ routing_table:
   inference:
   - provider_type: meta-reference
     config:
-      model: Llama3.1-8B-Instruct
+      model: Llama3.2-1B
       quantization: null
       torch_seed: null
       max_seq_len: 4096
       max_batch_size: 1
-    routing_key: Llama3.1-8B-Instruct
+    routing_key: Llama3.2-1B
   safety:
   - provider_type: meta-reference
     config: