diff --git a/llama_toolchain/batch_inference/api/api.py b/llama_toolchain/batch_inference/api/api.py
index a02815388..3d67120dd 100644
--- a/llama_toolchain/batch_inference/api/api.py
+++ b/llama_toolchain/batch_inference/api/api.py
@@ -51,11 +51,21 @@ class BatchInference(Protocol):
     @webmethod(route="/batch_inference/completion")
     async def batch_completion(
         self,
-        request: BatchCompletionRequest,
+        model: str,
+        content_batch: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        logprobs: Optional[LogProbConfig] = None,
     ) -> BatchCompletionResponse: ...
 
     @webmethod(route="/batch_inference/chat_completion")
     async def batch_chat_completion(
         self,
-        request: BatchChatCompletionRequest,
+        model: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        # zero-shot tool definitions as input to the model
+        tools: Optional[List[ToolDefinition]] = list,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> BatchChatCompletionResponse: ...
diff --git a/llama_toolchain/dataset/api/api.py b/llama_toolchain/dataset/api/api.py
index c22fc01b0..2fa8bb4e5 100644
--- a/llama_toolchain/dataset/api/api.py
+++ b/llama_toolchain/dataset/api/api.py
@@ -46,7 +46,8 @@ class Datasets(Protocol):
     @webmethod(route="/datasets/create")
     def create_dataset(
         self,
-        request: CreateDatasetRequest,
+        uuid: str,
+        dataset: TrainEvalDataset,
     ) -> None: ...
 
     @webmethod(route="/datasets/get")
diff --git a/llama_toolchain/evaluations/api/api.py b/llama_toolchain/evaluations/api/api.py
index b8f3fa825..898dc2822 100644
--- a/llama_toolchain/evaluations/api/api.py
+++ b/llama_toolchain/evaluations/api/api.py
@@ -86,19 +86,19 @@ class Evaluations(Protocol):
     @webmethod(route="/evaluate/text_generation/")
     def evaluate_text_generation(
         self,
-        request: EvaluateTextGenerationRequest,
+        metrics: List[TextGenerationMetric],
     ) -> EvaluationJob: ...
 
     @webmethod(route="/evaluate/question_answering/")
     def evaluate_question_answering(
         self,
-        request: EvaluateQuestionAnsweringRequest,
+        metrics: List[QuestionAnsweringMetric],
     ) -> EvaluationJob: ...
 
     @webmethod(route="/evaluate/summarization/")
     def evaluate_summarization(
         self,
-        request: EvaluateSummarizationRequest,
+        metrics: List[SummarizationMetric],
     ) -> EvaluationJob: ...
 
     @webmethod(route="/evaluate/jobs")
diff --git a/llama_toolchain/post_training/api/api.py b/llama_toolchain/post_training/api/api.py
index 447a729fb..378515f83 100644
--- a/llama_toolchain/post_training/api/api.py
+++ b/llama_toolchain/post_training/api/api.py
@@ -179,13 +179,33 @@ class PostTraining(Protocol):
     @webmethod(route="/post_training/supervised_fine_tune")
     def supervised_fine_tune(
         self,
-        request: PostTrainingSFTRequest,
+        job_uuid: str,
+        model: str,
+        dataset: TrainEvalDataset,
+        validation_dataset: TrainEvalDataset,
+        algorithm: FinetuningAlgorithm,
+        algorithm_config: Union[
+            LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
+        ],
+        optimizer_config: OptimizerConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post_training/preference_optimize")
     def preference_optimize(
         self,
-        request: PostTrainingRLHFRequest,
+        job_uuid: str,
+        finetuned_model: URL,
+        dataset: TrainEvalDataset,
+        validation_dataset: TrainEvalDataset,
+        algorithm: RLHFAlgorithm,
+        algorithm_config: Union[DPOAlignmentConfig],
+        optimizer_config: OptimizerConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post_training/jobs")
diff --git a/llama_toolchain/reward_scoring/api/api.py b/llama_toolchain/reward_scoring/api/api.py
index c91931f09..9d689f232 100644
--- a/llama_toolchain/reward_scoring/api/api.py
+++ b/llama_toolchain/reward_scoring/api/api.py
@@ -50,5 +50,6 @@ class RewardScoring(Protocol):
     @webmethod(route="/reward_scoring/score")
     def reward_score(
         self,
-        request: RewardScoringRequest,
+        dialog_generations: List[DialogGenerations],
+        model: str,
     ) -> Union[RewardScoringResponse]: ...
diff --git a/llama_toolchain/synthetic_data_generation/api/api.py b/llama_toolchain/synthetic_data_generation/api/api.py
index 44b8327a9..9a6c487af 100644
--- a/llama_toolchain/synthetic_data_generation/api/api.py
+++ b/llama_toolchain/synthetic_data_generation/api/api.py
@@ -48,5 +48,7 @@ class SyntheticDataGeneration(Protocol):
     @webmethod(route="/synthetic_data_generation/generate")
     def synthetic_data_generate(
         self,
-        request: SyntheticDataGenerationRequest,
+        dialogs: List[Message],
+        filtering_function: FilteringFunction = FilteringFunction.none,
+        model: Optional[str] = None,
     ) -> Union[SyntheticDataGenerationResponse]: ...
diff --git a/llama_toolchain/telemetry/api/api.py b/llama_toolchain/telemetry/api/api.py
index ae784428b..eec34a596 100644
--- a/llama_toolchain/telemetry/api/api.py
+++ b/llama_toolchain/telemetry/api/api.py
@@ -136,7 +136,11 @@ class LogSearchRequest(BaseModel):
 
 class Telemetry(Protocol):
     @webmethod(route="/experiments/create")
-    def create_experiment(self, request: CreateExperimentRequest) -> Experiment: ...
+    def create_experiment(
+        self,
+        name: str,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Experiment: ...
 
     @webmethod(route="/experiments/list")
     def list_experiments(self) -> List[Experiment]: ...
@@ -145,28 +149,62 @@ class Telemetry(Protocol):
     def get_experiment(self, experiment_id: str) -> Experiment: ...
 
     @webmethod(route="/experiments/update")
-    def update_experiment(self, request: UpdateExperimentRequest) -> Experiment: ...
+    def update_experiment(
+        self,
+        experiment_id: str,
+        status: Optional[ExperimentStatus] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Experiment: ...
 
     @webmethod(route="/experiments/create_run")
-    def create_run(self, request: CreateRunRequest) -> Run: ...
+    def create_run(
+        self,
+        experiment_id: str,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Run: ...
 
     @webmethod(route="/runs/update")
-    def update_run(self, request: UpdateRunRequest) -> Run: ...
+    def update_run(
+        self,
+        run_id: str,
+        status: Optional[str] = None,
+        ended_at: Optional[datetime] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Run: ...
 
     @webmethod(route="/runs/log_metrics")
-    def log_metrics(self, request: LogMetricsRequest) -> None: ...
+    def log_metrics(
+        self,
+        run_id: str,
+        metrics: List[Metric],
+    ) -> None: ...
 
     @webmethod(route="/runs/metrics", method="GET")
     def get_metrics(self, run_id: str) -> List[Metric]: ...
 
     @webmethod(route="/logging/log_messages")
-    def log_messages(self, request: LogMessagesRequest) -> None: ...
+    def log_messages(
+        self,
+        logs: List[Log],
+        run_id: Optional[str] = None,
+    ) -> None: ...
 
     @webmethod(route="/logging/get_logs")
-    def get_logs(self, request: LogSearchRequest) -> List[Log]: ...
+    def get_logs(
+        self,
+        query: str,
+        filters: Optional[Dict[str, Any]] = None,
+    ) -> List[Log]: ...
 
     @webmethod(route="/experiments/artifacts/upload")
-    def upload_artifact(self, request: UploadArtifactRequest) -> Artifact: ...
+    def upload_artifact(
+        self,
+        experiment_id: str,
+        name: str,
+        artifact_type: str,
+        content: bytes,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Artifact: ...
 
     @webmethod(route="/experiments/artifacts/get")
     def list_artifacts(self, experiment_id: str) -> List[Artifact]: ...
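
Illustrative usage (not part of the patch): a minimal sketch of how a caller might invoke the flattened signatures above, assuming some object `impl` that implements the BatchInference and Telemetry protocols; the import path, model identifier, and message construction are placeholder assumptions, not something this diff defines.

# Sketch only: `impl` is a hypothetical provider implementing the protocols above;
# the import path and model name are assumptions for illustration.
from llama_toolchain.inference.api import SamplingParams, UserMessage

async def demo(impl) -> None:
    # Fields that previously lived on BatchChatCompletionRequest are now passed directly.
    batch_response = await impl.batch_chat_completion(
        model="Meta-Llama3.1-8B-Instruct",
        messages_batch=[[UserMessage(content="Say hello.")]],
        sampling_params=SamplingParams(),
    )

    # Telemetry follows the same pattern: named parameters replace the Create*Request objects.
    experiment = impl.create_experiment(name="demo-experiment", metadata={"owner": "ci"})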