Merge branch 'pr1573' into api_2

This commit is contained in:
Xi Yan 2025-03-12 21:02:30 -07:00
commit d7dbc8cf64
21 changed files with 673 additions and 232 deletions

View file

@ -1,6 +1,8 @@
name: Unit Tests
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:

View file

@ -4,6 +4,7 @@
[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/)
[![License](https://img.shields.io/pypi/l/llama_stack.svg)](https://github.com/meta-llama/llama-stack/blob/main/LICENSE)
[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/llama-stack)
![Unit](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)

View file

@ -4570,7 +4570,7 @@
"metrics": {
"type": "array",
"items": {
"$ref": "#/components/schemas/MetricEvent"
"$ref": "#/components/schemas/MetricInResponse"
}
},
"completion_message": {
@ -4592,46 +4592,9 @@
"title": "ChatCompletionResponse",
"description": "Response from a chat completion request."
},
"MetricEvent": {
"MetricInResponse": {
"type": "object",
"properties": {
"trace_id": {
"type": "string"
},
"span_id": {
"type": "string"
},
"timestamp": {
"type": "string",
"format": "date-time"
},
"attributes": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
}
]
}
},
"type": {
"type": "string",
"const": "metric",
"default": "metric"
},
"metric": {
"type": "string"
},
@ -4651,15 +4614,10 @@
},
"additionalProperties": false,
"required": [
"trace_id",
"span_id",
"timestamp",
"type",
"metric",
"value",
"unit"
"value"
],
"title": "MetricEvent"
"title": "MetricInResponse"
},
"TokenLogProbs": {
"type": "object",
@ -4736,6 +4694,12 @@
"CompletionResponse": {
"type": "object",
"properties": {
"metrics": {
"type": "array",
"items": {
"$ref": "#/components/schemas/MetricInResponse"
}
},
"content": {
"type": "string",
"description": "The generated completion text"
@ -4945,7 +4909,7 @@
"metrics": {
"type": "array",
"items": {
"$ref": "#/components/schemas/MetricEvent"
"$ref": "#/components/schemas/MetricInResponse"
}
},
"event": {
@ -5103,6 +5067,12 @@
"CompletionResponseStreamChunk": {
"type": "object",
"properties": {
"metrics": {
"type": "array",
"items": {
"$ref": "#/components/schemas/MetricInResponse"
}
},
"delta": {
"type": "string",
"description": "New content generated since last chunk. This can be one or more tokens."
@ -7192,15 +7162,16 @@
"const": "dataset",
"default": "dataset"
},
"schema": {
"purpose": {
"type": "string",
"enum": [
"messages"
"post-training/messages",
"eval/question-answer"
],
"title": "Schema",
"description": "Schema of the dataset. Each type has a different column format."
"title": "DatasetPurpose",
"description": "Purpose of the dataset. Each type has a different column format."
},
"data_source": {
"source": {
"$ref": "#/components/schemas/DataSource"
},
"metadata": {
@ -7235,8 +7206,8 @@
"provider_resource_id",
"provider_id",
"type",
"schema",
"data_source",
"purpose",
"source",
"metadata"
],
"title": "Dataset"
@ -7249,8 +7220,9 @@
"const": "huggingface",
"default": "huggingface"
},
"dataset_path": {
"type": "string"
"path": {
"type": "string",
"description": "The path to the dataset in Huggingface. E.g. - \"llamastack/simpleqa\""
},
"params": {
"type": "object",
@ -7275,16 +7247,18 @@
"type": "object"
}
]
}
},
"description": "The parameters for the dataset."
}
},
"additionalProperties": false,
"required": [
"type",
"dataset_path",
"path",
"params"
],
"title": "HuggingfaceDataSource"
"title": "HuggingfaceDataSource",
"description": "A dataset stored in Huggingface."
},
"RowsDataSource": {
"type": "object",
@ -7320,7 +7294,8 @@
}
]
}
}
},
"description": "The dataset is stored in rows. E.g. - [ {\"messages\": [{\"role\": \"user\", \"content\": \"Hello, world!\"}, {\"role\": \"assistant\", \"content\": \"Hello, world!\"}]} ]"
}
},
"additionalProperties": false,
@ -7328,7 +7303,8 @@
"type",
"rows"
],
"title": "RowsDataSource"
"title": "RowsDataSource",
"description": "A dataset stored in rows."
},
"URIDataSource": {
"type": "object",
@ -7339,7 +7315,8 @@
"default": "uri"
},
"uri": {
"type": "string"
"type": "string",
"description": "The dataset can be obtained from a URI. E.g. - \"https://mywebsite.com/mydata.jsonl\" - \"lsfs://mydata.jsonl\" - \"data:csv;base64,{base64_content}\""
}
},
"additionalProperties": false,
@ -7347,7 +7324,8 @@
"type",
"uri"
],
"title": "URIDataSource"
"title": "URIDataSource",
"description": "A dataset that can be obtained from a URI."
},
"Model": {
"type": "object",
@ -8634,6 +8612,75 @@
],
"title": "LogSeverity"
},
"MetricEvent": {
"type": "object",
"properties": {
"trace_id": {
"type": "string"
},
"span_id": {
"type": "string"
},
"timestamp": {
"type": "string",
"format": "date-time"
},
"attributes": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
}
]
}
},
"type": {
"type": "string",
"const": "metric",
"default": "metric"
},
"metric": {
"type": "string"
},
"value": {
"oneOf": [
{
"type": "integer"
},
{
"type": "number"
}
]
},
"unit": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"trace_id",
"span_id",
"timestamp",
"type",
"metric",
"value",
"unit"
],
"title": "MetricEvent"
},
"SpanEndPayload": {
"type": "object",
"properties": {
@ -9510,14 +9557,15 @@
"RegisterDatasetRequest": {
"type": "object",
"properties": {
"schema": {
"purpose": {
"type": "string",
"enum": [
"messages"
"post-training/messages",
"eval/question-answer"
],
"description": "The schema format of the dataset. One of - messages: The dataset contains a messages column with list of messages for post-training."
"description": "The purpose of the dataset. One of - \"post-training/messages\": The dataset contains a messages column with list of messages for post-training. - \"eval/question-answer\": The dataset contains a question and answer column."
},
"data_source": {
"source": {
"$ref": "#/components/schemas/DataSource",
"description": "The data source of the dataset. Examples: - { \"type\": \"uri\", \"uri\": \"https://mywebsite.com/mydata.jsonl\" } - { \"type\": \"uri\", \"uri\": \"lsfs://mydata.jsonl\" } - { \"type\": \"huggingface\", \"dataset_path\": \"tatsu-lab/alpaca\", \"params\": { \"split\": \"train\" } } - { \"type\": \"rows\", \"rows\": [ { \"messages\": [ {\"role\": \"user\", \"content\": \"Hello, world!\"}, {\"role\": \"assistant\", \"content\": \"Hello, world!\"}, ] } ] }"
},
@ -9554,8 +9602,8 @@
},
"additionalProperties": false,
"required": [
"schema",
"data_source"
"purpose",
"source"
],
"title": "RegisterDatasetRequest"
},
@ -9769,21 +9817,11 @@
"type": "object",
"properties": {
"tool_responses": {
"oneOf": [
{
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolResponse"
}
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolResponseMessage"
}
}
],
"description": "The tool call responses to resume the turn with. NOTE: ToolResponseMessage will be deprecated. Use ToolResponse."
"description": "The tool call responses to resume the turn with."
},
"stream": {
"type": "boolean",

View file

@ -3115,7 +3115,7 @@ components:
metrics:
type: array
items:
$ref: '#/components/schemas/MetricEvent'
$ref: '#/components/schemas/MetricInResponse'
completion_message:
$ref: '#/components/schemas/CompletionMessage'
description: The complete response message
@ -3130,29 +3130,9 @@ components:
- completion_message
title: ChatCompletionResponse
description: Response from a chat completion request.
MetricEvent:
MetricInResponse:
type: object
properties:
trace_id:
type: string
span_id:
type: string
timestamp:
type: string
format: date-time
attributes:
type: object
additionalProperties:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
type:
type: string
const: metric
default: metric
metric:
type: string
value:
@ -3163,14 +3143,9 @@ components:
type: string
additionalProperties: false
required:
- trace_id
- span_id
- timestamp
- type
- metric
- value
- unit
title: MetricEvent
title: MetricInResponse
TokenLogProbs:
type: object
properties:
@ -3227,6 +3202,10 @@ components:
CompletionResponse:
type: object
properties:
metrics:
type: array
items:
$ref: '#/components/schemas/MetricInResponse'
content:
type: string
description: The generated completion text
@ -3426,7 +3405,7 @@ components:
metrics:
type: array
items:
$ref: '#/components/schemas/MetricEvent'
$ref: '#/components/schemas/MetricInResponse'
event:
$ref: '#/components/schemas/ChatCompletionResponseEvent'
description: The event containing the new content
@ -3545,6 +3524,10 @@ components:
CompletionResponseStreamChunk:
type: object
properties:
metrics:
type: array
items:
$ref: '#/components/schemas/MetricInResponse'
delta:
type: string
description: >-
@ -5008,14 +4991,15 @@ components:
type: string
const: dataset
default: dataset
schema:
purpose:
type: string
enum:
- messages
title: Schema
- post-training/messages
- eval/question-answer
title: DatasetPurpose
description: >-
Schema of the dataset. Each type has a different column format.
data_source:
Purpose of the dataset. Each type has a different column format.
source:
$ref: '#/components/schemas/DataSource'
metadata:
type: object
@ -5033,8 +5017,8 @@ components:
- provider_resource_id
- provider_id
- type
- schema
- data_source
- purpose
- source
- metadata
title: Dataset
HuggingfaceDataSource:
@ -5044,8 +5028,10 @@ components:
type: string
const: huggingface
default: huggingface
dataset_path:
path:
type: string
description: >-
The path to the dataset in Huggingface. E.g. - "llamastack/simpleqa"
params:
type: object
additionalProperties:
@ -5056,12 +5042,14 @@ components:
- type: string
- type: array
- type: object
description: The parameters for the dataset.
additionalProperties: false
required:
- type
- dataset_path
- path
- params
title: HuggingfaceDataSource
description: A dataset stored in Huggingface.
RowsDataSource:
type: object
properties:
@ -5081,11 +5069,16 @@ components:
- type: string
- type: array
- type: object
description: >-
The dataset is stored in rows. E.g. - [ {"messages": [{"role": "user",
"content": "Hello, world!"}, {"role": "assistant", "content": "Hello,
world!"}]} ]
additionalProperties: false
required:
- type
- rows
title: RowsDataSource
description: A dataset stored in rows.
URIDataSource:
type: object
properties:
@ -5095,11 +5088,16 @@ components:
default: uri
uri:
type: string
description: >-
The dataset can be obtained from a URI. E.g. - "https://mywebsite.com/mydata.jsonl"
- "lsfs://mydata.jsonl" - "data:csv;base64,{base64_content}"
additionalProperties: false
required:
- type
- uri
title: URIDataSource
description: >-
A dataset that can be obtained from a URI.
Model:
type: object
properties:
@ -5920,6 +5918,47 @@ components:
- error
- critical
title: LogSeverity
MetricEvent:
type: object
properties:
trace_id:
type: string
span_id:
type: string
timestamp:
type: string
format: date-time
attributes:
type: object
additionalProperties:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
type:
type: string
const: metric
default: metric
metric:
type: string
value:
oneOf:
- type: integer
- type: number
unit:
type: string
additionalProperties: false
required:
- trace_id
- span_id
- timestamp
- type
- metric
- value
- unit
title: MetricEvent
SpanEndPayload:
type: object
properties:
@ -6483,14 +6522,16 @@ components:
RegisterDatasetRequest:
type: object
properties:
schema:
purpose:
type: string
enum:
- messages
- post-training/messages
- eval/question-answer
description: >-
The schema format of the dataset. One of - messages: The dataset contains
a messages column with list of messages for post-training.
data_source:
The purpose of the dataset. One of - "post-training/messages": The dataset
contains a messages column with list of messages for post-training. -
"eval/question-answer": The dataset contains a question and answer column.
source:
$ref: '#/components/schemas/DataSource'
description: >-
The data source of the dataset. Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl"
@ -6517,8 +6558,8 @@ components:
The ID of the dataset. If not provided, a random ID will be generated.
additionalProperties: false
required:
- schema
- data_source
- purpose
- source
title: RegisterDatasetRequest
RegisterModelRequest:
type: object
@ -6643,16 +6684,11 @@ components:
type: object
properties:
tool_responses:
oneOf:
- type: array
type: array
items:
$ref: '#/components/schemas/ToolResponse'
- type: array
items:
$ref: '#/components/schemas/ToolResponseMessage'
description: >-
The tool call responses to resume the turn with. NOTE: ToolResponseMessage
will be deprecated. Use ToolResponse.
The tool call responses to resume the turn with.
stream:
type: boolean
description: Whether to stream the response.

View file

@ -370,7 +370,7 @@ class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]]
tool_responses: List[ToolResponse]
stream: Optional[bool] = False
@ -449,7 +449,7 @@ class Agents(Protocol):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Resume an agent turn with executed tool call responses.
@ -460,7 +460,6 @@ class Agents(Protocol):
:param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with.
NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
:param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
"""

View file

@ -13,10 +13,10 @@ from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class Schema(Enum):
class DatasetPurpose(Enum):
"""
Schema of the dataset. Each type has a different column format.
:cvar messages: The dataset contains messages used for post-training. Examples:
Purpose of the dataset. Each type has a different column format.
:cvar post-training/messages: The dataset contains messages used for post-training. Examples:
{
"messages": [
{"role": "user", "content": "Hello, world!"},
@ -25,11 +25,19 @@ class Schema(Enum):
}
"""
messages = "messages"
post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer"
# TODO: add more schemas here
class DatasetType(Enum):
"""
Type of the dataset source.
:cvar huggingface: The dataset is stored in Huggingface.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is stored in rows.
"""
huggingface = "huggingface"
uri = "uri"
rows = "rows"
@ -37,19 +45,36 @@ class DatasetType(Enum):
@json_schema_type
class URIDataSource(BaseModel):
"""A dataset that can be obtained from a URI.
:param uri: The dataset can be obtained from a URI. E.g.
- "https://mywebsite.com/mydata.jsonl"
- "lsfs://mydata.jsonl"
- "data:csv;base64,{base64_content}"
"""
type: Literal["uri"] = "uri"
uri: str
@json_schema_type
class HuggingfaceDataSource(BaseModel):
"""A dataset stored in Huggingface.
:param path: The path to the dataset in Huggingface. E.g.
- "llamastack/simpleqa"
:param params: The parameters for the dataset.
"""
type: Literal["huggingface"] = "huggingface"
dataset_path: str
path: str
params: Dict[str, Any]
@json_schema_type
class RowsDataSource(BaseModel):
"""A dataset stored in rows.
:param rows: The dataset is stored in rows. E.g.
- [
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
]
"""
type: Literal["rows"] = "rows"
rows: List[Dict[str, Any]]
@ -64,8 +89,11 @@ DataSource = register_schema(
class CommonDatasetFields(BaseModel):
schema: Schema
data_source: DataSource
"""
Common fields for a dataset.
"""
purpose: DatasetPurpose
source: DataSource
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this dataset",
@ -99,17 +127,18 @@ class Datasets(Protocol):
@webmethod(route="/datasets", method="POST")
async def register_dataset(
self,
schema: Schema,
data_source: DataSource,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> Dataset:
"""
Register a new dataset.
:param schema: The schema format of the dataset. One of
- messages: The dataset contains a messages column with list of messages for post-training.
:param data_source: The data source of the dataset. Examples:
:param purpose: The purpose of the dataset. One of
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
- "eval/question-answer": The dataset contains a question and answer column.
:param source: The data source of the dataset. Examples:
- {
"type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl"

View file

@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):
@json_schema_type
class CompletionResponse(BaseModel):
class CompletionResponse(MetricResponseMixin):
"""Response from a completion request.
:param content: The generated completion text
@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):
@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response.
:param delta: New content generated since last chunk. This can be one or more tokens.
@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed chat completion response.
:param event: The event containing the new content
@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request.
:param completion_message: The complete response message

View file

@ -96,6 +96,13 @@ class MetricEvent(EventCommon):
unit: str
@json_schema_type
class MetricInResponse(BaseModel):
metric: str
value: Union[int, float]
unit: Optional[str] = None
# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be inlcuded with the response
@ -117,7 +124,7 @@ class MetricEvent(EventCommon):
class MetricResponseMixin(BaseModel):
metrics: Optional[List[MetricEvent]] = None
metrics: Optional[List[MetricInResponse]] = None
@json_schema_type

View file

@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator,
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
@ -44,8 +44,10 @@ from llama_stack.distribution.stack import (
redact_sensitive_fields,
replace_env_vars,
)
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.distribution.utils.exec import in_notebook
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
@ -384,8 +386,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
finally:
await end_trace()
# Wrap the generator to preserve context across iterations
wrapped_gen = preserve_headers_context_async_generator(gen())
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=wrapped_gen,

View file

@ -7,14 +7,14 @@
import contextvars
import json
import logging
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
from typing import Any, ContextManager, Dict, Optional
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
# Context variable for request provider data
_provider_data_var = contextvars.ContextVar("provider_data", default=None)
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(ContextManager):
@ -26,40 +26,13 @@ class RequestProviderDataContext(ContextManager):
def __enter__(self):
# Save the current value and set the new one
self.token = _provider_data_var.set(self.provider_data)
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Restore the previous value
if self.token is not None:
_provider_data_var.reset(self.token)
T = TypeVar("T")
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve request headers context variables across iterations.
This ensures that context variables set during generator creation are
available during each iteration of the generator, even if the original
context manager has exited.
"""
# Capture the current context value right now
context_value = _provider_data_var.get()
async def wrapper():
while True:
# Set context before each anext() call
_ = _provider_data_var.set(context_value)
try:
item = await gen.__anext__()
yield item
except StopAsyncIteration:
break
return wrapper()
PROVIDER_DATA_VAR.reset(self.token)
class NeedsRequestProviderData:
@ -72,7 +45,7 @@ class NeedsRequestProviderData:
if not validator_class:
raise ValueError(f"Provider {provider_type} does not have a validator")
val = _provider_data_var.get()
val = PROVIDER_DATA_VAR.get()
if not val:
return None

View file

@ -165,7 +165,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=[info.routing_table_api.value],
# Add telemetry as an optional dependency to all auto-routed providers
optional_api_dependencies=[Api.telemetry],
deps__=([info.routing_table_api.value, Api.telemetry.value]),
),
)
}

View file

@ -45,7 +45,7 @@ async def get_routing_table_impl(
return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,
@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
}
api_to_deps = {
"inference": {"telemetry": Api.telemetry},
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_routers[api.value](routing_table)
api_to_dep_impl = {}
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api]
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl

View file

@ -4,7 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack.apis.common.content_types import (
URL,
@ -20,6 +21,10 @@ from llama_stack.apis.eval import (
JobStatus,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -27,13 +32,14 @@ from llama_stack.apis.inference import (
Message,
ResponseFormat,
SamplingParams,
StopReason,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
ScoreBatchResponse,
@ -42,6 +48,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
@ -52,7 +59,10 @@ from llama_stack.apis.tools import (
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
@ -119,9 +129,14 @@ class InferenceRouter(Inference):
def __init__(
self,
routing_table: RoutingTable,
telemetry: Optional[Telemetry] = None,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
if self.telemetry:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize")
@ -144,6 +159,71 @@ class InferenceRouter(Inference):
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
) -> List[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
Args:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
total_tokens: Total number of tokens used
model: Model object containing model_id and provider_id
Returns:
List of MetricEvent objects with token usage metrics
"""
span = get_current_span()
if span is None:
logger.warning("No span found for token usage metrics")
return []
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
metric_events = []
for metric_name, value in metrics:
metric_events.append(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=time.time(),
unit="tokens",
attributes={
"model_id": model.model_id,
"provider_id": model.provider_id,
},
)
)
return metric_events
async def _compute_and_log_token_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
await self.telemetry.log_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def _count_tokens(
self,
messages: List[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0
async def chat_completion(
self,
model_id: str,
@ -156,7 +236,7 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
@ -206,10 +286,47 @@ class InferenceRouter(Inference):
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
return (chunk async for chunk in await provider.chat_completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.chat_completion(**params):
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
return await provider.chat_completion(**params)
response = await provider.chat_completion(**params)
completion_tokens = await self._count_tokens(
[response.completion_message],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def completion(
self,
@ -239,10 +356,41 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
prompt_tokens = await self._count_tokens(content)
if stream:
return (chunk async for chunk in await provider.completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
return await provider.completion(**params)
response = await provider.completion(**params)
completion_tokens = await self._count_tokens(response.content)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def embeddings(
self,

View file

@ -28,7 +28,7 @@ from typing_extensions import Annotated
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import (
preserve_headers_context_async_generator,
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
@ -38,6 +38,7 @@ from llama_stack.distribution.stack import (
replace_env_vars,
validate_env_pair,
)
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
@ -45,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
@ -182,7 +184,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
try:
if is_streaming:
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
)
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)

View file

@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextvars import ContextVar
from typing import AsyncGenerator, List, TypeVar
T = TypeVar("T")
def preserve_contexts_async_generator(
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve context variables across iterations.
This is needed because we start a new asyncio event loop for each streaming request,
and we need to preserve the context across the event loop boundary.
"""
async def wrapper() -> AsyncGenerator[T, None]:
while True:
try:
item = await gen.__anext__()
context_values = {context_var.name: context_var.get() for context_var in context_vars}
yield item
for context_var in context_vars:
_ = context_var.set(context_values[context_var.name])
except StopAsyncIteration:
break
return wrapper()

View file

@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar
import pytest
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
# Create context variable
context_var = ContextVar("exception_var", default="initial")
token = context_var.set("start_value")
# Create an async generator that raises an exception
async def exception_generator():
yield context_var.get()
context_var.set("modified")
raise ValueError("Test exception")
yield None # This will never be reached
# Wrap the generator
wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
# First iteration should work
value = await wrapped_gen.__anext__()
assert value == "start_value"
# Second iteration should raise the exception
with pytest.raises(ValueError, match="Test exception"):
await wrapped_gen.__anext__()
# Clean up
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
# Create context variable
context_var = ContextVar("empty_var", default="initial")
token = context_var.set("value")
# Create an empty async generator
async def empty_generator():
if False: # This condition ensures the generator yields nothing
yield None
# Wrap the generator
wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
# The generator should raise StopAsyncIteration immediately
with pytest.raises(StopAsyncIteration):
await wrapped_gen.__anext__()
# Context variable should remain unchanged
assert context_var.get() == "value"
# Clean up
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
"""
Test that context variables are preserved across event loop boundaries with nested generators.
This simulates the real-world scenario where:
1. A new event loop is created for each streaming request
2. The async generator runs inside that loop
3. There are multiple levels of nested generators
4. Context needs to be preserved across these boundaries
"""
# Create context variables
request_id = ContextVar("request_id", default=None)
user_id = ContextVar("user_id", default=None)
# Set initial values
# Results container to verify values across thread boundaries
results = []
# Inner-most generator (level 2)
async def inner_generator():
# Should have the context from the outer scope
yield (1, request_id.get(), user_id.get())
# Modify one context variable
user_id.set("user-modified")
# Should reflect the modification
yield (2, request_id.get(), user_id.get())
# Middle generator (level 1)
async def middle_generator():
inner_gen = inner_generator()
# Forward the first yield from inner
item = await inner_gen.__anext__()
yield item
# Forward the second yield from inner
item = await inner_gen.__anext__()
yield item
request_id.set("req-modified")
# Add our own yield with both modified variables
yield (3, request_id.get(), user_id.get())
# Function to run in a separate thread with a new event loop
def run_in_new_loop():
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Outer generator (runs in the new loop)
async def outer_generator():
request_id.set("req-12345")
user_id.set("user-6789")
# Wrap the middle generator
wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
# Process all items from the middle generator
async for item in wrapped_gen:
# Store results for verification
results.append(item)
# Run the outer generator in the new loop
loop.run_until_complete(outer_generator())
finally:
loop.close()
# Run the generator chain in a separate thread with a new event loop
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
future.result() # Wait for completion
# Verify the results
assert len(results) == 3
# First yield should have original values
assert results[0] == (1, "req-12345", "user-6789")
# Second yield should have modified user_id
assert results[1] == (2, "req-12345", "user-modified")
# Third yield should have both modified values
assert results[2] == (3, "req-modified", "user-modified")

View file

@ -218,18 +218,10 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
messages = await self.get_messages_from_turns(turns)
if is_resume:
if isinstance(request.tool_responses[0], ToolResponseMessage):
tool_response_messages = request.tool_responses
tool_responses = [
ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
for x in request.tool_responses
]
else:
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
for x in request.tool_responses
]
tool_responses = request.tool_responses
messages.extend(tool_response_messages)
last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn)
@ -252,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin):
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=tool_responses,
tool_responses=request.tool_responses,
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)

View file

@ -172,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(

View file

@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
self.meter = None
resource = Resource.create(
{
@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
if self.meter is None:
return
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)

View file

@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Tuple
from typing import Dict, List, Tuple
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
BenchmarkInput,
@ -15,21 +16,27 @@ from llama_stack.distribution.datatypes import (
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
SQLiteVectorIOConfig,
)
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
get_model_registry,
)
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
# in this template, we allow each API key to be optional
providers = [
(
@ -164,7 +171,7 @@ def get_distribution_template() -> DistributionTemplate:
DatasetInput(
dataset_id="simpleqa",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"),
metadata={
"path": "llamastack/simpleqa",
"split": "train",
@ -178,7 +185,7 @@ def get_distribution_template() -> DistributionTemplate:
DatasetInput(
dataset_id="mmlu_cot",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"},
url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"),
metadata={
"path": "llamastack/mmlu_cot",
"name": "all",
@ -193,7 +200,7 @@ def get_distribution_template() -> DistributionTemplate:
DatasetInput(
dataset_id="gpqa_cot",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"},
url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"),
metadata={
"path": "llamastack/gpqa_0shot_cot",
"name": "gpqa_main",
@ -208,7 +215,7 @@ def get_distribution_template() -> DistributionTemplate:
DatasetInput(
dataset_id="math_500",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/math_500"},
url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"),
metadata={
"path": "llamastack/math_500",
"split": "test",

View file

@ -30,7 +30,9 @@ from llama_stack.providers.utils.inference.model_registry import ProviderModelEn
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]:
def get_model_registry(
available_models: Dict[str, List[ProviderModelEntry]],
) -> List[ModelInput]:
models = []
for provider_id, entries in available_models.items():
for entry in entries:
@ -193,7 +195,7 @@ class DistributionTemplate(BaseModel):
default_models.append(
DefaultModel(
model_id=model_entry.provider_model_id,
doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "",
doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
)
)