Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-22 08:17:18 +00:00)

Merge branch 'main' into responses-and-safety

Commit 90ee3001d9: 33 changed files with 16970 additions and 713 deletions
@@ -73,7 +73,7 @@ class Inspect(Protocol):
        """
        ...

-    @webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
    async def health(self) -> HealthInfo:
        """Get health status.

@@ -83,7 +83,7 @@ class Inspect(Protocol):
        """
        ...

-    @webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
    async def version(self) -> VersionInfo:
        """Get version.

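Because these two routes now declare require_authentication=False, a monitoring probe can reach them without credentials. A minimal client-side sketch, assuming a Llama Stack server on localhost:8321 and the v1 path prefix (both are assumptions, not part of the diff):

# Minimal sketch: probing the now-unauthenticated operational endpoints.
# The host, port, and /v1 prefix are assumptions; adjust for your deployment.
import json
import urllib.request

for route in ("/v1/health", "/v1/version"):
    # No Authorization header is sent; with require_authentication=False the
    # middleware forwards the request straight to the handler.
    with urllib.request.urlopen(f"http://localhost:8321{route}") as resp:
        print(route, resp.status, json.loads(resp.read()))
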
@@ -27,6 +27,11 @@ class AuthenticationMiddleware:
    3. Extracts user attributes from the provider's response
    4. Makes these attributes available to the route handlers for access control

+    Unauthenticated Access:
+    Endpoints can opt out of authentication by setting require_authentication=False
+    in their @webmethod decorator. This is typically used for operational endpoints
+    like /health and /version to support monitoring, load balancers, and observability tools.
+
    The middleware supports multiple authentication providers through the AuthProvider interface:
    - Kubernetes: Validates tokens against the Kubernetes API server
    - Custom: Validates tokens against a custom endpoint
@@ -88,7 +93,26 @@ class AuthenticationMiddleware:

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
+            # First, handle authentication
+            # Find the route and check if authentication is required
+            path = scope.get("path", "")
+            method = scope.get("method", hdrs.METH_GET)
+
+            if not hasattr(self, "route_impls"):
+                self.route_impls = initialize_route_impls(self.impls)
+
+            webmethod = None
+            try:
+                _, _, _, webmethod = find_matching_route(method, path, self.route_impls)
+            except ValueError:
+                # If no matching endpoint is found, pass here to run auth anyways
+                pass
+
+            # If webmethod explicitly sets require_authentication=False, allow without auth
+            if webmethod and webmethod.require_authentication is False:
+                logger.debug(f"Allowing unauthenticated access to endpoint: {path}")
+                return await self.app(scope, receive, send)
+
            # Handle authentication
            headers = dict(scope.get("headers", []))
            auth_header = headers.get(b"authorization", b"").decode()

@@ -127,19 +151,7 @@ class AuthenticationMiddleware:
                )

            # Scope-based API access control
-            path = scope.get("path", "")
-            method = scope.get("method", hdrs.METH_GET)
-
-            if not hasattr(self, "route_impls"):
-                self.route_impls = initialize_route_impls(self.impls)
-
-            try:
-                _, _, _, webmethod = find_matching_route(method, path, self.route_impls)
-            except ValueError:
-                # If no matching endpoint is found, pass through to FastAPI
-                return await self.app(scope, receive, send)
-
-            if webmethod.required_scope:
+            if webmethod and webmethod.required_scope:
                user = user_from_scope(scope)
                if not _has_required_scope(webmethod.required_scope, user):
                    return await self._send_auth_error(

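Taken together, the two hunks move route resolution ahead of credential checks: the middleware looks up the route's WebMethod first, short-circuits when require_authentication is False, and only afterwards validates the token and enforces required_scope. A condensed sketch of that order, with simplified stand-ins for the auth provider and error helper (not the actual class):

# Condensed sketch of the control flow described above; auth_provider.validate_token
# and the _send_auth_error signature are simplified stand-ins, not the real API.
async def _handle_http(self, scope, receive, send):
    path = scope.get("path", "")
    method = scope.get("method", "GET")

    webmethod = None
    try:
        _, _, _, webmethod = find_matching_route(method, path, self.route_impls)
    except ValueError:
        pass  # unknown route: still run auth and let the app return 404 later

    # 1. An explicit opt-out bypasses authentication entirely
    if webmethod and webmethod.require_authentication is False:
        return await self.app(scope, receive, send)

    # 2. Otherwise validate the bearer token and attach the user to the scope
    token = dict(scope.get("headers", [])).get(b"authorization", b"").decode().removeprefix("Bearer ")
    scope["user"] = user = await self.auth_provider.validate_token(token)  # stand-in call

    # 3. Scope checks only apply to routes that declare required_scope
    if webmethod and webmethod.required_scope and not _has_required_scope(webmethod.required_scope, user):
        return await self._send_auth_error(send, "insufficient scope")  # stand-in signature

    return await self.app(scope, receive, send)
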
@@ -41,12 +41,16 @@ from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseOutputMessageFunctionToolCall,
    OpenAIResponseOutputMessageMCPListTools,
    OpenAIResponseText,
+    OpenAIResponseUsage,
+    OpenAIResponseUsageInputTokensDetails,
+    OpenAIResponseUsageOutputTokensDetails,
    WebSearchToolTypes,
)
from llama_stack.apis.inference import (
    Inference,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
    OpenAIChatCompletionToolCall,
    OpenAIChoice,
    OpenAIMessageParam,
@@ -116,43 +120,8 @@ class StreamingResponseOrchestrator:
        self.final_messages: list[OpenAIMessageParam] = []
        # mapping for annotations
        self.citation_files: dict[str, str] = {}
-        # Track accumulated text for shield validation
-        self.accumulated_text = ""
-        # Track if we've sent a refusal response
-        self.violation_detected = False
-
-    async def _check_output_stream_safety(self, text_delta: str) -> str | None:
-        """Check streaming text content against shields. Returns violation message if blocked."""
-        if not self.shield_ids:
-            return None
-
-        self.accumulated_text += text_delta
-
-        # Check accumulated text periodically for violations (every 50 characters or at word boundaries)
-        if len(self.accumulated_text) > 50 or text_delta.endswith((" ", "\n", ".", "!", "?")):
-            temp_messages = [{"role": "assistant", "content": self.accumulated_text}]
-            messages = convert_openai_to_inference_messages(temp_messages)
-
-            try:
-                await run_multiple_shields(self.safety_api, messages, self.shield_ids)
-            except SafetyException as e:
-                logger.info(f"Output shield violation: {e.violation.user_message}")
-                return e.violation.user_message or "Generated content blocked by safety shields"
-
-    async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
-        """Create a refusal response to replace streaming content."""
-        refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)
-
-        # Create a completed refusal response
-        refusal_response = OpenAIResponseObject(
-            id=self.response_id,
-            created_at=self.created_at,
-            model=self.ctx.model,
-            status="completed",
-            output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
-        )
-
-        return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
+        # Track accumulated usage across all inference calls
+        self.accumulated_usage: OpenAIResponseUsage | None = None

    def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
        cloned: list[OpenAIResponseOutput] = []

@@ -180,6 +149,7 @@ class StreamingResponseOrchestrator:
            text=self.text,
            tools=self.ctx.available_tools(),
            error=error,
+            usage=self.accumulated_usage,
        )

    async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:

@@ -217,6 +187,9 @@ class StreamingResponseOrchestrator:
                stream=True,
                temperature=self.ctx.temperature,
                response_format=response_format,
+                stream_options={
+                    "include_usage": True,
+                },
            )

            # Process streaming chunks and build complete response

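stream_options with include_usage asks the provider to report token usage on the stream itself; in the OpenAI-compatible protocol this arrives on the final chunk, with earlier chunks carrying usage=None. A rough client-side sketch against an OpenAI-compatible endpoint (base URL, API key, and model name are placeholders, not taken from the diff):

# Rough sketch: reading usage from a streamed chat completion.
# Assumes the openai client pointed at an OpenAI-compatible server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")
stream = client.chat.completions.create(
    model="llama3.2:3b",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    if chunk.usage:  # typically only present on the final chunk
        print("\nprompt:", chunk.usage.prompt_tokens, "completion:", chunk.usage.completion_tokens)
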
@@ -352,6 +325,51 @@ class StreamingResponseOrchestrator:

        return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages

+    def _accumulate_chunk_usage(self, chunk: OpenAIChatCompletionChunk) -> None:
+        """Accumulate usage from a streaming chunk into the response usage format."""
+        if not chunk.usage:
+            return
+
+        if self.accumulated_usage is None:
+            # Convert from chat completion format to response format
+            self.accumulated_usage = OpenAIResponseUsage(
+                input_tokens=chunk.usage.prompt_tokens,
+                output_tokens=chunk.usage.completion_tokens,
+                total_tokens=chunk.usage.total_tokens,
+                input_tokens_details=(
+                    OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
+                    if chunk.usage.prompt_tokens_details
+                    else None
+                ),
+                output_tokens_details=(
+                    OpenAIResponseUsageOutputTokensDetails(
+                        reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
+                    )
+                    if chunk.usage.completion_tokens_details
+                    else None
+                ),
+            )
+        else:
+            # Accumulate across multiple inference calls
+            self.accumulated_usage = OpenAIResponseUsage(
+                input_tokens=self.accumulated_usage.input_tokens + chunk.usage.prompt_tokens,
+                output_tokens=self.accumulated_usage.output_tokens + chunk.usage.completion_tokens,
+                total_tokens=self.accumulated_usage.total_tokens + chunk.usage.total_tokens,
+                # Use latest non-null details
+                input_tokens_details=(
+                    OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
+                    if chunk.usage.prompt_tokens_details
+                    else self.accumulated_usage.input_tokens_details
+                ),
+                output_tokens_details=(
+                    OpenAIResponseUsageOutputTokensDetails(
+                        reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
+                    )
+                    if chunk.usage.completion_tokens_details
+                    else self.accumulated_usage.output_tokens_details
+                ),
+            )
+
    async def _process_streaming_chunks(
        self, completion_result, output_messages: list[OpenAIResponseOutput]
    ) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:

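To make the accumulation concrete, here is a small worked example with assumed token counts: two inference calls within one response (say a tool-call turn followed by the final answer) merged into a single response-level total. Plain dicts stand in for the OpenAIResponseUsage models, and the prompt/completion to input/output renaming is elided:

# Illustrative only: merging usage from two streamed inference calls into one
# response-level total, mirroring the logic of _accumulate_chunk_usage above.
def merge(total: dict | None, chunk_usage: dict) -> dict:
    if total is None:
        return dict(chunk_usage)
    return {
        "input_tokens": total["input_tokens"] + chunk_usage["input_tokens"],
        "output_tokens": total["output_tokens"] + chunk_usage["output_tokens"],
        "total_tokens": total["total_tokens"] + chunk_usage["total_tokens"],
    }

total = None
for usage in (
    {"input_tokens": 120, "output_tokens": 15, "total_tokens": 135},   # first call (assumed numbers)
    {"input_tokens": 160, "output_tokens": 40, "total_tokens": 200},   # second call (assumed numbers)
):
    total = merge(total, usage)

print(total)  # {'input_tokens': 280, 'output_tokens': 55, 'total_tokens': 335}
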
@@ -377,6 +395,10 @@ class StreamingResponseOrchestrator:
                chat_response_id = chunk.id
                chunk_created = chunk.created
                chunk_model = chunk.model
+
+                # Accumulate usage from chunks (typically in final chunk with stream_options)
+                self._accumulate_chunk_usage(chunk)
+
                for chunk_choice in chunk.choices:
                    # Emit incremental text content as delta events
                    if chunk_choice.delta.content:

@@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):

        # TODO: set expiration time for garbage collection

-        if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
+        if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
            raise ValueError(
-                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
+                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
            )

        if completion_window != "24h":

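With /v1/embeddings accepted, an input line for an embeddings batch follows the same JSONL request shape already used for the other endpoints. A sketch of building one such line (custom_id, model, and input are placeholders, not taken from the diff):

# Sketch: one JSONL line for an embeddings batch request.
import json

line = {
    "custom_id": "embed-0001",               # placeholder id
    "method": "POST",
    "url": "/v1/embeddings",                  # the newly supported endpoint
    "body": {
        "model": "all-MiniLM-L6-v2",          # placeholder model name
        "input": "Llama Stack batches now support embeddings.",
    },
}
print(json.dumps(line))
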
@@ -425,18 +425,23 @@ class ReferenceBatchesImpl(Batches):
            valid = False

            if batch.endpoint == "/v1/chat/completions":
-                required_params = [
+                required_params: list[tuple[str, Any, str]] = [
                    ("model", str, "a string"),
                    # messages is specific to /v1/chat/completions
                    # we could skip validating messages here and let inference fail. however,
                    # that would be a very expensive way to find out messages is wrong.
                    ("messages", list, "an array"),  # TODO: allow messages to be a string?
                ]
-            else:  # /v1/completions
+            elif batch.endpoint == "/v1/completions":
                required_params = [
                    ("model", str, "a string"),
                    ("prompt", str, "a string"),  # TODO: allow prompt to be a list of strings??
                ]
+            else:  # /v1/embeddings
+                required_params = [
+                    ("model", str, "a string"),
+                    ("input", (str, list), "a string or array of strings"),
+                ]

            for param, expected_type, type_string in required_params:
                if param not in body:

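The new ("input", (str, list), ...) entry relies on isinstance accepting a tuple of types, so the existing loop validates both a single string and a list of strings without extra branching. A tiny isolated sketch of that check (the sample bodies are made up):

# Tiny sketch: validating an embeddings body the same way the loop above does.
required_params = [
    ("model", str, "a string"),
    ("input", (str, list), "a string or array of strings"),
]

for body in ({"model": "m", "input": "hello"}, {"model": "m", "input": ["a", "b"]}, {"model": "m", "input": 42}):
    errors = [
        f"{param} must be {type_string}"
        for param, expected_type, type_string in required_params
        if param not in body or not isinstance(body[param], expected_type)
    ]
    print(errors or "ok")
# prints: ok, ok, ['input must be a string or array of strings']
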
@@ -614,7 +619,7 @@ class ReferenceBatchesImpl(Batches):
                        "body": chat_response.model_dump_json(),
                    },
                }
-            else:  # /v1/completions
+            elif request.url == "/v1/completions":
                completion_response = await self.inference_api.openai_completion(**request.body)

                # this is for mypy, we don't allow streaming so we'll get the right type

@@ -630,6 +635,20 @@ class ReferenceBatchesImpl(Batches):
                        "body": completion_response.model_dump_json(),
                    },
                }
+            else:  # /v1/embeddings
+                embeddings_response = await self.inference_api.openai_embeddings(**request.body)
+                assert hasattr(embeddings_response, "model_dump_json"), (
+                    "Embeddings response must have model_dump_json method"
+                )
+                return {
+                    "id": request_id,
+                    "custom_id": request.custom_id,
+                    "response": {
+                        "status_code": 200,
+                        "request_id": request_id,  # TODO: should this be different?
+                        "body": embeddings_response.model_dump_json(),
+                    },
+                }
        except Exception as e:
            logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
            return {

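End to end, an embeddings batch can then be submitted the same way as the existing endpoints, for example through an OpenAI-compatible client pointed at the server, using a file of JSONL lines like the one sketched earlier. The base URL, API key, and file name below are placeholders, and OpenAI-client compatibility for the files/batches routes is an assumption, not something this diff shows:

# Sketch: submitting an embeddings batch through an OpenAI-compatible client.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # placeholder server

batch_file = client.files.create(file=open("embeddings_requests.jsonl", "rb"), purpose="batch")
batch = client.batches.create(
    input_file_id=batch_file.id,
    endpoint="/v1/embeddings",      # now accepted by the reference provider
    completion_window="24h",        # the only window the provider accepts
)
print(batch.id, batch.status)
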
@@ -61,6 +61,7 @@ class WebMethod:
    descriptive_name: str | None = None
    required_scope: str | None = None
    deprecated: bool | None = False
+    require_authentication: bool | None = True


CallableT = TypeVar("CallableT", bound=Callable[..., Any])

@@ -77,6 +78,7 @@ def webmethod(
    descriptive_name: str | None = None,
    required_scope: str | None = None,
    deprecated: bool | None = False,
+    require_authentication: bool | None = True,
) -> Callable[[CallableT], CallableT]:
    """
    Decorator that supplies additional metadata to an endpoint operation function.

@@ -86,6 +88,7 @@ def webmethod(
    :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
    :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
    :param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
+    :param require_authentication: Whether this endpoint requires authentication (default True).
    """

    def wrap(func: CallableT) -> CallableT:

@@ -100,6 +103,7 @@ def webmethod(
            descriptive_name=descriptive_name,
            required_scope=required_scope,
            deprecated=deprecated,
+            require_authentication=require_authentication if require_authentication is not None else True,
        )

        # Store all webmethods in a list to support multiple decorators

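Downstream, any endpoint can now opt out of authentication the same way the Inspect routes do. A minimal sketch with an invented Metrics protocol and /metrics route; only the decorator arguments come from this diff:

# Minimal sketch: a hypothetical unauthenticated operational endpoint.
# The Metrics protocol and /metrics route are invented for illustration;
# webmethod and LLAMA_STACK_API_V1 are assumed importable as in the files above.
from typing import Protocol

class Metrics(Protocol):
    @webmethod(route="/metrics", method="GET", level=LLAMA_STACK_API_V1, require_authentication=False)
    async def metrics(self) -> dict:
        """Expose operational metrics without requiring a bearer token."""
        ...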