This commit is contained in:
Xi Yan 2025-03-16 18:30:06 -07:00
parent d9264a0925
commit d34b70e3ab
2 changed files with 246 additions and 2 deletions

View file

@ -0,0 +1,238 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .graders import * # noqa: F401 F403
from enum import Enum
from typing import (
Annotated,
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_stack.apis.datasets import DatasetPurpose
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
class GraderType(Enum):
    """
    A type of grader. Each type is a criterion for evaluating answers.

    Each member's value is the string used as the `type` discriminator of the
    corresponding grader definition model declared in this module.
    """

    llm = "llm"
    regex_parser = "regex_parser"
    equality = "equality"
    subset_of = "subset_of"
    factuality = "factuality"
    faithfulness = "faithfulness"
@json_schema_type
class GraderTypeInfo(BaseModel):
    """
    Information about a grader type.

    :param grader_type: The type of grader.
    :param description: A description of the grader type.
        - E.g. Write your custom judge prompt to score the answer.
    :param supported_dataset_purposes: The purposes that this grader can be used for.
    """

    grader_type: GraderType
    description: str
    # Defaults to an empty list when no purposes are given.
    supported_dataset_purposes: List[DatasetPurpose] = Field(
        description="The supported purposes (supported dataset schema) that this grader can be used for. E.g. eval/question-answer",
        default_factory=list,
    )
class AggregationFunctionType(Enum):
    """
    A type of aggregation function.

    :cvar average: Average the scores of each row.
    :cvar median: Take the median of the scores of each row.
    :cvar categorical_count: Count the number of rows that match each category.
    :cvar accuracy: Number of correct results over total results.
    """

    average = "average"
    median = "median"
    categorical_count = "categorical_count"
    accuracy = "accuracy"
class BasicGraderParams(BaseModel):
    """
    Parameters used by the equality, subset_of, factuality and faithfulness graders.

    :param aggregation_functions: The aggregation functions selected for this grader
        (see AggregationFunctionType).
    """

    aggregation_functions: List[AggregationFunctionType]
class LlmGraderParams(BaseModel):
    """
    Parameters of the LLM grader.

    :param model: The model to grade with, e.g. "llama-405b".
    :param prompt: The judge prompt used to score the answer.
    :param score_regexes: Regexes associated with the score.
    :param aggregation_functions: The aggregation functions selected for this grader
        (see AggregationFunctionType).
    """

    model: str
    prompt: str
    # NOTE(review): presumably applied to the model output to extract the
    # score -- the consumer is not visible in this file; confirm.
    score_regexes: List[str]
    aggregation_functions: List[AggregationFunctionType]
class RegexParserGraderParams(BaseModel):
    """
    Parameters of the regex parser grader.

    :param parsing_regexes: The regexes used for parsing.
    :param aggregation_functions: The aggregation functions selected for this grader
        (see AggregationFunctionType).
    """

    # NOTE(review): what text these regexes are applied to is not visible in
    # this file -- presumably the generated answer; confirm.
    parsing_regexes: List[str]
    aggregation_functions: List[AggregationFunctionType]
@json_schema_type
class LlmGrader(BaseModel):
    """
    Grader definition for the `llm` grader type.

    :param type: Discriminator of the definition, always "llm".
    :param llm: The parameters of the LLM grader.
    """

    type: Literal[GraderType.llm.value] = GraderType.llm.value
    llm: LlmGraderParams
@json_schema_type
class RegexParserGrader(BaseModel):
    """
    Grader definition for the `regex_parser` grader type.

    :param type: Discriminator of the definition, always "regex_parser".
    :param regex_parser: The parameters of the regex parser grader.
    """

    type: Literal[GraderType.regex_parser.value] = GraderType.regex_parser.value
    regex_parser: RegexParserGraderParams
@json_schema_type
class EqualityGrader(BaseModel):
    """
    Grader definition for the `equality` grader type.

    :param type: Discriminator of the definition, always "equality".
    :param equality: The parameters of the equality grader.
    """

    type: Literal[GraderType.equality.value] = GraderType.equality.value
    equality: BasicGraderParams
@json_schema_type
class SubsetOfGrader(BaseModel):
    """
    Grader definition for the `subset_of` grader type.

    :param type: Discriminator of the definition, always "subset_of".
    :param subset_of: The parameters of the subset_of grader.
    """

    type: Literal[GraderType.subset_of.value] = GraderType.subset_of.value
    subset_of: BasicGraderParams
@json_schema_type
class FactualityGrader(BaseModel):
    """
    Grader definition for the `factuality` grader type.

    :param type: Discriminator of the definition, always "factuality".
    :param factuality: The parameters of the factuality grader.
    """

    type: Literal[GraderType.factuality.value] = GraderType.factuality.value
    factuality: BasicGraderParams
@json_schema_type
class FaithfulnessGrader(BaseModel):
    """
    Grader definition for the `faithfulness` grader type.

    :param type: Discriminator of the definition, always "faithfulness".
    :param faithfulness: The parameters of the faithfulness grader.
    """

    type: Literal[GraderType.faithfulness.value] = GraderType.faithfulness.value
    faithfulness: BasicGraderParams
# Every concrete grader definition model. Each one declares a distinct literal
# value for its `type` field, which is what the discriminator below keys on.
_GraderDefinitionUnion = Union[
    LlmGrader,
    RegexParserGrader,
    EqualityGrader,
    SubsetOfGrader,
    FactualityGrader,
    FaithfulnessGrader,
]

# Tagged union over all grader definitions, discriminated by the `type` field
# and passed to register_schema under the name "GraderDefinition".
GraderDefinition = register_schema(
    Annotated[_GraderDefinitionUnion, Field(discriminator="type")],
    name="GraderDefinition",
)
class CommonGraderFields(BaseModel):
    """
    Fields shared by Grader and GraderInput.

    :param grader: The grader definition.
    :param description: (Optional) A description of the grader.
    :param metadata: Any additional metadata for this definition.
    """

    grader: GraderDefinition
    description: Optional[str] = None
    # default_factory gives every instance its own fresh dict.
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this definition",
    )
@json_schema_type
class Grader(CommonGraderFields, Resource):
    """
    A grader resource: the common grader fields combined with the base Resource fields.

    :param type: The resource type, always "grader".
    """

    type: Literal[ResourceType.grader.value] = ResourceType.grader.value

    @property
    def grader_id(self) -> str:
        """Return the resource `identifier` under a grader-specific name."""
        return self.identifier

    @property
    def provider_grader_id(self) -> str:
        """Return the resource `provider_resource_id` under a grader-specific name."""
        return self.provider_resource_id
# NOTE(review): not referenced by any endpoint visible in this file --
# presumably consumed by configuration/registration code elsewhere; confirm.
class GraderInput(CommonGraderFields, BaseModel):
    """
    Input description of a grader: the common grader fields plus its identifiers.

    :param grader_id: The ID of the grader.
    :param provider_id: (Optional) The ID of the provider.
    :param provider_grader_id: (Optional) The provider-specific ID of the grader.
    """

    grader_id: str
    provider_id: Optional[str] = None
    provider_grader_id: Optional[str] = None
# NOTE(review): no endpoint visible in this file returns this model;
# Graders.list_graders is annotated as returning List[Grader] directly.
class ListGradersResponse(BaseModel):
    """
    Wrapper around a list of graders.

    :param data: The graders.
    """

    data: List[Grader]
class ListGraderTypesResponse(BaseModel):
    """
    Response of Graders.list_grader_types.

    :param data: Information about each grader type.
    """

    data: List[GraderTypeInfo]
@runtime_checkable
class Graders(Protocol):
    """
    API for registering, listing, retrieving and deleting graders, and for
    listing the available grader types.
    """

    @webmethod(route="/graders", method="POST")
    async def register_grader(
        self,
        grader: GraderDefinition,
        grader_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Grader:
        """
        Register a new grader.

        :param grader: The grader definition, E.g.
            - {
                "type": "llm",
                "llm": {
                    "model": "llama-405b",
                    "prompt": "You are a judge. Score the answer based on the question. {question} {answer}",
                    "score_regexes": ["Score: ([0-9]+)"],
                    "aggregation_functions": ["average"],
                }
            }
        :param grader_id: (Optional) The ID of the grader. If not provided, a random ID will be generated.
        :param metadata: (Optional) Any additional metadata for this grader.
            - E.g. {
                "description": "A grader that scores the answer based on the question.",
            }
        :return: The registered grader.
        """
        ...

    # NOTE(review): this returns a bare list while list_grader_types returns a
    # wrapper model, and ListGradersResponse is defined in this module but not
    # used. The annotation is left unchanged to keep the API contract stable.
    @webmethod(route="/graders", method="GET")
    async def list_graders(self) -> List[Grader]:
        """
        List all graders.

        :return: A list of graders.
        """
        ...

    # Declared before the "/graders/{grader_id:path}" routes on purpose: the
    # ":path" parameter also matches the literal segment "types", so a router
    # that tries GET routes in declaration order would otherwise dispatch
    # "/graders/types" to get_grader with grader_id="types".
    @webmethod(route="/graders/types", method="GET")
    async def list_grader_types(self) -> ListGraderTypesResponse:
        """
        List all grader types.

        :return: A list of grader types and information about the types.
        """
        ...

    @webmethod(route="/graders/{grader_id:path}", method="GET")
    async def get_grader(self, grader_id: str) -> Grader:
        """
        Get a grader by ID.

        :param grader_id: The ID of the grader.
        :return: The grader.
        """
        ...

    @webmethod(route="/graders/{grader_id:path}", method="DELETE")
    async def delete_grader(self, grader_id: str) -> None:
        """
        Delete a grader by ID.

        :param grader_id: The ID of the grader.
        """
        ...

View file

@ -14,6 +14,8 @@ class ResourceType(Enum):
shield = "shield"
vector_db = "vector_db"
dataset = "dataset"
grader = "grader"
# TODO: migrate scoring_function -> grader
scoring_function = "scoring_function"
benchmark = "benchmark"
tool = "tool"
@ -23,7 +25,9 @@ class ResourceType(Enum):
class Resource(BaseModel):
"""Base class for all Llama Stack resources"""
identifier: str = Field(description="Unique identifier for this resource in llama stack")
identifier: str = Field(
description="Unique identifier for this resource in llama stack"
)
provider_resource_id: str = Field(
description="Unique identifier for this resource in the provider",
@ -32,4 +36,6 @@ class Resource(BaseModel):
provider_id: str = Field(description="ID of the provider that owns this resource")
type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)")
type: ResourceType = Field(
description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)"
)