mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
rename toolchain/ --> llama_toolchain/
This commit is contained in:
parent
d95f5f863d
commit
f9111652ef
73 changed files with 36 additions and 37 deletions
2
llama_toolchain/reward_scoring/api/__init__.py
Normal file
2
llama_toolchain/reward_scoring/api/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from .datatypes import * # noqa: F401 F403
|
||||
from .endpoints import * # noqa: F401 F403
|
25
llama_toolchain/reward_scoring/api/datatypes.py
Normal file
25
llama_toolchain/reward_scoring/api/datatypes.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from strong_typing.schema import json_schema_type
|
||||
|
||||
from llama_models.llama3_1.api.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoredMessage(BaseModel):
|
||||
message: Message
|
||||
score: float
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DialogGenerations(BaseModel):
|
||||
dialog: List[Message]
|
||||
sampled_generations: List[Message]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoredDialogGenerations(BaseModel):
|
||||
dialog: List[Message]
|
||||
scored_generations: List[ScoredMessage]
|
27
llama_toolchain/reward_scoring/api/endpoints.py
Normal file
27
llama_toolchain/reward_scoring/api/endpoints.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
from typing import List, Protocol, Union
|
||||
from .datatypes import * # noqa: F403
|
||||
|
||||
from pyopenapi import webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RewardScoringRequest(BaseModel):
|
||||
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
|
||||
|
||||
dialog_generations: List[DialogGenerations]
|
||||
model: RewardModel
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RewardScoringResponse(BaseModel):
|
||||
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
|
||||
scored_generations: List[ScoredDialogGenerations]
|
||||
|
||||
|
||||
class RewardScoring(Protocol):
|
||||
@webmethod(route="/reward_scoring/score")
|
||||
def post_score(
|
||||
self,
|
||||
request: RewardScoringRequest,
|
||||
) -> Union[RewardScoringResponse]: ...
|
Loading…
Add table
Add a link
Reference in a new issue