Merge branch 'main' into responses-and-safety

slekkala1 2025-10-10 13:57:47 -07:00 committed by GitHub
commit 90ee3001d9
33 changed files with 16970 additions and 713 deletions

@@ -73,7 +73,7 @@ class Inspect(Protocol):
"""
...
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
async def health(self) -> HealthInfo:
"""Get health status.
@@ -83,7 +83,7 @@ class Inspect(Protocol):
"""
...
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1)
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
async def version(self) -> VersionInfo:
"""Get version.

@@ -27,6 +27,11 @@ class AuthenticationMiddleware:
3. Extracts user attributes from the provider's response
4. Makes these attributes available to the route handlers for access control
Unauthenticated Access:
Endpoints can opt out of authentication by setting require_authentication=False
in their @webmethod decorator. This is typically used for operational endpoints
like /health and /version to support monitoring, load balancers, and observability tools.
The middleware supports multiple authentication providers through the AuthProvider interface:
- Kubernetes: Validates tokens against the Kubernetes API server
- Custom: Validates tokens against a custom endpoint
@@ -88,7 +93,26 @@ class AuthenticationMiddleware:
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
# First, handle authentication
# Find the route and check if authentication is required
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
webmethod = None
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass here to run auth anyways
pass
# If webmethod explicitly sets require_authentication=False, allow without auth
if webmethod and webmethod.require_authentication is False:
logger.debug(f"Allowing unauthenticated access to endpoint: {path}")
return await self.app(scope, receive, send)
# Handle authentication
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
@@ -127,19 +151,7 @@ class AuthenticationMiddleware:
)
# Scope-based API access control
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
return await self.app(scope, receive, send)
if webmethod.required_scope:
if webmethod and webmethod.required_scope:
user = user_from_scope(scope)
if not _has_required_scope(webmethod.required_scope, user):
return await self._send_auth_error(
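From a client's perspective, the effect of this middleware change is that operational probes no longer need a bearer token, while every other route still goes through token validation and the scope check. A minimal sketch with httpx; the base URL, the /v1 path prefix, and the protected route chosen here are illustrative assumptions, not part of this diff:

import httpx

BASE = "http://localhost:8321"  # assumed local Llama Stack server; adjust to your deployment

# /health is declared with require_authentication=False, so no Authorization header is needed.
# The exact URL prefix depends on the API level the endpoint is registered under.
health = httpx.get(f"{BASE}/v1/health")
print(health.status_code)  # expected 200 even without a token

# Any other endpoint (shown here with an illustrative route) still requires a valid token.
protected = httpx.get(f"{BASE}/v1/models", headers={"Authorization": "Bearer <your-token>"})
print(protected.status_code)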

@@ -41,12 +41,16 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText,
OpenAIResponseUsage,
OpenAIResponseUsageInputTokensDetails,
OpenAIResponseUsageOutputTokensDetails,
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
@@ -116,43 +120,8 @@ class StreamingResponseOrchestrator:
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
self.citation_files: dict[str, str] = {}
# Track accumulated text for shield validation
self.accumulated_text = ""
# Track if we've sent a refusal response
self.violation_detected = False
async def _check_output_stream_safety(self, text_delta: str) -> str | None:
"""Check streaming text content against shields. Returns violation message if blocked."""
if not self.shield_ids:
return None
self.accumulated_text += text_delta
# Check accumulated text periodically for violations (every 50 characters or at word boundaries)
if len(self.accumulated_text) > 50 or text_delta.endswith((" ", "\n", ".", "!", "?")):
temp_messages = [{"role": "assistant", "content": self.accumulated_text}]
messages = convert_openai_to_inference_messages(temp_messages)
try:
await run_multiple_shields(self.safety_api, messages, self.shield_ids)
except SafetyException as e:
logger.info(f"Output shield violation: {e.violation.user_message}")
return e.violation.user_message or "Generated content blocked by safety shields"
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)
# Create a completed refusal response
refusal_response = OpenAIResponseObject(
id=self.response_id,
created_at=self.created_at,
model=self.ctx.model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
# Track accumulated usage across all inference calls
self.accumulated_usage: OpenAIResponseUsage | None = None
def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
cloned: list[OpenAIResponseOutput] = []
@@ -180,6 +149,7 @@ class StreamingResponseOrchestrator:
text=self.text,
tools=self.ctx.available_tools(),
error=error,
usage=self.accumulated_usage,
)
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@@ -217,6 +187,9 @@ class StreamingResponseOrchestrator:
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
stream_options={
"include_usage": True,
},
)
# Process streaming chunks and build complete response
@@ -352,6 +325,51 @@ class StreamingResponseOrchestrator:
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
def _accumulate_chunk_usage(self, chunk: OpenAIChatCompletionChunk) -> None:
"""Accumulate usage from a streaming chunk into the response usage format."""
if not chunk.usage:
return
if self.accumulated_usage is None:
# Convert from chat completion format to response format
self.accumulated_usage = OpenAIResponseUsage(
input_tokens=chunk.usage.prompt_tokens,
output_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
input_tokens_details=(
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
if chunk.usage.prompt_tokens_details
else None
),
output_tokens_details=(
OpenAIResponseUsageOutputTokensDetails(
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
)
if chunk.usage.completion_tokens_details
else None
),
)
else:
# Accumulate across multiple inference calls
self.accumulated_usage = OpenAIResponseUsage(
input_tokens=self.accumulated_usage.input_tokens + chunk.usage.prompt_tokens,
output_tokens=self.accumulated_usage.output_tokens + chunk.usage.completion_tokens,
total_tokens=self.accumulated_usage.total_tokens + chunk.usage.total_tokens,
# Use latest non-null details
input_tokens_details=(
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
if chunk.usage.prompt_tokens_details
else self.accumulated_usage.input_tokens_details
),
output_tokens_details=(
OpenAIResponseUsageOutputTokensDetails(
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
)
if chunk.usage.completion_tokens_details
else self.accumulated_usage.output_tokens_details
),
)
async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
@@ -377,6 +395,10 @@ class StreamingResponseOrchestrator:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
# Accumulate usage from chunks (typically in final chunk with stream_options)
self._accumulate_chunk_usage(chunk)
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
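For context on where the accumulated usage comes from: with stream_options={"include_usage": True}, OpenAI-compatible chat completion streams emit delta chunks whose usage is None, followed by a final chunk that carries the totals. A standalone sketch of the same accumulation pattern using the openai client; the client setup and model name are illustrative and not taken from this diff:

from openai import OpenAI

client = OpenAI()  # assumes an OpenAI-compatible endpoint; pass base_url/api_key for a local stack

stream = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
    stream_options={"include_usage": True},
)

total_input = total_output = 0
for chunk in stream:
    for choice in chunk.choices:
        if choice.delta.content:
            print(choice.delta.content, end="")
    # usage is None on delta chunks and populated once, on the final chunk
    if chunk.usage:
        total_input += chunk.usage.prompt_tokens
        total_output += chunk.usage.completion_tokens

print(f"\ninput={total_input} output={total_output} total={total_input + total_output}")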

@@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
@@ -425,18 +425,23 @@ class ReferenceBatchesImpl(Batches):
valid = False
if batch.endpoint == "/v1/chat/completions":
required_params = [
required_params: list[tuple[str, Any, str]] = [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]
else: # /v1/completions
elif batch.endpoint == "/v1/completions":
required_params = [
("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
]
else: # /v1/embeddings
required_params = [
("model", str, "a string"),
("input", (str, list), "a string or array of strings"),
]
for param, expected_type, type_string in required_params:
if param not in body:
@@ -614,7 +619,7 @@ class ReferenceBatchesImpl(Batches):
"body": chat_response.model_dump_json(),
},
}
else: # /v1/completions
elif request.url == "/v1/completions":
completion_response = await self.inference_api.openai_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
@@ -630,6 +635,20 @@ class ReferenceBatchesImpl(Batches):
"body": completion_response.model_dump_json(),
},
}
else: # /v1/embeddings
embeddings_response = await self.inference_api.openai_embeddings(**request.body)
assert hasattr(embeddings_response, "model_dump_json"), (
"Embeddings response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": embeddings_response.model_dump_json(),
},
}
except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return {
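With this change, a batch input file can also target /v1/embeddings alongside chat and text completions. A sketch of one .jsonl input line, assuming the OpenAI-style batch line format (custom_id, method, url, body) that the reference implementation reads; the ids and model name are illustrative, and per the validation above, body.model must be a string and body.input a string or array of strings:

import json

# One line of a batches input .jsonl file targeting the newly supported endpoint.
embedding_request = {
    "custom_id": "embed-001",  # illustrative id
    "method": "POST",
    "url": "/v1/embeddings",
    "body": {
        "model": "text-embedding-3-small",  # illustrative model name
        "input": ["first passage", "second passage"],  # a single string would also pass validation
    },
}
print(json.dumps(embedding_request))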

@@ -61,6 +61,7 @@ class WebMethod:
descriptive_name: str | None = None
required_scope: str | None = None
deprecated: bool | None = False
require_authentication: bool | None = True
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
@@ -77,6 +78,7 @@ def webmethod(
descriptive_name: str | None = None,
required_scope: str | None = None,
deprecated: bool | None = False,
require_authentication: bool | None = True,
) -> Callable[[CallableT], CallableT]:
"""
Decorator that supplies additional metadata to an endpoint operation function.
@@ -86,6 +88,7 @@ def webmethod(
:param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
:param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
:param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
:param require_authentication: Whether this endpoint requires authentication (default True).
"""
def wrap(func: CallableT) -> CallableT:
@@ -100,6 +103,7 @@ def webmethod(
descriptive_name=descriptive_name,
required_scope=required_scope,
deprecated=deprecated,
require_authentication=require_authentication if require_authentication is not None else True,
)
# Store all webmethods in a list to support multiple decorators
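Taken together with required_scope, the decorator now carries two independent access-control knobs: whether a request must be authenticated at all, and which scope an authenticated caller needs. A sketch combining both on hypothetical endpoints; the import path is assumed from the llama_stack layout, the routes, handler names, and scope string are illustrative, and the level argument used by the real endpoints in this commit is omitted for brevity:

from typing import Protocol

from llama_stack.schema_utils import webmethod  # assumed location of the decorator


class OpsProtocol(Protocol):
    # Public operational endpoint: reachable without a bearer token.
    @webmethod(route="/status", method="GET", require_authentication=False)
    async def status(self) -> dict: ...

    # Authenticated endpoint that additionally requires a specific scope.
    @webmethod(route="/admin/metrics", method="GET", required_scope="monitoring.viewer")
    async def admin_metrics(self) -> dict: ...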