migrate apis without implementations

This commit is contained in:
Xi Yan 2024-09-11 14:15:13 -07:00
parent 6049aada71
commit a3081f28fc
7 changed files with 90 additions and 18 deletions

View file

@ -51,11 +51,21 @@ class BatchInference(Protocol):
@webmethod(route="/batch_inference/completion") @webmethod(route="/batch_inference/completion")
async def batch_completion( async def batch_completion(
self, self,
request: BatchCompletionRequest, model: str,
content_batch: List[InterleavedTextMedia],
sampling_params: Optional[SamplingParams] = SamplingParams(),
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ... ) -> BatchCompletionResponse: ...
@webmethod(route="/batch_inference/chat_completion") @webmethod(route="/batch_inference/chat_completion")
async def batch_chat_completion( async def batch_chat_completion(
self, self,
request: BatchChatCompletionRequest, model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = SamplingParams(),
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ... ) -> BatchChatCompletionResponse: ...

View file

@ -46,7 +46,8 @@ class Datasets(Protocol):
@webmethod(route="/datasets/create") @webmethod(route="/datasets/create")
def create_dataset( def create_dataset(
self, self,
request: CreateDatasetRequest, uuid: str,
dataset: TrainEvalDataset,
) -> None: ... ) -> None: ...
@webmethod(route="/datasets/get") @webmethod(route="/datasets/get")

View file

@ -86,19 +86,19 @@ class Evaluations(Protocol):
@webmethod(route="/evaluate/text_generation/") @webmethod(route="/evaluate/text_generation/")
def evaluate_text_generation( def evaluate_text_generation(
self, self,
request: EvaluateTextGenerationRequest, metrics: List[TextGenerationMetric],
) -> EvaluationJob: ... ) -> EvaluationJob: ...
@webmethod(route="/evaluate/question_answering/") @webmethod(route="/evaluate/question_answering/")
def evaluate_question_answering( def evaluate_question_answering(
self, self,
request: EvaluateQuestionAnsweringRequest, metrics: List[QuestionAnsweringMetric],
) -> EvaluationJob: ... ) -> EvaluationJob: ...
@webmethod(route="/evaluate/summarization/") @webmethod(route="/evaluate/summarization/")
def evaluate_summarization( def evaluate_summarization(
self, self,
request: EvaluateSummarizationRequest, metrics: List[SummarizationMetric],
) -> EvaluationJob: ... ) -> EvaluationJob: ...
@webmethod(route="/evaluate/jobs") @webmethod(route="/evaluate/jobs")

View file

@ -179,13 +179,33 @@ class PostTraining(Protocol):
@webmethod(route="/post_training/supervised_fine_tune") @webmethod(route="/post_training/supervised_fine_tune")
def supervised_fine_tune( def supervised_fine_tune(
self, self,
request: PostTrainingSFTRequest, job_uuid: str,
model: str,
dataset: TrainEvalDataset,
validation_dataset: TrainEvalDataset,
algorithm: FinetuningAlgorithm,
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
],
optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post_training/preference_optimize") @webmethod(route="/post_training/preference_optimize")
def preference_optimize( def preference_optimize(
self, self,
request: PostTrainingRLHFRequest, job_uuid: str,
finetuned_model: URL,
dataset: TrainEvalDataset,
validation_dataset: TrainEvalDataset,
algorithm: RLHFAlgorithm,
algorithm_config: Union[DPOAlignmentConfig],
optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post_training/jobs") @webmethod(route="/post_training/jobs")

View file

@ -50,5 +50,6 @@ class RewardScoring(Protocol):
@webmethod(route="/reward_scoring/score") @webmethod(route="/reward_scoring/score")
def reward_score( def reward_score(
self, self,
request: RewardScoringRequest, dialog_generations: List[DialogGenerations],
model: str,
) -> Union[RewardScoringResponse]: ... ) -> Union[RewardScoringResponse]: ...

View file

@ -48,5 +48,7 @@ class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic_data_generation/generate") @webmethod(route="/synthetic_data_generation/generate")
def synthetic_data_generate( def synthetic_data_generate(
self, self,
request: SyntheticDataGenerationRequest, dialogs: List[Message],
filtering_function: FilteringFunction = FilteringFunction.none,
model: Optional[str] = None,
) -> Union[SyntheticDataGenerationResponse]: ... ) -> Union[SyntheticDataGenerationResponse]: ...

View file

@ -136,7 +136,11 @@ class LogSearchRequest(BaseModel):
class Telemetry(Protocol): class Telemetry(Protocol):
@webmethod(route="/experiments/create") @webmethod(route="/experiments/create")
def create_experiment(self, request: CreateExperimentRequest) -> Experiment: ... def create_experiment(
self,
name: str,
metadata: Optional[Dict[str, Any]] = None,
) -> Experiment: ...
@webmethod(route="/experiments/list") @webmethod(route="/experiments/list")
def list_experiments(self) -> List[Experiment]: ... def list_experiments(self) -> List[Experiment]: ...
@ -145,28 +149,62 @@ class Telemetry(Protocol):
def get_experiment(self, experiment_id: str) -> Experiment: ... def get_experiment(self, experiment_id: str) -> Experiment: ...
@webmethod(route="/experiments/update") @webmethod(route="/experiments/update")
def update_experiment(self, request: UpdateExperimentRequest) -> Experiment: ... def update_experiment(
self,
experiment_id: str,
status: Optional[ExperimentStatus] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Experiment: ...
@webmethod(route="/experiments/create_run") @webmethod(route="/experiments/create_run")
def create_run(self, request: CreateRunRequest) -> Run: ... def create_run(
self,
experiment_id: str,
metadata: Optional[Dict[str, Any]] = None,
) -> Run: ...
@webmethod(route="/runs/update") @webmethod(route="/runs/update")
def update_run(self, request: UpdateRunRequest) -> Run: ... def update_run(
self,
run_id: str,
status: Optional[str] = None,
ended_at: Optional[datetime] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Run: ...
@webmethod(route="/runs/log_metrics") @webmethod(route="/runs/log_metrics")
def log_metrics(self, request: LogMetricsRequest) -> None: ... def log_metrics(
self,
run_id: str,
metrics: List[Metric],
) -> None: ...
@webmethod(route="/runs/metrics", method="GET") @webmethod(route="/runs/metrics", method="GET")
def get_metrics(self, run_id: str) -> List[Metric]: ... def get_metrics(self, run_id: str) -> List[Metric]: ...
@webmethod(route="/logging/log_messages") @webmethod(route="/logging/log_messages")
def log_messages(self, request: LogMessagesRequest) -> None: ... def log_messages(
self,
logs: List[Log],
run_id: Optional[str] = None,
) -> None: ...
@webmethod(route="/logging/get_logs") @webmethod(route="/logging/get_logs")
def get_logs(self, request: LogSearchRequest) -> List[Log]: ... def get_logs(
self,
query: str,
filters: Optional[Dict[str, Any]] = None,
) -> List[Log]: ...
@webmethod(route="/experiments/artifacts/upload") @webmethod(route="/experiments/artifacts/upload")
def upload_artifact(self, request: UploadArtifactRequest) -> Artifact: ... def upload_artifact(
self,
experiment_id: str,
name: str,
artifact_type: str,
content: bytes,
metadata: Optional[Dict[str, Any]] = None,
) -> Artifact: ...
@webmethod(route="/experiments/artifacts/get") @webmethod(route="/experiments/artifacts/get")
def list_artifacts(self, experiment_id: str) -> List[Artifact]: ... def list_artifacts(self, experiment_id: str) -> List[Artifact]: ...