forked from phoenix-oss/llama-stack-mirror
* [1/n] migrate inference/chat_completion * migrate inference/completion * inference/completion * inference regenerate openapi spec * safety api * migrate agentic system * migrate apis without implementations * re-generate openapi spec * remove hack from openapi generator * fix inference * fix inference * openapi generator rerun * Simplified Telemetry API and tying it to logger (#57) * Simplified Telemetry API and tying it to logger * small update which adds a METRIC type * move span events one level down into structured log events --------- Co-authored-by: Ashwin Bharambe <ashwin@meta.com> * fix api to work with openapi generator * fix agentic calling inference * together adapter inference * update inference adapters --------- Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com> Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import List, Protocol, Union
|
|
|
|
from llama_models.schema_utils import json_schema_type, webmethod
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
|
|
|
|
@json_schema_type
class ScoredMessage(BaseModel):
    """A single generated message paired with the score assigned to it."""

    # The message that was scored (``Message`` comes from the wildcard
    # ``llama_models.llama3.api.datatypes`` import).
    message: Message
    # Score for ``message``; a plain float — this schema enforces no range.
    score: float
|
|
|
|
|
|
@json_schema_type
class DialogGenerations(BaseModel):
    """A dialog (prompt) together with the responses sampled for it.

    This is the unit of input to reward scoring.
    """

    # The conversation (prompt) that the sampled responses answer.
    dialog: List[Message]
    # Candidate responses for ``dialog``; not yet scored.
    sampled_generations: List[Message]
|
|
|
|
|
|
@json_schema_type
class ScoredDialogGenerations(BaseModel):
    """A dialog together with its generations, each paired with a score.

    Scored counterpart of ``DialogGenerations``; used in reward scoring results.
    """

    # The conversation (prompt) that the generations answer.
    dialog: List[Message]
    # Generations for ``dialog``, each wrapped with its score.
    scored_generations: List[ScoredMessage]
|
|
|
|
|
|
@json_schema_type
class RewardScoringRequest(BaseModel):
    """Request to score sampled responses with a model.

    Carries a list of prompts (dialogs), each with a list of sampled
    responses, plus the model to score them with.
    """

    # NOTE: RewardScoring.reward_score accepts these same two fields as
    # separate arguments rather than as this model.

    # One entry per prompt: the dialog plus its sampled responses.
    dialog_generations: List[DialogGenerations]
    # Model to score with. NOTE(review): the expected format (name vs. id)
    # is not constrained here — confirm against the implementation.
    model: str
|
|
|
|
|
|
@json_schema_type
class RewardScoringResponse(BaseModel):
    """Response from reward scoring: a batch of dialogs whose generations have been scored."""

    # NOTE(review): whether an implementation drops low-scoring generations
    # (e.g. by some threshold) is not visible here, and no threshold is part
    # of the request fields — confirm with the provider before relying on it.
    scored_generations: List[ScoredDialogGenerations]
|
|
|
|
|
|
class RewardScoring(Protocol):
    """Structural interface for scoring sampled generations.

    ``@webmethod`` binds ``reward_score`` to the ``/reward_scoring/score`` route.
    """

    # Annotated directly with RewardScoringResponse: a single-member
    # ``Union[X]`` collapses to ``X`` anyway, so the annotation object the
    # route/schema tooling sees is unchanged.
    @webmethod(route="/reward_scoring/score")
    def reward_score(
        self,
        dialog_generations: List[DialogGenerations],
        model: str,
    ) -> RewardScoringResponse:
        """Score the sampled generations of each dialog.

        :param dialog_generations: dialogs, each with the responses sampled for it.
        :param model: the model to score with.
        :returns: the scored dialog generations.
        """
        ...
|