wip api
parent 7143ecfc0d, commit 8339b2cef3
10 changed files with 174 additions and 51 deletions
llama_stack/apis/evals/client.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

import fire
import httpx
from termcolor import cprint

from .evals import *  # noqa: F403


class EvaluationClient(Evals):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def run_evals(self, model: str, dataset: str, task: str) -> EvaluateResponse:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/evals/run",
                json={
                    "model": model,
                    "dataset": dataset,
                    "task": task,
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return EvaluateResponse(**response.json())


async def run_main(host: str, port: int):
    client = EvaluationClient(f"http://{host}:{port}")

    response = await client.run_evals(
        "Llama3.1-8B-Instruct",
        "mmlu.csv",
        "mmlu",
    )
    cprint(f"evaluate response={response}", "green")


def main(host: str, port: int):
    asyncio.run(run_main(host, port))


if __name__ == "__main__":
    fire.Fire(main)
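Since `main` is wired through `fire.Fire`, this client can be driven from a shell once a stack server is listening (e.g. `python -m llama_stack.apis.evals.client localhost 5000`), or programmatically; a minimal sketch, with host, port, and model as placeholder values rather than anything this commit prescribes:

    # hypothetical driver for the new client; host/port/model are placeholders
    import asyncio

    from llama_stack.apis.evals.client import EvaluationClient

    async def demo():
        client = EvaluationClient("http://localhost:5000")
        response = await client.run_evals("Llama3.2-1B", "mmlu.csv", "mmlu")
        print(response.metrics)  # EvaluateResponse.metrics is a Dict[str, float]

    asyncio.run(demo())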
llama_stack/apis/evals/evals.py (path inferred; this file's header is missing from the capture)
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from enum import Enum
 from typing import List, Protocol
 
 from llama_models.schema_utils import webmethod
@@ -16,22 +15,6 @@ from llama_stack.apis.dataset import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
 
 
-class TextGenerationMetric(Enum):
-    perplexity = "perplexity"
-    rouge = "rouge"
-    bleu = "bleu"
-
-
-class QuestionAnsweringMetric(Enum):
-    em = "em"
-    f1 = "f1"
-
-
-class SummarizationMetric(Enum):
-    rouge = "rouge"
-    bleu = "bleu"
-
-
 class EvaluationJob(BaseModel):
     job_uuid: str
 
@@ -54,28 +37,7 @@ class EvaluateTaskRequestCommon(BaseModel):
 class EvaluateResponse(BaseModel):
     """Scores for evaluation."""
 
-    scores = Dict[str, str]
-
-
-@json_schema_type
-class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate text generation."""
-
-    metrics: List[TextGenerationMetric]
-
-
-@json_schema_type
-class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate question answering."""
-
-    metrics: List[QuestionAnsweringMetric]
-
-
-@json_schema_type
-class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate summarization."""
-
-    metrics: List[SummarizationMetric]
+    metrics: Dict[str, float]
 
 
 @json_schema_type
@@ -97,33 +59,36 @@ class EvaluationJobCreateResponse(BaseModel):
     job_uuid: str
 
 
-class Evaluations(Protocol):
-    @webmethod(route="/evaluate")
-    async def evaluate(
-        self, model: str, dataset: str, task: str
+class Evals(Protocol):
+    @webmethod(route="/evals/run")
+    async def run_evals(
+        self,
+        model: str,
+        dataset: str,
+        task: str,
     ) -> EvaluateResponse: ...
 
-    @webmethod(route="/evaluate/jobs")
+    @webmethod(route="/evals/jobs")
     def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
 
-    @webmethod(route="/evaluate/job/create")
+    @webmethod(route="/evals/job/create")
     async def create_evaluation_job(
         self, model: str, dataset: str, task: str
     ) -> EvaluationJob: ...
 
-    @webmethod(route="/evaluate/job/status")
+    @webmethod(route="/evals/job/status")
     def get_evaluation_job_status(
         self, job_uuid: str
    ) -> EvaluationJobStatusResponse: ...
 
     # sends SSE stream of logs
-    @webmethod(route="/evaluate/job/logs")
+    @webmethod(route="/evals/job/logs")
     def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
 
-    @webmethod(route="/evaluate/job/cancel")
+    @webmethod(route="/evals/job/cancel")
    def cancel_evaluation_job(self, job_uuid: str) -> None: ...
 
-    @webmethod(route="/evaluate/job/artifacts")
+    @webmethod(route="/evals/job/artifacts")
     def get_evaluation_job_artifacts(
         self, job_uuid: str
     ) -> EvaluationJobArtifactsResponse: ...
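Only `/evals/run` gets client coverage in this commit; the job-oriented routes are declared but unexercised. A sketch of how one might call them in the same httpx style as client.py (the HTTP method and request shape are assumptions, since the commit defines routes but no payloads):

    # hypothetical status poll against the renamed route; params are an assumption
    import httpx

    async def get_job_status(base_url: str, job_uuid: str) -> dict:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{base_url}/evals/job/status",
                params={"job_uuid": job_uuid},
            )
            response.raise_for_status()
            return response.json()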
@@ -8,6 +8,7 @@ import importlib
from typing import Any, Dict, List, Set

from llama_stack.distribution.datatypes import *  # noqa: F403

from llama_stack.distribution.distribution import (
    builtin_automatically_routed_apis,
    get_provider_registry,
@@ -10,6 +10,7 @@ from typing import Dict, List
 from pydantic import BaseModel
 
 from llama_stack.apis.agents import Agents
+from llama_stack.apis.evals import Evals
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.memory import Memory
@@ -41,6 +42,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.shields: Shields,
         Api.memory_banks: MemoryBanks,
         Api.inspect: Inspect,
+        Api.evals: Evals,
     }
 
     for api, protocol in protocols.items():
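Registering `Api.evals: Evals` in this map is all the server needs, because endpoint discovery walks each protocol's methods and reads the route metadata attached by the `@webmethod` decorator. Roughly, as a sketch (attribute and helper names approximated, not the actual discovery code):

    # approximate sketch of webmethod-based endpoint discovery
    import inspect

    def routes_for(protocol) -> list:
        routes = []
        for _, member in inspect.getmembers(protocol):
            webmethod = getattr(member, "__webmethod__", None)  # set by @webmethod
            if webmethod is not None:
                routes.append(webmethod.route)  # e.g. "/evals/run"
        return routes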
@@ -24,6 +24,8 @@ class Api(Enum):
     shields = "shields"
     memory_banks = "memory_banks"
 
+    evals = "evals"
+
     # built-in API
     inspect = "inspect"
 
llama_stack/providers/impls/meta_reference/evals/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import MetaReferenceEvalsImplConfig  # noqa
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.distribution.datatypes import Api, ProviderSpec


async def get_provider_impl(
    config: MetaReferenceEvalsImplConfig, deps: Dict[Api, ProviderSpec]
):
    from .evals import MetaReferenceEvalsImpl

    impl = MetaReferenceEvalsImpl(config, deps[Api.inference])
    await impl.initialize()
    return impl
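The entry point can also be exercised directly in a test, bypassing the distribution machinery; a sketch assuming some `Inference` implementation is already in hand (`inference_impl` is a placeholder, not part of this commit):

    # hypothetical wiring sketch; inference_impl is a placeholder
    from llama_stack.distribution.datatypes import Api
    from llama_stack.providers.impls.meta_reference.evals import (
        MetaReferenceEvalsImplConfig,
        get_provider_impl,
    )

    async def wire_up(inference_impl):
        impl = await get_provider_impl(
            MetaReferenceEvalsImplConfig(), {Api.inference: inference_impl}
        )
        return await impl.run_evals("Llama3.2-1B", "mmlu.csv", "mmlu")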
llama_stack/providers/impls/meta_reference/evals/config.py (new file, 10 lines)
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel


class MetaReferenceEvalsImplConfig(BaseModel): ...
llama_stack/providers/impls/meta_reference/evals/evals.py (new file, 34 lines)
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.evals import *  # noqa: F403

from .config import MetaReferenceEvalsImplConfig


class MetaReferenceEvalsImpl(Evals):
    def __init__(self, config: MetaReferenceEvalsImplConfig, inference_api: Inference):
        self.inference_api = inference_api

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def run_evals(
        self,
        model: str,
        dataset: str,
        task: str,
    ) -> EvaluateResponse:
        print("hi")
        return EvaluateResponse(
            metrics={
                "accuracy": 0.5,
            }
        )
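`run_evals` here is a stub that ignores its inputs and returns a fixed score, which is presumably why the commit is flagged "wip". One plausible direction for the real thing, purely as a sketch of a drop-in replacement for the method above: iterate the dataset, generate with the injected `inference_api`, and score the answers. The dataset loader, the answer matching, and the exact `chat_completion` surface below are all assumptions, not part of this commit:

    # hypothetical non-stub run_evals; helpers and scoring are assumptions
    async def run_evals(self, model: str, dataset: str, task: str) -> EvaluateResponse:
        rows = load_dataset_rows(dataset)  # placeholder helper, not in this commit
        correct = 0
        for row in rows:
            # UserMessage would come in via the star import from apis.inference
            response = await self.inference_api.chat_completion(
                model=model,
                messages=[UserMessage(content=row["question"])],
                stream=False,
            )
            if row["answer"] in response.completion_message.content:
                correct += 1
        return EvaluateResponse(metrics={"accuracy": correct / max(len(rows), 1)})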
llama_stack/providers/registry/evals.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.distribution.datatypes import *  # noqa: F403


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.evals,
            provider_type="meta-reference",
            pip_packages=[
                "matplotlib",
                "pillow",
                "pandas",
                "scikit-learn",
            ],
            module="llama_stack.providers.impls.meta_reference.evals",
            config_class="llama_stack.providers.impls.meta_reference.evals.MetaReferenceEvalsImplConfig",
            api_dependencies=[
                Api.inference,
            ],
        ),
    ]
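This spec is what `get_provider_registry` (imported in the resolver hunk above) surfaces to the distribution; a quick sanity check, using only names defined in this commit:

    # sketch: confirming the new provider spec resolves
    from llama_stack.providers.registry.evals import available_providers

    specs = {spec.provider_type: spec for spec in available_providers()}
    print(specs["meta-reference"].module)
    # llama_stack.providers.impls.meta_reference.evals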
@@ -10,7 +10,11 @@ apis_to_serve:
 - memory_banks
 - inference
 - safety
+- evals
 api_providers:
+  evals:
+    provider_type: meta-reference
+    config: {}
   inference:
     providers:
     - meta-reference
@@ -34,12 +38,12 @@ routing_table:
   inference:
   - provider_type: meta-reference
     config:
-      model: Llama3.1-8B-Instruct
+      model: Llama3.2-1B
       quantization: null
       torch_seed: null
       max_seq_len: 4096
       max_batch_size: 1
-    routing_key: Llama3.1-8B-Instruct
+    routing_key: Llama3.2-1B
   safety:
   - provider_type: meta-reference
     config: