mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
wip api
This commit is contained in:
parent 7143ecfc0d
commit 8339b2cef3
10 changed files with 174 additions and 51 deletions
llama_stack/apis/evals/client.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

import fire
import httpx
from termcolor import cprint

from .evals import *  # noqa: F403


class EvaluationClient(Evals):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def run_evals(self, model: str, dataset: str, task: str) -> EvaluateResponse:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/evals/run",
                json={
                    "model": model,
                    "dataset": dataset,
                    "task": task,
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return EvaluateResponse(**response.json())


async def run_main(host: str, port: int):
    client = EvaluationClient(f"http://{host}:{port}")

    response = await client.run_evals(
        "Llama3.1-8B-Instruct",
        "mmlu.csv",
        "mmlu",
    )
    cprint(f"evaluate response={response}", "green")


def main(host: str, port: int):
    asyncio.run(run_main(host, port))


if __name__ == "__main__":
    fire.Fire(main)
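Since main is wired through fire, the file can also be invoked directly as a script with a host and port (e.g. python -m llama_stack.apis.evals.client localhost 5000, assuming the package is installed). A minimal programmatic driver might look like the sketch below; it assumes a llama-stack server with the evals API enabled is already listening on localhost:5000, and the demo() name is illustrative, not part of the commit.

import asyncio

from llama_stack.apis.evals.client import EvaluationClient


async def demo():
    # Point the client at a running stack server; it must serve /evals/run.
    client = EvaluationClient("http://localhost:5000")
    response = await client.run_evals("Llama3.1-8B-Instruct", "mmlu.csv", "mmlu")
    print(response.metrics)


asyncio.run(demo())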
llama_stack/apis/evals/evals.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
 from typing import List, Protocol

 from llama_models.schema_utils import webmethod
@@ -16,22 +15,6 @@ from llama_stack.apis.dataset import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403


-class TextGenerationMetric(Enum):
-    perplexity = "perplexity"
-    rouge = "rouge"
-    bleu = "bleu"
-
-
-class QuestionAnsweringMetric(Enum):
-    em = "em"
-    f1 = "f1"
-
-
-class SummarizationMetric(Enum):
-    rouge = "rouge"
-    bleu = "bleu"
-
-
 class EvaluationJob(BaseModel):
     job_uuid: str
@@ -54,28 +37,7 @@ class EvaluateTaskRequestCommon(BaseModel):
 class EvaluateResponse(BaseModel):
     """Scores for evaluation."""

-    scores = Dict[str, str]
-
-
-@json_schema_type
-class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate text generation."""
-
-    metrics: List[TextGenerationMetric]
-
-
-@json_schema_type
-class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate question answering."""
-
-    metrics: List[QuestionAnsweringMetric]
-
-
-@json_schema_type
-class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate summarization."""
-
-    metrics: List[SummarizationMetric]
+    metrics: Dict[str, float]


 @json_schema_type
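The response model also changes shape here: the old "scores = Dict[str, str]" was a plain assignment (it binds the typing object itself rather than declaring a pydantic field), and it becomes a properly annotated metrics field. A tiny round-trip check, assuming llama_stack is installed so the import resolves and that the package re-exports the model:

from llama_stack.apis.evals import EvaluateResponse

# metrics is now a declared pydantic field mapping metric names to values.
resp = EvaluateResponse(metrics={"accuracy": 0.5})
assert resp.metrics["accuracy"] == 0.5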
@@ -97,33 +59,36 @@ class EvaluationJobCreateResponse(BaseModel):
     job_uuid: str


-class Evaluations(Protocol):
-    @webmethod(route="/evaluate")
-    async def evaluate(
-        self, model: str, dataset: str, task: str
+class Evals(Protocol):
+    @webmethod(route="/evals/run")
+    async def run_evals(
+        self,
+        model: str,
+        dataset: str,
+        task: str,
     ) -> EvaluateResponse: ...

-    @webmethod(route="/evaluate/jobs")
+    @webmethod(route="/evals/jobs")
     def get_evaluation_jobs(self) -> List[EvaluationJob]: ...

-    @webmethod(route="/evaluate/job/create")
+    @webmethod(route="/evals/job/create")
     async def create_evaluation_job(
         self, model: str, dataset: str, task: str
     ) -> EvaluationJob: ...

-    @webmethod(route="/evaluate/job/status")
+    @webmethod(route="/evals/job/status")
     def get_evaluation_job_status(
         self, job_uuid: str
     ) -> EvaluationJobStatusResponse: ...

     # sends SSE stream of logs
-    @webmethod(route="/evaluate/job/logs")
+    @webmethod(route="/evals/job/logs")
     def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...

-    @webmethod(route="/evaluate/job/cancel")
+    @webmethod(route="/evals/job/cancel")
     def cancel_evaluation_job(self, job_uuid: str) -> None: ...

-    @webmethod(route="/evaluate/job/artifacts")
+    @webmethod(route="/evals/job/artifacts")
     def get_evaluation_job_artifacts(
         self, job_uuid: str
     ) -> EvaluationJobArtifactsResponse: ...
@@ -8,6 +8,7 @@ import importlib
 from typing import Any, Dict, List, Set

 from llama_stack.distribution.datatypes import *  # noqa: F403

 from llama_stack.distribution.distribution import (
     builtin_automatically_routed_apis,
     get_provider_registry,
@@ -10,6 +10,7 @@ from typing import Dict, List
 from pydantic import BaseModel

 from llama_stack.apis.agents import Agents
+from llama_stack.apis.evals import Evals
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.memory import Memory

@@ -41,6 +42,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.shields: Shields,
         Api.memory_banks: MemoryBanks,
         Api.inspect: Inspect,
+        Api.evals: Evals,
     }

     for api, protocol in protocols.items():
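The protocols mapping is what ties each Api enum value to its Protocol class so the server can discover HTTP routes from the @webmethod metadata. A rough standalone sketch of that discovery idea follows; the __webmethod__ attribute name is an assumption about what the decorator in llama_models.schema_utils attaches, not code from this commit.

import inspect
from typing import List


def routes_for(protocol: type) -> List[str]:
    # Walk the protocol's callable members and collect any route metadata
    # that the @webmethod decorator may have attached to them.
    routes = []
    for _, member in inspect.getmembers(protocol, callable):
        webmethod = getattr(member, "__webmethod__", None)  # attribute name assumed
        if webmethod is not None:
            routes.append(webmethod.route)
    return routes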
@@ -24,6 +24,8 @@ class Api(Enum):
     shields = "shields"
     memory_banks = "memory_banks"

+    evals = "evals"
+
     # built-in API
     inspect = "inspect"
llama_stack/providers/impls/meta_reference/evals/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import MetaReferenceEvalsImplConfig  # noqa
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.distribution.datatypes import Api, ProviderSpec


async def get_provider_impl(
    config: MetaReferenceEvalsImplConfig, deps: Dict[Api, ProviderSpec]
):
    from .evals import MetaReferenceEvalsImpl

    impl = MetaReferenceEvalsImpl(config, deps[Api.inference])
    await impl.initialize()
    return impl
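This entry point shows the dependency-injection pattern: the resolver hands over resolved providers keyed by Api, and the evals provider pulls out its inference dependency. Note that despite the Dict[Api, ProviderSpec] annotation, the value is used as a live Inference impl (it becomes inference_api). A hypothetical direct wiring, workable only because the wip implementation never actually calls inference_api; the object() stand-in and wire() name are illustrative:

import asyncio

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.impls.meta_reference.evals import (
    MetaReferenceEvalsImplConfig,
    get_provider_impl,
)


async def wire():
    # A bare object stands in for a real Inference provider here.
    impl = await get_provider_impl(MetaReferenceEvalsImplConfig(), {Api.inference: object()})
    print(await impl.run_evals("Llama3.1-8B-Instruct", "mmlu.csv", "mmlu"))


asyncio.run(wire())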
llama_stack/providers/impls/meta_reference/evals/config.py (new file, 10 lines)
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel


class MetaReferenceEvalsImplConfig(BaseModel): ...
llama_stack/providers/impls/meta_reference/evals/evals.py (new file, 34 lines)
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.evals import *  # noqa: F403

from .config import MetaReferenceEvalsImplConfig


class MetaReferenceEvalsImpl(Evals):
    def __init__(self, config: MetaReferenceEvalsImplConfig, inference_api: Inference):
        self.inference_api = inference_api

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def run_evals(
        self,
        model: str,
        dataset: str,
        task: str,
    ) -> EvaluateResponse:
        print("hi")
        return EvaluateResponse(
            metrics={
                "accuracy": 0.5,
            }
        )
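run_evals is still a stub: it ignores its arguments, prints a placeholder, and returns a hard-coded accuracy. For a sense of what would eventually populate the metrics dict, here is a small self-contained helper computing exact-match accuracy over prediction/target pairs; it is purely illustrative and not part of the commit:

from typing import List


def exact_match_accuracy(predictions: List[str], targets: List[str]) -> float:
    # Fraction of predictions that equal their target after whitespace stripping.
    matches = sum(p.strip() == t.strip() for p, t in zip(predictions, targets))
    return matches / max(len(targets), 1)


# exact_match_accuracy(["A", "B"], ["A", "C"]) -> 0.5, the same shape the stub returns.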
llama_stack/providers/registry/evals.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.distribution.datatypes import *  # noqa: F403


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.evals,
            provider_type="meta-reference",
            pip_packages=[
                "matplotlib",
                "pillow",
                "pandas",
                "scikit-learn",
            ],
            module="llama_stack.providers.impls.meta_reference.evals",
            config_class="llama_stack.providers.impls.meta_reference.evals.MetaReferenceEvalsImplConfig",
            api_dependencies=[
                Api.inference,
            ],
        ),
    ]
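A quick sanity check of the new registry entry, assuming llama_stack is installed; the assertion is illustrative:

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.registry.evals import available_providers

specs = [p for p in available_providers() if p.api == Api.evals]
assert specs and specs[0].provider_type == "meta-reference"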
@@ -10,7 +10,11 @@ apis_to_serve:
 - memory_banks
 - inference
 - safety
+- evals
 api_providers:
+  evals:
+    provider_type: meta-reference
+    config: {}
   inference:
     providers:
     - meta-reference

@@ -34,12 +38,12 @@ routing_table:
   inference:
   - provider_type: meta-reference
     config:
-      model: Llama3.1-8B-Instruct
+      model: Llama3.2-1B
       quantization: null
       torch_seed: null
       max_seq_len: 4096
       max_batch_size: 1
-    routing_key: Llama3.1-8B-Instruct
+    routing_key: Llama3.2-1B
   safety:
   - provider_type: meta-reference
     config:
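With these run-config additions in place, a server started from this config should serve the evals API via the meta-reference provider, so the client example near the top of the diff can hit /evals/run end to end and get back the stub's hard-coded metrics (to the extent the wip wiring is complete).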