mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Merge a9ab928ac4
into 188a56af5c
This commit is contained in:
commit
4cc23c4e5f
19 changed files with 387 additions and 32 deletions
|
@ -50,6 +50,7 @@ from .specification import (
|
|||
Document,
|
||||
Example,
|
||||
ExampleRef,
|
||||
ExtraBodyParameter,
|
||||
MediaType,
|
||||
Operation,
|
||||
Parameter,
|
||||
|
@ -677,6 +678,27 @@ class Generator:
|
|||
# parameters passed anywhere
|
||||
parameters = path_parameters + query_parameters
|
||||
|
||||
# Build extra body parameters documentation
|
||||
extra_body_parameters = []
|
||||
for param_name, param_type, description in op.extra_body_params:
|
||||
if is_type_optional(param_type):
|
||||
inner_type: type = unwrap_optional_type(param_type)
|
||||
required = False
|
||||
else:
|
||||
inner_type = param_type
|
||||
required = True
|
||||
|
||||
# Use description from ExtraBodyField if available, otherwise from docstring
|
||||
param_description = description or doc_params.get(param_name)
|
||||
|
||||
extra_body_param = ExtraBodyParameter(
|
||||
name=param_name,
|
||||
schema=self.schema_builder.classdef_to_ref(inner_type),
|
||||
description=param_description,
|
||||
required=required,
|
||||
)
|
||||
extra_body_parameters.append(extra_body_param)
|
||||
|
||||
webmethod = getattr(op.func_ref, "__webmethod__", None)
|
||||
raw_bytes_request_body = False
|
||||
if webmethod:
|
||||
|
@ -898,6 +920,7 @@ class Generator:
|
|||
deprecated=getattr(op.webmethod, "deprecated", False)
|
||||
or "DEPRECATED" in op.func_name,
|
||||
security=[] if op.public else None,
|
||||
extraBodyParameters=extra_body_parameters if extra_body_parameters else None,
|
||||
)
|
||||
|
||||
def _get_api_stability_priority(self, api_level: str) -> int:
|
||||
|
|
|
@ -23,6 +23,8 @@ from fastapi import UploadFile
|
|||
from fastapi.params import File, Form
|
||||
from typing import Annotated
|
||||
|
||||
from llama_stack.schema_utils import ExtraBodyField
|
||||
|
||||
|
||||
def split_prefix(
|
||||
s: str, sep: str, prefix: Union[str, Iterable[str]]
|
||||
|
@ -89,6 +91,7 @@ class EndpointOperation:
|
|||
:param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
|
||||
:param request_params: The parameter that corresponds to the data transmitted in the request body.
|
||||
:param multipart_params: Parameters that indicate multipart/form-data request body.
|
||||
:param extra_body_params: Parameters that arrive via extra_body and are documented but not in SDK.
|
||||
:param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
|
||||
:param response_type: The Python type of the data that is transmitted in the response body.
|
||||
:param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.
|
||||
|
@ -106,6 +109,7 @@ class EndpointOperation:
|
|||
query_params: List[OperationParameter]
|
||||
request_params: Optional[OperationParameter]
|
||||
multipart_params: List[OperationParameter]
|
||||
extra_body_params: List[tuple[str, type, str | None]]
|
||||
event_type: Optional[type]
|
||||
response_type: type
|
||||
http_method: HTTPMethod
|
||||
|
@ -265,6 +269,7 @@ def get_endpoint_operations(
|
|||
query_params = []
|
||||
request_params = []
|
||||
multipart_params = []
|
||||
extra_body_params = []
|
||||
|
||||
for param_name, parameter in signature.parameters.items():
|
||||
param_type = _get_annotation_type(parameter.annotation, func_ref)
|
||||
|
@ -279,6 +284,13 @@ def get_endpoint_operations(
|
|||
f"parameter '{param_name}' in function '{func_name}' has no type annotation"
|
||||
)
|
||||
|
||||
# Check if this is an extra_body parameter
|
||||
is_extra_body, extra_body_desc = _is_extra_body_param(param_type)
|
||||
if is_extra_body:
|
||||
# Store in a separate list for documentation
|
||||
extra_body_params.append((param_name, param_type, extra_body_desc))
|
||||
continue # Skip adding to request_params
|
||||
|
||||
is_multipart = _is_multipart_param(param_type)
|
||||
|
||||
if prefix in ["get", "delete"]:
|
||||
|
@ -351,6 +363,7 @@ def get_endpoint_operations(
|
|||
query_params=query_params,
|
||||
request_params=request_params,
|
||||
multipart_params=multipart_params,
|
||||
extra_body_params=extra_body_params,
|
||||
event_type=event_type,
|
||||
response_type=response_type,
|
||||
http_method=http_method,
|
||||
|
@ -429,3 +442,22 @@ def _is_multipart_param(param_type: type) -> bool:
|
|||
if isinstance(annotation, (File, Form)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_extra_body_param(param_type: type) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if parameter is marked as coming from extra_body.
|
||||
|
||||
Returns:
|
||||
(is_extra_body, description): Tuple of boolean and optional description
|
||||
"""
|
||||
origin = get_origin(param_type)
|
||||
if origin is Annotated:
|
||||
args = get_args(param_type)
|
||||
for annotation in args[1:]:
|
||||
if isinstance(annotation, ExtraBodyField):
|
||||
return True, annotation.description
|
||||
# Also check by type name for cases where import matters
|
||||
if type(annotation).__name__ == 'ExtraBodyField':
|
||||
return True, getattr(annotation, 'description', None)
|
||||
return False, None
|
||||
|
|
|
@ -106,6 +106,15 @@ class Parameter:
|
|||
example: Optional[Any] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtraBodyParameter:
|
||||
"""Represents a parameter that arrives via extra_body in the request."""
|
||||
name: str
|
||||
schema: SchemaOrRef
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Operation:
|
||||
responses: Dict[str, Union[Response, ResponseRef]]
|
||||
|
@ -118,6 +127,7 @@ class Operation:
|
|||
callbacks: Optional[Dict[str, "Callback"]] = None
|
||||
security: Optional[List["SecurityRequirement"]] = None
|
||||
deprecated: Optional[bool] = None
|
||||
extraBodyParameters: Optional[List[ExtraBodyParameter]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -52,6 +52,17 @@ class Specification:
|
|||
if display_name:
|
||||
tag["x-displayName"] = display_name
|
||||
|
||||
# Handle operations to rename extraBodyParameters -> x-llama-stack-extra-body-params
|
||||
paths = json_doc.get("paths", {})
|
||||
for path_item in paths.values():
|
||||
if isinstance(path_item, dict):
|
||||
for method in ["get", "post", "put", "delete", "patch"]:
|
||||
operation = path_item.get(method)
|
||||
if operation and isinstance(operation, dict):
|
||||
extra_body_params = operation.pop("extraBodyParameters", None)
|
||||
if extra_body_params:
|
||||
operation["x-llama-stack-extra-body-params"] = extra_body_params
|
||||
|
||||
return json_doc
|
||||
|
||||
def get_json_string(self, pretty_print: bool = False) -> str:
|
||||
|
|
37
docs/static/deprecated-llama-stack-spec.html
vendored
37
docs/static/deprecated-llama-stack-spec.html
vendored
|
@ -2132,7 +2132,27 @@
|
|||
},
|
||||
"required": true
|
||||
},
|
||||
"deprecated": true
|
||||
"deprecated": true,
|
||||
"x-llama-stack-extra-body-params": [
|
||||
{
|
||||
"name": "shields",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ResponseShieldSpec"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"description": "List of shields to apply during response generation. Shields provide safety and content moderation.",
|
||||
"required": false
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/v1/openai/v1/responses/{response_id}": {
|
||||
|
@ -9521,6 +9541,21 @@
|
|||
"title": "OpenAIResponseText",
|
||||
"description": "Text response configuration for OpenAI responses."
|
||||
},
|
||||
"ResponseShieldSpec": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "The type/identifier of the shield."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type"
|
||||
],
|
||||
"title": "ResponseShieldSpec",
|
||||
"description": "Specification for a shield to apply during response generation."
|
||||
},
|
||||
"OpenAIResponseInputTool": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
|
24
docs/static/deprecated-llama-stack-spec.yaml
vendored
24
docs/static/deprecated-llama-stack-spec.yaml
vendored
|
@ -1559,6 +1559,18 @@ paths:
|
|||
$ref: '#/components/schemas/CreateOpenaiResponseRequest'
|
||||
required: true
|
||||
deprecated: true
|
||||
x-llama-stack-extra-body-params:
|
||||
- name: shields
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
oneOf:
|
||||
- type: string
|
||||
- $ref: '#/components/schemas/ResponseShieldSpec'
|
||||
description: >-
|
||||
List of shields to apply during response generation. Shields provide safety
|
||||
and content moderation.
|
||||
required: false
|
||||
/v1/openai/v1/responses/{response_id}:
|
||||
get:
|
||||
responses:
|
||||
|
@ -7076,6 +7088,18 @@ components:
|
|||
title: OpenAIResponseText
|
||||
description: >-
|
||||
Text response configuration for OpenAI responses.
|
||||
ResponseShieldSpec:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
description: The type/identifier of the shield.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
title: ResponseShieldSpec
|
||||
description: >-
|
||||
Specification for a shield to apply during response generation.
|
||||
OpenAIResponseInputTool:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'
|
||||
|
|
37
docs/static/llama-stack-spec.html
vendored
37
docs/static/llama-stack-spec.html
vendored
|
@ -1830,7 +1830,27 @@
|
|||
},
|
||||
"required": true
|
||||
},
|
||||
"deprecated": false
|
||||
"deprecated": false,
|
||||
"x-llama-stack-extra-body-params": [
|
||||
{
|
||||
"name": "shields",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ResponseShieldSpec"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"description": "List of shields to apply during response generation. Shields provide safety and content moderation.",
|
||||
"required": false
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/v1/responses/{response_id}": {
|
||||
|
@ -7616,6 +7636,21 @@
|
|||
"title": "OpenAIResponseText",
|
||||
"description": "Text response configuration for OpenAI responses."
|
||||
},
|
||||
"ResponseShieldSpec": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "The type/identifier of the shield."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type"
|
||||
],
|
||||
"title": "ResponseShieldSpec",
|
||||
"description": "Specification for a shield to apply during response generation."
|
||||
},
|
||||
"OpenAIResponseInputTool": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
|
24
docs/static/llama-stack-spec.yaml
vendored
24
docs/static/llama-stack-spec.yaml
vendored
|
@ -1411,6 +1411,18 @@ paths:
|
|||
$ref: '#/components/schemas/CreateOpenaiResponseRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
x-llama-stack-extra-body-params:
|
||||
- name: shields
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
oneOf:
|
||||
- type: string
|
||||
- $ref: '#/components/schemas/ResponseShieldSpec'
|
||||
description: >-
|
||||
List of shields to apply during response generation. Shields provide safety
|
||||
and content moderation.
|
||||
required: false
|
||||
/v1/responses/{response_id}:
|
||||
get:
|
||||
responses:
|
||||
|
@ -5739,6 +5751,18 @@ components:
|
|||
title: OpenAIResponseText
|
||||
description: >-
|
||||
Text response configuration for OpenAI responses.
|
||||
ResponseShieldSpec:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
description: The type/identifier of the shield.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
title: ResponseShieldSpec
|
||||
description: >-
|
||||
Specification for a shield to apply during response generation.
|
||||
OpenAIResponseInputTool:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'
|
||||
|
|
37
docs/static/stainless-llama-stack-spec.html
vendored
37
docs/static/stainless-llama-stack-spec.html
vendored
|
@ -1830,7 +1830,27 @@
|
|||
},
|
||||
"required": true
|
||||
},
|
||||
"deprecated": false
|
||||
"deprecated": false,
|
||||
"x-llama-stack-extra-body-params": [
|
||||
{
|
||||
"name": "shields",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ResponseShieldSpec"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"description": "List of shields to apply during response generation. Shields provide safety and content moderation.",
|
||||
"required": false
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/v1/responses/{response_id}": {
|
||||
|
@ -9625,6 +9645,21 @@
|
|||
"title": "OpenAIResponseText",
|
||||
"description": "Text response configuration for OpenAI responses."
|
||||
},
|
||||
"ResponseShieldSpec": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "The type/identifier of the shield."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type"
|
||||
],
|
||||
"title": "ResponseShieldSpec",
|
||||
"description": "Specification for a shield to apply during response generation."
|
||||
},
|
||||
"OpenAIResponseInputTool": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
|
24
docs/static/stainless-llama-stack-spec.yaml
vendored
24
docs/static/stainless-llama-stack-spec.yaml
vendored
|
@ -1414,6 +1414,18 @@ paths:
|
|||
$ref: '#/components/schemas/CreateOpenaiResponseRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
x-llama-stack-extra-body-params:
|
||||
- name: shields
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
oneOf:
|
||||
- type: string
|
||||
- $ref: '#/components/schemas/ResponseShieldSpec'
|
||||
description: >-
|
||||
List of shields to apply during response generation. Shields provide safety
|
||||
and content moderation.
|
||||
required: false
|
||||
/v1/responses/{response_id}:
|
||||
get:
|
||||
responses:
|
||||
|
@ -7184,6 +7196,18 @@ components:
|
|||
title: OpenAIResponseText
|
||||
description: >-
|
||||
Text response configuration for OpenAI responses.
|
||||
ResponseShieldSpec:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
description: The type/identifier of the shield.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
title: ResponseShieldSpec
|
||||
description: >-
|
||||
Specification for a shield to apply during response generation.
|
||||
OpenAIResponseInputTool:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'
|
||||
|
|
|
@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
|
@ -42,6 +42,20 @@ from .openai_responses import (
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ResponseShieldSpec(BaseModel):
|
||||
"""Specification for a shield to apply during response generation.
|
||||
|
||||
:param type: The type/identifier of the shield.
|
||||
"""
|
||||
|
||||
type: str
|
||||
# TODO: more fields to be added for shield configuration
|
||||
|
||||
|
||||
ResponseShield = str | ResponseShieldSpec
|
||||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
"""An attachment to an agent turn.
|
||||
|
||||
|
@ -805,6 +819,12 @@ class Agents(Protocol):
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
|
||||
shields: Annotated[
|
||||
list[ResponseShield] | None,
|
||||
ExtraBodyField(
|
||||
"List of shields to apply during response generation. Shields provide safety and content moderation."
|
||||
),
|
||||
] = None,
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a new OpenAI response.
|
||||
|
||||
|
@ -812,6 +832,7 @@ class Agents(Protocol):
|
|||
:param model: The underlying LLM used for completions.
|
||||
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
||||
:param include: (Optional) Additional fields to include in the response.
|
||||
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -374,6 +374,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
||||
# Merge extra_json parameters (extra_body from SDK is converted to extra_json)
|
||||
if hasattr(options, "extra_json") and options.extra_json:
|
||||
body |= options.extra_json
|
||||
|
||||
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
|
|
|
@ -329,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
shields: list | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
input,
|
||||
|
@ -342,6 +343,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tools,
|
||||
include,
|
||||
max_infer_iters,
|
||||
shields,
|
||||
)
|
||||
|
||||
async def list_openai_responses(
|
||||
|
|
|
@ -208,10 +208,15 @@ class OpenAIResponsesImpl:
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
shields: list | None = None,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
# Shields parameter received via extra_body - not yet implemented
|
||||
if shields is not None:
|
||||
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
model=model,
|
||||
|
|
|
@ -11,6 +11,43 @@ from typing import Any, TypeVar
|
|||
from .strong_typing.schema import json_schema_type, register_schema # noqa: F401
|
||||
|
||||
|
||||
class ExtraBodyField[T]:
|
||||
"""
|
||||
Marker annotation for parameters that arrive via extra_body in the client SDK.
|
||||
|
||||
These parameters:
|
||||
- Will NOT appear in the generated client SDK method signature
|
||||
- WILL be documented in OpenAPI spec under x-llama-stack-extra-body-params
|
||||
- MUST be passed via the extra_body parameter in client SDK calls
|
||||
- WILL be available in server-side method signature with proper typing
|
||||
|
||||
Example:
|
||||
```python
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str,
|
||||
model: str,
|
||||
shields: Annotated[
|
||||
list[str] | None, ExtraBodyField("List of shields to apply")
|
||||
] = None,
|
||||
) -> ResponseObject:
|
||||
# shields is available here with proper typing
|
||||
if shields:
|
||||
print(f"Using shields: {shields}")
|
||||
```
|
||||
|
||||
Client usage:
|
||||
```python
|
||||
client.responses.create(
|
||||
input="hello", model="llama-3", extra_body={"shields": ["shield-1"]}
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
self.description = description
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebMethod:
|
||||
level: str | None = None
|
||||
|
@ -26,7 +63,7 @@ class WebMethod:
|
|||
deprecated: bool | None = False
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable[..., Any])
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def webmethod(
|
||||
|
@ -40,7 +77,7 @@ def webmethod(
|
|||
descriptive_name: str | None = None,
|
||||
required_scope: str | None = None,
|
||||
deprecated: bool | None = False,
|
||||
) -> Callable[[T], T]:
|
||||
) -> Callable[[CallableT], CallableT]:
|
||||
"""
|
||||
Decorator that supplies additional metadata to an endpoint operation function.
|
||||
|
||||
|
@ -51,7 +88,7 @@ def webmethod(
|
|||
:param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
|
||||
"""
|
||||
|
||||
def wrap(func: T) -> T:
|
||||
def wrap(func: CallableT) -> CallableT:
|
||||
webmethod_obj = WebMethod(
|
||||
route=route,
|
||||
method=method,
|
||||
|
|
33
tests/integration/responses/test_extra_body_shields.py
Normal file
33
tests/integration/responses/test_extra_body_shields.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Test for extra_body parameter support with shields example.
|
||||
|
||||
This test demonstrates that parameters marked with ExtraBodyField annotation
|
||||
can be passed via extra_body in the client SDK and are received by the
|
||||
server-side implementation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import APIStatusError
|
||||
|
||||
|
||||
def test_shields_via_extra_body(compat_client, text_model_id):
|
||||
"""Test that shields parameter is received by the server and raises NotImplementedError."""
|
||||
|
||||
# Test with shields as list of strings (shield IDs)
|
||||
with pytest.raises((APIStatusError, NotImplementedError)) as exc_info:
|
||||
compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What is the capital of France?",
|
||||
stream=False,
|
||||
extra_body={"shields": ["test-shield-1", "test-shield-2"]},
|
||||
)
|
||||
|
||||
# Verify the error message indicates shields are not implemented
|
||||
error_message = str(exc_info.value)
|
||||
assert "not yet implemented" in error_message.lower() or "not implemented" in error_message.lower()
|
Loading…
Add table
Add a link
Reference in a new issue