generate openapi

Xi Yan 2024-10-24 17:41:15 -07:00
parent cdfd584a8f
commit ec7c8f95de
6 changed files with 2854 additions and 1225 deletions


@@ -33,14 +33,16 @@ schema_utils.json_schema_type = json_schema_type
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.apis.dataset import *  # noqa: F403
-from llama_stack.apis.evals import *  # noqa: F403
+from llama_stack.apis.datasets import *  # noqa: F403
+from llama_stack.apis.datasetio import *  # noqa: F403
+from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+from llama_stack.apis.eval import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.batch_inference import *  # noqa: F403
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.telemetry import *  # noqa: F403
 from llama_stack.apis.post_training import *  # noqa: F403
-from llama_stack.apis.reward_scoring import *  # noqa: F403
 from llama_stack.apis.synthetic_data_generation import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.models import *  # noqa: F403
@@ -54,14 +56,16 @@ class LlamaStack(
     Inference,
     BatchInference,
     Agents,
-    RewardScoring,
     Safety,
     SyntheticDataGeneration,
     Datasets,
     Telemetry,
     PostTraining,
     Memory,
-    Evaluations,
+    Eval,
+    Scoring,
+    ScoringFunctions,
+    DatasetIO,
     Models,
     Shields,
     Inspect,
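
The hunk above swaps the retired RewardScoring and Evaluations surfaces for the new Eval, Scoring, ScoringFunctions, and DatasetIO protocols. A minimal sketch of the composition pattern, using stand-in protocol bodies rather than the real llama_stack definitions:

# Illustrative sketch only: the method bodies below are stand-ins, not
# the actual llama_stack protocol definitions.
from typing import Any, Dict, List, Protocol


class Eval(Protocol):
    async def run_eval(self, task_id: str) -> Dict[str, Any]: ...


class Scoring(Protocol):
    async def score(self, rows: List[Dict[str, Any]]) -> Dict[str, Any]: ...


class DatasetIO(Protocol):
    async def get_rows(self, dataset_id: str, limit: int) -> List[Dict[str, Any]]: ...


# Multiple inheritance of Protocols yields one aggregate surface that an
# OpenAPI generator can introspect for every endpoint at once.
class LlamaStack(Eval, Scoring, DatasetIO, Protocol):
    pass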

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Dict, List, Literal, Union
+from typing import Literal, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
@@ -24,12 +24,10 @@ class BooleanType(BaseModel):
 
 class ArrayType(BaseModel):
     type: Literal["array"] = "array"
-    items: "ParamType"
 
 
 class ObjectType(BaseModel):
     type: Literal["object"] = "object"
-    properties: Dict[str, "ParamType"] = Field(default_factory=dict)
 
 
 class JsonType(BaseModel):
@@ -38,7 +36,6 @@ class JsonType(BaseModel):
 
 class UnionType(BaseModel):
     type: Literal["union"] = "union"
-    options: List["ParamType"] = Field(default_factory=list)
 
 
 class CustomType(BaseModel):
@ -77,7 +74,3 @@ ParamType = Annotated[
], ],
Field(discriminator="type"), Field(discriminator="type"),
] ]
ArrayType.model_rebuild()
ObjectType.model_rebuild()
UnionType.model_rebuild()
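
Dropping items, properties, and options makes every ParamType variant a flat, self-describing tag, so the model_rebuild() calls for forward references are no longer needed. A hedged sketch of the resulting discriminated union, assuming pydantic v2's TypeAdapter; only two variants are shown for brevity:

# Sketch of the simplified, non-recursive discriminated union; assumes
# pydantic v2 (TypeAdapter).
from typing import Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class ArrayType(BaseModel):
    type: Literal["array"] = "array"


class ObjectType(BaseModel):
    type: Literal["object"] = "object"


ParamType = Annotated[Union[ArrayType, ObjectType], Field(discriminator="type")]

# The "type" tag alone selects the variant; there is no nested schema to
# resolve, hence no model_rebuild() step.
parsed = TypeAdapter(ParamType).validate_python({"type": "array"})
assert isinstance(parsed, ArrayType)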


@@ -14,7 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.dataset import *  # noqa: F403
+from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
 
@@ -107,8 +107,8 @@ class PostTrainingSFTRequest(BaseModel):
 
     job_uuid: str
     model: str
-    dataset: TrainEvalDataset
-    validation_dataset: TrainEvalDataset
+    dataset: str
+    validation_dataset: str
 
     algorithm: FinetuningAlgorithm
     algorithm_config: Union[
@@ -131,8 +131,8 @@ class PostTrainingRLHFRequest(BaseModel):
 
     finetuned_model: URL
 
-    dataset: TrainEvalDataset
-    validation_dataset: TrainEvalDataset
+    dataset: str
+    validation_dataset: str
 
     algorithm: RLHFAlgorithm
     algorithm_config: Union[DPOAlignmentConfig]
@@ -181,8 +181,8 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         model: str,
-        dataset: TrainEvalDataset,
-        validation_dataset: TrainEvalDataset,
+        dataset: str,
+        validation_dataset: str,
         algorithm: FinetuningAlgorithm,
         algorithm_config: Union[
             LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
@@ -198,8 +198,8 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         finetuned_model: URL,
-        dataset: TrainEvalDataset,
-        validation_dataset: TrainEvalDataset,
+        dataset: str,
+        validation_dataset: str,
         algorithm: RLHFAlgorithm,
         algorithm_config: Union[DPOAlignmentConfig],
         optimizer_config: OptimizerConfig,
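
Callers of the post-training endpoints now pass dataset identifiers, resolved through the Datasets/DatasetIO APIs, instead of inline TrainEvalDataset objects. A hedged caller-side sketch; the protocol stub and identifier values are illustrative, and only the parameters changed by this diff are shown:

# Illustrative stub: the real method also takes algorithm, optimizer,
# and training configs, omitted here for brevity.
from typing import Protocol


class PostTrainingLike(Protocol):
    def supervised_fine_tune(
        self, job_uuid: str, model: str, dataset: str, validation_dataset: str
    ) -> None: ...


def launch_sft(api: PostTrainingLike) -> None:
    api.supervised_fine_tune(
        job_uuid="job-0001",              # hypothetical values
        model="Llama3.1-8B-Instruct",
        dataset="sft-train",              # was: TrainEvalDataset(...)
        validation_dataset="sft-val",     # was: TrainEvalDataset(...)
    )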


@@ -13,7 +13,6 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.reward_scoring import *  # noqa: F403
 
 
 class FilteringFunction(Enum):
@@ -40,7 +39,7 @@ class SyntheticDataGenerationRequest(BaseModel):
 
 class SyntheticDataGenerationResponse(BaseModel):
     """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
 
-    synthetic_data: List[ScoredDialogGenerations]
+    synthetic_data: List[Dict[str, Any]]
     statistics: Optional[Dict[str, Any]] = None
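
With ScoredDialogGenerations removed from this API surface, the response now carries generic dicts. A minimal sketch of the loosened model; the payload keys in the example are illustrative, not a documented schema:

# Sketch of the loosened response model; example payload keys are
# illustrative only.
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class SyntheticDataGenerationResponse(BaseModel):
    synthetic_data: List[Dict[str, Any]]
    statistics: Optional[Dict[str, Any]] = None


resp = SyntheticDataGenerationResponse(
    synthetic_data=[{"prompt": "Hi", "response": "Hello!", "score": 0.92}],
    statistics={"kept": 1, "filtered": 0},
)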