mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat(openapi): switch to fastapi-based generator (#3944)
Some checks failed
Pre-commit / pre-commit (push) Successful in 3m27s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test llama stack list-deps / generate-matrix (push) Successful in 3s
Python Package Build Test / build (3.12) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
Test llama stack list-deps / show-single-provider (push) Successful in 25s
Test External API and Providers / test-external (venv) (push) Failing after 34s
Vector IO Integration Tests / test-matrix (push) Failing after 43s
Test Llama Stack Build / build (push) Successful in 37s
Test Llama Stack Build / build-single-provider (push) Successful in 48s
Test llama stack list-deps / list-deps-from-config (push) Successful in 52s
Test llama stack list-deps / list-deps (push) Failing after 52s
Python Package Build Test / build (3.13) (push) Failing after 1m2s
UI Tests / ui-tests (22) (push) Successful in 1m15s
Test Llama Stack Build / build-custom-container-distribution (push) Successful in 1m29s
Unit Tests / unit-tests (3.12) (push) Failing after 1m45s
Test Llama Stack Build / build-ubi9-container-distribution (push) Successful in 1m54s
Unit Tests / unit-tests (3.13) (push) Failing after 2m13s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 2m20s
Some checks failed
Pre-commit / pre-commit (push) Successful in 3m27s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test llama stack list-deps / generate-matrix (push) Successful in 3s
Python Package Build Test / build (3.12) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
Test llama stack list-deps / show-single-provider (push) Successful in 25s
Test External API and Providers / test-external (venv) (push) Failing after 34s
Vector IO Integration Tests / test-matrix (push) Failing after 43s
Test Llama Stack Build / build (push) Successful in 37s
Test Llama Stack Build / build-single-provider (push) Successful in 48s
Test llama stack list-deps / list-deps-from-config (push) Successful in 52s
Test llama stack list-deps / list-deps (push) Failing after 52s
Python Package Build Test / build (3.13) (push) Failing after 1m2s
UI Tests / ui-tests (22) (push) Successful in 1m15s
Test Llama Stack Build / build-custom-container-distribution (push) Successful in 1m29s
Unit Tests / unit-tests (3.12) (push) Failing after 1m45s
Test Llama Stack Build / build-ubi9-container-distribution (push) Successful in 1m54s
Unit Tests / unit-tests (3.13) (push) Failing after 2m13s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 2m20s
# What does this PR do?
This replaces the legacy "pyopenapi + strong_typing" pipeline with a
FastAPI-backed generator that has an explicit schema registry inside
`llama_stack_api`. The key changes:
1. **New generator architecture.** FastAPI now builds the OpenAPI schema
directly from the real routes, while helper modules
(`schema_collection`, `endpoints`, `schema_transforms`, etc.)
post-process the result. The old pyopenapi stack and its strong_typing
helpers are removed entirely, so we no longer rely on fragile AST
analysis or top-level import side effects.
2. **Schema registry in `llama_stack_api`.** `schema_utils.py` keeps a
`SchemaInfo` record for every `@json_schema_type`, `register_schema`,
and dynamically created request model. The OpenAPI generator and other
tooling query this registry instead of scanning the package tree,
producing deterministic names (e.g., `{MethodName}Request`), capturing
all optional/nullable fields, and making schema discovery testable. A
new unit test covers the registry behavior.
3. **Regenerated specs + CI alignment.** All docs/Stainless specs are
regenerated from the new pipeline, so optional/nullable fields now match
reality (expect the API Conformance workflow to report breaking
changes—this PR establishes the new baseline). The workflow itself is
back to the stock oasdiff invocation so future regressions surface
normally.
*Conformance will be RED on this PR; we choose to accept the
deviations.*
## Test Plan
- `uv run pytest tests/unit/server/test_schema_registry.py`
- `uv run python -m scripts.openapi_generator.main docs/static`
---------
Signed-off-by: Sébastien Han <seb@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
cc88789071
commit
97f535c4f1
64 changed files with 47592 additions and 30218 deletions
16
scripts/openapi_generator/__init__.py
Normal file
16
scripts/openapi_generator/__init__.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
OpenAPI generator module for Llama Stack.
|
||||
|
||||
This module provides functionality to generate OpenAPI specifications
|
||||
from FastAPI applications.
|
||||
"""
|
||||
|
||||
from .main import generate_openapi_spec, main
|
||||
|
||||
__all__ = ["generate_openapi_spec", "main"]
|
||||
14
scripts/openapi_generator/__main__.py
Normal file
14
scripts/openapi_generator/__main__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Entry point for running the openapi_generator module as a package.
|
||||
"""
|
||||
|
||||
from .main import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
502
scripts/openapi_generator/_legacy_order.py
Normal file
502
scripts/openapi_generator/_legacy_order.py
Normal file
|
|
@ -0,0 +1,502 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Temporary ordering helpers extracted from origin/main client-sdks/stainless/openapi.yml.
|
||||
|
||||
These lists help the new generator match the previous ordering so that diffs
|
||||
remain readable while we debug schema content regressions. Remove once stable.
|
||||
"""
|
||||
|
||||
LEGACY_PATH_ORDER = [
|
||||
"/v1/batches",
|
||||
"/v1/batches/{batch_id}",
|
||||
"/v1/batches/{batch_id}/cancel",
|
||||
"/v1/chat/completions",
|
||||
"/v1/chat/completions/{completion_id}",
|
||||
"/v1/completions",
|
||||
"/v1/conversations",
|
||||
"/v1/conversations/{conversation_id}",
|
||||
"/v1/conversations/{conversation_id}/items",
|
||||
"/v1/conversations/{conversation_id}/items/{item_id}",
|
||||
"/v1/embeddings",
|
||||
"/v1/files",
|
||||
"/v1/files/{file_id}",
|
||||
"/v1/files/{file_id}/content",
|
||||
"/v1/health",
|
||||
"/v1/inspect/routes",
|
||||
"/v1/models",
|
||||
"/v1/models/{model_id}",
|
||||
"/v1/moderations",
|
||||
"/v1/prompts",
|
||||
"/v1/prompts/{prompt_id}",
|
||||
"/v1/prompts/{prompt_id}/set-default-version",
|
||||
"/v1/prompts/{prompt_id}/versions",
|
||||
"/v1/providers",
|
||||
"/v1/providers/{provider_id}",
|
||||
"/v1/responses",
|
||||
"/v1/responses/{response_id}",
|
||||
"/v1/responses/{response_id}/input_items",
|
||||
"/v1/safety/run-shield",
|
||||
"/v1/scoring-functions",
|
||||
"/v1/scoring-functions/{scoring_fn_id}",
|
||||
"/v1/scoring/score",
|
||||
"/v1/scoring/score-batch",
|
||||
"/v1/shields",
|
||||
"/v1/shields/{identifier}",
|
||||
"/v1/tool-runtime/invoke",
|
||||
"/v1/tool-runtime/list-tools",
|
||||
"/v1/toolgroups",
|
||||
"/v1/toolgroups/{toolgroup_id}",
|
||||
"/v1/tools",
|
||||
"/v1/tools/{tool_name}",
|
||||
"/v1/vector-io/insert",
|
||||
"/v1/vector-io/query",
|
||||
"/v1/vector_stores",
|
||||
"/v1/vector_stores/{vector_store_id}",
|
||||
"/v1/vector_stores/{vector_store_id}/file_batches",
|
||||
"/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}",
|
||||
"/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
|
||||
"/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
|
||||
"/v1/vector_stores/{vector_store_id}/files",
|
||||
"/v1/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
"/v1/vector_stores/{vector_store_id}/files/{file_id}/content",
|
||||
"/v1/vector_stores/{vector_store_id}/search",
|
||||
"/v1/version",
|
||||
"/v1beta/datasetio/append-rows/{dataset_id}",
|
||||
"/v1beta/datasetio/iterrows/{dataset_id}",
|
||||
"/v1beta/datasets",
|
||||
"/v1beta/datasets/{dataset_id}",
|
||||
"/v1alpha/eval/benchmarks",
|
||||
"/v1alpha/eval/benchmarks/{benchmark_id}",
|
||||
"/v1alpha/eval/benchmarks/{benchmark_id}/evaluations",
|
||||
"/v1alpha/eval/benchmarks/{benchmark_id}/jobs",
|
||||
"/v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
|
||||
"/v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
|
||||
"/v1alpha/inference/rerank",
|
||||
"/v1alpha/post-training/job/artifacts",
|
||||
"/v1alpha/post-training/job/cancel",
|
||||
"/v1alpha/post-training/job/status",
|
||||
"/v1alpha/post-training/jobs",
|
||||
"/v1alpha/post-training/preference-optimize",
|
||||
"/v1alpha/post-training/supervised-fine-tune",
|
||||
]
|
||||
|
||||
LEGACY_SCHEMA_ORDER = [
|
||||
"Error",
|
||||
"ListBatchesResponse",
|
||||
"CreateBatchRequest",
|
||||
"Batch",
|
||||
"Order",
|
||||
"ListOpenAIChatCompletionResponse",
|
||||
"OpenAIAssistantMessageParam",
|
||||
"OpenAIChatCompletionContentPartImageParam",
|
||||
"OpenAIChatCompletionContentPartParam",
|
||||
"OpenAIChatCompletionContentPartTextParam",
|
||||
"OpenAIChatCompletionToolCall",
|
||||
"OpenAIChatCompletionToolCallFunction",
|
||||
"OpenAIChatCompletionUsage",
|
||||
"OpenAIChoice",
|
||||
"OpenAIChoiceLogprobs",
|
||||
"OpenAIDeveloperMessageParam",
|
||||
"OpenAIFile",
|
||||
"OpenAIFileFile",
|
||||
"OpenAIImageURL",
|
||||
"OpenAIMessageParam",
|
||||
"OpenAISystemMessageParam",
|
||||
"OpenAITokenLogProb",
|
||||
"OpenAIToolMessageParam",
|
||||
"OpenAITopLogProb",
|
||||
"OpenAIUserMessageParam",
|
||||
"OpenAIJSONSchema",
|
||||
"OpenAIResponseFormatJSONObject",
|
||||
"OpenAIResponseFormatJSONSchema",
|
||||
"OpenAIResponseFormatParam",
|
||||
"OpenAIResponseFormatText",
|
||||
"OpenAIChatCompletionRequestWithExtraBody",
|
||||
"OpenAIChatCompletion",
|
||||
"OpenAIChatCompletionChunk",
|
||||
"OpenAIChoiceDelta",
|
||||
"OpenAIChunkChoice",
|
||||
"OpenAICompletionWithInputMessages",
|
||||
"OpenAICompletionRequestWithExtraBody",
|
||||
"OpenAICompletion",
|
||||
"OpenAICompletionChoice",
|
||||
"ConversationItem",
|
||||
"OpenAIResponseAnnotationCitation",
|
||||
"OpenAIResponseAnnotationContainerFileCitation",
|
||||
"OpenAIResponseAnnotationFileCitation",
|
||||
"OpenAIResponseAnnotationFilePath",
|
||||
"OpenAIResponseAnnotations",
|
||||
"OpenAIResponseContentPartRefusal",
|
||||
"OpenAIResponseInputFunctionToolCallOutput",
|
||||
"OpenAIResponseInputMessageContent",
|
||||
"OpenAIResponseInputMessageContentFile",
|
||||
"OpenAIResponseInputMessageContentImage",
|
||||
"OpenAIResponseInputMessageContentText",
|
||||
"OpenAIResponseMCPApprovalRequest",
|
||||
"OpenAIResponseMCPApprovalResponse",
|
||||
"OpenAIResponseMessage",
|
||||
"OpenAIResponseOutputMessageContent",
|
||||
"OpenAIResponseOutputMessageContentOutputText",
|
||||
"OpenAIResponseOutputMessageFileSearchToolCall",
|
||||
"OpenAIResponseOutputMessageFunctionToolCall",
|
||||
"OpenAIResponseOutputMessageMCPCall",
|
||||
"OpenAIResponseOutputMessageMCPListTools",
|
||||
"OpenAIResponseOutputMessageWebSearchToolCall",
|
||||
"CreateConversationRequest",
|
||||
"Conversation",
|
||||
"UpdateConversationRequest",
|
||||
"ConversationDeletedResource",
|
||||
"ConversationItemList",
|
||||
"AddItemsRequest",
|
||||
"ConversationItemDeletedResource",
|
||||
"OpenAIEmbeddingsRequestWithExtraBody",
|
||||
"OpenAIEmbeddingData",
|
||||
"OpenAIEmbeddingUsage",
|
||||
"OpenAIEmbeddingsResponse",
|
||||
"OpenAIFilePurpose",
|
||||
"ListOpenAIFileResponse",
|
||||
"OpenAIFileObject",
|
||||
"ExpiresAfter",
|
||||
"OpenAIFileDeleteResponse",
|
||||
"Response",
|
||||
"HealthInfo",
|
||||
"RouteInfo",
|
||||
"ListRoutesResponse",
|
||||
"OpenAIModel",
|
||||
"OpenAIListModelsResponse",
|
||||
"Model",
|
||||
"ModelType",
|
||||
"RunModerationRequest",
|
||||
"ModerationObject",
|
||||
"ModerationObjectResults",
|
||||
"Prompt",
|
||||
"ListPromptsResponse",
|
||||
"CreatePromptRequest",
|
||||
"UpdatePromptRequest",
|
||||
"SetDefaultVersionRequest",
|
||||
"ProviderInfo",
|
||||
"ListProvidersResponse",
|
||||
"ListOpenAIResponseObject",
|
||||
"OpenAIResponseError",
|
||||
"OpenAIResponseInput",
|
||||
"OpenAIResponseInputToolFileSearch",
|
||||
"OpenAIResponseInputToolFunction",
|
||||
"OpenAIResponseInputToolWebSearch",
|
||||
"OpenAIResponseObjectWithInput",
|
||||
"OpenAIResponseOutput",
|
||||
"OpenAIResponsePrompt",
|
||||
"OpenAIResponseText",
|
||||
"OpenAIResponseTool",
|
||||
"OpenAIResponseToolMCP",
|
||||
"OpenAIResponseUsage",
|
||||
"ResponseGuardrailSpec",
|
||||
"OpenAIResponseInputTool",
|
||||
"OpenAIResponseInputToolMCP",
|
||||
"CreateOpenaiResponseRequest",
|
||||
"OpenAIResponseObject",
|
||||
"OpenAIResponseContentPartOutputText",
|
||||
"OpenAIResponseContentPartReasoningSummary",
|
||||
"OpenAIResponseContentPartReasoningText",
|
||||
"OpenAIResponseObjectStream",
|
||||
"OpenAIResponseObjectStreamResponseCompleted",
|
||||
"OpenAIResponseObjectStreamResponseContentPartAdded",
|
||||
"OpenAIResponseObjectStreamResponseContentPartDone",
|
||||
"OpenAIResponseObjectStreamResponseCreated",
|
||||
"OpenAIResponseObjectStreamResponseFailed",
|
||||
"OpenAIResponseObjectStreamResponseFileSearchCallCompleted",
|
||||
"OpenAIResponseObjectStreamResponseFileSearchCallInProgress",
|
||||
"OpenAIResponseObjectStreamResponseFileSearchCallSearching",
|
||||
"OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta",
|
||||
"OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone",
|
||||
"OpenAIResponseObjectStreamResponseInProgress",
|
||||
"OpenAIResponseObjectStreamResponseIncomplete",
|
||||
"OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta",
|
||||
"OpenAIResponseObjectStreamResponseMcpCallArgumentsDone",
|
||||
"OpenAIResponseObjectStreamResponseMcpCallCompleted",
|
||||
"OpenAIResponseObjectStreamResponseMcpCallFailed",
|
||||
"OpenAIResponseObjectStreamResponseMcpCallInProgress",
|
||||
"OpenAIResponseObjectStreamResponseMcpListToolsCompleted",
|
||||
"OpenAIResponseObjectStreamResponseMcpListToolsFailed",
|
||||
"OpenAIResponseObjectStreamResponseMcpListToolsInProgress",
|
||||
"OpenAIResponseObjectStreamResponseOutputItemAdded",
|
||||
"OpenAIResponseObjectStreamResponseOutputItemDone",
|
||||
"OpenAIResponseObjectStreamResponseOutputTextAnnotationAdded",
|
||||
"OpenAIResponseObjectStreamResponseOutputTextDelta",
|
||||
"OpenAIResponseObjectStreamResponseOutputTextDone",
|
||||
"OpenAIResponseObjectStreamResponseReasoningSummaryPartAdded",
|
||||
"OpenAIResponseObjectStreamResponseReasoningSummaryPartDone",
|
||||
"OpenAIResponseObjectStreamResponseReasoningSummaryTextDelta",
|
||||
"OpenAIResponseObjectStreamResponseReasoningSummaryTextDone",
|
||||
"OpenAIResponseObjectStreamResponseReasoningTextDelta",
|
||||
"OpenAIResponseObjectStreamResponseReasoningTextDone",
|
||||
"OpenAIResponseObjectStreamResponseRefusalDelta",
|
||||
"OpenAIResponseObjectStreamResponseRefusalDone",
|
||||
"OpenAIResponseObjectStreamResponseWebSearchCallCompleted",
|
||||
"OpenAIResponseObjectStreamResponseWebSearchCallInProgress",
|
||||
"OpenAIResponseObjectStreamResponseWebSearchCallSearching",
|
||||
"OpenAIDeleteResponseObject",
|
||||
"ListOpenAIResponseInputItem",
|
||||
"RunShieldRequest",
|
||||
"RunShieldResponse",
|
||||
"SafetyViolation",
|
||||
"ViolationLevel",
|
||||
"AggregationFunctionType",
|
||||
"ArrayType",
|
||||
"BasicScoringFnParams",
|
||||
"BooleanType",
|
||||
"ChatCompletionInputType",
|
||||
"CompletionInputType",
|
||||
"JsonType",
|
||||
"LLMAsJudgeScoringFnParams",
|
||||
"NumberType",
|
||||
"ObjectType",
|
||||
"RegexParserScoringFnParams",
|
||||
"ScoringFn",
|
||||
"ScoringFnParams",
|
||||
"ScoringFnParamsType",
|
||||
"StringType",
|
||||
"UnionType",
|
||||
"ListScoringFunctionsResponse",
|
||||
"ScoreRequest",
|
||||
"ScoreResponse",
|
||||
"ScoringResult",
|
||||
"ScoreBatchRequest",
|
||||
"ScoreBatchResponse",
|
||||
"Shield",
|
||||
"ListShieldsResponse",
|
||||
"InvokeToolRequest",
|
||||
"ImageContentItem",
|
||||
"InterleavedContent",
|
||||
"InterleavedContentItem",
|
||||
"TextContentItem",
|
||||
"ToolInvocationResult",
|
||||
"URL",
|
||||
"ToolDef",
|
||||
"ListToolDefsResponse",
|
||||
"ToolGroup",
|
||||
"ListToolGroupsResponse",
|
||||
"Chunk",
|
||||
"ChunkMetadata",
|
||||
"InsertChunksRequest",
|
||||
"QueryChunksRequest",
|
||||
"QueryChunksResponse",
|
||||
"VectorStoreFileCounts",
|
||||
"VectorStoreListResponse",
|
||||
"VectorStoreObject",
|
||||
"VectorStoreChunkingStrategy",
|
||||
"VectorStoreChunkingStrategyAuto",
|
||||
"VectorStoreChunkingStrategyStatic",
|
||||
"VectorStoreChunkingStrategyStaticConfig",
|
||||
"OpenAICreateVectorStoreRequestWithExtraBody",
|
||||
"OpenaiUpdateVectorStoreRequest",
|
||||
"VectorStoreDeleteResponse",
|
||||
"OpenAICreateVectorStoreFileBatchRequestWithExtraBody",
|
||||
"VectorStoreFileBatchObject",
|
||||
"VectorStoreFileStatus",
|
||||
"VectorStoreFileLastError",
|
||||
"VectorStoreFileObject",
|
||||
"VectorStoreFilesListInBatchResponse",
|
||||
"VectorStoreListFilesResponse",
|
||||
"OpenaiAttachFileToVectorStoreRequest",
|
||||
"OpenaiUpdateVectorStoreFileRequest",
|
||||
"VectorStoreFileDeleteResponse",
|
||||
"bool",
|
||||
"VectorStoreContent",
|
||||
"VectorStoreFileContentResponse",
|
||||
"OpenaiSearchVectorStoreRequest",
|
||||
"VectorStoreSearchResponse",
|
||||
"VectorStoreSearchResponsePage",
|
||||
"VersionInfo",
|
||||
"AppendRowsRequest",
|
||||
"PaginatedResponse",
|
||||
"Dataset",
|
||||
"RowsDataSource",
|
||||
"URIDataSource",
|
||||
"ListDatasetsResponse",
|
||||
"Benchmark",
|
||||
"ListBenchmarksResponse",
|
||||
"BenchmarkConfig",
|
||||
"GreedySamplingStrategy",
|
||||
"ModelCandidate",
|
||||
"SamplingParams",
|
||||
"SystemMessage",
|
||||
"TopKSamplingStrategy",
|
||||
"TopPSamplingStrategy",
|
||||
"EvaluateRowsRequest",
|
||||
"EvaluateResponse",
|
||||
"RunEvalRequest",
|
||||
"Job",
|
||||
"RerankRequest",
|
||||
"RerankData",
|
||||
"RerankResponse",
|
||||
"Checkpoint",
|
||||
"PostTrainingJobArtifactsResponse",
|
||||
"PostTrainingMetric",
|
||||
"CancelTrainingJobRequest",
|
||||
"PostTrainingJobStatusResponse",
|
||||
"ListPostTrainingJobsResponse",
|
||||
"DPOAlignmentConfig",
|
||||
"DPOLossType",
|
||||
"DataConfig",
|
||||
"DatasetFormat",
|
||||
"EfficiencyConfig",
|
||||
"OptimizerConfig",
|
||||
"OptimizerType",
|
||||
"TrainingConfig",
|
||||
"PreferenceOptimizeRequest",
|
||||
"PostTrainingJob",
|
||||
"AlgorithmConfig",
|
||||
"LoraFinetuningConfig",
|
||||
"QATFinetuningConfig",
|
||||
"SupervisedFineTuneRequest",
|
||||
"RegisterModelRequest",
|
||||
"ParamType",
|
||||
"RegisterScoringFunctionRequest",
|
||||
"RegisterShieldRequest",
|
||||
"RegisterToolGroupRequest",
|
||||
"DataSource",
|
||||
"RegisterDatasetRequest",
|
||||
"RegisterBenchmarkRequest",
|
||||
]
|
||||
|
||||
LEGACY_RESPONSE_ORDER = ["BadRequest400", "TooManyRequests429", "InternalServerError500", "DefaultError"]
|
||||
|
||||
LEGACY_TAGS = [
|
||||
{
|
||||
"description": "APIs for creating and interacting with agentic systems.",
|
||||
"name": "Agents",
|
||||
"x-displayName": "Agents",
|
||||
},
|
||||
{
|
||||
"description": "The API is designed to allow use of openai client libraries for seamless integration.\n"
|
||||
"\n"
|
||||
"This API provides the following extensions:\n"
|
||||
" - idempotent batch creation\n"
|
||||
"\n"
|
||||
"Note: This API is currently under active development and may undergo changes.",
|
||||
"name": "Batches",
|
||||
"x-displayName": "The Batches API enables efficient processing of multiple requests in a single operation, "
|
||||
"particularly useful for processing large datasets, batch evaluation workflows, and cost-effective "
|
||||
"inference at scale.",
|
||||
},
|
||||
{"description": "", "name": "Benchmarks"},
|
||||
{
|
||||
"description": "Protocol for conversation management operations.",
|
||||
"name": "Conversations",
|
||||
"x-displayName": "Conversations",
|
||||
},
|
||||
{"description": "", "name": "DatasetIO"},
|
||||
{"description": "", "name": "Datasets"},
|
||||
{
|
||||
"description": "Llama Stack Evaluation API for running evaluations on model and agent candidates.",
|
||||
"name": "Eval",
|
||||
"x-displayName": "Evaluations",
|
||||
},
|
||||
{
|
||||
"description": "This API is used to upload documents that can be used with other Llama Stack APIs.",
|
||||
"name": "Files",
|
||||
"x-displayName": "Files",
|
||||
},
|
||||
{
|
||||
"description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n"
|
||||
"\n"
|
||||
"This API provides the raw interface to the underlying models. Three kinds of models are supported:\n"
|
||||
'- LLM models: these models generate "raw" and "chat" (conversational) completions.\n'
|
||||
"- Embedding models: these models generate embeddings to be used for semantic search.\n"
|
||||
"- Rerank models: these models reorder the documents based on their relevance to a query.",
|
||||
"name": "Inference",
|
||||
"x-displayName": "Inference",
|
||||
},
|
||||
{
|
||||
"description": "APIs for inspecting the Llama Stack service, including health status, available API routes with "
|
||||
"methods and implementing providers.",
|
||||
"name": "Inspect",
|
||||
"x-displayName": "Inspect",
|
||||
},
|
||||
{"description": "", "name": "Models"},
|
||||
{"description": "", "name": "PostTraining (Coming Soon)"},
|
||||
{"description": "Protocol for prompt management operations.", "name": "Prompts", "x-displayName": "Prompts"},
|
||||
{
|
||||
"description": "Providers API for inspecting, listing, and modifying providers and their configurations.",
|
||||
"name": "Providers",
|
||||
"x-displayName": "Providers",
|
||||
},
|
||||
{"description": "OpenAI-compatible Moderations API.", "name": "Safety", "x-displayName": "Safety"},
|
||||
{"description": "", "name": "Scoring"},
|
||||
{"description": "", "name": "ScoringFunctions"},
|
||||
{"description": "", "name": "Shields"},
|
||||
{"description": "", "name": "ToolGroups"},
|
||||
{"description": "", "name": "ToolRuntime"},
|
||||
{"description": "", "name": "VectorIO"},
|
||||
]
|
||||
|
||||
LEGACY_TAG_ORDER = [
|
||||
"Agents",
|
||||
"Batches",
|
||||
"Benchmarks",
|
||||
"Conversations",
|
||||
"DatasetIO",
|
||||
"Datasets",
|
||||
"Eval",
|
||||
"Files",
|
||||
"Inference",
|
||||
"Inspect",
|
||||
"Models",
|
||||
"PostTraining (Coming Soon)",
|
||||
"Prompts",
|
||||
"Providers",
|
||||
"Safety",
|
||||
"Scoring",
|
||||
"ScoringFunctions",
|
||||
"Shields",
|
||||
"ToolGroups",
|
||||
"ToolRuntime",
|
||||
"VectorIO",
|
||||
]
|
||||
|
||||
LEGACY_TAG_GROUPS = [
|
||||
{
|
||||
"name": "Operations",
|
||||
"tags": [
|
||||
"Agents",
|
||||
"Batches",
|
||||
"Benchmarks",
|
||||
"Conversations",
|
||||
"DatasetIO",
|
||||
"Datasets",
|
||||
"Eval",
|
||||
"Files",
|
||||
"Inference",
|
||||
"Inspect",
|
||||
"Models",
|
||||
"PostTraining (Coming Soon)",
|
||||
"Prompts",
|
||||
"Providers",
|
||||
"Safety",
|
||||
"Scoring",
|
||||
"ScoringFunctions",
|
||||
"Shields",
|
||||
"ToolGroups",
|
||||
"ToolRuntime",
|
||||
"VectorIO",
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
LEGACY_SECURITY = [{"Default": []}]
|
||||
|
||||
LEGACY_OPERATION_KEYS = [
|
||||
"responses",
|
||||
"tags",
|
||||
"summary",
|
||||
"description",
|
||||
"operationId",
|
||||
"parameters",
|
||||
"requestBody",
|
||||
"deprecated",
|
||||
]
|
||||
91
scripts/openapi_generator/app.py
Normal file
91
scripts/openapi_generator/app.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
FastAPI app creation for OpenAPI generation.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack_api import Api
|
||||
|
||||
from .state import _protocol_methods_cache
|
||||
|
||||
|
||||
def _get_protocol_method(api: Api, method_name: str) -> Any | None:
|
||||
"""
|
||||
Get a protocol method function by API and method name.
|
||||
Uses caching to avoid repeated lookups.
|
||||
|
||||
Args:
|
||||
api: The API enum
|
||||
method_name: The method name (function name)
|
||||
|
||||
Returns:
|
||||
The function object, or None if not found
|
||||
"""
|
||||
global _protocol_methods_cache
|
||||
|
||||
if _protocol_methods_cache is None:
|
||||
_protocol_methods_cache = {}
|
||||
protocols = api_protocol_map()
|
||||
from llama_stack_api.tools import SpecialToolGroup, ToolRuntime
|
||||
|
||||
toolgroup_protocols = {
|
||||
SpecialToolGroup.rag_tool: ToolRuntime,
|
||||
}
|
||||
|
||||
for api_key, protocol in protocols.items():
|
||||
method_map: dict[str, Any] = {}
|
||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
for name, method in protocol_methods:
|
||||
method_map[name] = method
|
||||
|
||||
# Handle tool_runtime special case
|
||||
if api_key == Api.tool_runtime:
|
||||
for tool_group, sub_protocol in toolgroup_protocols.items():
|
||||
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
|
||||
for name, method in sub_protocol_methods:
|
||||
if hasattr(method, "__webmethod__"):
|
||||
method_map[f"{tool_group.value}.{name}"] = method
|
||||
|
||||
_protocol_methods_cache[api_key] = method_map
|
||||
|
||||
return _protocol_methods_cache.get(api, {}).get(method_name)
|
||||
|
||||
|
||||
def create_llama_stack_app() -> FastAPI:
|
||||
"""
|
||||
Create a FastAPI app that represents the Llama Stack API.
|
||||
This uses the existing route discovery system to automatically find all routes.
|
||||
"""
|
||||
app = FastAPI(
|
||||
title="Llama Stack API",
|
||||
description="A comprehensive API for building and deploying AI applications",
|
||||
version="1.0.0",
|
||||
servers=[
|
||||
{"url": "http://any-hosted-llama-stack.com"},
|
||||
],
|
||||
)
|
||||
|
||||
# Get all API routes
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
|
||||
api_routes = get_all_api_routes()
|
||||
|
||||
# Create FastAPI routes from the discovered routes
|
||||
from . import endpoints
|
||||
|
||||
for api, routes in api_routes.items():
|
||||
for route, webmethod in routes:
|
||||
# Convert the route to a FastAPI endpoint
|
||||
endpoints._create_fastapi_endpoint(app, route, webmethod, api)
|
||||
|
||||
return app
|
||||
657
scripts/openapi_generator/endpoints.py
Normal file
657
scripts/openapi_generator/endpoints.py
Normal file
|
|
@ -0,0 +1,657 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Endpoint generation logic for FastAPI OpenAPI generation.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import types
|
||||
import typing
|
||||
from typing import Annotated, Any, get_args, get_origin
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import Field, create_model
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import Api
|
||||
from llama_stack_api.schema_utils import get_registered_schema_info
|
||||
|
||||
from . import app as app_module
|
||||
from .state import _extra_body_fields, register_dynamic_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def _to_pascal_case(segment: str) -> str:
|
||||
tokens = re.findall(r"[A-Za-z]+|\d+", segment)
|
||||
return "".join(token.capitalize() for token in tokens if token)
|
||||
|
||||
|
||||
def _compose_request_model_name(api: Api, method_name: str, variant: str | None = None) -> str:
|
||||
"""Generate a deterministic model name from the protocol method."""
|
||||
|
||||
def _to_pascal_from_snake(value: str) -> str:
|
||||
return "".join(segment.capitalize() for segment in value.split("_") if segment)
|
||||
|
||||
base_name = _to_pascal_from_snake(method_name)
|
||||
if not base_name:
|
||||
base_name = _to_pascal_case(api.value)
|
||||
base_name = f"{base_name}Request"
|
||||
if variant:
|
||||
base_name = f"{base_name}{variant}"
|
||||
return base_name
|
||||
|
||||
|
||||
def _extract_path_parameters(path: str) -> list[dict[str, Any]]:
|
||||
"""Extract path parameters from a URL path and return them as OpenAPI parameter definitions."""
|
||||
matches = re.findall(r"\{([^}:]+)(?::[^}]+)?\}", path)
|
||||
return [
|
||||
{
|
||||
"name": param_name,
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"schema": {"type": "string"},
|
||||
"description": f"Path parameter: {param_name}",
|
||||
}
|
||||
for param_name in matches
|
||||
]
|
||||
|
||||
|
||||
def _create_endpoint_with_request_model(
|
||||
request_model: type, response_model: type | None, operation_description: str | None
|
||||
):
|
||||
"""Create an endpoint function with a request body model."""
|
||||
|
||||
async def endpoint(request: request_model) -> response_model:
|
||||
return response_model() if response_model else {}
|
||||
|
||||
if operation_description:
|
||||
endpoint.__doc__ = operation_description
|
||||
return endpoint
|
||||
|
||||
|
||||
def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_any: bool = False) -> dict[str, tuple]:
|
||||
"""Build field definitions for a Pydantic model from query parameters."""
|
||||
from typing import Any
|
||||
|
||||
field_definitions = {}
|
||||
for param_name, param_type, default_value in query_parameters:
|
||||
if use_any:
|
||||
field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value)
|
||||
continue
|
||||
|
||||
base_type = param_type
|
||||
extracted_field = None
|
||||
if get_origin(param_type) is Annotated:
|
||||
args = get_args(param_type)
|
||||
if args:
|
||||
base_type = args[0]
|
||||
for arg in args[1:]:
|
||||
if isinstance(arg, Field):
|
||||
extracted_field = arg
|
||||
break
|
||||
|
||||
try:
|
||||
if extracted_field:
|
||||
field_definitions[param_name] = (base_type, extracted_field)
|
||||
else:
|
||||
field_definitions[param_name] = (
|
||||
base_type,
|
||||
... if default_value is inspect.Parameter.empty else default_value,
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value)
|
||||
|
||||
# Ensure all parameters are included
|
||||
expected_params = {name for name, _, _ in query_parameters}
|
||||
missing = expected_params - set(field_definitions.keys())
|
||||
if missing:
|
||||
for param_name, _, default_value in query_parameters:
|
||||
if param_name in missing:
|
||||
field_definitions[param_name] = (
|
||||
Any,
|
||||
... if default_value is inspect.Parameter.empty else default_value,
|
||||
)
|
||||
|
||||
return field_definitions
|
||||
|
||||
|
||||
def _create_dynamic_request_model(
|
||||
api: Api,
|
||||
webmethod,
|
||||
method_name: str,
|
||||
http_method: str,
|
||||
query_parameters: list[tuple[str, type, Any]],
|
||||
use_any: bool = False,
|
||||
variant_suffix: str | None = None,
|
||||
) -> type | None:
|
||||
"""Create a dynamic Pydantic model for request body."""
|
||||
try:
|
||||
field_definitions = _build_field_definitions(query_parameters, use_any)
|
||||
if not field_definitions:
|
||||
return None
|
||||
model_name = _compose_request_model_name(api, method_name, variant_suffix or None)
|
||||
request_model = create_model(model_name, **field_definitions)
|
||||
return register_dynamic_model(model_name, request_model)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _build_signature_params(
|
||||
query_parameters: list[tuple[str, type, Any]],
|
||||
) -> tuple[list[inspect.Parameter], dict[str, type]]:
|
||||
"""Build signature parameters and annotations from query parameters."""
|
||||
signature_params = []
|
||||
param_annotations = {}
|
||||
for param_name, param_type, default_value in query_parameters:
|
||||
param_annotations[param_name] = param_type
|
||||
signature_params.append(
|
||||
inspect.Parameter(
|
||||
param_name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
default=default_value if default_value is not inspect.Parameter.empty else inspect.Parameter.empty,
|
||||
annotation=param_type,
|
||||
)
|
||||
)
|
||||
return signature_params, param_annotations
|
||||
|
||||
|
||||
def _extract_operation_description_from_docstring(api: Api, method_name: str) -> str | None:
|
||||
"""Extract operation description from the actual function docstring."""
|
||||
func = app_module._get_protocol_method(api, method_name)
|
||||
if not func or not func.__doc__:
|
||||
return None
|
||||
|
||||
doc_lines = func.__doc__.split("\n")
|
||||
description_lines = []
|
||||
metadata_markers = (":param", ":type", ":return", ":returns", ":raises", ":exception", ":yield", ":yields", ":cvar")
|
||||
|
||||
for line in doc_lines:
|
||||
if line.strip().startswith(metadata_markers):
|
||||
break
|
||||
description_lines.append(line)
|
||||
|
||||
description = "\n".join(description_lines).strip()
|
||||
return description if description else None
|
||||
|
||||
|
||||
def _extract_response_description_from_docstring(webmethod, response_model, api: Api, method_name: str) -> str:
|
||||
"""Extract response description from the actual function docstring."""
|
||||
func = app_module._get_protocol_method(api, method_name)
|
||||
if not func or not func.__doc__:
|
||||
return "Successful Response"
|
||||
for line in func.__doc__.split("\n"):
|
||||
if line.strip().startswith(":returns:"):
|
||||
if desc := line.strip()[9:].strip():
|
||||
return desc
|
||||
return "Successful Response"
|
||||
|
||||
|
||||
def _get_tag_from_api(api: Api) -> str:
|
||||
"""Extract a tag name from the API enum for API grouping."""
|
||||
return api.value.replace("_", " ").title()
|
||||
|
||||
|
||||
def _is_file_or_form_param(param_type: Any) -> bool:
|
||||
"""Check if a parameter type is annotated with File() or Form()."""
|
||||
if get_origin(param_type) is Annotated:
|
||||
args = get_args(param_type)
|
||||
if len(args) > 1:
|
||||
# Check metadata for File or Form
|
||||
for metadata in args[1:]:
|
||||
# Check if it's a File or Form instance
|
||||
if hasattr(metadata, "__class__"):
|
||||
class_name = metadata.__class__.__name__
|
||||
if class_name in ("File", "Form"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_extra_body_field(metadata_item: Any) -> bool:
|
||||
"""Check if a metadata item is an ExtraBodyField instance."""
|
||||
from llama_stack_api.schema_utils import ExtraBodyField
|
||||
|
||||
return isinstance(metadata_item, ExtraBodyField)
|
||||
|
||||
|
||||
def _is_async_iterator_type(type_obj: Any) -> bool:
|
||||
"""Check if a type is AsyncIterator or AsyncIterable."""
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
|
||||
origin = get_origin(type_obj)
|
||||
if origin is None:
|
||||
# Check if it's the class itself
|
||||
return type_obj in (AsyncIterator, AsyncIterable) or (
|
||||
hasattr(type_obj, "__origin__") and type_obj.__origin__ in (AsyncIterator, AsyncIterable)
|
||||
)
|
||||
return origin in (AsyncIterator, AsyncIterable)
|
||||
|
||||
|
||||
def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, type | None]:
|
||||
"""
|
||||
Extract non-streaming and streaming response models from a union type.
|
||||
|
||||
Returns:
|
||||
tuple: (non_streaming_model, streaming_model)
|
||||
"""
|
||||
non_streaming_model = None
|
||||
streaming_model = None
|
||||
|
||||
args = get_args(union_type)
|
||||
for arg in args:
|
||||
# Check if it's an AsyncIterator
|
||||
if _is_async_iterator_type(arg):
|
||||
# Extract the type argument from AsyncIterator[T]
|
||||
iterator_args = get_args(arg)
|
||||
if iterator_args:
|
||||
inner_type = iterator_args[0]
|
||||
# Check if the inner type is a registered schema (union type)
|
||||
# or a Pydantic model
|
||||
if hasattr(inner_type, "model_json_schema"):
|
||||
streaming_model = inner_type
|
||||
else:
|
||||
# Might be a registered schema - check if it's registered
|
||||
if get_registered_schema_info(inner_type):
|
||||
# We'll need to look this up later, but for now store the type
|
||||
streaming_model = inner_type
|
||||
elif hasattr(arg, "model_json_schema"):
|
||||
# Non-streaming Pydantic model
|
||||
if non_streaming_model is None:
|
||||
non_streaming_model = arg
|
||||
|
||||
return non_streaming_model, streaming_model
|
||||
|
||||
|
||||
def _find_models_for_endpoint(
|
||||
webmethod, api: Api, method_name: str, is_post_put: bool = False
|
||||
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None, str | None]:
|
||||
"""
|
||||
Find appropriate request and response models for an endpoint by analyzing the actual function signature.
|
||||
This uses the protocol function to determine the correct models dynamically.
|
||||
|
||||
Args:
|
||||
webmethod: The webmethod metadata
|
||||
api: The API enum for looking up the function
|
||||
method_name: The method name (function name)
|
||||
is_post_put: Whether this is a POST, PUT, or PATCH request (GET requests should never have request bodies)
|
||||
|
||||
Returns:
|
||||
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name)
|
||||
where query_parameters is a list of (name, type, default_value) tuples
|
||||
and file_form_params is a list of inspect.Parameter objects for File()/Form() params
|
||||
and streaming_response_model is the model for streaming responses (AsyncIterator content)
|
||||
"""
|
||||
route_descriptor = f"{webmethod.method or 'UNKNOWN'} {webmethod.route}"
|
||||
try:
|
||||
# Get the function from the protocol
|
||||
func = app_module._get_protocol_method(api, method_name)
|
||||
if not func:
|
||||
logger.warning("No protocol method for %s.%s (%s)", api, method_name, route_descriptor)
|
||||
return None, None, [], [], None, None
|
||||
|
||||
# Analyze the function signature
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Find request model and collect all body parameters
|
||||
request_model = None
|
||||
query_parameters = []
|
||||
file_form_params = []
|
||||
path_params = set()
|
||||
extra_body_params = []
|
||||
response_schema_name = None
|
||||
|
||||
# Extract path parameters from the route
|
||||
if webmethod and hasattr(webmethod, "route"):
|
||||
path_matches = re.findall(r"\{([^}:]+)(?::[^}]+)?\}", webmethod.route)
|
||||
path_params = set(path_matches)
|
||||
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param_name == "self":
|
||||
continue
|
||||
|
||||
# Skip *args and **kwargs parameters - these are not real API parameters
|
||||
if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
|
||||
continue
|
||||
|
||||
# Check if this is a path parameter
|
||||
if param_name in path_params:
|
||||
# Path parameters are handled separately, skip them
|
||||
continue
|
||||
|
||||
# Check if it's a File() or Form() parameter - these need special handling
|
||||
param_type = param.annotation
|
||||
if _is_file_or_form_param(param_type):
|
||||
# File() and Form() parameters must be in the function signature directly
|
||||
# They cannot be part of a Pydantic model
|
||||
file_form_params.append(param)
|
||||
continue
|
||||
|
||||
# Check for ExtraBodyField in Annotated types
|
||||
is_extra_body = False
|
||||
extra_body_description = None
|
||||
if get_origin(param_type) is Annotated:
|
||||
args = get_args(param_type)
|
||||
base_type = args[0] if args else param_type
|
||||
metadata = args[1:] if len(args) > 1 else []
|
||||
|
||||
# Check if any metadata item is an ExtraBodyField
|
||||
for metadata_item in metadata:
|
||||
if _is_extra_body_field(metadata_item):
|
||||
is_extra_body = True
|
||||
extra_body_description = metadata_item.description
|
||||
break
|
||||
|
||||
if is_extra_body:
|
||||
# Store as extra body parameter - exclude from request model
|
||||
extra_body_params.append((param_name, base_type, extra_body_description))
|
||||
continue
|
||||
|
||||
# Check if it's a Pydantic model (for POST/PUT requests)
|
||||
if hasattr(param_type, "model_json_schema"):
|
||||
# Collect all body parameters including Pydantic models
|
||||
# We'll decide later whether to use a single model or create a combined one
|
||||
query_parameters.append((param_name, param_type, param.default))
|
||||
elif get_origin(param_type) is Annotated:
|
||||
# Handle Annotated types - get the base type
|
||||
args = get_args(param_type)
|
||||
if args and hasattr(args[0], "model_json_schema"):
|
||||
# Collect Pydantic models from Annotated types
|
||||
query_parameters.append((param_name, args[0], param.default))
|
||||
else:
|
||||
# Regular annotated parameter (but not File/Form, already handled above)
|
||||
query_parameters.append((param_name, param_type, param.default))
|
||||
else:
|
||||
# This is likely a body parameter for POST/PUT or query parameter for GET
|
||||
# Store the parameter info for later use
|
||||
# Preserve inspect.Parameter.empty to distinguish "no default" from "default=None"
|
||||
default_value = param.default
|
||||
|
||||
# Extract the base type from union types (e.g., str | None -> str)
|
||||
# Also make it safe for FastAPI to avoid forward reference issues
|
||||
query_parameters.append((param_name, param_type, default_value))
|
||||
|
||||
# Store extra body fields for later use in post-processing
|
||||
# We'll store them when the endpoint is created, as we need the full path
|
||||
# For now, attach to the function for later retrieval
|
||||
if extra_body_params:
|
||||
func._extra_body_params = extra_body_params # type: ignore
|
||||
|
||||
# If there's exactly one body parameter and it's a Pydantic model, use it directly
|
||||
# Otherwise, we'll create a combined request model from all parameters
|
||||
# BUT: For GET requests, never create a request body - all parameters should be query parameters
|
||||
if is_post_put and len(query_parameters) == 1:
|
||||
param_name, param_type, default_value = query_parameters[0]
|
||||
if hasattr(param_type, "model_json_schema"):
|
||||
request_model = param_type
|
||||
query_parameters = [] # Clear query_parameters so we use the single model
|
||||
|
||||
# Find response model from return annotation
|
||||
# Also detect streaming response models (AsyncIterator)
|
||||
response_model = None
|
||||
streaming_response_model = None
|
||||
return_annotation = sig.return_annotation
|
||||
if return_annotation != inspect.Signature.empty:
|
||||
origin = get_origin(return_annotation)
|
||||
if hasattr(return_annotation, "model_json_schema"):
|
||||
response_model = return_annotation
|
||||
elif origin is Annotated:
|
||||
# Handle Annotated return types
|
||||
args = get_args(return_annotation)
|
||||
if args:
|
||||
# Check if the first argument is a Pydantic model
|
||||
if hasattr(args[0], "model_json_schema"):
|
||||
response_model = args[0]
|
||||
else:
|
||||
# Check if the first argument is a union type
|
||||
inner_origin = get_origin(args[0])
|
||||
if inner_origin is not None and (
|
||||
inner_origin is types.UnionType or inner_origin is typing.Union
|
||||
):
|
||||
response_model, streaming_response_model = _extract_response_models_from_union(args[0])
|
||||
elif origin is not None and (origin is types.UnionType or origin is typing.Union):
|
||||
# Handle union types - extract both non-streaming and streaming models
|
||||
response_model, streaming_response_model = _extract_response_models_from_union(return_annotation)
|
||||
else:
|
||||
try:
|
||||
from fastapi import Response as FastAPIResponse
|
||||
except ImportError:
|
||||
fastapi_response_cls = None
|
||||
else:
|
||||
fastapi_response_cls = FastAPIResponse
|
||||
try:
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
except ImportError:
|
||||
starlette_response_cls = None
|
||||
else:
|
||||
starlette_response_cls = StarletteResponse
|
||||
|
||||
response_types = tuple(t for t in (fastapi_response_cls, starlette_response_cls) if t is not None)
|
||||
if response_types and any(return_annotation is t for t in response_types):
|
||||
response_schema_name = "Response"
|
||||
|
||||
return (
|
||||
request_model,
|
||||
response_model,
|
||||
query_parameters,
|
||||
file_form_params,
|
||||
streaming_response_model,
|
||||
response_schema_name,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to analyze endpoint %s.%s (%s): %s", api, method_name, route_descriptor, exc, exc_info=True
|
||||
)
|
||||
return None, None, [], [], None, None
|
||||
|
||||
|
||||
def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
||||
"""Create a FastAPI endpoint from a discovered route and webmethod."""
|
||||
path = route.path
|
||||
raw_methods = route.methods or set()
|
||||
method_list = sorted({method.upper() for method in raw_methods if method and method.upper() != "HEAD"})
|
||||
if not method_list:
|
||||
method_list = ["GET"]
|
||||
primary_method = method_list[0]
|
||||
name = route.name
|
||||
fastapi_path = path.replace("{", "{").replace("}", "}")
|
||||
is_post_put = any(method in ["POST", "PUT", "PATCH"] for method in method_list)
|
||||
|
||||
(
|
||||
request_model,
|
||||
response_model,
|
||||
query_parameters,
|
||||
file_form_params,
|
||||
streaming_response_model,
|
||||
response_schema_name,
|
||||
) = _find_models_for_endpoint(webmethod, api, name, is_post_put)
|
||||
operation_description = _extract_operation_description_from_docstring(api, name)
|
||||
response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name)
|
||||
|
||||
# Retrieve and store extra body fields for this endpoint
|
||||
func = app_module._get_protocol_method(api, name)
|
||||
extra_body_params = getattr(func, "_extra_body_params", []) if func else []
|
||||
if extra_body_params:
|
||||
for method in method_list:
|
||||
key = (fastapi_path, method.upper())
|
||||
_extra_body_fields[key] = extra_body_params
|
||||
|
||||
if is_post_put and not request_model and not file_form_params and query_parameters:
|
||||
request_model = _create_dynamic_request_model(
|
||||
api, webmethod, name, primary_method, query_parameters, use_any=False
|
||||
)
|
||||
if not request_model:
|
||||
request_model = _create_dynamic_request_model(
|
||||
api, webmethod, name, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
|
||||
)
|
||||
if request_model:
|
||||
query_parameters = []
|
||||
|
||||
if file_form_params and is_post_put:
|
||||
signature_params = list(file_form_params)
|
||||
param_annotations = {param.name: param.annotation for param in file_form_params}
|
||||
for param_name, param_type, default_value in query_parameters:
|
||||
signature_params.append(
|
||||
inspect.Parameter(
|
||||
param_name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
default=default_value if default_value is not inspect.Parameter.empty else inspect.Parameter.empty,
|
||||
annotation=param_type,
|
||||
)
|
||||
)
|
||||
param_annotations[param_name] = param_type
|
||||
|
||||
async def file_form_endpoint():
|
||||
return response_model() if response_model else {}
|
||||
|
||||
if operation_description:
|
||||
file_form_endpoint.__doc__ = operation_description
|
||||
file_form_endpoint.__signature__ = inspect.Signature(signature_params)
|
||||
file_form_endpoint.__annotations__ = param_annotations
|
||||
endpoint_func = file_form_endpoint
|
||||
elif request_model and response_model:
|
||||
endpoint_func = _create_endpoint_with_request_model(request_model, response_model, operation_description)
|
||||
elif request_model:
|
||||
endpoint_func = _create_endpoint_with_request_model(request_model, None, operation_description)
|
||||
elif response_model and query_parameters:
|
||||
if is_post_put:
|
||||
request_model = _create_dynamic_request_model(
|
||||
api, webmethod, name, primary_method, query_parameters, use_any=False
|
||||
)
|
||||
if not request_model:
|
||||
request_model = _create_dynamic_request_model(
|
||||
api, webmethod, name, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
|
||||
)
|
||||
|
||||
if request_model:
|
||||
endpoint_func = _create_endpoint_with_request_model(
|
||||
request_model, response_model, operation_description
|
||||
)
|
||||
else:
|
||||
|
||||
async def empty_endpoint() -> response_model:
|
||||
return response_model() if response_model else {}
|
||||
|
||||
if operation_description:
|
||||
empty_endpoint.__doc__ = operation_description
|
||||
endpoint_func = empty_endpoint
|
||||
else:
|
||||
sorted_params = sorted(query_parameters, key=lambda x: (x[2] is not inspect.Parameter.empty, x[0]))
|
||||
signature_params, param_annotations = _build_signature_params(sorted_params)
|
||||
|
||||
async def query_endpoint():
|
||||
return response_model()
|
||||
|
||||
if operation_description:
|
||||
query_endpoint.__doc__ = operation_description
|
||||
query_endpoint.__signature__ = inspect.Signature(signature_params)
|
||||
query_endpoint.__annotations__ = param_annotations
|
||||
endpoint_func = query_endpoint
|
||||
elif response_model:
|
||||
|
||||
async def response_only_endpoint() -> response_model:
|
||||
return response_model()
|
||||
|
||||
if operation_description:
|
||||
response_only_endpoint.__doc__ = operation_description
|
||||
endpoint_func = response_only_endpoint
|
||||
elif query_parameters:
|
||||
signature_params, param_annotations = _build_signature_params(query_parameters)
|
||||
|
||||
async def params_only_endpoint():
|
||||
return {}
|
||||
|
||||
if operation_description:
|
||||
params_only_endpoint.__doc__ = operation_description
|
||||
params_only_endpoint.__signature__ = inspect.Signature(signature_params)
|
||||
params_only_endpoint.__annotations__ = param_annotations
|
||||
endpoint_func = params_only_endpoint
|
||||
else:
|
||||
# Endpoint with no parameters and no response model
|
||||
# If we have a response_model from the function signature, use it even if _find_models_for_endpoint didn't find it
|
||||
# This can happen if there was an exception during model finding
|
||||
if response_model is None:
|
||||
# Try to get response model directly from the function signature as a fallback
|
||||
func = app_module._get_protocol_method(api, name)
|
||||
if func:
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
return_annotation = sig.return_annotation
|
||||
if return_annotation != inspect.Signature.empty:
|
||||
if hasattr(return_annotation, "model_json_schema"):
|
||||
response_model = return_annotation
|
||||
elif get_origin(return_annotation) is Annotated:
|
||||
args = get_args(return_annotation)
|
||||
if args and hasattr(args[0], "model_json_schema"):
|
||||
response_model = args[0]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if response_model:
|
||||
|
||||
async def no_params_endpoint() -> response_model:
|
||||
return response_model() if response_model else {}
|
||||
else:
|
||||
|
||||
async def no_params_endpoint():
|
||||
return {}
|
||||
|
||||
if operation_description:
|
||||
no_params_endpoint.__doc__ = operation_description
|
||||
endpoint_func = no_params_endpoint
|
||||
|
||||
# Build response content with both application/json and text/event-stream if streaming
|
||||
response_content: dict[str, Any] = {}
|
||||
if response_model:
|
||||
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}}
|
||||
elif response_schema_name:
|
||||
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_schema_name}"}}
|
||||
if streaming_response_model:
|
||||
# Get the schema name for the streaming model
|
||||
# It might be a registered schema or a Pydantic model
|
||||
streaming_schema_name = None
|
||||
# Check if it's a registered schema first (before checking __name__)
|
||||
# because registered schemas might be Annotated types
|
||||
if schema_info := get_registered_schema_info(streaming_response_model):
|
||||
streaming_schema_name = schema_info.name
|
||||
elif hasattr(streaming_response_model, "__name__"):
|
||||
streaming_schema_name = streaming_response_model.__name__
|
||||
|
||||
if streaming_schema_name:
|
||||
response_content["text/event-stream"] = {
|
||||
"schema": {"$ref": f"#/components/schemas/{streaming_schema_name}"}
|
||||
}
|
||||
|
||||
# If no content types, use empty schema
|
||||
# Add the endpoint to the FastAPI app
|
||||
is_deprecated = webmethod.deprecated or False
|
||||
route_kwargs = {
|
||||
"name": name,
|
||||
"tags": [_get_tag_from_api(api)],
|
||||
"deprecated": is_deprecated,
|
||||
"responses": {
|
||||
400: {"$ref": "#/components/responses/BadRequest400"},
|
||||
429: {"$ref": "#/components/responses/TooManyRequests429"},
|
||||
500: {"$ref": "#/components/responses/InternalServerError500"},
|
||||
"default": {"$ref": "#/components/responses/DefaultError"},
|
||||
},
|
||||
}
|
||||
success_response: dict[str, Any] = {"description": response_description}
|
||||
if response_content:
|
||||
success_response["content"] = response_content
|
||||
route_kwargs["responses"][200] = success_response
|
||||
|
||||
# FastAPI needs response_model parameter to properly generate OpenAPI spec
|
||||
# Use the non-streaming response model if available
|
||||
if response_model:
|
||||
route_kwargs["response_model"] = response_model
|
||||
|
||||
method_map = {"GET": app.get, "POST": app.post, "PUT": app.put, "DELETE": app.delete, "PATCH": app.patch}
|
||||
for method in method_list:
|
||||
if handler := method_map.get(method):
|
||||
handler(fastapi_path, **route_kwargs)(endpoint_func)
|
||||
241
scripts/openapi_generator/main.py
Executable file
241
scripts/openapi_generator/main.py
Executable file
|
|
@ -0,0 +1,241 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Main entry point for the FastAPI OpenAPI generator.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from . import app, schema_collection, schema_filtering, schema_transforms, state
|
||||
|
||||
|
||||
def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
|
||||
"""
|
||||
Generate OpenAPI specification using FastAPI's built-in method.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to save the generated files
|
||||
|
||||
Returns:
|
||||
The generated OpenAPI specification as a dictionary
|
||||
"""
|
||||
state.reset_generator_state()
|
||||
# Create the FastAPI app
|
||||
fastapi_app = app.create_llama_stack_app()
|
||||
|
||||
# Generate the OpenAPI schema
|
||||
openapi_schema = get_openapi(
|
||||
title=fastapi_app.title,
|
||||
version=fastapi_app.version,
|
||||
description=fastapi_app.description,
|
||||
routes=fastapi_app.routes,
|
||||
servers=fastapi_app.servers,
|
||||
)
|
||||
|
||||
# Set OpenAPI version to 3.1.0
|
||||
openapi_schema["openapi"] = "3.1.0"
|
||||
|
||||
# Add standard error responses
|
||||
openapi_schema = schema_transforms._add_error_responses(openapi_schema)
|
||||
|
||||
# Ensure all @json_schema_type decorated models are included
|
||||
openapi_schema = schema_collection._ensure_json_schema_types_included(openapi_schema)
|
||||
|
||||
# Fix $ref references to point to components/schemas instead of $defs
|
||||
openapi_schema = schema_transforms._fix_ref_references(openapi_schema)
|
||||
|
||||
# Fix path parameter resolution issues
|
||||
openapi_schema = schema_transforms._fix_path_parameters(openapi_schema)
|
||||
|
||||
# Eliminate $defs section entirely for oasdiff compatibility
|
||||
openapi_schema = schema_transforms._eliminate_defs_section(openapi_schema)
|
||||
|
||||
# Clean descriptions in schema definitions by removing docstring metadata
|
||||
openapi_schema = schema_transforms._clean_schema_descriptions(openapi_schema)
|
||||
openapi_schema = schema_transforms._normalize_empty_responses(openapi_schema)
|
||||
|
||||
# Remove query parameters from POST/PUT/PATCH endpoints that have a request body
|
||||
# FastAPI sometimes infers parameters as query params even when they should be in the request body
|
||||
openapi_schema = schema_transforms._remove_query_params_from_body_endpoints(openapi_schema)
|
||||
|
||||
# Add x-llama-stack-extra-body-params extension for ExtraBodyField parameters
|
||||
openapi_schema = schema_transforms._add_extra_body_params_extension(openapi_schema)
|
||||
|
||||
# Remove request bodies from GET endpoints (GET requests should never have request bodies)
|
||||
# This must run AFTER _add_extra_body_params_extension to ensure any request bodies
|
||||
# that FastAPI incorrectly added to GET endpoints are removed
|
||||
openapi_schema = schema_transforms._remove_request_bodies_from_get_endpoints(openapi_schema)
|
||||
|
||||
# Extract duplicate union types to shared schema references
|
||||
openapi_schema = schema_transforms._extract_duplicate_union_types(openapi_schema)
|
||||
|
||||
# Split into stable (v1 only), experimental (v1alpha + v1beta), deprecated, and combined (stainless) specs
|
||||
# Each spec needs its own deep copy of the full schema to avoid cross-contamination
|
||||
stable_schema = schema_filtering._filter_schema_by_version(
|
||||
copy.deepcopy(openapi_schema), stable_only=True, exclude_deprecated=True
|
||||
)
|
||||
experimental_schema = schema_filtering._filter_schema_by_version(
|
||||
copy.deepcopy(openapi_schema), stable_only=False, exclude_deprecated=True
|
||||
)
|
||||
deprecated_schema = schema_filtering._filter_deprecated_schema(copy.deepcopy(openapi_schema))
|
||||
combined_schema = schema_filtering._filter_combined_schema(copy.deepcopy(openapi_schema))
|
||||
|
||||
# Apply duplicate union extraction to combined schema (used by Stainless)
|
||||
combined_schema = schema_transforms._extract_duplicate_union_types(combined_schema)
|
||||
|
||||
base_description = (
|
||||
"This is the specification of the Llama Stack that provides\n"
|
||||
" a set of endpoints and their corresponding interfaces that are\n"
|
||||
" tailored to\n"
|
||||
" best leverage Llama Models."
|
||||
)
|
||||
|
||||
schema_configs = [
|
||||
(
|
||||
stable_schema,
|
||||
"Llama Stack Specification",
|
||||
"**✅ STABLE**: Production-ready APIs with backward compatibility guarantees.",
|
||||
),
|
||||
(
|
||||
experimental_schema,
|
||||
"Llama Stack Specification - Experimental APIs",
|
||||
"**🧪 EXPERIMENTAL**: Pre-release APIs (v1alpha, v1beta) that may change before\n becoming stable.",
|
||||
),
|
||||
(
|
||||
deprecated_schema,
|
||||
"Llama Stack Specification - Deprecated APIs",
|
||||
"**⚠️ DEPRECATED**: Legacy APIs that may be removed in future versions. Use for\n migration reference only.",
|
||||
),
|
||||
(
|
||||
combined_schema,
|
||||
"Llama Stack Specification - Stable & Experimental APIs",
|
||||
"**🔗 COMBINED**: This specification includes both stable production-ready APIs\n and experimental pre-release APIs. Use stable APIs for production deployments\n and experimental APIs for testing new features.",
|
||||
),
|
||||
]
|
||||
|
||||
for schema, title, description_suffix in schema_configs:
|
||||
if "info" not in schema:
|
||||
schema["info"] = {}
|
||||
schema["info"].update(
|
||||
{
|
||||
"title": title,
|
||||
"version": "v1",
|
||||
"description": f"{base_description}\n\n {description_suffix}",
|
||||
}
|
||||
)
|
||||
|
||||
schemas_to_validate = [
|
||||
(stable_schema, "Stable schema"),
|
||||
(experimental_schema, "Experimental schema"),
|
||||
(deprecated_schema, "Deprecated schema"),
|
||||
(combined_schema, "Combined (stainless) schema"),
|
||||
]
|
||||
|
||||
for schema, _ in schemas_to_validate:
|
||||
schema_transforms._fix_schema_issues(schema)
|
||||
schema_transforms._apply_legacy_sorting(schema)
|
||||
|
||||
print("\nValidating generated schemas...")
|
||||
failed_schemas = [
|
||||
name for schema, name in schemas_to_validate if not schema_transforms.validate_openapi_schema(schema, name)
|
||||
]
|
||||
if failed_schemas:
|
||||
raise ValueError(f"Invalid schemas: {', '.join(failed_schemas)}")
|
||||
|
||||
# Ensure output directory exists
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save the stable specification
|
||||
yaml_path = output_path / "llama-stack-spec.yaml"
|
||||
schema_transforms._write_yaml_file(yaml_path, stable_schema)
|
||||
# Post-process the YAML file to remove $defs section and fix references
|
||||
with open(yaml_path) as f:
|
||||
yaml_content = f.read()
|
||||
|
||||
if " $defs:" in yaml_content or "#/$defs/" in yaml_content:
|
||||
# Use string replacement to fix references directly
|
||||
if "#/$defs/" in yaml_content:
|
||||
yaml_content = yaml_content.replace("#/$defs/", "#/components/schemas/")
|
||||
|
||||
# Parse the YAML content
|
||||
yaml_data = yaml.safe_load(yaml_content)
|
||||
|
||||
# Move $defs to components/schemas if it exists
|
||||
if "$defs" in yaml_data:
|
||||
if "components" not in yaml_data:
|
||||
yaml_data["components"] = {}
|
||||
if "schemas" not in yaml_data["components"]:
|
||||
yaml_data["components"]["schemas"] = {}
|
||||
|
||||
# Move all $defs to components/schemas
|
||||
for def_name, def_schema in yaml_data["$defs"].items():
|
||||
yaml_data["components"]["schemas"][def_name] = def_schema
|
||||
|
||||
# Remove the $defs section
|
||||
del yaml_data["$defs"]
|
||||
|
||||
# Write the modified YAML back
|
||||
schema_transforms._write_yaml_file(yaml_path, yaml_data)
|
||||
|
||||
print(f"Generated YAML (stable): {yaml_path}")
|
||||
|
||||
experimental_yaml_path = output_path / "experimental-llama-stack-spec.yaml"
|
||||
schema_transforms._write_yaml_file(experimental_yaml_path, experimental_schema)
|
||||
print(f"Generated YAML (experimental): {experimental_yaml_path}")
|
||||
|
||||
deprecated_yaml_path = output_path / "deprecated-llama-stack-spec.yaml"
|
||||
schema_transforms._write_yaml_file(deprecated_yaml_path, deprecated_schema)
|
||||
print(f"Generated YAML (deprecated): {deprecated_yaml_path}")
|
||||
|
||||
# Generate combined (stainless) spec
|
||||
stainless_yaml_path = output_path / "stainless-llama-stack-spec.yaml"
|
||||
schema_transforms._write_yaml_file(stainless_yaml_path, combined_schema)
|
||||
print(f"Generated YAML (stainless/combined): {stainless_yaml_path}")
|
||||
|
||||
return stable_schema
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the FastAPI OpenAPI generator."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate OpenAPI specification using FastAPI")
|
||||
parser.add_argument("output_dir", help="Output directory for generated files")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Generating OpenAPI specification using FastAPI...")
|
||||
print(f"Output directory: {args.output_dir}")
|
||||
|
||||
try:
|
||||
openapi_schema = generate_openapi_spec(output_dir=args.output_dir)
|
||||
|
||||
print("\nOpenAPI specification generated successfully!")
|
||||
print(f"Schemas: {len(openapi_schema.get('components', {}).get('schemas', {}))}")
|
||||
print(f"Paths: {len(openapi_schema.get('paths', {}))}")
|
||||
operation_count = sum(
|
||||
1
|
||||
for path_info in openapi_schema.get("paths", {}).values()
|
||||
for method in ["get", "post", "put", "delete", "patch"]
|
||||
if method in path_info
|
||||
)
|
||||
print(f"Operations: {operation_count}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error generating OpenAPI specification: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
131
scripts/openapi_generator/schema_collection.py
Normal file
131
scripts/openapi_generator/schema_collection.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Schema discovery and collection for OpenAPI generation.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None:
|
||||
"""Ensure components.schemas exists in the schema."""
|
||||
if "components" not in openapi_schema:
|
||||
openapi_schema["components"] = {}
|
||||
if "schemas" not in openapi_schema["components"]:
|
||||
openapi_schema["components"]["schemas"] = {}
|
||||
|
||||
|
||||
def _load_extra_schema_modules() -> None:
|
||||
"""
|
||||
Import modules outside llama_stack_api that use schema_utils to register schemas.
|
||||
|
||||
The API package already imports its submodules via __init__, but server-side modules
|
||||
like telemetry need to be imported explicitly so their decorator side effects run.
|
||||
"""
|
||||
extra_modules = [
|
||||
"llama_stack.core.telemetry.telemetry",
|
||||
]
|
||||
for module_name in extra_modules:
|
||||
try:
|
||||
importlib.import_module(module_name)
|
||||
except ImportError:
|
||||
continue
|
||||
|
||||
|
||||
def _extract_and_fix_defs(schema: dict[str, Any], openapi_schema: dict[str, Any]) -> None:
|
||||
"""
|
||||
Extract $defs from a schema, move them to components/schemas, and fix references.
|
||||
This handles both TypeAdapter-generated schemas and model_json_schema() schemas.
|
||||
"""
|
||||
if "$defs" in schema:
|
||||
defs = schema.pop("$defs")
|
||||
for def_name, def_schema in defs.items():
|
||||
if def_name not in openapi_schema["components"]["schemas"]:
|
||||
openapi_schema["components"]["schemas"][def_name] = def_schema
|
||||
# Recursively handle $defs in nested schemas
|
||||
_extract_and_fix_defs(def_schema, openapi_schema)
|
||||
|
||||
# Fix any references in the main schema that point to $defs
|
||||
def fix_refs_in_schema(obj: Any) -> None:
|
||||
if isinstance(obj, dict):
|
||||
if "$ref" in obj and obj["$ref"].startswith("#/$defs/"):
|
||||
obj["$ref"] = obj["$ref"].replace("#/$defs/", "#/components/schemas/")
|
||||
for value in obj.values():
|
||||
fix_refs_in_schema(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
fix_refs_in_schema(item)
|
||||
|
||||
fix_refs_in_schema(schema)
|
||||
|
||||
|
||||
def _ensure_json_schema_types_included(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Ensure all registered schemas (decorated, explicit, and dynamic) are included in the OpenAPI schema.
|
||||
Relies on llama_stack_api's registry instead of recursively importing every module.
|
||||
"""
|
||||
_ensure_components_schemas(openapi_schema)
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack_api.schema_utils import (
|
||||
iter_dynamic_schema_types,
|
||||
iter_json_schema_types,
|
||||
iter_registered_schema_types,
|
||||
)
|
||||
|
||||
# Import extra modules (e.g., telemetry) whose schema registrations live outside llama_stack_api
|
||||
_load_extra_schema_modules()
|
||||
|
||||
# Handle explicitly registered schemas first (union types, Annotated structs, etc.)
|
||||
for registration_info in iter_registered_schema_types():
|
||||
schema_type = registration_info.type
|
||||
schema_name = registration_info.name
|
||||
if schema_name not in openapi_schema["components"]["schemas"]:
|
||||
try:
|
||||
adapter = TypeAdapter(schema_type)
|
||||
schema = adapter.json_schema(ref_template="#/components/schemas/{model}")
|
||||
_extract_and_fix_defs(schema, openapi_schema)
|
||||
openapi_schema["components"]["schemas"][schema_name] = schema
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to generate schema for registered type {schema_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Add @json_schema_type decorated models
|
||||
for model in iter_json_schema_types():
|
||||
schema_name = getattr(model, "_llama_stack_schema_name", None) or getattr(model, "__name__", None)
|
||||
if not schema_name:
|
||||
continue
|
||||
if schema_name not in openapi_schema["components"]["schemas"]:
|
||||
try:
|
||||
if hasattr(model, "model_json_schema"):
|
||||
schema = model.model_json_schema(ref_template="#/components/schemas/{model}")
|
||||
else:
|
||||
adapter = TypeAdapter(model)
|
||||
schema = adapter.json_schema(ref_template="#/components/schemas/{model}")
|
||||
_extract_and_fix_defs(schema, openapi_schema)
|
||||
openapi_schema["components"]["schemas"][schema_name] = schema
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to generate schema for {schema_name}: {e}")
|
||||
continue
|
||||
|
||||
# Include any dynamic models generated while building endpoints
|
||||
for model in iter_dynamic_schema_types():
|
||||
try:
|
||||
schema_name = model.__name__
|
||||
if schema_name not in openapi_schema["components"]["schemas"]:
|
||||
schema = model.model_json_schema(ref_template="#/components/schemas/{model}")
|
||||
_extract_and_fix_defs(schema, openapi_schema)
|
||||
openapi_schema["components"]["schemas"][schema_name] = schema
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return openapi_schema
|
||||
297
scripts/openapi_generator/schema_filtering.py
Normal file
297
scripts/openapi_generator/schema_filtering.py
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Schema filtering and version filtering for OpenAPI generation.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack_api.schema_utils import iter_json_schema_types, iter_registered_schema_types
|
||||
from llama_stack_api.version import (
|
||||
LLAMA_STACK_API_V1,
|
||||
LLAMA_STACK_API_V1ALPHA,
|
||||
LLAMA_STACK_API_V1BETA,
|
||||
)
|
||||
|
||||
|
||||
def _get_all_json_schema_type_names() -> set[str]:
|
||||
"""Collect schema names from @json_schema_type-decorated models."""
|
||||
schema_names = set()
|
||||
for model in iter_json_schema_types():
|
||||
schema_name = getattr(model, "_llama_stack_schema_name", None) or getattr(model, "__name__", None)
|
||||
if schema_name:
|
||||
schema_names.add(schema_name)
|
||||
return schema_names
|
||||
|
||||
|
||||
def _get_explicit_schema_names(openapi_schema: dict[str, Any]) -> set[str]:
|
||||
"""Schema names to keep even if not referenced by a path."""
|
||||
registered_schema_names = {info.name for info in iter_registered_schema_types()}
|
||||
json_schema_type_names = _get_all_json_schema_type_names()
|
||||
return registered_schema_names | json_schema_type_names
|
||||
|
||||
|
||||
def _find_schema_refs_in_object(obj: Any) -> set[str]:
|
||||
"""
|
||||
Recursively find all schema references ($ref) in an object.
|
||||
"""
|
||||
refs = set()
|
||||
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
if key == "$ref" and isinstance(value, str) and value.startswith("#/components/schemas/"):
|
||||
schema_name = value.split("/")[-1]
|
||||
refs.add(schema_name)
|
||||
else:
|
||||
refs.update(_find_schema_refs_in_object(value))
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
refs.update(_find_schema_refs_in_object(item))
|
||||
|
||||
return refs
|
||||
|
||||
|
||||
def _add_transitive_references(
|
||||
referenced_schemas: set[str], all_schemas: dict[str, Any], initial_schemas: set[str] | None = None
|
||||
) -> set[str]:
|
||||
"""Add transitive references for given schemas."""
|
||||
if initial_schemas:
|
||||
referenced_schemas.update(initial_schemas)
|
||||
additional_schemas = set()
|
||||
for schema_name in initial_schemas:
|
||||
if schema_name in all_schemas:
|
||||
additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
|
||||
else:
|
||||
additional_schemas = set()
|
||||
for schema_name in referenced_schemas:
|
||||
if schema_name in all_schemas:
|
||||
additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
|
||||
|
||||
while additional_schemas:
|
||||
new_schemas = additional_schemas - referenced_schemas
|
||||
if not new_schemas:
|
||||
break
|
||||
referenced_schemas.update(new_schemas)
|
||||
additional_schemas = set()
|
||||
for schema_name in new_schemas:
|
||||
if schema_name in all_schemas:
|
||||
additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
|
||||
|
||||
return referenced_schemas
|
||||
|
||||
|
||||
def _find_schemas_referenced_by_paths(filtered_paths: dict[str, Any], openapi_schema: dict[str, Any]) -> set[str]:
|
||||
"""
|
||||
Find all schemas that are referenced by the filtered paths.
|
||||
This recursively traverses the path definitions to find all $ref references.
|
||||
"""
|
||||
referenced_schemas = set()
|
||||
|
||||
# Traverse all filtered paths
|
||||
for _, path_item in filtered_paths.items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
|
||||
# Check each HTTP method in the path
|
||||
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
||||
if method in path_item:
|
||||
operation = path_item[method]
|
||||
if isinstance(operation, dict):
|
||||
# Find all schema references in this operation
|
||||
referenced_schemas.update(_find_schema_refs_in_object(operation))
|
||||
|
||||
# Also check the responses section for schema references
|
||||
if "components" in openapi_schema and "responses" in openapi_schema["components"]:
|
||||
referenced_schemas.update(_find_schema_refs_in_object(openapi_schema["components"]["responses"]))
|
||||
|
||||
# Also include schemas that are referenced by other schemas (transitive references)
|
||||
# This ensures we include all dependencies
|
||||
all_schemas = openapi_schema.get("components", {}).get("schemas", {})
|
||||
additional_schemas = set()
|
||||
|
||||
for schema_name in referenced_schemas:
|
||||
if schema_name in all_schemas:
|
||||
additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
|
||||
|
||||
# Keep adding transitive references until no new ones are found
|
||||
while additional_schemas:
|
||||
new_schemas = additional_schemas - referenced_schemas
|
||||
if not new_schemas:
|
||||
break
|
||||
referenced_schemas.update(new_schemas)
|
||||
additional_schemas = set()
|
||||
for schema_name in new_schemas:
|
||||
if schema_name in all_schemas:
|
||||
additional_schemas.update(_find_schema_refs_in_object(all_schemas[schema_name]))
|
||||
|
||||
return referenced_schemas
|
||||
|
||||
|
||||
def _filter_schemas_by_references(
|
||||
filtered_schema: dict[str, Any], filtered_paths: dict[str, Any], openapi_schema: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Filter schemas to only include ones referenced by filtered paths and explicit schemas."""
|
||||
if "components" not in filtered_schema or "schemas" not in filtered_schema["components"]:
|
||||
return filtered_schema
|
||||
|
||||
referenced_schemas = _find_schemas_referenced_by_paths(filtered_paths, openapi_schema)
|
||||
all_schemas = openapi_schema.get("components", {}).get("schemas", {})
|
||||
explicit_names = _get_explicit_schema_names(openapi_schema)
|
||||
referenced_schemas = _add_transitive_references(referenced_schemas, all_schemas, explicit_names)
|
||||
|
||||
filtered_schemas = {
|
||||
name: schema for name, schema in filtered_schema["components"]["schemas"].items() if name in referenced_schemas
|
||||
}
|
||||
filtered_schema["components"]["schemas"] = filtered_schemas
|
||||
|
||||
if "components" in openapi_schema and "$defs" in openapi_schema["components"]:
|
||||
if "components" not in filtered_schema:
|
||||
filtered_schema["components"] = {}
|
||||
filtered_schema["components"]["$defs"] = openapi_schema["components"]["$defs"]
|
||||
|
||||
return filtered_schema
|
||||
|
||||
|
||||
def _path_starts_with_version(path: str, version: str) -> bool:
|
||||
"""Check if a path starts with a specific API version prefix."""
|
||||
return path.startswith(f"/{version}/")
|
||||
|
||||
|
||||
def _is_stable_path(path: str) -> bool:
|
||||
"""Check if a path is a stable v1 path (not v1alpha or v1beta)."""
|
||||
return (
|
||||
_path_starts_with_version(path, LLAMA_STACK_API_V1)
|
||||
and not _path_starts_with_version(path, LLAMA_STACK_API_V1ALPHA)
|
||||
and not _path_starts_with_version(path, LLAMA_STACK_API_V1BETA)
|
||||
)
|
||||
|
||||
|
||||
def _is_experimental_path(path: str) -> bool:
|
||||
"""Check if a path is an experimental path (v1alpha or v1beta)."""
|
||||
return _path_starts_with_version(path, LLAMA_STACK_API_V1ALPHA) or _path_starts_with_version(
|
||||
path, LLAMA_STACK_API_V1BETA
|
||||
)
|
||||
|
||||
|
||||
def _is_path_deprecated(path_item: dict[str, Any]) -> bool:
|
||||
"""Check if a path item has any deprecated operations."""
|
||||
if not isinstance(path_item, dict):
|
||||
return False
|
||||
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
||||
if isinstance(path_item.get(method), dict) and path_item[method].get("deprecated", False):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _filter_schema_by_version(
|
||||
openapi_schema: dict[str, Any], stable_only: bool = True, exclude_deprecated: bool = True
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Filter OpenAPI schema by API version.
|
||||
|
||||
Args:
|
||||
openapi_schema: The full OpenAPI schema
|
||||
stable_only: If True, return only /v1/ paths (stable). If False, return only /v1alpha/ and /v1beta/ paths (experimental).
|
||||
exclude_deprecated: If True, exclude deprecated endpoints from the result.
|
||||
|
||||
Returns:
|
||||
Filtered OpenAPI schema
|
||||
"""
|
||||
filtered_schema = openapi_schema.copy()
|
||||
|
||||
if "paths" not in filtered_schema:
|
||||
return filtered_schema
|
||||
|
||||
filtered_paths = {}
|
||||
for path, path_item in filtered_schema["paths"].items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
|
||||
# Filter at operation level, not path level
|
||||
# This allows paths with both deprecated and non-deprecated operations
|
||||
filtered_path_item = {}
|
||||
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
||||
if method not in path_item:
|
||||
continue
|
||||
operation = path_item[method]
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
# Skip deprecated operations if exclude_deprecated is True
|
||||
if exclude_deprecated and operation.get("deprecated", False):
|
||||
continue
|
||||
|
||||
filtered_path_item[method] = operation
|
||||
|
||||
# Only include path if it has at least one operation after filtering
|
||||
if filtered_path_item:
|
||||
# Check if path matches version filter
|
||||
if (stable_only and _is_stable_path(path)) or (not stable_only and _is_experimental_path(path)):
|
||||
filtered_paths[path] = filtered_path_item
|
||||
|
||||
filtered_schema["paths"] = filtered_paths
|
||||
return _filter_schemas_by_references(filtered_schema, filtered_paths, openapi_schema)
|
||||
|
||||
|
||||
def _filter_deprecated_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Filter OpenAPI schema to include only deprecated endpoints.
|
||||
Includes all deprecated endpoints regardless of version (v1, v1alpha, v1beta).
|
||||
"""
|
||||
filtered_schema = openapi_schema.copy()
|
||||
|
||||
if "paths" not in filtered_schema:
|
||||
return filtered_schema
|
||||
|
||||
# Filter paths to only include deprecated ones
|
||||
filtered_paths = {}
|
||||
for path, path_item in filtered_schema["paths"].items():
|
||||
if _is_path_deprecated(path_item):
|
||||
filtered_paths[path] = path_item
|
||||
|
||||
filtered_schema["paths"] = filtered_paths
|
||||
|
||||
return filtered_schema
|
||||
|
||||
|
||||
def _filter_combined_schema(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Filter OpenAPI schema to include both stable (v1) and experimental (v1alpha, v1beta) APIs.
|
||||
Includes deprecated endpoints. This is used for the combined "stainless" spec.
|
||||
"""
|
||||
filtered_schema = openapi_schema.copy()
|
||||
|
||||
if "paths" not in filtered_schema:
|
||||
return filtered_schema
|
||||
|
||||
# Filter paths to include stable (v1) and experimental (v1alpha, v1beta), excluding deprecated
|
||||
filtered_paths = {}
|
||||
for path, path_item in filtered_schema["paths"].items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
|
||||
# Filter at operation level, not path level
|
||||
# This allows paths with both deprecated and non-deprecated operations
|
||||
filtered_path_item = {}
|
||||
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
||||
if method not in path_item:
|
||||
continue
|
||||
operation = path_item[method]
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
filtered_path_item[method] = operation
|
||||
|
||||
# Only include path if it has at least one operation after filtering
|
||||
if filtered_path_item:
|
||||
# Check if path matches version filter (stable or experimental)
|
||||
if _is_stable_path(path) or _is_experimental_path(path):
|
||||
filtered_paths[path] = filtered_path_item
|
||||
|
||||
filtered_schema["paths"] = filtered_paths
|
||||
|
||||
return _filter_schemas_by_references(filtered_schema, filtered_paths, openapi_schema)
|
||||
963
scripts/openapi_generator/schema_transforms.py
Normal file
963
scripts/openapi_generator/schema_transforms.py
Normal file
|
|
@ -0,0 +1,963 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Schema transformations and fixes for OpenAPI generation.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from openapi_spec_validator import validate_spec
|
||||
from openapi_spec_validator.exceptions import OpenAPISpecValidatorError
|
||||
|
||||
from . import endpoints, schema_collection
|
||||
from ._legacy_order import (
|
||||
LEGACY_OPERATION_KEYS,
|
||||
LEGACY_PATH_ORDER,
|
||||
LEGACY_RESPONSE_ORDER,
|
||||
LEGACY_SCHEMA_ORDER,
|
||||
LEGACY_SECURITY,
|
||||
LEGACY_TAG_GROUPS,
|
||||
LEGACY_TAGS,
|
||||
)
|
||||
from .state import _extra_body_fields
|
||||
|
||||
|
||||
def _fix_ref_references(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Fix $ref references to point to components/schemas instead of $defs.
|
||||
This prevents the YAML dumper from creating a root-level $defs section.
|
||||
"""
|
||||
|
||||
def fix_refs(obj: Any) -> None:
|
||||
if isinstance(obj, dict):
|
||||
if "$ref" in obj and obj["$ref"].startswith("#/$defs/"):
|
||||
# Replace #/$defs/ with #/components/schemas/
|
||||
obj["$ref"] = obj["$ref"].replace("#/$defs/", "#/components/schemas/")
|
||||
for value in obj.values():
|
||||
fix_refs(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
fix_refs(item)
|
||||
|
||||
fix_refs(openapi_schema)
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _normalize_empty_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert empty 200 responses into 204 No Content."""
|
||||
|
||||
for path_item in openapi_schema.get("paths", {}).values():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
for method in list(path_item.keys()):
|
||||
operation = path_item.get(method)
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
responses = operation.get("responses")
|
||||
if not isinstance(responses, dict):
|
||||
continue
|
||||
response_200 = responses.get("200") or responses.get(200)
|
||||
if response_200 is None:
|
||||
continue
|
||||
content = response_200.get("content")
|
||||
if content and any(
|
||||
isinstance(media, dict) and media.get("schema") not in ({}, None) for media in content.values()
|
||||
):
|
||||
continue
|
||||
responses.pop("200", None)
|
||||
responses.pop(200, None)
|
||||
responses["204"] = {"description": response_200.get("description", "No Content")}
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _eliminate_defs_section(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Eliminate $defs section entirely by moving all definitions to components/schemas.
|
||||
This matches the structure of the old pyopenapi generator for oasdiff compatibility.
|
||||
"""
|
||||
schema_collection._ensure_components_schemas(openapi_schema)
|
||||
|
||||
# First pass: collect all $defs from anywhere in the schema
|
||||
defs_to_move = {}
|
||||
|
||||
def collect_defs(obj: Any) -> None:
|
||||
if isinstance(obj, dict):
|
||||
if "$defs" in obj:
|
||||
# Collect $defs for later processing
|
||||
for def_name, def_schema in obj["$defs"].items():
|
||||
if def_name not in defs_to_move:
|
||||
defs_to_move[def_name] = def_schema
|
||||
|
||||
# Recursively process all values
|
||||
for value in obj.values():
|
||||
collect_defs(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
collect_defs(item)
|
||||
|
||||
# Collect all $defs
|
||||
collect_defs(openapi_schema)
|
||||
|
||||
# Move all $defs to components/schemas
|
||||
for def_name, def_schema in defs_to_move.items():
|
||||
if def_name not in openapi_schema["components"]["schemas"]:
|
||||
openapi_schema["components"]["schemas"][def_name] = def_schema
|
||||
|
||||
# Also move any existing root-level $defs to components/schemas
|
||||
if "$defs" in openapi_schema:
|
||||
print(f"Found root-level $defs with {len(openapi_schema['$defs'])} items, moving to components/schemas")
|
||||
for def_name, def_schema in openapi_schema["$defs"].items():
|
||||
if def_name not in openapi_schema["components"]["schemas"]:
|
||||
openapi_schema["components"]["schemas"][def_name] = def_schema
|
||||
# Remove the root-level $defs
|
||||
del openapi_schema["$defs"]
|
||||
|
||||
# Second pass: remove all $defs sections from anywhere in the schema
|
||||
def remove_defs(obj: Any) -> None:
|
||||
if isinstance(obj, dict):
|
||||
if "$defs" in obj:
|
||||
del obj["$defs"]
|
||||
|
||||
# Recursively process all values
|
||||
for value in obj.values():
|
||||
remove_defs(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
remove_defs(item)
|
||||
|
||||
# Remove all $defs sections
|
||||
remove_defs(openapi_schema)
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _add_error_responses(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Add standard error response definitions to the OpenAPI schema.
|
||||
Uses the actual Error model from the codebase for consistency.
|
||||
"""
|
||||
if "components" not in openapi_schema:
|
||||
openapi_schema["components"] = {}
|
||||
if "responses" not in openapi_schema["components"]:
|
||||
openapi_schema["components"]["responses"] = {}
|
||||
|
||||
try:
|
||||
from llama_stack_api.datatypes import Error
|
||||
|
||||
schema_collection._ensure_components_schemas(openapi_schema)
|
||||
if "Error" not in openapi_schema["components"]["schemas"]:
|
||||
openapi_schema["components"]["schemas"]["Error"] = Error.model_json_schema()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
schema_collection._ensure_components_schemas(openapi_schema)
|
||||
if "Response" not in openapi_schema["components"]["schemas"]:
|
||||
openapi_schema["components"]["schemas"]["Response"] = {"title": "Response", "type": "object"}
|
||||
|
||||
# Define standard HTTP error responses
|
||||
error_responses = {
|
||||
400: {
|
||||
"name": "BadRequest400",
|
||||
"description": "The request was invalid or malformed",
|
||||
"example": {"status": 400, "title": "Bad Request", "detail": "The request was invalid or malformed"},
|
||||
},
|
||||
429: {
|
||||
"name": "TooManyRequests429",
|
||||
"description": "The client has sent too many requests in a given amount of time",
|
||||
"example": {
|
||||
"status": 429,
|
||||
"title": "Too Many Requests",
|
||||
"detail": "You have exceeded the rate limit. Please try again later.",
|
||||
},
|
||||
},
|
||||
500: {
|
||||
"name": "InternalServerError500",
|
||||
"description": "The server encountered an unexpected error",
|
||||
"example": {"status": 500, "title": "Internal Server Error", "detail": "An unexpected error occurred"},
|
||||
},
|
||||
}
|
||||
|
||||
# Add each error response to the schema
|
||||
for _, error_info in error_responses.items():
|
||||
response_name = error_info["name"]
|
||||
openapi_schema["components"]["responses"][response_name] = {
|
||||
"description": error_info["description"],
|
||||
"content": {
|
||||
"application/json": {"schema": {"$ref": "#/components/schemas/Error"}, "example": error_info["example"]}
|
||||
},
|
||||
}
|
||||
|
||||
# Add a default error response
|
||||
openapi_schema["components"]["responses"]["DefaultError"] = {
|
||||
"description": "An error occurred",
|
||||
"content": {"application/json": {"schema": {"$ref": "#/components/schemas/Error"}}},
|
||||
}
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _fix_path_parameters(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Fix path parameter resolution issues by adding explicit parameter definitions.
|
||||
"""
|
||||
if "paths" not in openapi_schema:
|
||||
return openapi_schema
|
||||
|
||||
for path, path_item in openapi_schema["paths"].items():
|
||||
# Extract path parameters from the URL
|
||||
path_params = endpoints._extract_path_parameters(path)
|
||||
|
||||
if not path_params:
|
||||
continue
|
||||
|
||||
# Add parameters to each operation in this path
|
||||
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
||||
if method in path_item and isinstance(path_item[method], dict):
|
||||
operation = path_item[method]
|
||||
if "parameters" not in operation:
|
||||
operation["parameters"] = []
|
||||
|
||||
# Add path parameters that aren't already defined
|
||||
existing_param_names = {p.get("name") for p in operation["parameters"] if p.get("in") == "path"}
|
||||
for param in path_params:
|
||||
if param["name"] not in existing_param_names:
|
||||
operation["parameters"].append(param)
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _get_schema_title(item: dict[str, Any]) -> str | None:
|
||||
"""Extract a title for a schema item to use in union variant names."""
|
||||
if "$ref" in item:
|
||||
return item["$ref"].split("/")[-1]
|
||||
elif "type" in item:
|
||||
type_val = item["type"]
|
||||
if type_val == "null":
|
||||
return None
|
||||
if type_val == "array" and "items" in item:
|
||||
items = item["items"]
|
||||
if isinstance(items, dict):
|
||||
if "anyOf" in items or "oneOf" in items:
|
||||
nested_union = items.get("anyOf") or items.get("oneOf")
|
||||
if isinstance(nested_union, list) and len(nested_union) > 0:
|
||||
nested_types = []
|
||||
for nested_item in nested_union:
|
||||
if isinstance(nested_item, dict):
|
||||
if "$ref" in nested_item:
|
||||
nested_types.append(nested_item["$ref"].split("/")[-1])
|
||||
elif "oneOf" in nested_item:
|
||||
one_of_items = nested_item.get("oneOf", [])
|
||||
if one_of_items and isinstance(one_of_items[0], dict) and "$ref" in one_of_items[0]:
|
||||
base_name = one_of_items[0]["$ref"].split("/")[-1].split("-")[0]
|
||||
nested_types.append(f"{base_name}Union")
|
||||
else:
|
||||
nested_types.append("Union")
|
||||
elif "type" in nested_item and nested_item["type"] != "null":
|
||||
nested_types.append(nested_item["type"])
|
||||
if nested_types:
|
||||
unique_nested = list(dict.fromkeys(nested_types))
|
||||
# Use more descriptive names for better code generation
|
||||
if len(unique_nested) <= 3:
|
||||
return f"list[{' | '.join(unique_nested)}]"
|
||||
else:
|
||||
# Include first few types for better naming
|
||||
return f"list[{unique_nested[0]} | {unique_nested[1]} | ...]"
|
||||
return "list[Union]"
|
||||
elif "$ref" in items:
|
||||
return f"list[{items['$ref'].split('/')[-1]}]"
|
||||
elif "type" in items:
|
||||
return f"list[{items['type']}]"
|
||||
return "array"
|
||||
return type_val
|
||||
elif "title" in item:
|
||||
return item["title"]
|
||||
return None
|
||||
|
||||
|
||||
def _add_titles_to_unions(obj: Any, parent_key: str | None = None) -> None:
|
||||
"""Recursively add titles to union schemas (anyOf/oneOf) to help code generators infer names."""
|
||||
if isinstance(obj, dict):
|
||||
# Check if this is a union schema (anyOf or oneOf)
|
||||
if "anyOf" in obj or "oneOf" in obj:
|
||||
union_type = "anyOf" if "anyOf" in obj else "oneOf"
|
||||
union_items = obj[union_type]
|
||||
|
||||
if isinstance(union_items, list) and len(union_items) > 0:
|
||||
# Skip simple nullable unions (type | null) - these don't need titles
|
||||
is_simple_nullable = (
|
||||
len(union_items) == 2
|
||||
and any(isinstance(item, dict) and item.get("type") == "null" for item in union_items)
|
||||
and any(
|
||||
isinstance(item, dict) and "type" in item and item.get("type") != "null" for item in union_items
|
||||
)
|
||||
and not any(
|
||||
isinstance(item, dict) and ("$ref" in item or "anyOf" in item or "oneOf" in item)
|
||||
for item in union_items
|
||||
)
|
||||
)
|
||||
|
||||
if is_simple_nullable:
|
||||
# Remove title from simple nullable unions if it exists
|
||||
if "title" in obj:
|
||||
del obj["title"]
|
||||
else:
|
||||
# Add titles to individual union variants that need them
|
||||
for item in union_items:
|
||||
if isinstance(item, dict):
|
||||
# Skip null types
|
||||
if item.get("type") == "null":
|
||||
continue
|
||||
# Add title to complex variants (arrays with unions, nested unions, etc.)
|
||||
# Also add to simple types if they're part of a complex union
|
||||
needs_title = (
|
||||
"items" in item
|
||||
or "anyOf" in item
|
||||
or "oneOf" in item
|
||||
or ("$ref" in item and "title" not in item)
|
||||
)
|
||||
if needs_title and "title" not in item:
|
||||
variant_title = _get_schema_title(item)
|
||||
if variant_title:
|
||||
item["title"] = variant_title
|
||||
|
||||
# Try to infer a meaningful title from the union items for the parent
|
||||
titles = []
|
||||
for item in union_items:
|
||||
if isinstance(item, dict):
|
||||
title = _get_schema_title(item)
|
||||
if title:
|
||||
titles.append(title)
|
||||
|
||||
if titles:
|
||||
# Create a title from the union items
|
||||
unique_titles = list(dict.fromkeys(titles)) # Preserve order, remove duplicates
|
||||
if len(unique_titles) <= 3:
|
||||
title = " | ".join(unique_titles)
|
||||
else:
|
||||
title = f"{unique_titles[0]} | ... ({len(unique_titles)} variants)"
|
||||
# Always set the title for unions to help code generators
|
||||
# This will replace generic property titles with union-specific ones
|
||||
obj["title"] = title
|
||||
elif "title" not in obj and parent_key:
|
||||
# Use parent key as fallback only if no title exists
|
||||
obj["title"] = f"{parent_key.title()}Union"
|
||||
|
||||
# Recursively process all values
|
||||
for key, value in obj.items():
|
||||
_add_titles_to_unions(value, key)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
_add_titles_to_unions(item, parent_key)
|
||||
|
||||
|
||||
def _convert_anyof_const_to_enum(obj: Any) -> None:
|
||||
"""Convert anyOf with multiple const string values to a proper enum."""
|
||||
if isinstance(obj, dict):
|
||||
if "anyOf" in obj:
|
||||
any_of = obj["anyOf"]
|
||||
if isinstance(any_of, list):
|
||||
# Check if all items are const string values
|
||||
const_values = []
|
||||
has_null = False
|
||||
can_convert = True
|
||||
for item in any_of:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "null":
|
||||
has_null = True
|
||||
elif item.get("type") == "string" and "const" in item:
|
||||
const_values.append(item["const"])
|
||||
else:
|
||||
# Not a simple const pattern, skip conversion for this anyOf
|
||||
can_convert = False
|
||||
break
|
||||
|
||||
# If we have const values and they're all strings, convert to enum
|
||||
if can_convert and const_values and len(const_values) == len(any_of) - (1 if has_null else 0):
|
||||
# Convert to enum
|
||||
obj["type"] = "string"
|
||||
obj["enum"] = const_values
|
||||
# Preserve default if present, otherwise try to get from first const item
|
||||
if "default" not in obj:
|
||||
for item in any_of:
|
||||
if isinstance(item, dict) and "const" in item:
|
||||
obj["default"] = item["const"]
|
||||
break
|
||||
# Remove anyOf
|
||||
del obj["anyOf"]
|
||||
# Handle nullable
|
||||
if has_null:
|
||||
obj["nullable"] = True
|
||||
# Remove title if it's just "string"
|
||||
if obj.get("title") == "string":
|
||||
del obj["title"]
|
||||
|
||||
# Recursively process all values
|
||||
for value in obj.values():
|
||||
_convert_anyof_const_to_enum(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
_convert_anyof_const_to_enum(item)
|
||||
|
||||
|
||||
def _fix_schema_recursive(obj: Any) -> None:
|
||||
"""Recursively fix schema issues: exclusiveMinimum and null defaults."""
|
||||
if isinstance(obj, dict):
|
||||
if "exclusiveMinimum" in obj and isinstance(obj["exclusiveMinimum"], int | float):
|
||||
obj["minimum"] = obj.pop("exclusiveMinimum")
|
||||
if "default" in obj and obj["default"] is None:
|
||||
del obj["default"]
|
||||
obj["nullable"] = True
|
||||
for value in obj.values():
|
||||
_fix_schema_recursive(value)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
_fix_schema_recursive(item)
|
||||
|
||||
|
||||
def _clean_description(description: str) -> str:
|
||||
"""Remove :param, :type, :returns, and other docstring metadata from description."""
|
||||
if not description:
|
||||
return description
|
||||
|
||||
lines = description.split("\n")
|
||||
cleaned_lines = []
|
||||
skip_until_empty = False
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
# Skip lines that start with docstring metadata markers
|
||||
if stripped.startswith(
|
||||
(":param", ":type", ":return", ":returns", ":raises", ":exception", ":yield", ":yields", ":cvar")
|
||||
):
|
||||
skip_until_empty = True
|
||||
continue
|
||||
# If we're skipping and hit an empty line, resume normal processing
|
||||
if skip_until_empty:
|
||||
if not stripped:
|
||||
skip_until_empty = False
|
||||
continue
|
||||
# Include the line if we're not skipping
|
||||
cleaned_lines.append(line)
|
||||
|
||||
# Join and strip trailing whitespace
|
||||
result = "\n".join(cleaned_lines).strip()
|
||||
return result
|
||||
|
||||
|
||||
def _clean_schema_descriptions(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Clean descriptions in schema definitions by removing docstring metadata."""
|
||||
if "components" not in openapi_schema or "schemas" not in openapi_schema["components"]:
|
||||
return openapi_schema
|
||||
|
||||
schemas = openapi_schema["components"]["schemas"]
|
||||
for schema_def in schemas.values():
|
||||
if isinstance(schema_def, dict) and "description" in schema_def and isinstance(schema_def["description"], str):
|
||||
schema_def["description"] = _clean_description(schema_def["description"])
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _add_extra_body_params_extension(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Add x-llama-stack-extra-body-params extension to requestBody for endpoints with ExtraBodyField parameters.
|
||||
"""
|
||||
if "paths" not in openapi_schema:
|
||||
return openapi_schema
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
for path, path_item in openapi_schema["paths"].items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
|
||||
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
||||
if method not in path_item:
|
||||
continue
|
||||
|
||||
operation = path_item[method]
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
# Check if we have extra body fields for this path/method
|
||||
key = (path, method.upper())
|
||||
if key not in _extra_body_fields:
|
||||
continue
|
||||
|
||||
extra_body_params = _extra_body_fields[key]
|
||||
|
||||
# Ensure requestBody exists
|
||||
if "requestBody" not in operation:
|
||||
continue
|
||||
|
||||
request_body = operation["requestBody"]
|
||||
if not isinstance(request_body, dict):
|
||||
continue
|
||||
|
||||
# Get the schema from requestBody
|
||||
content = request_body.get("content", {})
|
||||
json_content = content.get("application/json", {})
|
||||
schema_ref = json_content.get("schema", {})
|
||||
|
||||
# Remove extra body fields from the schema if they exist as properties
|
||||
# Handle both $ref schemas and inline schemas
|
||||
if isinstance(schema_ref, dict):
|
||||
if "$ref" in schema_ref:
|
||||
# Schema is a reference - remove from the referenced schema
|
||||
ref_path = schema_ref["$ref"]
|
||||
if ref_path.startswith("#/components/schemas/"):
|
||||
schema_name = ref_path.split("/")[-1]
|
||||
if "components" in openapi_schema and "schemas" in openapi_schema["components"]:
|
||||
schema_def = openapi_schema["components"]["schemas"].get(schema_name)
|
||||
if isinstance(schema_def, dict) and "properties" in schema_def:
|
||||
for param_name, _, _ in extra_body_params:
|
||||
if param_name in schema_def["properties"]:
|
||||
del schema_def["properties"][param_name]
|
||||
# Also remove from required if present
|
||||
if "required" in schema_def and param_name in schema_def["required"]:
|
||||
schema_def["required"].remove(param_name)
|
||||
elif "properties" in schema_ref:
|
||||
# Schema is inline - remove directly from it
|
||||
for param_name, _, _ in extra_body_params:
|
||||
if param_name in schema_ref["properties"]:
|
||||
del schema_ref["properties"][param_name]
|
||||
# Also remove from required if present
|
||||
if "required" in schema_ref and param_name in schema_ref["required"]:
|
||||
schema_ref["required"].remove(param_name)
|
||||
|
||||
# Build the extra body params schema
|
||||
extra_params_schema = {}
|
||||
for param_name, param_type, description in extra_body_params:
|
||||
try:
|
||||
# Generate JSON schema for the parameter type
|
||||
adapter = TypeAdapter(param_type)
|
||||
param_schema = adapter.json_schema(ref_template="#/components/schemas/{model}")
|
||||
|
||||
# Add description if provided
|
||||
if description:
|
||||
param_schema["description"] = description
|
||||
|
||||
extra_params_schema[param_name] = param_schema
|
||||
except Exception:
|
||||
# If we can't generate schema, skip this parameter
|
||||
continue
|
||||
|
||||
if extra_params_schema:
|
||||
# Add the extension to requestBody
|
||||
if "x-llama-stack-extra-body-params" not in request_body:
|
||||
request_body["x-llama-stack-extra-body-params"] = extra_params_schema
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _remove_query_params_from_body_endpoints(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Remove query parameters from POST/PUT/PATCH endpoints that have a request body.
|
||||
FastAPI sometimes infers parameters as query params even when they should be in the request body.
|
||||
"""
|
||||
if "paths" not in openapi_schema:
|
||||
return openapi_schema
|
||||
|
||||
body_methods = {"post", "put", "patch"}
|
||||
|
||||
for _path, path_item in openapi_schema["paths"].items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
|
||||
for method in body_methods:
|
||||
if method not in path_item:
|
||||
continue
|
||||
|
||||
operation = path_item[method]
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
# Check if this operation has a request body
|
||||
has_request_body = "requestBody" in operation and operation["requestBody"]
|
||||
|
||||
if has_request_body:
|
||||
# Remove all query parameters (parameters with "in": "query")
|
||||
if "parameters" in operation:
|
||||
# Filter out query parameters, keep path and header parameters
|
||||
operation["parameters"] = [
|
||||
param
|
||||
for param in operation["parameters"]
|
||||
if isinstance(param, dict) and param.get("in") != "query"
|
||||
]
|
||||
# Remove the parameters key if it's now empty
|
||||
if not operation["parameters"]:
|
||||
del operation["parameters"]
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _remove_request_bodies_from_get_endpoints(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Remove request bodies from GET endpoints and convert their parameters to query parameters.
|
||||
|
||||
GET requests should never have request bodies - all parameters should be query parameters.
|
||||
This function removes any requestBody that FastAPI may have incorrectly added to GET endpoints
|
||||
and converts any parameters in the requestBody to query parameters.
|
||||
"""
|
||||
if "paths" not in openapi_schema:
|
||||
return openapi_schema
|
||||
|
||||
for _path, path_item in openapi_schema["paths"].items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
|
||||
# Check GET method specifically
|
||||
if "get" in path_item:
|
||||
operation = path_item["get"]
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
if "requestBody" in operation:
|
||||
request_body = operation["requestBody"]
|
||||
# Extract parameters from requestBody and convert to query parameters
|
||||
if isinstance(request_body, dict) and "content" in request_body:
|
||||
content = request_body.get("content", {})
|
||||
json_content = content.get("application/json", {})
|
||||
schema = json_content.get("schema", {})
|
||||
|
||||
if "parameters" not in operation:
|
||||
operation["parameters"] = []
|
||||
elif not isinstance(operation["parameters"], list):
|
||||
operation["parameters"] = []
|
||||
|
||||
# If the schema has properties, convert each to a query parameter
|
||||
if isinstance(schema, dict) and "properties" in schema:
|
||||
for param_name, param_schema in schema["properties"].items():
|
||||
# Check if this parameter is already in the parameters list
|
||||
existing_param = None
|
||||
for existing in operation["parameters"]:
|
||||
if isinstance(existing, dict) and existing.get("name") == param_name:
|
||||
existing_param = existing
|
||||
break
|
||||
|
||||
if not existing_param:
|
||||
# Create a new query parameter from the requestBody property
|
||||
required = param_name in schema.get("required", [])
|
||||
query_param = {
|
||||
"name": param_name,
|
||||
"in": "query",
|
||||
"required": required,
|
||||
"schema": param_schema,
|
||||
}
|
||||
# Add description if present
|
||||
if "description" in param_schema:
|
||||
query_param["description"] = param_schema["description"]
|
||||
operation["parameters"].append(query_param)
|
||||
elif isinstance(schema, dict):
|
||||
# Handle direct schema (not a model with properties)
|
||||
# Try to infer parameter name from schema title
|
||||
param_name = schema.get("title", "").lower().replace(" ", "_")
|
||||
if param_name:
|
||||
# Check if this parameter is already in the parameters list
|
||||
existing_param = None
|
||||
for existing in operation["parameters"]:
|
||||
if isinstance(existing, dict) and existing.get("name") == param_name:
|
||||
existing_param = existing
|
||||
break
|
||||
|
||||
if not existing_param:
|
||||
# Create a new query parameter from the requestBody schema
|
||||
query_param = {
|
||||
"name": param_name,
|
||||
"in": "query",
|
||||
"required": False, # Default to optional for GET requests
|
||||
"schema": schema,
|
||||
}
|
||||
# Add description if present
|
||||
if "description" in schema:
|
||||
query_param["description"] = schema["description"]
|
||||
operation["parameters"].append(query_param)
|
||||
|
||||
# Remove request body from GET endpoint
|
||||
del operation["requestBody"]
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _extract_duplicate_union_types(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Extract duplicate union types to shared schema references.
|
||||
|
||||
Stainless generates type names from union types based on their context, which can cause
|
||||
duplicate names when the same union appears in different places. This function extracts
|
||||
these duplicate unions to shared schema definitions and replaces inline definitions with
|
||||
references to them.
|
||||
|
||||
According to Stainless docs, when duplicate types are detected, they should be extracted
|
||||
to the same ref and declared as a model. This ensures Stainless generates consistent
|
||||
type names regardless of where the union is referenced.
|
||||
|
||||
Fixes: https://www.stainless.com/docs/reference/diagnostics#Python/DuplicateDeclaration
|
||||
"""
|
||||
if "components" not in openapi_schema or "schemas" not in openapi_schema["components"]:
|
||||
return openapi_schema
|
||||
|
||||
schemas = openapi_schema["components"]["schemas"]
|
||||
|
||||
# Extract the Output union type (used in OpenAIResponseObjectWithInput-Output and ListOpenAIResponseInputItem)
|
||||
output_union_schema_name = "OpenAIResponseMessageOutputUnion"
|
||||
output_union_title = None
|
||||
|
||||
# Get the union type from OpenAIResponseObjectWithInput-Output.input.items.anyOf
|
||||
if "OpenAIResponseObjectWithInput-Output" in schemas:
|
||||
schema = schemas["OpenAIResponseObjectWithInput-Output"]
|
||||
if isinstance(schema, dict) and "properties" in schema:
|
||||
input_prop = schema["properties"].get("input")
|
||||
if isinstance(input_prop, dict) and "items" in input_prop:
|
||||
items = input_prop["items"]
|
||||
if isinstance(items, dict) and "anyOf" in items:
|
||||
# Extract the union schema with deep copy
|
||||
output_union_schema = copy.deepcopy(items["anyOf"])
|
||||
output_union_title = items.get("title", "OpenAIResponseMessageOutputUnion")
|
||||
|
||||
# Collect all refs from the oneOf to detect duplicates
|
||||
refs_in_oneof = set()
|
||||
for item in output_union_schema:
|
||||
if isinstance(item, dict) and "oneOf" in item:
|
||||
oneof = item["oneOf"]
|
||||
if isinstance(oneof, list):
|
||||
for variant in oneof:
|
||||
if isinstance(variant, dict) and "$ref" in variant:
|
||||
refs_in_oneof.add(variant["$ref"])
|
||||
item["x-stainless-naming"] = "OpenAIResponseMessageOutputOneOf"
|
||||
|
||||
# Remove duplicate refs from anyOf that are already in oneOf
|
||||
deduplicated_schema = []
|
||||
for item in output_union_schema:
|
||||
if isinstance(item, dict) and "$ref" in item:
|
||||
if item["$ref"] not in refs_in_oneof:
|
||||
deduplicated_schema.append(item)
|
||||
else:
|
||||
deduplicated_schema.append(item)
|
||||
output_union_schema = deduplicated_schema
|
||||
|
||||
# Create the shared schema with x-stainless-naming to ensure consistent naming
|
||||
if output_union_schema_name not in schemas:
|
||||
schemas[output_union_schema_name] = {
|
||||
"anyOf": output_union_schema,
|
||||
"title": output_union_title,
|
||||
"x-stainless-naming": output_union_schema_name,
|
||||
}
|
||||
# Replace with reference
|
||||
input_prop["items"] = {"$ref": f"#/components/schemas/{output_union_schema_name}"}
|
||||
|
||||
# Replace the same union in ListOpenAIResponseInputItem.data.items.anyOf
|
||||
if "ListOpenAIResponseInputItem" in schemas and output_union_schema_name in schemas:
|
||||
schema = schemas["ListOpenAIResponseInputItem"]
|
||||
if isinstance(schema, dict) and "properties" in schema:
|
||||
data_prop = schema["properties"].get("data")
|
||||
if isinstance(data_prop, dict) and "items" in data_prop:
|
||||
items = data_prop["items"]
|
||||
if isinstance(items, dict) and "anyOf" in items:
|
||||
# Replace with reference
|
||||
data_prop["items"] = {"$ref": f"#/components/schemas/{output_union_schema_name}"}
|
||||
|
||||
# Extract the Input union type (used in _responses_Request.input.anyOf[1].items.anyOf)
|
||||
input_union_schema_name = "OpenAIResponseMessageInputUnion"
|
||||
|
||||
if "_responses_Request" in schemas:
|
||||
schema = schemas["_responses_Request"]
|
||||
if isinstance(schema, dict) and "properties" in schema:
|
||||
input_prop = schema["properties"].get("input")
|
||||
if isinstance(input_prop, dict) and "anyOf" in input_prop:
|
||||
any_of = input_prop["anyOf"]
|
||||
if isinstance(any_of, list) and len(any_of) > 1:
|
||||
# Check the second item (index 1) which should be the array type
|
||||
second_item = any_of[1]
|
||||
if isinstance(second_item, dict) and "items" in second_item:
|
||||
items = second_item["items"]
|
||||
if isinstance(items, dict) and "anyOf" in items:
|
||||
# Extract the union schema with deep copy
|
||||
input_union_schema = copy.deepcopy(items["anyOf"])
|
||||
input_union_title = items.get("title", "OpenAIResponseMessageInputUnion")
|
||||
|
||||
# Collect all refs from the oneOf to detect duplicates
|
||||
refs_in_oneof = set()
|
||||
for item in input_union_schema:
|
||||
if isinstance(item, dict) and "oneOf" in item:
|
||||
oneof = item["oneOf"]
|
||||
if isinstance(oneof, list):
|
||||
for variant in oneof:
|
||||
if isinstance(variant, dict) and "$ref" in variant:
|
||||
refs_in_oneof.add(variant["$ref"])
|
||||
item["x-stainless-naming"] = "OpenAIResponseMessageInputOneOf"
|
||||
|
||||
# Remove duplicate refs from anyOf that are already in oneOf
|
||||
deduplicated_schema = []
|
||||
for item in input_union_schema:
|
||||
if isinstance(item, dict) and "$ref" in item:
|
||||
if item["$ref"] not in refs_in_oneof:
|
||||
deduplicated_schema.append(item)
|
||||
else:
|
||||
deduplicated_schema.append(item)
|
||||
input_union_schema = deduplicated_schema
|
||||
|
||||
# Create the shared schema with x-stainless-naming to ensure consistent naming
|
||||
if input_union_schema_name not in schemas:
|
||||
schemas[input_union_schema_name] = {
|
||||
"anyOf": input_union_schema,
|
||||
"title": input_union_title,
|
||||
"x-stainless-naming": input_union_schema_name,
|
||||
}
|
||||
# Replace with reference
|
||||
second_item["items"] = {"$ref": f"#/components/schemas/{input_union_schema_name}"}
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _convert_multiline_strings_to_literal(obj: Any) -> Any:
|
||||
"""Recursively convert multi-line strings to LiteralScalarString for YAML block scalar formatting."""
|
||||
try:
|
||||
from ruamel.yaml.scalarstring import LiteralScalarString
|
||||
|
||||
if isinstance(obj, str) and "\n" in obj:
|
||||
return LiteralScalarString(obj)
|
||||
elif isinstance(obj, dict):
|
||||
return {key: _convert_multiline_strings_to_literal(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [_convert_multiline_strings_to_literal(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
except ImportError:
|
||||
return obj
|
||||
|
||||
|
||||
def _write_yaml_file(file_path: Path, schema: dict[str, Any]) -> None:
|
||||
"""Write schema to YAML file using ruamel.yaml if available, otherwise standard yaml."""
|
||||
try:
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
yaml_writer = YAML()
|
||||
yaml_writer.default_flow_style = False
|
||||
yaml_writer.sort_keys = False
|
||||
yaml_writer.width = 4096
|
||||
yaml_writer.allow_unicode = True
|
||||
schema = _convert_multiline_strings_to_literal(schema)
|
||||
with open(file_path, "w") as f:
|
||||
yaml_writer.dump(schema, f)
|
||||
except ImportError:
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(schema, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
# Post-process to remove trailing whitespace from all lines
|
||||
with open(file_path) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Strip trailing whitespace from each line, preserving newlines
|
||||
cleaned_lines = [line.rstrip() + "\n" if line.endswith("\n") else line.rstrip() for line in lines]
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
f.writelines(cleaned_lines)
|
||||
|
||||
|
||||
def _apply_legacy_sorting(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Temporarily match the legacy ordering from origin/main so diffs are easier to read.
|
||||
Remove this once the generator output stabilizes and we no longer need legacy diffs.
|
||||
"""
|
||||
|
||||
def order_mapping(data: dict[str, Any], priority: list[str]) -> OrderedDict[str, Any]:
|
||||
ordered: OrderedDict[str, Any] = OrderedDict()
|
||||
for key in priority:
|
||||
if key in data:
|
||||
ordered[key] = data[key]
|
||||
for key, value in data.items():
|
||||
if key not in ordered:
|
||||
ordered[key] = value
|
||||
return ordered
|
||||
|
||||
paths = openapi_schema.get("paths")
|
||||
if isinstance(paths, dict):
|
||||
openapi_schema["paths"] = order_mapping(paths, LEGACY_PATH_ORDER)
|
||||
for path, path_item in openapi_schema["paths"].items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
ordered_path_item = OrderedDict()
|
||||
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
|
||||
if method in path_item:
|
||||
ordered_path_item[method] = order_mapping(path_item[method], LEGACY_OPERATION_KEYS)
|
||||
for key, value in path_item.items():
|
||||
if key not in ordered_path_item:
|
||||
if isinstance(value, dict) and key.lower() in {
|
||||
"get",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"head",
|
||||
"options",
|
||||
}:
|
||||
ordered_path_item[key] = order_mapping(value, LEGACY_OPERATION_KEYS)
|
||||
else:
|
||||
ordered_path_item[key] = value
|
||||
openapi_schema["paths"][path] = ordered_path_item
|
||||
|
||||
components = openapi_schema.setdefault("components", {})
|
||||
schemas = components.get("schemas")
|
||||
if isinstance(schemas, dict):
|
||||
components["schemas"] = order_mapping(schemas, LEGACY_SCHEMA_ORDER)
|
||||
responses = components.get("responses")
|
||||
if isinstance(responses, dict):
|
||||
components["responses"] = order_mapping(responses, LEGACY_RESPONSE_ORDER)
|
||||
|
||||
if LEGACY_TAGS:
|
||||
openapi_schema["tags"] = LEGACY_TAGS
|
||||
|
||||
if LEGACY_TAG_GROUPS:
|
||||
openapi_schema["x-tagGroups"] = LEGACY_TAG_GROUPS
|
||||
|
||||
if LEGACY_SECURITY:
|
||||
openapi_schema["security"] = LEGACY_SECURITY
|
||||
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def _fix_schema_issues(openapi_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix common schema issues: exclusiveMinimum, null defaults, and add titles to unions."""
|
||||
# Convert anyOf with const values to enums across the entire schema
|
||||
_convert_anyof_const_to_enum(openapi_schema)
|
||||
|
||||
# Fix other schema issues and add titles to unions
|
||||
if "components" in openapi_schema and "schemas" in openapi_schema["components"]:
|
||||
for schema_name, schema_def in openapi_schema["components"]["schemas"].items():
|
||||
_fix_schema_recursive(schema_def)
|
||||
_add_titles_to_unions(schema_def, schema_name)
|
||||
return openapi_schema
|
||||
|
||||
|
||||
def validate_openapi_schema(schema: dict[str, Any], schema_name: str = "OpenAPI schema") -> bool:
|
||||
"""
|
||||
Validate an OpenAPI schema using openapi-spec-validator.
|
||||
|
||||
Args:
|
||||
schema: The OpenAPI schema dictionary to validate
|
||||
schema_name: Name of the schema for error reporting
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
|
||||
Raises:
|
||||
OpenAPIValidationError: If validation fails
|
||||
"""
|
||||
try:
|
||||
validate_spec(schema)
|
||||
print(f"{schema_name} is valid")
|
||||
return True
|
||||
except OpenAPISpecValidatorError as e:
|
||||
print(f"{schema_name} validation failed: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"{schema_name} validation error: {e}")
|
||||
return False
|
||||
41
scripts/openapi_generator/state.py
Normal file
41
scripts/openapi_generator/state.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Shared state for the OpenAPI generator module.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack_api import Api
|
||||
from llama_stack_api.schema_utils import clear_dynamic_schema_types, register_dynamic_schema_type
|
||||
|
||||
_dynamic_model_registry: dict[str, type] = {}
|
||||
|
||||
# Cache for protocol methods to avoid repeated lookups
|
||||
_protocol_methods_cache: dict[Api, dict[str, Any]] | None = None
|
||||
|
||||
# Global dict to store extra body field information by endpoint
|
||||
# Key: (path, method) tuple, Value: list of (param_name, param_type, description) tuples
|
||||
_extra_body_fields: dict[tuple[str, str], list[tuple[str, type, str | None]]] = {}
|
||||
|
||||
|
||||
def register_dynamic_model(name: str, model: type) -> type:
|
||||
"""Register and deduplicate dynamically generated request models."""
|
||||
existing = _dynamic_model_registry.get(name)
|
||||
if existing is not None:
|
||||
register_dynamic_schema_type(existing)
|
||||
return existing
|
||||
_dynamic_model_registry[name] = model
|
||||
register_dynamic_schema_type(model)
|
||||
return model
|
||||
|
||||
|
||||
def reset_generator_state() -> None:
|
||||
"""Clear per-run caches so repeated generations stay deterministic."""
|
||||
_dynamic_model_registry.clear()
|
||||
_extra_body_fields.clear()
|
||||
clear_dynamic_schema_types()
|
||||
19
scripts/run_openapi_generator.sh
Executable file
19
scripts/run_openapi_generator.sh
Executable file
|
|
@ -0,0 +1,19 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
PYTHONPATH=${PYTHONPATH:-}
|
||||
THIS_DIR="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
|
||||
stack_dir=$(dirname "$THIS_DIR")
|
||||
PYTHONPATH=$PYTHONPATH:$stack_dir \
|
||||
python3 -m scripts.openapi_generator "$stack_dir"/docs/static
|
||||
|
||||
cp "$stack_dir"/docs/static/stainless-llama-stack-spec.yaml "$stack_dir"/client-sdks/stainless/openapi.yml
|
||||
Loading…
Add table
Add a link
Reference in a new issue