feat: openai files api (#2321)

# What does this PR do?
* Adds the OpenAI compatible Files API
* Modified doc gen script to support multipart parameter

## Test Plan

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/2321).
* #2330
* __->__ #2321
This commit is contained in:
ehhuang 2025-06-02 11:45:53 -07:00 committed by GitHub
parent 17f4414be9
commit 31a3ae60f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 881 additions and 946 deletions

View file

@ -30,6 +30,9 @@ from llama_stack.strong_typing.schema import (
Schema,
SchemaOptions,
)
from typing import get_origin, get_args
from typing import Annotated
from fastapi import UploadFile
from llama_stack.strong_typing.serialization import json_dump_string, object_to_json
from .operations import (
@ -618,6 +621,45 @@ class Generator:
},
required=True,
)
# data passed in request body as multipart/form-data
elif op.multipart_params:
builder = ContentBuilder(self.schema_builder)
# Create schema properties for multipart form fields
properties = {}
required_fields = []
for name, param_type in op.multipart_params:
if get_origin(param_type) is Annotated:
base_type = get_args(param_type)[0]
else:
base_type = param_type
if base_type is UploadFile:
# File upload
properties[name] = {
"type": "string",
"format": "binary"
}
else:
# Form field
properties[name] = self.schema_builder.classdef_to_ref(base_type)
required_fields.append(name)
multipart_schema = {
"type": "object",
"properties": properties,
"required": required_fields
}
requestBody = RequestBody(
content={
"multipart/form-data": {
"schema": multipart_schema
}
},
required=True,
)
# data passed in payload as JSON and mapped to request parameters
elif op.request_params:
builder = ContentBuilder(self.schema_builder)

View file

@ -17,6 +17,12 @@ from termcolor import colored
from llama_stack.strong_typing.inspection import get_signature
from typing import get_origin, get_args
from fastapi import UploadFile
from fastapi.params import File, Form
from typing import Annotated
def split_prefix(
s: str, sep: str, prefix: Union[str, Iterable[str]]
@ -82,6 +88,7 @@ class EndpointOperation:
:param path_params: Parameters of the operation signature that are passed in the path component of the URL string.
:param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
:param request_params: The parameter that corresponds to the data transmitted in the request body.
:param multipart_params: Parameters that indicate multipart/form-data request body.
:param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
:param response_type: The Python type of the data that is transmitted in the response body.
:param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.
@ -98,6 +105,7 @@ class EndpointOperation:
path_params: List[OperationParameter]
query_params: List[OperationParameter]
request_params: Optional[OperationParameter]
multipart_params: List[OperationParameter]
event_type: Optional[type]
response_type: type
http_method: HTTPMethod
@ -252,6 +260,7 @@ def get_endpoint_operations(
path_params = []
query_params = []
request_params = []
multipart_params = []
for param_name, parameter in signature.parameters.items():
param_type = _get_annotation_type(parameter.annotation, func_ref)
@ -266,6 +275,8 @@ def get_endpoint_operations(
f"parameter '{param_name}' in function '{func_name}' has no type annotation"
)
is_multipart = _is_multipart_param(param_type)
if prefix in ["get", "delete"]:
if route_params is not None and param_name in route_params:
path_params.append((param_name, param_type))
@ -274,6 +285,8 @@ def get_endpoint_operations(
else:
if route_params is not None and param_name in route_params:
path_params.append((param_name, param_type))
elif is_multipart:
multipart_params.append((param_name, param_type))
else:
request_params.append((param_name, param_type))
@ -333,6 +346,7 @@ def get_endpoint_operations(
path_params=path_params,
query_params=query_params,
request_params=request_params,
multipart_params=multipart_params,
event_type=event_type,
response_type=response_type,
http_method=http_method,
@ -377,3 +391,34 @@ def get_endpoint_events(endpoint: type) -> Dict[str, type]:
results[param_type.__name__] = param_type
return results
def _is_multipart_param(param_type: type) -> bool:
"""
Check if a parameter type indicates multipart form data.
Returns True if the type is:
- UploadFile
- Annotated[UploadFile, File()]
- Annotated[str, Form()]
- Annotated[Any, File()]
- Annotated[Any, Form()]
"""
if param_type is UploadFile:
return True
# Check for Annotated types
origin = get_origin(param_type)
if origin is None:
return False
if origin is Annotated:
args = get_args(param_type)
if len(args) < 2:
return False
# Check the annotations for File() or Form()
for annotation in args[1:]:
if isinstance(annotation, (File, Form)):
return True
return False

View file

@ -153,6 +153,12 @@ def _validate_api_delete_method_returns_none(method) -> str | None:
return "has no return type annotation"
return_type = hints['return']
# Allow OpenAI endpoints to return response objects since they follow OpenAI specification
method_name = getattr(method, '__name__', '')
if method_name.startswith('openai_'):
return None
if return_type is not None and return_type is not type(None):
return "does not return None where None is mandatory"