mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
chore: re-add text/event-stream media type
Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
66056ddb87
commit
e3d831f504
4 changed files with 120 additions and 26 deletions
|
|
@ -1513,6 +1513,9 @@ paths:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/OpenAIResponseObject'
|
$ref: '#/components/schemas/OpenAIResponseObject'
|
||||||
|
text/event-stream:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseObjectStream'
|
||||||
'400':
|
'400':
|
||||||
$ref: '#/components/responses/BadRequest400'
|
$ref: '#/components/responses/BadRequest400'
|
||||||
description: Bad Request
|
description: Bad Request
|
||||||
|
|
@ -1866,6 +1869,9 @@ paths:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/OpenAIChatCompletion'
|
$ref: '#/components/schemas/OpenAIChatCompletion'
|
||||||
|
text/event-stream:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/OpenAIChatCompletionChunk'
|
||||||
'400':
|
'400':
|
||||||
$ref: '#/components/responses/BadRequest400'
|
$ref: '#/components/responses/BadRequest400'
|
||||||
description: Bad Request
|
description: Bad Request
|
||||||
|
|
|
||||||
6
docs/static/llama-stack-spec.yaml
vendored
6
docs/static/llama-stack-spec.yaml
vendored
|
|
@ -121,6 +121,9 @@ paths:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/OpenAIResponseObject'
|
$ref: '#/components/schemas/OpenAIResponseObject'
|
||||||
|
text/event-stream:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseObjectStream'
|
||||||
'400':
|
'400':
|
||||||
$ref: '#/components/responses/BadRequest400'
|
$ref: '#/components/responses/BadRequest400'
|
||||||
description: Bad Request
|
description: Bad Request
|
||||||
|
|
@ -454,6 +457,9 @@ paths:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/OpenAIChatCompletion'
|
$ref: '#/components/schemas/OpenAIChatCompletion'
|
||||||
|
text/event-stream:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/OpenAIChatCompletionChunk'
|
||||||
'400':
|
'400':
|
||||||
$ref: '#/components/responses/BadRequest400'
|
$ref: '#/components/responses/BadRequest400'
|
||||||
description: Bad Request
|
description: Bad Request
|
||||||
|
|
|
||||||
6
docs/static/stainless-llama-stack-spec.yaml
vendored
6
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@ -1513,6 +1513,9 @@ paths:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/OpenAIResponseObject'
|
$ref: '#/components/schemas/OpenAIResponseObject'
|
||||||
|
text/event-stream:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/OpenAIResponseObjectStream'
|
||||||
'400':
|
'400':
|
||||||
$ref: '#/components/responses/BadRequest400'
|
$ref: '#/components/responses/BadRequest400'
|
||||||
description: Bad Request
|
description: Bad Request
|
||||||
|
|
@ -1864,6 +1867,9 @@ paths:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/OpenAIChatCompletion'
|
$ref: '#/components/schemas/OpenAIChatCompletion'
|
||||||
|
text/event-stream:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/OpenAIChatCompletionChunk'
|
||||||
'400':
|
'400':
|
||||||
$ref: '#/components/responses/BadRequest400'
|
$ref: '#/components/responses/BadRequest400'
|
||||||
description: Bad Request
|
description: Bad Request
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ FastAPI-based OpenAPI generator for Llama Stack.
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import pkgutil
|
import pkgutil
|
||||||
|
import types
|
||||||
|
import typing
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, get_args, get_origin
|
from typing import Annotated, Any, get_args, get_origin
|
||||||
|
|
||||||
|
|
@ -237,7 +239,9 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
||||||
name = route.name
|
name = route.name
|
||||||
fastapi_path = path.replace("{", "{").replace("}", "}")
|
fastapi_path = path.replace("{", "{").replace("}", "}")
|
||||||
|
|
||||||
request_model, response_model, query_parameters, file_form_params = _find_models_for_endpoint(webmethod, api, name)
|
request_model, response_model, query_parameters, file_form_params, streaming_response_model = (
|
||||||
|
_find_models_for_endpoint(webmethod, api, name)
|
||||||
|
)
|
||||||
operation_description = _extract_operation_description_from_docstring(api, name)
|
operation_description = _extract_operation_description_from_docstring(api, name)
|
||||||
response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name)
|
response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name)
|
||||||
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
|
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
|
||||||
|
|
@ -336,6 +340,32 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
||||||
no_params_endpoint.__doc__ = operation_description
|
no_params_endpoint.__doc__ = operation_description
|
||||||
endpoint_func = no_params_endpoint
|
endpoint_func = no_params_endpoint
|
||||||
|
|
||||||
|
# Build response content with both application/json and text/event-stream if streaming
|
||||||
|
response_content = {}
|
||||||
|
if response_model:
|
||||||
|
response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}}
|
||||||
|
if streaming_response_model:
|
||||||
|
# Get the schema name for the streaming model
|
||||||
|
# It might be a registered schema or a Pydantic model
|
||||||
|
streaming_schema_name = None
|
||||||
|
# Check if it's a registered schema first (before checking __name__)
|
||||||
|
# because registered schemas might be Annotated types
|
||||||
|
from llama_stack.schema_utils import _registered_schemas
|
||||||
|
|
||||||
|
if streaming_response_model in _registered_schemas:
|
||||||
|
streaming_schema_name = _registered_schemas[streaming_response_model]["name"]
|
||||||
|
elif hasattr(streaming_response_model, "__name__"):
|
||||||
|
streaming_schema_name = streaming_response_model.__name__
|
||||||
|
|
||||||
|
if streaming_schema_name:
|
||||||
|
response_content["text/event-stream"] = {
|
||||||
|
"schema": {"$ref": f"#/components/schemas/{streaming_schema_name}"}
|
||||||
|
}
|
||||||
|
|
||||||
|
# If no content types, use empty schema
|
||||||
|
if not response_content:
|
||||||
|
response_content["application/json"] = {"schema": {}}
|
||||||
|
|
||||||
# Add the endpoint to the FastAPI app
|
# Add the endpoint to the FastAPI app
|
||||||
is_deprecated = webmethod.deprecated or False
|
is_deprecated = webmethod.deprecated or False
|
||||||
route_kwargs = {
|
route_kwargs = {
|
||||||
|
|
@ -345,11 +375,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
||||||
"responses": {
|
"responses": {
|
||||||
200: {
|
200: {
|
||||||
"description": response_description,
|
"description": response_description,
|
||||||
"content": {
|
"content": response_content,
|
||||||
"application/json": {
|
|
||||||
"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"} if response_model else {}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
400: {"$ref": "#/components/responses/BadRequest400"},
|
400: {"$ref": "#/components/responses/BadRequest400"},
|
||||||
429: {"$ref": "#/components/responses/TooManyRequests429"},
|
429: {"$ref": "#/components/responses/TooManyRequests429"},
|
||||||
|
|
@ -422,9 +448,59 @@ def _is_extra_body_field(metadata_item: Any) -> bool:
|
||||||
return isinstance(metadata_item, ExtraBodyField)
|
return isinstance(metadata_item, ExtraBodyField)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_async_iterator_type(type_obj: Any) -> bool:
    """Return True when *type_obj* denotes an asynchronous stream type.

    Recognized types are ``AsyncIterator``, ``AsyncIterable`` and
    ``AsyncGenerator`` from ``collections.abc``, either bare or parameterized
    (for example ``AsyncIterator[Chunk]`` or ``AsyncGenerator[Chunk, None]``).
    The ``typing`` aliases of these ABCs resolve to the same origins via
    ``get_origin`` and are therefore accepted as well.

    Args:
        type_obj: Any annotation object (class, generic alias, or anything
            else found inside a return annotation).

    Returns:
        True if the annotation is one of the async stream types above,
        False otherwise.
    """
    from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator

    # AsyncGenerator is included because it is an AsyncIterator subtype that
    # is commonly used to annotate streaming endpoints.
    stream_types = (AsyncIterator, AsyncIterable, AsyncGenerator)

    origin = get_origin(type_obj)
    if origin is not None:
        return origin in stream_types

    # get_origin() did not recognize the object: it is either the bare ABC
    # itself or an alias-like object that only exposes a raw ``__origin__``.
    if type_obj in stream_types:
        return True
    return getattr(type_obj, "__origin__", None) in stream_types
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, type | None]:
    """
    Split the members of a union return annotation into response models.

    A member that is an async iterator contributes its first type argument as
    the streaming model, provided that argument either exposes
    ``model_json_schema`` or is a key of ``_registered_schemas``. Any other
    member exposing ``model_json_schema`` is a candidate non-streaming model;
    only the first such member is kept. If several async-iterator members
    qualify, the last one is kept.

    Returns:
        tuple: (non_streaming_model, streaming_model); either entry is None
        when no matching member was found.
    """
    plain_model = None
    stream_model = None

    for member in get_args(union_type):
        if not _is_async_iterator_type(member):
            # Non-streaming Pydantic model: the first match wins
            if hasattr(member, "model_json_schema") and plain_model is None:
                plain_model = member
            continue

        # Pull T out of AsyncIterator[T]; a bare iterator carries no model
        item_types = get_args(member)
        if not item_types:
            continue

        item_type = item_types[0]
        if hasattr(item_type, "model_json_schema"):
            stream_model = item_type
            continue

        # Not a Pydantic model - it may still be a registered schema (e.g. a
        # union type); keep the type itself so the name can be resolved later
        from llama_stack.schema_utils import _registered_schemas

        if item_type in _registered_schemas:
            stream_model = item_type

    return plain_model, stream_model
|
||||||
|
|
||||||
|
|
||||||
def _find_models_for_endpoint(
|
def _find_models_for_endpoint(
|
||||||
webmethod, api: Api, method_name: str
|
webmethod, api: Api, method_name: str
|
||||||
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter]]:
|
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None]:
|
||||||
"""
|
"""
|
||||||
Find appropriate request and response models for an endpoint by analyzing the actual function signature.
|
Find appropriate request and response models for an endpoint by analyzing the actual function signature.
|
||||||
This uses the protocol function to determine the correct models dynamically.
|
This uses the protocol function to determine the correct models dynamically.
|
||||||
|
|
@ -435,15 +511,16 @@ def _find_models_for_endpoint(
|
||||||
method_name: The method name (function name)
|
method_name: The method name (function name)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (request_model, response_model, query_parameters, file_form_params)
|
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model)
|
||||||
where query_parameters is a list of (name, type, default_value) tuples
|
where query_parameters is a list of (name, type, default_value) tuples
|
||||||
and file_form_params is a list of inspect.Parameter objects for File()/Form() params
|
and file_form_params is a list of inspect.Parameter objects for File()/Form() params
|
||||||
|
and streaming_response_model is the model for streaming responses (AsyncIterator content)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get the function from the protocol
|
# Get the function from the protocol
|
||||||
func = _get_protocol_method(api, method_name)
|
func = _get_protocol_method(api, method_name)
|
||||||
if not func:
|
if not func:
|
||||||
return None, None, [], []
|
return None, None, [], [], None
|
||||||
|
|
||||||
# Analyze the function signature
|
# Analyze the function signature
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
|
|
@ -542,38 +619,37 @@ def _find_models_for_endpoint(
|
||||||
query_parameters = [] # Clear query_parameters so we use the single model
|
query_parameters = [] # Clear query_parameters so we use the single model
|
||||||
|
|
||||||
# Find response model from return annotation
|
# Find response model from return annotation
|
||||||
|
# Also detect streaming response models (AsyncIterator)
|
||||||
response_model = None
|
response_model = None
|
||||||
|
streaming_response_model = None
|
||||||
return_annotation = sig.return_annotation
|
return_annotation = sig.return_annotation
|
||||||
if return_annotation != inspect.Signature.empty:
|
if return_annotation != inspect.Signature.empty:
|
||||||
|
origin = get_origin(return_annotation)
|
||||||
if hasattr(return_annotation, "model_json_schema"):
|
if hasattr(return_annotation, "model_json_schema"):
|
||||||
response_model = return_annotation
|
response_model = return_annotation
|
||||||
elif get_origin(return_annotation) is Annotated:
|
elif origin is Annotated:
|
||||||
# Handle Annotated return types
|
# Handle Annotated return types
|
||||||
args = get_args(return_annotation)
|
args = get_args(return_annotation)
|
||||||
if args:
|
if args:
|
||||||
# Check if the first argument is a Pydantic model
|
# Check if the first argument is a Pydantic model
|
||||||
if hasattr(args[0], "model_json_schema"):
|
if hasattr(args[0], "model_json_schema"):
|
||||||
response_model = args[0]
|
response_model = args[0]
|
||||||
# Check if the first argument is a union type
|
else:
|
||||||
elif get_origin(args[0]) is type(args[0]): # Union type
|
# Check if the first argument is a union type
|
||||||
union_args = get_args(args[0])
|
inner_origin = get_origin(args[0])
|
||||||
for arg in union_args:
|
if inner_origin is not None and (
|
||||||
if hasattr(arg, "model_json_schema"):
|
inner_origin is types.UnionType or inner_origin is typing.Union
|
||||||
response_model = arg
|
):
|
||||||
break
|
response_model, streaming_response_model = _extract_response_models_from_union(args[0])
|
||||||
elif get_origin(return_annotation) is type(return_annotation): # Union type
|
elif origin is not None and (origin is types.UnionType or origin is typing.Union):
|
||||||
# Handle union types - try to find the first Pydantic model
|
# Handle union types - extract both non-streaming and streaming models
|
||||||
args = get_args(return_annotation)
|
response_model, streaming_response_model = _extract_response_models_from_union(return_annotation)
|
||||||
for arg in args:
|
|
||||||
if hasattr(arg, "model_json_schema"):
|
|
||||||
response_model = arg
|
|
||||||
break
|
|
||||||
|
|
||||||
return request_model, response_model, query_parameters, file_form_params
|
return request_model, response_model, query_parameters, file_form_params, streaming_response_model
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# If we can't analyze the function signature, return None
|
# If we can't analyze the function signature, return None
|
||||||
return None, None, [], []
|
return None, None, [], [], None
|
||||||
|
|
||||||
|
|
||||||
def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None:
|
def _ensure_components_schemas(openapi_schema: dict[str, Any]) -> None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue