diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 39505ba11..59d18b3be 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,6 +1,8 @@ name: Unit Tests on: + push: + branches: [ main ] pull_request: branches: [ main ] workflow_dispatch: diff --git a/README.md b/README.md index b24e69514..6e1fd088e 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/) [![License](https://img.shields.io/pypi/l/llama_stack.svg)](https://github.com/meta-llama/llama-stack/blob/main/LICENSE) [![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/llama-stack) +![Unit](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main) [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 817a65ca8..53ccd5326 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4570,7 +4570,7 @@ "metrics": { "type": "array", "items": { - "$ref": "#/components/schemas/MetricEvent" + "$ref": "#/components/schemas/MetricInResponse" } }, "completion_message": { @@ -4592,46 +4592,9 @@ "title": "ChatCompletionResponse", "description": "Response from a chat completion request." }, - "MetricEvent": { + "MetricInResponse": { "type": "object", "properties": { - "trace_id": { - "type": "string" - }, - "span_id": { - "type": "string" - }, - "timestamp": { - "type": "string", - "format": "date-time" - }, - "attributes": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - }, - { - "type": "number" - }, - { - "type": "boolean" - }, - { - "type": "null" - } - ] - } - }, - "type": { - "type": "string", - "const": "metric", - "default": "metric" - }, "metric": { "type": "string" }, @@ -4651,15 +4614,10 @@ }, "additionalProperties": false, "required": [ - "trace_id", - "span_id", - "timestamp", - "type", "metric", - "value", - "unit" + "value" ], - "title": "MetricEvent" + "title": "MetricInResponse" }, "TokenLogProbs": { "type": "object", @@ -4736,6 +4694,12 @@ "CompletionResponse": { "type": "object", "properties": { + "metrics": { + "type": "array", + "items": { + "$ref": "#/components/schemas/MetricInResponse" + } + }, "content": { "type": "string", "description": "The generated completion text" @@ -4945,7 +4909,7 @@ "metrics": { "type": "array", "items": { - "$ref": "#/components/schemas/MetricEvent" + "$ref": "#/components/schemas/MetricInResponse" } }, "event": { @@ -5103,6 +5067,12 @@ "CompletionResponseStreamChunk": { "type": "object", "properties": { + "metrics": { + "type": "array", + "items": { + "$ref": "#/components/schemas/MetricInResponse" + } + }, "delta": { "type": "string", "description": "New content generated since last chunk. This can be one or more tokens." @@ -7192,15 +7162,16 @@ "const": "dataset", "default": "dataset" }, - "schema": { + "purpose": { "type": "string", "enum": [ - "messages" + "post-training/messages", + "eval/question-answer" ], - "title": "Schema", - "description": "Schema of the dataset. Each type has a different column format." + "title": "DatasetPurpose", + "description": "Purpose of the dataset. 
Each type has a different column format." }, - "data_source": { + "source": { "$ref": "#/components/schemas/DataSource" }, "metadata": { @@ -7235,8 +7206,8 @@ "provider_resource_id", "provider_id", "type", - "schema", - "data_source", + "purpose", + "source", "metadata" ], "title": "Dataset" @@ -7249,8 +7220,9 @@ "const": "huggingface", "default": "huggingface" }, - "dataset_path": { - "type": "string" + "path": { + "type": "string", + "description": "The path to the dataset in Huggingface. E.g. - \"llamastack/simpleqa\"" }, "params": { "type": "object", @@ -7275,16 +7247,18 @@ "type": "object" } ] - } + }, + "description": "The parameters for the dataset." } }, "additionalProperties": false, "required": [ "type", - "dataset_path", + "path", "params" ], - "title": "HuggingfaceDataSource" + "title": "HuggingfaceDataSource", + "description": "A dataset stored in Huggingface." }, "RowsDataSource": { "type": "object", @@ -7320,7 +7294,8 @@ } ] } - } + }, + "description": "The dataset is stored in rows. E.g. - [ {\"messages\": [{\"role\": \"user\", \"content\": \"Hello, world!\"}, {\"role\": \"assistant\", \"content\": \"Hello, world!\"}]} ]" } }, "additionalProperties": false, @@ -7328,7 +7303,8 @@ "type", "rows" ], - "title": "RowsDataSource" + "title": "RowsDataSource", + "description": "A dataset stored in rows." }, "URIDataSource": { "type": "object", @@ -7339,7 +7315,8 @@ "default": "uri" }, "uri": { - "type": "string" + "type": "string", + "description": "The dataset can be obtained from a URI. E.g. - \"https://mywebsite.com/mydata.jsonl\" - \"lsfs://mydata.jsonl\" - \"data:csv;base64,{base64_content}\"" } }, "additionalProperties": false, @@ -7347,7 +7324,8 @@ "type", "uri" ], - "title": "URIDataSource" + "title": "URIDataSource", + "description": "A dataset that can be obtained from a URI." }, "Model": { "type": "object", @@ -8634,6 +8612,75 @@ ], "title": "LogSeverity" }, + "MetricEvent": { + "type": "object", + "properties": { + "trace_id": { + "type": "string" + }, + "span_id": { + "type": "string" + }, + "timestamp": { + "type": "string", + "format": "date-time" + }, + "attributes": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "null" + } + ] + } + }, + "type": { + "type": "string", + "const": "metric", + "default": "metric" + }, + "metric": { + "type": "string" + }, + "value": { + "oneOf": [ + { + "type": "integer" + }, + { + "type": "number" + } + ] + }, + "unit": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "trace_id", + "span_id", + "timestamp", + "type", + "metric", + "value", + "unit" + ], + "title": "MetricEvent" + }, "SpanEndPayload": { "type": "object", "properties": { @@ -9510,14 +9557,15 @@ "RegisterDatasetRequest": { "type": "object", "properties": { - "schema": { + "purpose": { "type": "string", "enum": [ - "messages" + "post-training/messages", + "eval/question-answer" ], - "description": "The schema format of the dataset. One of - messages: The dataset contains a messages column with list of messages for post-training." + "description": "The purpose of the dataset. One of - \"post-training/messages\": The dataset contains a messages column with list of messages for post-training. - \"eval/question-answer\": The dataset contains a question and answer column." }, - "data_source": { + "source": { "$ref": "#/components/schemas/DataSource", "description": "The data source of the dataset. 
Examples: - { \"type\": \"uri\", \"uri\": \"https://mywebsite.com/mydata.jsonl\" } - { \"type\": \"uri\", \"uri\": \"lsfs://mydata.jsonl\" } - { \"type\": \"huggingface\", \"dataset_path\": \"tatsu-lab/alpaca\", \"params\": { \"split\": \"train\" } } - { \"type\": \"rows\", \"rows\": [ { \"messages\": [ {\"role\": \"user\", \"content\": \"Hello, world!\"}, {\"role\": \"assistant\", \"content\": \"Hello, world!\"}, ] } ] }" }, @@ -9554,8 +9602,8 @@ }, "additionalProperties": false, "required": [ - "schema", - "data_source" + "purpose", + "source" ], "title": "RegisterDatasetRequest" }, @@ -9769,21 +9817,11 @@ "type": "object", "properties": { "tool_responses": { - "oneOf": [ - { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolResponse" - } - }, - { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolResponseMessage" - } - } - ], - "description": "The tool call responses to resume the turn with. NOTE: ToolResponseMessage will be deprecated. Use ToolResponse." + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolResponse" + }, + "description": "The tool call responses to resume the turn with." }, "stream": { "type": "boolean", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 62fb02651..c8687e9d7 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3115,7 +3115,7 @@ components: metrics: type: array items: - $ref: '#/components/schemas/MetricEvent' + $ref: '#/components/schemas/MetricInResponse' completion_message: $ref: '#/components/schemas/CompletionMessage' description: The complete response message @@ -3130,29 +3130,9 @@ components: - completion_message title: ChatCompletionResponse description: Response from a chat completion request. - MetricEvent: + MetricInResponse: type: object properties: - trace_id: - type: string - span_id: - type: string - timestamp: - type: string - format: date-time - attributes: - type: object - additionalProperties: - oneOf: - - type: string - - type: integer - - type: number - - type: boolean - - type: 'null' - type: - type: string - const: metric - default: metric metric: type: string value: @@ -3163,14 +3143,9 @@ components: type: string additionalProperties: false required: - - trace_id - - span_id - - timestamp - - type - metric - value - - unit - title: MetricEvent + title: MetricInResponse TokenLogProbs: type: object properties: @@ -3227,6 +3202,10 @@ components: CompletionResponse: type: object properties: + metrics: + type: array + items: + $ref: '#/components/schemas/MetricInResponse' content: type: string description: The generated completion text @@ -3426,7 +3405,7 @@ components: metrics: type: array items: - $ref: '#/components/schemas/MetricEvent' + $ref: '#/components/schemas/MetricInResponse' event: $ref: '#/components/schemas/ChatCompletionResponseEvent' description: The event containing the new content @@ -3545,6 +3524,10 @@ components: CompletionResponseStreamChunk: type: object properties: + metrics: + type: array + items: + $ref: '#/components/schemas/MetricInResponse' delta: type: string description: >- @@ -5008,14 +4991,15 @@ components: type: string const: dataset default: dataset - schema: + purpose: type: string enum: - - messages - title: Schema + - post-training/messages + - eval/question-answer + title: DatasetPurpose description: >- - Schema of the dataset. Each type has a different column format. - data_source: + Purpose of the dataset. Each type has a different column format. 
+ source: $ref: '#/components/schemas/DataSource' metadata: type: object @@ -5033,8 +5017,8 @@ components: - provider_resource_id - provider_id - type - - schema - - data_source + - purpose + - source - metadata title: Dataset HuggingfaceDataSource: @@ -5044,8 +5028,10 @@ components: type: string const: huggingface default: huggingface - dataset_path: + path: type: string + description: >- + The path to the dataset in Huggingface. E.g. - "llamastack/simpleqa" params: type: object additionalProperties: @@ -5056,12 +5042,14 @@ components: - type: string - type: array - type: object + description: The parameters for the dataset. additionalProperties: false required: - type - - dataset_path + - path - params title: HuggingfaceDataSource + description: A dataset stored in Huggingface. RowsDataSource: type: object properties: @@ -5081,11 +5069,16 @@ components: - type: string - type: array - type: object + description: >- + The dataset is stored in rows. E.g. - [ {"messages": [{"role": "user", + "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, + world!"}]} ] additionalProperties: false required: - type - rows title: RowsDataSource + description: A dataset stored in rows. URIDataSource: type: object properties: @@ -5095,11 +5088,16 @@ components: default: uri uri: type: string + description: >- + The dataset can be obtained from a URI. E.g. - "https://mywebsite.com/mydata.jsonl" + - "lsfs://mydata.jsonl" - "data:csv;base64,{base64_content}" additionalProperties: false required: - type - uri title: URIDataSource + description: >- + A dataset that can be obtained from a URI. Model: type: object properties: @@ -5920,6 +5918,47 @@ components: - error - critical title: LogSeverity + MetricEvent: + type: object + properties: + trace_id: + type: string + span_id: + type: string + timestamp: + type: string + format: date-time + attributes: + type: object + additionalProperties: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + type: + type: string + const: metric + default: metric + metric: + type: string + value: + oneOf: + - type: integer + - type: number + unit: + type: string + additionalProperties: false + required: + - trace_id + - span_id + - timestamp + - type + - metric + - value + - unit + title: MetricEvent SpanEndPayload: type: object properties: @@ -6483,14 +6522,16 @@ components: RegisterDatasetRequest: type: object properties: - schema: + purpose: type: string enum: - - messages + - post-training/messages + - eval/question-answer description: >- - The schema format of the dataset. One of - messages: The dataset contains - a messages column with list of messages for post-training. - data_source: + The purpose of the dataset. One of - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. - + "eval/question-answer": The dataset contains a question and answer column. + source: $ref: '#/components/schemas/DataSource' description: >- The data source of the dataset. Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl" @@ -6517,8 +6558,8 @@ components: The ID of the dataset. If not provided, a random ID will be generated. 
additionalProperties: false required: - - schema - - data_source + - purpose + - source title: RegisterDatasetRequest RegisterModelRequest: type: object @@ -6643,16 +6684,11 @@ components: type: object properties: tool_responses: - oneOf: - - type: array - items: - $ref: '#/components/schemas/ToolResponse' - - type: array - items: - $ref: '#/components/schemas/ToolResponseMessage' + type: array + items: + $ref: '#/components/schemas/ToolResponse' description: >- - The tool call responses to resume the turn with. NOTE: ToolResponseMessage - will be deprecated. Use ToolResponse. + The tool call responses to resume the turn with. stream: type: boolean description: Whether to stream the response. diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 1170a56d5..5cc910a55 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -370,7 +370,7 @@ class AgentTurnResumeRequest(BaseModel): agent_id: str session_id: str turn_id: str - tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]] + tool_responses: List[ToolResponse] stream: Optional[bool] = False @@ -449,7 +449,7 @@ class Agents(Protocol): agent_id: str, session_id: str, turn_id: str, - tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]], + tool_responses: List[ToolResponse], stream: Optional[bool] = False, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: """Resume an agent turn with executed tool call responses. @@ -460,7 +460,6 @@ class Agents(Protocol): :param session_id: The ID of the session to resume. :param turn_id: The ID of the turn to resume. :param tool_responses: The tool call responses to resume the turn with. - NOTE: ToolResponseMessage will be deprecated. Use ToolResponse. :param stream: Whether to stream the response. :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects. """ diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 4b3ce3e6f..36f75d7b3 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -13,10 +13,10 @@ from llama_stack.apis.resource import Resource, ResourceType from llama_stack.schema_utils import json_schema_type, register_schema, webmethod -class Schema(Enum): +class DatasetPurpose(Enum): """ - Schema of the dataset. Each type has a different column format. - :cvar messages: The dataset contains messages used for post-training. Examples: + Purpose of the dataset. Each type has a different column format. + :cvar post-training/messages: The dataset contains messages used for post-training. Examples: { "messages": [ {"role": "user", "content": "Hello, world!"}, @@ -25,11 +25,19 @@ class Schema(Enum): } """ - messages = "messages" + post_training_messages = "post-training/messages" + eval_question_answer = "eval/question-answer" + # TODO: add more schemas here class DatasetType(Enum): + """ + Type of the dataset source. + :cvar huggingface: The dataset is stored in Huggingface. + :cvar uri: The dataset can be obtained from a URI. + :cvar rows: The dataset is stored in rows. + """ huggingface = "huggingface" uri = "uri" rows = "rows" @@ -37,19 +45,36 @@ class DatasetType(Enum): @json_schema_type class URIDataSource(BaseModel): + """A dataset that can be obtained from a URI. + :param uri: The dataset can be obtained from a URI. E.g. 
+ - "https://mywebsite.com/mydata.jsonl" + - "lsfs://mydata.jsonl" + - "data:csv;base64,{base64_content}" + """ type: Literal["uri"] = "uri" uri: str @json_schema_type class HuggingfaceDataSource(BaseModel): + """A dataset stored in Huggingface. + :param path: The path to the dataset in Huggingface. E.g. + - "llamastack/simpleqa" + :param params: The parameters for the dataset. + """ type: Literal["huggingface"] = "huggingface" - dataset_path: str + path: str params: Dict[str, Any] @json_schema_type class RowsDataSource(BaseModel): + """A dataset stored in rows. + :param rows: The dataset is stored in rows. E.g. + - [ + {"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]} + ] + """ type: Literal["rows"] = "rows" rows: List[Dict[str, Any]] @@ -64,8 +89,11 @@ DataSource = register_schema( class CommonDatasetFields(BaseModel): - schema: Schema - data_source: DataSource + """ + Common fields for a dataset. + """ + purpose: DatasetPurpose + source: DataSource metadata: Dict[str, Any] = Field( default_factory=dict, description="Any additional metadata for this dataset", @@ -99,17 +127,18 @@ class Datasets(Protocol): @webmethod(route="/datasets", method="POST") async def register_dataset( self, - schema: Schema, - data_source: DataSource, + purpose: DatasetPurpose, + source: DataSource, metadata: Optional[Dict[str, Any]] = None, dataset_id: Optional[str] = None, ) -> Dataset: """ Register a new dataset. - :param schema: The schema format of the dataset. One of - - messages: The dataset contains a messages column with list of messages for post-training. - :param data_source: The data source of the dataset. Examples: + :param purpose: The purpose of the dataset. One of + - "post-training/messages": The dataset contains a messages column with list of messages for post-training. + - "eval/question-answer": The dataset contains a question and answer column. + :param source: The data source of the dataset. Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl" diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index d0f5d15c5..fa917ac22 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -285,7 +285,7 @@ class CompletionRequest(BaseModel): @json_schema_type -class CompletionResponse(BaseModel): +class CompletionResponse(MetricResponseMixin): """Response from a completion request. :param content: The generated completion text @@ -299,7 +299,7 @@ class CompletionResponse(BaseModel): @json_schema_type -class CompletionResponseStreamChunk(BaseModel): +class CompletionResponseStreamChunk(MetricResponseMixin): """A chunk of a streamed completion response. :param delta: New content generated since last chunk. This can be one or more tokens. @@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel): @json_schema_type -class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel): +class ChatCompletionResponseStreamChunk(MetricResponseMixin): """A chunk of a streamed chat completion response. :param event: The event containing the new content @@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel): @json_schema_type -class ChatCompletionResponse(MetricResponseMixin, BaseModel): +class ChatCompletionResponse(MetricResponseMixin): """Response from a chat completion request. 
:param completion_message: The complete response message diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index fe75677e7..cbea57e79 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -96,6 +96,13 @@ class MetricEvent(EventCommon): unit: str +@json_schema_type +class MetricInResponse(BaseModel): + metric: str + value: Union[int, float] + unit: Optional[str] = None + + # This is a short term solution to allow inference API to return metrics # The ideal way to do this is to have a way for all response types to include metrics # and all metric events logged to the telemetry API to be inlcuded with the response @@ -117,7 +124,7 @@ class MetricEvent(EventCommon): class MetricResponseMixin(BaseModel): - metrics: Optional[List[MetricEvent]] = None + metrics: Optional[List[MetricInResponse]] = None @json_schema_type diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 5dc70bb67..15c4fe6ea 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.datatypes import Api from llama_stack.distribution.request_headers import ( - preserve_headers_context_async_generator, + PROVIDER_DATA_VAR, request_provider_data_context, ) from llama_stack.distribution.resolver import ProviderRegistry @@ -44,8 +44,10 @@ from llama_stack.distribution.stack import ( redact_sensitive_fields, replace_env_vars, ) +from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.distribution.utils.exec import in_notebook from llama_stack.providers.utils.telemetry.tracing import ( + CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace, @@ -384,8 +386,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): finally: await end_trace() - # Wrap the generator to preserve context across iterations - wrapped_gen = preserve_headers_context_async_generator(gen()) + wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]) + mock_response = httpx.Response( status_code=httpx.codes.OK, content=wrapped_gen, diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 19afae59b..8709fc040 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -7,14 +7,14 @@ import contextvars import json import logging -from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar +from typing import Any, ContextManager, Dict, Optional from .utils.dynamic import instantiate_class_type log = logging.getLogger(__name__) # Context variable for request provider data -_provider_data_var = contextvars.ContextVar("provider_data", default=None) +PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) class RequestProviderDataContext(ContextManager): @@ -26,40 +26,13 @@ class RequestProviderDataContext(ContextManager): def __enter__(self): # Save the current value and set the new one - self.token = _provider_data_var.set(self.provider_data) + self.token = PROVIDER_DATA_VAR.set(self.provider_data) return self def __exit__(self, exc_type, exc_val, exc_tb): # Restore the previous value if self.token is not None: - _provider_data_var.reset(self.token) - - -T 
= TypeVar("T") - - -def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]: - """ - Wraps an async generator to preserve request headers context variables across iterations. - - This ensures that context variables set during generator creation are - available during each iteration of the generator, even if the original - context manager has exited. - """ - # Capture the current context value right now - context_value = _provider_data_var.get() - - async def wrapper(): - while True: - # Set context before each anext() call - _ = _provider_data_var.set(context_value) - try: - item = await gen.__anext__() - yield item - except StopAsyncIteration: - break - - return wrapper() + PROVIDER_DATA_VAR.reset(self.token) class NeedsRequestProviderData: @@ -72,7 +45,7 @@ class NeedsRequestProviderData: if not validator_class: raise ValueError(f"Provider {provider_type} does not have a validator") - val = _provider_data_var.get() + val = PROVIDER_DATA_VAR.get() if not val: return None diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index d7ca4414d..ab075f399 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -165,7 +165,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, module="llama_stack.distribution.routers", routing_table_api=info.routing_table_api, api_dependencies=[info.routing_table_api], - deps__=[info.routing_table_api.value], + # Add telemetry as an optional dependency to all auto-routed providers + optional_api_dependencies=[Api.telemetry], + deps__=([info.routing_table_api.value, Api.telemetry.value]), ), ) } diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index a54f57fb3..d0fca8771 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -45,7 +45,7 @@ async def get_routing_table_impl( return impl -async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: +async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any: from .routers import ( DatasetIORouter, EvalRouter, @@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> "eval": EvalRouter, "tool_runtime": ToolRuntimeRouter, } + api_to_deps = { + "inference": {"telemetry": Api.telemetry}, + } if api.value not in api_to_routers: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_routers[api.value](routing_table) + api_to_dep_impl = {} + for dep_name, dep_api in api_to_deps.get(api.value, {}).items(): + if dep_api in deps: + api_to_dep_impl[dep_name] = deps[dep_api] + + impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 28df67922..22a1e46f9 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -4,7 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, AsyncGenerator, Dict, List, Optional +import time +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from llama_stack.apis.common.content_types import ( URL, @@ -20,6 +21,10 @@ from llama_stack.apis.eval import ( JobStatus, ) from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + CompletionMessage, EmbeddingsResponse, EmbeddingTaskType, Inference, @@ -27,13 +32,14 @@ from llama_stack.apis.inference import ( Message, ResponseFormat, SamplingParams, + StopReason, TextTruncation, ToolChoice, ToolConfig, ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.models import ModelType +from llama_stack.apis.models import Model, ModelType from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.scoring import ( ScoreBatchResponse, @@ -42,6 +48,7 @@ from llama_stack.apis.scoring import ( ScoringFnParams, ) from llama_stack.apis.shields import Shield +from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry from llama_stack.apis.tools import ( RAGDocument, RAGQueryConfig, @@ -52,7 +59,10 @@ from llama_stack.apis.tools import ( ) from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.log import get_logger +from llama_stack.models.llama.llama3.chat_format import ChatFormat +from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import RoutingTable +from llama_stack.providers.utils.telemetry.tracing import get_current_span logger = get_logger(name=__name__, category="core") @@ -119,9 +129,14 @@ class InferenceRouter(Inference): def __init__( self, routing_table: RoutingTable, + telemetry: Optional[Telemetry] = None, ) -> None: logger.debug("Initializing InferenceRouter") self.routing_table = routing_table + self.telemetry = telemetry + if self.telemetry: + self.tokenizer = Tokenizer.get_instance() + self.formatter = ChatFormat(self.tokenizer) async def initialize(self) -> None: logger.debug("InferenceRouter.initialize") @@ -144,6 +159,71 @@ class InferenceRouter(Inference): ) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) + def _construct_metrics( + self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model + ) -> List[MetricEvent]: + """Constructs a list of MetricEvent objects containing token usage metrics. 
+ + Args: + prompt_tokens: Number of tokens in the prompt + completion_tokens: Number of tokens in the completion + total_tokens: Total number of tokens used + model: Model object containing model_id and provider_id + + Returns: + List of MetricEvent objects with token usage metrics + """ + span = get_current_span() + if span is None: + logger.warning("No span found for token usage metrics") + return [] + metrics = [ + ("prompt_tokens", prompt_tokens), + ("completion_tokens", completion_tokens), + ("total_tokens", total_tokens), + ] + metric_events = [] + for metric_name, value in metrics: + metric_events.append( + MetricEvent( + trace_id=span.trace_id, + span_id=span.span_id, + metric=metric_name, + value=value, + timestamp=time.time(), + unit="tokens", + attributes={ + "model_id": model.model_id, + "provider_id": model.provider_id, + }, + ) + ) + return metric_events + + async def _compute_and_log_token_usage( + self, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + model: Model, + ) -> List[MetricInResponse]: + metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) + if self.telemetry: + for metric in metrics: + await self.telemetry.log_event(metric) + return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] + + async def _count_tokens( + self, + messages: List[Message] | InterleavedContent, + tool_prompt_format: Optional[ToolPromptFormat] = None, + ) -> Optional[int]: + if isinstance(messages, list): + encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format) + else: + encoded = self.formatter.encode_content(messages) + return len(encoded.tokens) if encoded and encoded.tokens else 0 + async def chat_completion( self, model_id: str, @@ -156,7 +236,7 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, - ) -> AsyncGenerator: + ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: logger.debug( f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) @@ -206,10 +286,47 @@ class InferenceRouter(Inference): tool_config=tool_config, ) provider = self.routing_table.get_provider_impl(model_id) + prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) + if stream: - return (chunk async for chunk in await provider.chat_completion(**params)) + + async def stream_generator(): + completion_text = "" + async for chunk in await provider.chat_completion(**params): + if chunk.event.event_type == ChatCompletionResponseEventType.progress: + if chunk.event.delta.type == "text": + completion_text += chunk.event.delta.text + if chunk.event.event_type == ChatCompletionResponseEventType.complete: + completion_tokens = await self._count_tokens( + [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)], + tool_config.tool_prompt_format, + ) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + yield chunk + + return stream_generator() else: - return await provider.chat_completion(**params) + response = await provider.chat_completion(**params) + completion_tokens = await self._count_tokens( + [response.completion_message], + 
tool_config.tool_prompt_format, + ) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics + return response async def completion( self, @@ -239,10 +356,41 @@ class InferenceRouter(Inference): stream=stream, logprobs=logprobs, ) + + prompt_tokens = await self._count_tokens(content) + if stream: - return (chunk async for chunk in await provider.completion(**params)) + + async def stream_generator(): + completion_text = "" + async for chunk in await provider.completion(**params): + if hasattr(chunk, "delta"): + completion_text += chunk.delta + if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + completion_tokens = await self._count_tokens(completion_text) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + yield chunk + + return stream_generator() else: - return await provider.completion(**params) + response = await provider.completion(**params) + completion_tokens = await self._count_tokens(response.content) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics + return response async def embeddings( self, diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 2cc70a738..7ca009b13 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -28,7 +28,7 @@ from typing_extensions import Annotated from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import ( - preserve_headers_context_async_generator, + PROVIDER_DATA_VAR, request_provider_data_context, ) from llama_stack.distribution.resolver import InvalidProviderError @@ -38,6 +38,7 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, ) +from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig @@ -45,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import ( TelemetryAdapter, ) from llama_stack.providers.utils.telemetry.tracing import ( + CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace, @@ -182,7 +184,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str): try: if is_streaming: - gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs))) + gen = preserve_contexts_async_generator( + sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR] + ) return StreamingResponse(gen, media_type="text/event-stream") else: value = func(**kwargs) diff --git a/llama_stack/distribution/utils/context.py b/llama_stack/distribution/utils/context.py new file mode 100644 index 000000000..2f32afba2 --- /dev/null +++ 
b/llama_stack/distribution/utils/context.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from contextvars import ContextVar +from typing import AsyncGenerator, List, TypeVar + +T = TypeVar("T") + + +def preserve_contexts_async_generator( + gen: AsyncGenerator[T, None], context_vars: List[ContextVar] +) -> AsyncGenerator[T, None]: + """ + Wraps an async generator to preserve context variables across iterations. + This is needed because we start a new asyncio event loop for each streaming request, + and we need to preserve the context across the event loop boundary. + """ + + async def wrapper() -> AsyncGenerator[T, None]: + while True: + try: + item = await gen.__anext__() + context_values = {context_var.name: context_var.get() for context_var in context_vars} + yield item + for context_var in context_vars: + _ = context_var.set(context_values[context_var.name]) + except StopAsyncIteration: + break + + return wrapper() diff --git a/llama_stack/distribution/utils/tests/test_context.py b/llama_stack/distribution/utils/tests/test_context.py new file mode 100644 index 000000000..84944bfe8 --- /dev/null +++ b/llama_stack/distribution/utils/tests/test_context.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from contextvars import ContextVar + +import pytest + +from llama_stack.distribution.utils.context import preserve_contexts_async_generator + + +@pytest.mark.asyncio +async def test_preserve_contexts_with_exception(): + # Create context variable + context_var = ContextVar("exception_var", default="initial") + token = context_var.set("start_value") + + # Create an async generator that raises an exception + async def exception_generator(): + yield context_var.get() + context_var.set("modified") + raise ValueError("Test exception") + yield None # This will never be reached + + # Wrap the generator + wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var]) + + # First iteration should work + value = await wrapped_gen.__anext__() + assert value == "start_value" + + # Second iteration should raise the exception + with pytest.raises(ValueError, match="Test exception"): + await wrapped_gen.__anext__() + + # Clean up + context_var.reset(token) + + +@pytest.mark.asyncio +async def test_preserve_contexts_empty_generator(): + # Create context variable + context_var = ContextVar("empty_var", default="initial") + token = context_var.set("value") + + # Create an empty async generator + async def empty_generator(): + if False: # This condition ensures the generator yields nothing + yield None + + # Wrap the generator + wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var]) + + # The generator should raise StopAsyncIteration immediately + with pytest.raises(StopAsyncIteration): + await wrapped_gen.__anext__() + + # Context variable should remain unchanged + assert context_var.get() == "value" + + # Clean up + context_var.reset(token) + + +@pytest.mark.asyncio +async def test_preserve_contexts_across_event_loops(): + """ + Test that context variables are preserved across event loop boundaries with nested generators. 
+ This simulates the real-world scenario where: + 1. A new event loop is created for each streaming request + 2. The async generator runs inside that loop + 3. There are multiple levels of nested generators + 4. Context needs to be preserved across these boundaries + """ + # Create context variables + request_id = ContextVar("request_id", default=None) + user_id = ContextVar("user_id", default=None) + + # Set initial values + + # Results container to verify values across thread boundaries + results = [] + + # Inner-most generator (level 2) + async def inner_generator(): + # Should have the context from the outer scope + yield (1, request_id.get(), user_id.get()) + + # Modify one context variable + user_id.set("user-modified") + + # Should reflect the modification + yield (2, request_id.get(), user_id.get()) + + # Middle generator (level 1) + async def middle_generator(): + inner_gen = inner_generator() + + # Forward the first yield from inner + item = await inner_gen.__anext__() + yield item + + # Forward the second yield from inner + item = await inner_gen.__anext__() + yield item + + request_id.set("req-modified") + + # Add our own yield with both modified variables + yield (3, request_id.get(), user_id.get()) + + # Function to run in a separate thread with a new event loop + def run_in_new_loop(): + # Create a new event loop for this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # Outer generator (runs in the new loop) + async def outer_generator(): + request_id.set("req-12345") + user_id.set("user-6789") + # Wrap the middle generator + wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id]) + + # Process all items from the middle generator + async for item in wrapped_gen: + # Store results for verification + results.append(item) + + # Run the outer generator in the new loop + loop.run_until_complete(outer_generator()) + finally: + loop.close() + + # Run the generator chain in a separate thread with a new event loop + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_in_new_loop) + future.result() # Wait for completion + + # Verify the results + assert len(results) == 3 + + # First yield should have original values + assert results[0] == (1, "req-12345", "user-6789") + + # Second yield should have modified user_id + assert results[1] == (2, "req-12345", "user-modified") + + # Third yield should have both modified values + assert results[2] == (3, "req-modified", "user-modified") diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index fedd695c1..1d9f54e96 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -218,18 +218,10 @@ class ChatAgent(ShieldRunnerMixin): steps = [] messages = await self.get_messages_from_turns(turns) if is_resume: - if isinstance(request.tool_responses[0], ToolResponseMessage): - tool_response_messages = request.tool_responses - tool_responses = [ - ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content) - for x in request.tool_responses - ] - else: - tool_response_messages = [ - ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content) - for x in request.tool_responses - ] - tool_responses = request.tool_responses + tool_response_messages = [ + ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, 
content=x.content) + for x in request.tool_responses + ] messages.extend(tool_response_messages) last_turn = turns[-1] last_turn_messages = self.turn_to_messages(last_turn) @@ -252,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin): step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), turn_id=request.turn_id, tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), - tool_responses=tool_responses, + tool_responses=request.tool_responses, completed_at=now, started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index c24b14e35..5ca123595 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -172,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents): agent_id: str, session_id: str, turn_id: str, - tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]], + tool_responses: List[ToolResponse], stream: Optional[bool] = False, ) -> AsyncGenerator: request = AgentTurnResumeRequest( diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index e713a057f..4cdb420b2 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None: self.config = config self.datasetio_api = deps.get(Api.datasetio) + self.meter = None resource = Resource.create( { @@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): return _GLOBAL_STORAGE["gauges"][name] def _log_metric(self, event: MetricEvent) -> None: + if self.meter is None: + return if isinstance(event.value, int): counter = self._get_or_create_counter(event.metric, event.unit) counter.add(event.value, attributes=event.attributes) diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py index 7df33a715..2b40797f9 100644 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ b/llama_stack/templates/open-benchmark/open_benchmark.py @@ -4,8 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import List, Tuple +from typing import Dict, List, Tuple +from llama_stack.apis.common.content_types import URL from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ( BenchmarkInput, @@ -15,21 +16,27 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) -from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig +from llama_stack.providers.inline.vector_io.sqlite_vec.config import ( + SQLiteVectorIOConfig, +) from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig from llama_stack.providers.remote.inference.gemini.config import GeminiConfig from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.together.config import TogetherImplConfig from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig -from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, +from llama_stack.providers.remote.vector_io.pgvector.config import ( + PGVectorVectorIOConfig, +) +from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry +from llama_stack.templates.template import ( + DistributionTemplate, + RunConfigSettings, + get_model_registry, ) -from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry -def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]: +def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]: # in this template, we allow each API key to be optional providers = [ ( @@ -164,7 +171,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( dataset_id="simpleqa", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"), metadata={ "path": "llamastack/simpleqa", "split": "train", @@ -178,7 +185,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( dataset_id="mmlu_cot", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"), metadata={ "path": "llamastack/mmlu_cot", "name": "all", @@ -193,7 +200,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( dataset_id="gpqa_cot", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"), metadata={ "path": "llamastack/gpqa_0shot_cot", "name": "gpqa_main", @@ -208,7 +215,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( dataset_id="math_500", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/math_500"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"), metadata={ "path": "llamastack/math_500", "split": "test", diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index aa1ce144f..a5c8e80bc 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -30,7 +30,9 @@ from llama_stack.providers.utils.inference.model_registry import ProviderModelEn from 
llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]: +def get_model_registry( + available_models: Dict[str, List[ProviderModelEntry]], +) -> List[ModelInput]: models = [] for provider_id, entries in available_models.items(): for entry in entries: @@ -193,7 +195,7 @@ class DistributionTemplate(BaseModel): default_models.append( DefaultModel( model_id=model_entry.provider_model_id, - doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "", + doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""), ) )
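
Usage notes (illustrative, not part of the patch). First, a minimal sketch of how the new preserve_contexts_async_generator utility is meant to be used, mirroring the server.py call site above; the REQUEST_ID variable and stream_tokens generator are hypothetical stand-ins for CURRENT_TRACE_CONTEXT and sse_generator:

    import asyncio
    from contextvars import ContextVar

    from llama_stack.distribution.utils.context import preserve_contexts_async_generator

    # Hypothetical context variable standing in for CURRENT_TRACE_CONTEXT / PROVIDER_DATA_VAR.
    REQUEST_ID: ContextVar[str] = ContextVar("request_id", default="unset")

    async def stream_tokens():
        # Each chunk should observe the context captured around the wrapper,
        # which is the point of the utility when iteration later happens in a
        # different task or event loop.
        for token in ("hello", "world"):
            yield f"{REQUEST_ID.get()}: {token}"

    async def main():
        REQUEST_ID.set("req-42")
        wrapped = preserve_contexts_async_generator(stream_tokens(), [REQUEST_ID])
        async for item in wrapped:
            print(item)  # "req-42: hello", then "req-42: world"

    asyncio.run(main())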
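
Likewise, a sketch of the renamed dataset-registration surface, using only the types shown in datasets.py above; the commented JSON mirrors RegisterDatasetRequest:

    from llama_stack.apis.datasets.datasets import (
        DatasetPurpose,
        HuggingfaceDataSource,
        URIDataSource,
    )

    # `purpose` replaces the old `schema` field, `source` replaces `data_source`.
    purpose = DatasetPurpose.post_training_messages  # serializes to "post-training/messages"
    source = URIDataSource(uri="https://mywebsite.com/mydata.jsonl")

    # Huggingface datasets now use `path` instead of `dataset_path`:
    hf_source = HuggingfaceDataSource(path="llamastack/simpleqa", params={"split": "train"})

    # Equivalent wire format for POST /datasets:
    # {
    #   "purpose": "post-training/messages",
    #   "source": {"type": "uri", "uri": "https://mywebsite.com/mydata.jsonl"}
    # }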
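
A sketch of reading the per-response token metrics that InferenceRouter now attaches; `response` is assumed to be a non-streaming ChatCompletionResponse (CompletionResponse carries the same MetricResponseMixin.metrics field):

    from typing import Dict, Union

    from llama_stack.apis.inference import ChatCompletionResponse

    def summarize_usage(response: ChatCompletionResponse) -> Dict[str, Union[int, float]]:
        # metrics is Optional[List[MetricInResponse]]; the router emits
        # prompt_tokens, completion_tokens, and total_tokens when a trace
        # span is active (it returns no metrics otherwise).
        return {
            m.metric: m.value
            for m in (response.metrics or [])
            if m.metric in ("prompt_tokens", "completion_tokens", "total_tokens")
        }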
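
Finally, a sketch of resuming a turn under the narrowed tool_responses type. The hunks above show only the parameter lists, so the resume_agent_turn method name and the ToolResponse import path are assumptions:

    from llama_stack.apis.agents import Agents  # protocol shown above
    from llama_stack.apis.inference import ToolResponse  # import path assumed

    async def resume_with_results(agents: Agents, agent_id: str, session_id: str, turn_id: str):
        tool_responses = [
            # ToolResponseMessage is no longer accepted here; pass ToolResponse directly.
            ToolResponse(call_id="call-1", tool_name="web_search", content="search results"),
        ]
        return await agents.resume_agent_turn(  # method name assumed
            agent_id=agent_id,
            session_id=session_id,
            turn_id=turn_id,
            tool_responses=tool_responses,
            stream=False,
        )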