reward scoring model enum

This commit is contained in:
Raghotham Murthy 2024-07-10 21:59:01 -07:00
parent ebb59aa35f
commit 6ec7c47938
4 changed files with 24 additions and 17 deletions

View file

@ -29,6 +29,7 @@ from model_types import (
InstructModel, InstructModel,
Message, Message,
PretrainedModel, PretrainedModel,
RewardModel,
SamplingParams, SamplingParams,
ShieldConfig, ShieldConfig,
StopReason, StopReason,
@ -245,9 +246,7 @@ class RewardScoringRequest:
"""Request to score a reward function. A list of prompts and a list of responses per prompt.""" """Request to score a reward function. A list of prompts and a list of responses per prompt."""
prompt_generations: List[KPromptGenerations] prompt_generations: List[KPromptGenerations]
model: RewardModel
# TODO(ragho): create a RewardModel enum tye
model: str
@json_schema_type @json_schema_type

View file

@ -137,3 +137,6 @@ class PretrainedModel(Enum):
class InstructModel(Enum): class InstructModel(Enum):
llama3_8b_chat = "llama3_8b_chat" llama3_8b_chat = "llama3_8b_chat"
llama3_70b_chat = "llama3_70b_chat" llama3_70b_chat = "llama3_70b_chat"
class RewardModel(Enum):
llama3_405b_reward = "llama3_405b_reward"

View file

@ -1976,7 +1976,10 @@
} }
}, },
"model": { "model": {
"type": "string" "type": "string",
"enum": [
"llama3_405b_reward"
]
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -2312,23 +2315,23 @@
} }
], ],
"tags": [ "tags": [
{
"name": "RewardScoring"
},
{ {
"name": "Inference" "name": "Inference"
}, },
{
"name": "SyntheticDataGeneration"
},
{
"name": "Datasets"
},
{ {
"name": "AgenticSystem" "name": "AgenticSystem"
}, },
{ {
"name": "Finetuning" "name": "Datasets"
}, },
{ {
"name": "RewardScoring" "name": "SyntheticDataGeneration"
},
{
"name": "Finetuning"
}, },
{ {
"name": "ShieldConfig", "name": "ShieldConfig",

View file

@ -1039,6 +1039,8 @@ components:
additionalProperties: false additionalProperties: false
properties: properties:
model: model:
enum:
- llama3_405b_reward
type: string type: string
prompt_generations: prompt_generations:
items: items:
@ -1412,12 +1414,12 @@ security:
servers: servers:
- url: http://llama.meta.com - url: http://llama.meta.com
tags: tags:
- name: Inference
- name: SyntheticDataGeneration
- name: Datasets
- name: AgenticSystem
- name: Finetuning
- name: RewardScoring - name: RewardScoring
- name: Inference
- name: AgenticSystem
- name: Datasets
- name: SyntheticDataGeneration
- name: Finetuning
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" /> - description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" />
name: ShieldConfig name: ShieldConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest"