Add OpenAPI generation utility, update SPEC to reflect latest types

This commit is contained in:
Ashwin Bharambe 2024-08-15 13:45:45 -07:00
parent 417ba2aea0
commit 1f5eb9ff96
10 changed files with 770 additions and 656 deletions

View file

@ -60,19 +60,19 @@ class EvaluationJobArtifactsResponse(BaseModel):
class Evaluations(Protocol): class Evaluations(Protocol):
@webmethod(route="/evaluate/text_generation/") @webmethod(route="/evaluate/text_generation/")
def post_evaluate_text_generation( def evaluate_text_generation(
self, self,
request: EvaluateTextGenerationRequest, request: EvaluateTextGenerationRequest,
) -> EvaluationJob: ... ) -> EvaluationJob: ...
@webmethod(route="/evaluate/question_answering/") @webmethod(route="/evaluate/question_answering/")
def post_evaluate_question_answering( def evaluate_question_answering(
self, self,
request: EvaluateQuestionAnsweringRequest, request: EvaluateQuestionAnsweringRequest,
) -> EvaluationJob: ... ) -> EvaluationJob: ...
@webmethod(route="/evaluate/summarization/") @webmethod(route="/evaluate/summarization/")
def post_evaluate_summarization( def evaluate_summarization(
self, self,
request: EvaluateSummarizationRequest, request: EvaluateSummarizationRequest,
) -> EvaluationJob: ... ) -> EvaluationJob: ...

View file

@ -13,7 +13,7 @@ from .datatypes import * # noqa: F403
class MemoryBanks(Protocol): class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/create") @webmethod(route="/memory_banks/create")
def post_create_memory_bank( def create_memory_bank(
self, self,
bank_id: str, bank_id: str,
bank_name: str, bank_name: str,
@ -33,14 +33,14 @@ class MemoryBanks(Protocol):
) -> str: ... ) -> str: ...
@webmethod(route="/memory_bank/insert") @webmethod(route="/memory_bank/insert")
def post_insert_memory_documents( def insert_memory_documents(
self, self,
bank_id: str, bank_id: str,
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
) -> None: ... ) -> None: ...
@webmethod(route="/memory_bank/update") @webmethod(route="/memory_bank/update")
def post_update_memory_documents( def update_memory_documents(
self, self,
bank_id: str, bank_id: str,
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],

View file

@ -95,13 +95,13 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol): class PostTraining(Protocol):
@webmethod(route="/post_training/supervised_fine_tune") @webmethod(route="/post_training/supervised_fine_tune")
def post_supervised_fine_tune( def supervised_fine_tune(
self, self,
request: PostTrainingSFTRequest, request: PostTrainingSFTRequest,
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post_training/preference_optimize") @webmethod(route="/post_training/preference_optimize")
def post_preference_optimize( def preference_optimize(
self, self,
request: PostTrainingRLHFRequest, request: PostTrainingRLHFRequest,
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...

View file

@ -27,7 +27,7 @@ class RewardScoringResponse(BaseModel):
class RewardScoring(Protocol): class RewardScoring(Protocol):
@webmethod(route="/reward_scoring/score") @webmethod(route="/reward_scoring/score")
def post_score( def reward_score(
self, self,
request: RewardScoringRequest, request: RewardScoringRequest,
) -> Union[RewardScoringResponse]: ... ) -> Union[RewardScoringResponse]: ...

View file

@ -34,7 +34,7 @@ class SyntheticDataGenerationResponse(BaseModel):
class SyntheticDataGeneration(Protocol): class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic_data_generation/generate") @webmethod(route="/synthetic_data_generation/generate")
def post_generate( def synthetic_data_generate(
self, self,
request: SyntheticDataGenerationRequest, request: SyntheticDataGenerationRequest,
) -> Union[SyntheticDataGenerationResponse]: ... ) -> Union[SyntheticDataGenerationResponse]: ...

File diff suppressed because it is too large Load diff

View file

@ -7,7 +7,7 @@ components:
instance_config: instance_config:
$ref: '#/components/schemas/AgenticSystemInstanceConfig' $ref: '#/components/schemas/AgenticSystemInstanceConfig'
model: model:
$ref: '#/components/schemas/InstructModel' type: string
required: required:
- model - model
- instance_config - instance_config
@ -170,7 +170,7 @@ components:
type: array type: array
type: array type: array
model: model:
$ref: '#/components/schemas/InstructModel' type: string
quantization_config: quantization_config:
oneOf: oneOf:
- $ref: '#/components/schemas/Bf16QuantizationConfig' - $ref: '#/components/schemas/Bf16QuantizationConfig'
@ -212,7 +212,7 @@ components:
type: integer type: integer
type: object type: object
model: model:
$ref: '#/components/schemas/PretrainedModel' type: string
quantization_config: quantization_config:
oneOf: oneOf:
- $ref: '#/components/schemas/Bf16QuantizationConfig' - $ref: '#/components/schemas/Bf16QuantizationConfig'
@ -279,7 +279,7 @@ components:
- $ref: '#/components/schemas/CompletionMessage' - $ref: '#/components/schemas/CompletionMessage'
type: array type: array
model: model:
$ref: '#/components/schemas/InstructModel' type: string
quantization_config: quantization_config:
oneOf: oneOf:
- $ref: '#/components/schemas/Bf16QuantizationConfig' - $ref: '#/components/schemas/Bf16QuantizationConfig'
@ -375,7 +375,7 @@ components:
type: integer type: integer
type: object type: object
model: model:
$ref: '#/components/schemas/PretrainedModel' type: string
quantization_config: quantization_config:
oneOf: oneOf:
- $ref: '#/components/schemas/Bf16QuantizationConfig' - $ref: '#/components/schemas/Bf16QuantizationConfig'
@ -629,11 +629,6 @@ components:
- step_type - step_type
- model_response - model_response
type: object type: object
InstructModel:
enum:
- llama3_8b_chat
- llama3_70b_chat
type: string
LoraFinetuningConfig: LoraFinetuningConfig:
additionalProperties: false additionalProperties: false
properties: properties:
@ -922,7 +917,7 @@ components:
- type: object - type: object
type: object type: object
model: model:
$ref: '#/components/schemas/PretrainedModel' type: string
optimizer_config: optimizer_config:
$ref: '#/components/schemas/OptimizerConfig' $ref: '#/components/schemas/OptimizerConfig'
training_config: training_config:
@ -942,9 +937,6 @@ components:
- logger_config - logger_config
title: Request to finetune a model. title: Request to finetune a model.
type: object type: object
PretrainedModel:
description: The type of the model. This is used to determine the model family
and SKU.
QLoraFinetuningConfig: QLoraFinetuningConfig:
additionalProperties: false additionalProperties: false
properties: properties:
@ -1001,11 +993,6 @@ components:
- PUT - PUT
- DELETE - DELETE
type: string type: string
RewardModel:
enum:
- llama3_70b_reward
- llama3_405b_reward
type: string
RewardScoringRequest: RewardScoringRequest:
additionalProperties: false additionalProperties: false
properties: properties:
@ -1014,7 +1001,7 @@ components:
$ref: '#/components/schemas/DialogGenerations' $ref: '#/components/schemas/DialogGenerations'
type: array type: array
model: model:
$ref: '#/components/schemas/RewardModel' type: string
required: required:
- dialog_generations - dialog_generations
- model - model
@ -1202,7 +1189,7 @@ components:
title: The type of filtering function. title: The type of filtering function.
type: string type: string
model: model:
$ref: '#/components/schemas/RewardModel' type: string
required: required:
- dialogs - dialogs
- filtering_function - filtering_function
@ -1551,7 +1538,7 @@ info:
description: "This is the specification of the llama stack that provides\n \ description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\ \ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\ \ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-07-23 02:02:16.069876" \ draft and subject to change.\n Generated at 2024-08-15 13:41:52.916332"
title: '[DRAFT] Llama Stack Specification' title: '[DRAFT] Llama Stack Specification'
version: 0.0.1 version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@ -2338,14 +2325,14 @@ security:
servers: servers:
- url: http://any-hosted-llama-stack.com - url: http://any-hosted-llama-stack.com
tags: tags:
- name: PostTraining
- name: MemoryBanks
- name: RewardScoring
- name: Datasets
- name: Evaluations - name: Evaluations
- name: AgenticSystem
- name: Inference - name: Inference
- name: SyntheticDataGeneration - name: SyntheticDataGeneration
- name: AgenticSystem
- name: RewardScoring
- name: Datasets
- name: PostTraining
- name: MemoryBanks
- description: <SchemaDefinition schemaRef="#/components/schemas/Attachment" /> - description: <SchemaDefinition schemaRef="#/components/schemas/Attachment" />
name: Attachment name: Attachment
- description: <SchemaDefinition schemaRef="#/components/schemas/BatchChatCompletionRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/BatchChatCompletionRequest"
@ -2362,8 +2349,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/Fp8QuantizationConfig" - description: <SchemaDefinition schemaRef="#/components/schemas/Fp8QuantizationConfig"
/> />
name: Fp8QuantizationConfig name: Fp8QuantizationConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/InstructModel" />
name: InstructModel
- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingParams" /> - description: <SchemaDefinition schemaRef="#/components/schemas/SamplingParams" />
name: SamplingParams name: SamplingParams
- description: <SchemaDefinition schemaRef="#/components/schemas/SamplingStrategy" - description: <SchemaDefinition schemaRef="#/components/schemas/SamplingStrategy"
@ -2393,12 +2378,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/BatchCompletionRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/BatchCompletionRequest"
/> />
name: BatchCompletionRequest name: BatchCompletionRequest
- description: 'The type of the model. This is used to determine the model family
and SKU.
<SchemaDefinition schemaRef="#/components/schemas/PretrainedModel" />'
name: PretrainedModel
- description: <SchemaDefinition schemaRef="#/components/schemas/BatchCompletionResponse" - description: <SchemaDefinition schemaRef="#/components/schemas/BatchCompletionResponse"
/> />
name: BatchCompletionResponse name: BatchCompletionResponse
@ -2489,11 +2468,36 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/TrainEvalDatasetColumnType" - description: <SchemaDefinition schemaRef="#/components/schemas/TrainEvalDatasetColumnType"
/> />
name: TrainEvalDatasetColumnType name: TrainEvalDatasetColumnType
- description: <SchemaDefinition schemaRef="#/components/schemas/InferenceStep" />
name: InferenceStep
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankDocument" - description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankDocument"
/> />
name: MemoryBankDocument name: MemoryBankDocument
- description: 'Checkpoint created during training runs
<SchemaDefinition schemaRef="#/components/schemas/Checkpoint" />'
name: Checkpoint
- description: 'Request to evaluate question answering.
<SchemaDefinition schemaRef="#/components/schemas/EvaluateQuestionAnsweringRequest"
/>'
name: EvaluateQuestionAnsweringRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/EvaluationJob" />
name: EvaluationJob
- description: 'Request to evaluate summarization.
<SchemaDefinition schemaRef="#/components/schemas/EvaluateSummarizationRequest"
/>'
name: EvaluateSummarizationRequest
- description: 'Request to evaluate text generation.
<SchemaDefinition schemaRef="#/components/schemas/EvaluateTextGenerationRequest"
/>'
name: EvaluateTextGenerationRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/InferenceStep" />
name: InferenceStep
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryRetrievalStep" - description: <SchemaDefinition schemaRef="#/components/schemas/MemoryRetrievalStep"
/> />
name: MemoryRetrievalStep name: MemoryRetrievalStep
@ -2531,15 +2535,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/EvaluationJobStatusResponse" - description: <SchemaDefinition schemaRef="#/components/schemas/EvaluationJobStatusResponse"
/> />
name: EvaluationJobStatusResponse name: EvaluationJobStatusResponse
- description: <SchemaDefinition schemaRef="#/components/schemas/EvaluationJob" />
name: EvaluationJob
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBank" /> - description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBank" />
name: MemoryBank name: MemoryBank
- description: 'Checkpoint created during training runs
<SchemaDefinition schemaRef="#/components/schemas/Checkpoint" />'
name: Checkpoint
- description: 'Artifacts of a finetuning job. - description: 'Artifacts of a finetuning job.
@ -2563,45 +2560,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/PostTrainingJob" - description: <SchemaDefinition schemaRef="#/components/schemas/PostTrainingJob"
/> />
name: PostTrainingJob name: PostTrainingJob
- description: 'Request to evaluate question answering.
<SchemaDefinition schemaRef="#/components/schemas/EvaluateQuestionAnsweringRequest"
/>'
name: EvaluateQuestionAnsweringRequest
- description: 'Request to evaluate summarization.
<SchemaDefinition schemaRef="#/components/schemas/EvaluateSummarizationRequest"
/>'
name: EvaluateSummarizationRequest
- description: 'Request to evaluate text generation.
<SchemaDefinition schemaRef="#/components/schemas/EvaluateTextGenerationRequest"
/>'
name: EvaluateTextGenerationRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/RewardModel" />
name: RewardModel
- description: 'Request to generate synthetic data. A small batch of prompts and a
filtering function
<SchemaDefinition schemaRef="#/components/schemas/SyntheticDataGenerationRequest"
/>'
name: SyntheticDataGenerationRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/ScoredDialogGenerations"
/>
name: ScoredDialogGenerations
- description: <SchemaDefinition schemaRef="#/components/schemas/ScoredMessage" />
name: ScoredMessage
- description: 'Response from the synthetic data generation. Batch of (prompt, response,
score) tuples that pass the threshold.
<SchemaDefinition schemaRef="#/components/schemas/SyntheticDataGenerationResponse"
/>'
name: SyntheticDataGenerationResponse
- description: <SchemaDefinition schemaRef="#/components/schemas/DPOAlignmentConfig" - description: <SchemaDefinition schemaRef="#/components/schemas/DPOAlignmentConfig"
/> />
name: DPOAlignmentConfig name: DPOAlignmentConfig
@ -2632,6 +2590,11 @@ tags:
<SchemaDefinition schemaRef="#/components/schemas/RewardScoringResponse" />' <SchemaDefinition schemaRef="#/components/schemas/RewardScoringResponse" />'
name: RewardScoringResponse name: RewardScoringResponse
- description: <SchemaDefinition schemaRef="#/components/schemas/ScoredDialogGenerations"
/>
name: ScoredDialogGenerations
- description: <SchemaDefinition schemaRef="#/components/schemas/ScoredMessage" />
name: ScoredMessage
- description: <SchemaDefinition schemaRef="#/components/schemas/DoraFinetuningConfig" - description: <SchemaDefinition schemaRef="#/components/schemas/DoraFinetuningConfig"
/> />
name: DoraFinetuningConfig name: DoraFinetuningConfig
@ -2649,6 +2612,20 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/QLoraFinetuningConfig" - description: <SchemaDefinition schemaRef="#/components/schemas/QLoraFinetuningConfig"
/> />
name: QLoraFinetuningConfig name: QLoraFinetuningConfig
- description: 'Request to generate synthetic data. A small batch of prompts and a
filtering function
<SchemaDefinition schemaRef="#/components/schemas/SyntheticDataGenerationRequest"
/>'
name: SyntheticDataGenerationRequest
- description: 'Response from the synthetic data generation. Batch of (prompt, response,
score) tuples that pass the threshold.
<SchemaDefinition schemaRef="#/components/schemas/SyntheticDataGenerationResponse"
/>'
name: SyntheticDataGenerationResponse
x-tagGroups: x-tagGroups:
- name: Operations - name: Operations
tags: tags:
@ -2701,7 +2678,6 @@ x-tagGroups:
- FinetuningAlgorithm - FinetuningAlgorithm
- Fp8QuantizationConfig - Fp8QuantizationConfig
- InferenceStep - InferenceStep
- InstructModel
- LoraFinetuningConfig - LoraFinetuningConfig
- MemoryBank - MemoryBank
- MemoryBankDocument - MemoryBankDocument
@ -2715,12 +2691,10 @@ x-tagGroups:
- PostTrainingJobStatusResponse - PostTrainingJobStatusResponse
- PostTrainingRLHFRequest - PostTrainingRLHFRequest
- PostTrainingSFTRequest - PostTrainingSFTRequest
- PretrainedModel
- QLoraFinetuningConfig - QLoraFinetuningConfig
- RLHFAlgorithm - RLHFAlgorithm
- RestAPIExecutionConfig - RestAPIExecutionConfig
- RestAPIMethod - RestAPIMethod
- RewardModel
- RewardScoringRequest - RewardScoringRequest
- RewardScoringResponse - RewardScoringResponse
- SamplingParams - SamplingParams

View file

@ -0,0 +1,9 @@
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_toolchain/[<subdir>]/api/endpoints.py` using the `generate.py` utility.
Please install the following packages before running the script:
```
pip install python-openapi json-strong-typing fire PyYAML llama-models
```
Then simply run `sh run_openapi_generator.sh <OUTPUT_DIR>`

View file

@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described found in the
# LICENSE file in the root directory of this source tree.
import inspect
from datetime import datetime
from pathlib import Path
from typing import Callable, Iterator, List, Tuple
import fire
import yaml
from llama_models import schema_utils
from pyopenapi import Info, operations, Options, Server, Specification
# We do a series of monkey-patching to ensure our definitions only use the minimal
# (json_schema_type, webmethod) definitions from the llama_models package. For
# generation though, we need the full definitions and implementations from the
# (python-openapi, json-strong-typing) packages.
from strong_typing.schema import json_schema_type
from termcolor import colored
# PATCH `json_schema_type` first
schema_utils.json_schema_type = json_schema_type
from llama_models.llama3_1.api.datatypes import * # noqa: F403
from llama_toolchain.agentic_system.api import * # noqa: F403
from llama_toolchain.dataset.api import * # noqa: F403
from llama_toolchain.evaluations.api import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403
from llama_toolchain.post_training.api import * # noqa: F403
from llama_toolchain.reward_scoring.api import * # noqa: F403
from llama_toolchain.synthetic_data_generation.api import * # noqa: F403
def patched_get_endpoint_functions(
endpoint: type, prefixes: List[str]
) -> Iterator[Tuple[str, str, str, Callable]]:
if not inspect.isclass(endpoint):
raise ValueError(f"object is not a class type: {endpoint}")
functions = inspect.getmembers(endpoint, inspect.isfunction)
for func_name, func_ref in functions:
webmethod = getattr(func_ref, "__webmethod__", None)
if not webmethod:
continue
print(f"Processing {colored(func_name, 'white')}...")
operation_name = func_name
if operation_name.startswith("get_") or operation_name.endswith("/get"):
prefix = "get"
elif (
operation_name.startswith("delete_")
or operation_name.startswith("remove_")
or operation_name.endswith("/delete")
or operation_name.endswith("/remove")
):
prefix = "delete"
else:
if webmethod.method == "GET":
prefix = "get"
elif webmethod.method == "DELETE":
prefix = "delete"
else:
# by default everything else is a POST
prefix = "post"
yield prefix, operation_name, func_name, func_ref
operations._get_endpoint_functions = patched_get_endpoint_functions
class LlamaStackEndpoints(
Inference,
AgenticSystem,
RewardScoring,
SyntheticDataGeneration,
Datasets,
PostTraining,
MemoryBanks,
Evaluations,
): ...
def main(output_dir: str):
output_dir = Path(output_dir)
if not output_dir.exists():
raise ValueError(f"Directory {output_dir} does not exist")
now = str(datetime.now())
print(
"Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now
)
print("")
spec = Specification(
LlamaStackEndpoints,
Options(
server=Server(url="http://any-hosted-llama-stack.com"),
info=Info(
title="[DRAFT] Llama Stack Specification",
version="0.0.1",
description="""This is the specification of the llama stack that provides
a set of endpoints and their corresponding interfaces that are tailored to
best leverage Llama Models. The specification is still in draft and subject to change.
Generated at """
+ now,
),
),
)
with open(output_dir / "llama-stack-spec.yaml", "w", encoding="utf-8") as fp:
yaml.dump(spec.get_json(), fp, allow_unicode=True)
with open(output_dir / "llama-stack-spec.html", "w") as fp:
spec.write_html(fp, pretty_print=True)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,33 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
PYTHONPATH=${PYTHONPATH:-}
set -euo pipefail
missing_packages=()
check_package() {
if ! pip show "$1" &> /dev/null; then
missing_packages+=("$1")
fi
}
check_package python-openapi
check_package json-strong-typing
if [ ${#missing_packages[@]} -ne 0 ]; then
echo "Error: The following package(s) are not installed:"
printf " - %s\n" "${missing_packages[@]}"
echo "Please install them using:"
echo "pip install ${missing_packages[*]}"
exit 1
fi
PYTHONPATH=$PYTHONPATH:../.. python3 -m rfcs.openapi_generator.generate $*