diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index d0f5d15c5..fa917ac22 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -285,7 +285,7 @@ class CompletionRequest(BaseModel): @json_schema_type -class CompletionResponse(BaseModel): +class CompletionResponse(MetricResponseMixin): """Response from a completion request. :param content: The generated completion text @@ -299,7 +299,7 @@ class CompletionResponse(BaseModel): @json_schema_type -class CompletionResponseStreamChunk(BaseModel): +class CompletionResponseStreamChunk(MetricResponseMixin): """A chunk of a streamed completion response. :param delta: New content generated since last chunk. This can be one or more tokens. @@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel): @json_schema_type -class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel): +class ChatCompletionResponseStreamChunk(MetricResponseMixin): """A chunk of a streamed chat completion response. :param event: The event containing the new content @@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel): @json_schema_type -class ChatCompletionResponse(MetricResponseMixin, BaseModel): +class ChatCompletionResponse(MetricResponseMixin): """Response from a chat completion request. 
:param completion_message: The complete response message diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 5dc70bb67..15c4fe6ea 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.datatypes import Api from llama_stack.distribution.request_headers import ( - preserve_headers_context_async_generator, + PROVIDER_DATA_VAR, request_provider_data_context, ) from llama_stack.distribution.resolver import ProviderRegistry @@ -44,8 +44,10 @@ from llama_stack.distribution.stack import ( redact_sensitive_fields, replace_env_vars, ) +from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.distribution.utils.exec import in_notebook from llama_stack.providers.utils.telemetry.tracing import ( + CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace, @@ -384,8 +386,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): finally: await end_trace() - # Wrap the generator to preserve context across iterations - wrapped_gen = preserve_headers_context_async_generator(gen()) + wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]) + mock_response = httpx.Response( status_code=httpx.codes.OK, content=wrapped_gen, diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 19afae59b..8709fc040 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -7,14 +7,14 @@ import contextvars import json import logging -from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar +from typing import Any, ContextManager, Dict, Optional from .utils.dynamic import 
instantiate_class_type log = logging.getLogger(__name__) # Context variable for request provider data -_provider_data_var = contextvars.ContextVar("provider_data", default=None) +PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) class RequestProviderDataContext(ContextManager): @@ -26,40 +26,13 @@ class RequestProviderDataContext(ContextManager): def __enter__(self): # Save the current value and set the new one - self.token = _provider_data_var.set(self.provider_data) + self.token = PROVIDER_DATA_VAR.set(self.provider_data) return self def __exit__(self, exc_type, exc_val, exc_tb): # Restore the previous value if self.token is not None: - _provider_data_var.reset(self.token) - - -T = TypeVar("T") - - -def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]: - """ - Wraps an async generator to preserve request headers context variables across iterations. - - This ensures that context variables set during generator creation are - available during each iteration of the generator, even if the original - context manager has exited. 
- """ - # Capture the current context value right now - context_value = _provider_data_var.get() - - async def wrapper(): - while True: - # Set context before each anext() call - _ = _provider_data_var.set(context_value) - try: - item = await gen.__anext__() - yield item - except StopAsyncIteration: - break - - return wrapper() + PROVIDER_DATA_VAR.reset(self.token) class NeedsRequestProviderData: @@ -72,7 +45,7 @@ class NeedsRequestProviderData: if not validator_class: raise ValueError(f"Provider {provider_type} does not have a validator") - val = _provider_data_var.get() + val = PROVIDER_DATA_VAR.get() if not val: return None diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index d7ca4414d..ab075f399 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -165,7 +165,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, module="llama_stack.distribution.routers", routing_table_api=info.routing_table_api, api_dependencies=[info.routing_table_api], - deps__=[info.routing_table_api.value], + # Add telemetry as an optional dependency to all auto-routed providers + optional_api_dependencies=[Api.telemetry], + deps__=([info.routing_table_api.value, Api.telemetry.value]), ), ) } diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index a54f57fb3..d0fca8771 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -45,7 +45,7 @@ async def get_routing_table_impl( return impl -async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: +async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any: from .routers import ( DatasetIORouter, EvalRouter, @@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> "eval": EvalRouter, "tool_runtime": ToolRuntimeRouter, 
} + api_to_deps = { + "inference": {"telemetry": Api.telemetry}, + } if api.value not in api_to_routers: raise ValueError(f"API {api.value} not found in router map") - impl = api_to_routers[api.value](routing_table) + api_to_dep_impl = {} + for dep_name, dep_api in api_to_deps.get(api.value, {}).items(): + if dep_api in deps: + api_to_dep_impl[dep_name] = deps[dep_api] + + impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) await impl.initialize() return impl diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 28df67922..68b8e55cb 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -4,7 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, AsyncGenerator, Dict, List, Optional +import time +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from llama_stack.apis.common.content_types import ( URL, @@ -20,6 +21,10 @@ from llama_stack.apis.eval import ( JobStatus, ) from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + CompletionMessage, EmbeddingsResponse, EmbeddingTaskType, Inference, @@ -27,13 +32,14 @@ from llama_stack.apis.inference import ( Message, ResponseFormat, SamplingParams, + StopReason, TextTruncation, ToolChoice, ToolConfig, ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.models import ModelType +from llama_stack.apis.models import Model, ModelType from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.scoring import ( ScoreBatchResponse, @@ -42,6 +48,7 @@ from llama_stack.apis.scoring import ( ScoringFnParams, ) from llama_stack.apis.shields import Shield +from llama_stack.apis.telemetry import MetricEvent, Telemetry from llama_stack.apis.tools import ( RAGDocument, 
     RAGQueryConfig,
@@ -52,7 +59,10 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
+from llama_stack.models.llama.llama3.chat_format import ChatFormat
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
 logger = get_logger(name=__name__, category="core")
 
@@ -119,9 +129,14 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
+        telemetry: Optional[Telemetry] = None,
     ) -> None:
         logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
+        self.telemetry = telemetry
+        # NOTE(review): always build these -- _count_tokens() uses self.formatter unconditionally
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)
 
     async def initialize(self) -> None:
         logger.debug("InferenceRouter.initialize")
@@ -144,6 +159,71 @@ class InferenceRouter(Inference):
             )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
+    def _construct_metrics(
+        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
+    ) -> List[MetricEvent]:
+        """Constructs a list of MetricEvent objects containing token usage metrics.
+
+        Args:
+            prompt_tokens: Number of tokens in the prompt
+            completion_tokens: Number of tokens in the completion
+            total_tokens: Total number of tokens used
+            model: Model object containing model_id and provider_id
+
+        Returns:
+            List of MetricEvent objects with token usage metrics
+        """
+        span = get_current_span()
+        if span is None:
+            logger.warning("No span found for token usage metrics")
+            return []
+        metrics = [
+            ("prompt_tokens", prompt_tokens),
+            ("completion_tokens", completion_tokens),
+            ("total_tokens", total_tokens),
+        ]
+        metric_events = []
+        for metric_name, value in metrics:
+            metric_events.append(
+                MetricEvent(
+                    trace_id=span.trace_id,
+                    span_id=span.span_id,
+                    metric=metric_name,
+                    value=value,
+                    timestamp=time.time(),
+                    unit="tokens",
+                    attributes={
+                        "model_id": model.model_id,
+                        "provider_id": model.provider_id,
+                    },
+                )
+            )
+        return metric_events
+
+    async def _compute_and_log_token_usage(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        total_tokens: int,
+        model: Model,
+    ) -> List[MetricEvent]:
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        if self.telemetry:
+            for metric in metrics:
+                await self.telemetry.log_event(metric)
+        return metrics
+
+    async def _count_tokens(
+        self,
+        messages: List[Message] | InterleavedContent,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+    ) -> Optional[int]:
+        if isinstance(messages, list):
+            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        else:
+            encoded = self.formatter.encode_content(messages)
+        return len(encoded.tokens) if encoded and encoded.tokens else 0
+
     async def chat_completion(
         self,
         model_id: str,
@@ -156,8 +236,9 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
+            # NOTE(review): removed stray "core" positional arg -- Logger.debug(msg, *args) would %-format it
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) if sampling_params is None: @@ -206,10 +287,47 @@ class InferenceRouter(Inference): tool_config=tool_config, ) provider = self.routing_table.get_provider_impl(model_id) + prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) + if stream: - return (chunk async for chunk in await provider.chat_completion(**params)) + + async def stream_generator(): + completion_text = "" + async for chunk in await provider.chat_completion(**params): + if chunk.event.event_type == ChatCompletionResponseEventType.progress: + if chunk.event.delta.type == "text": + completion_text += chunk.event.delta.text + if chunk.event.event_type == ChatCompletionResponseEventType.complete: + completion_tokens = await self._count_tokens( + [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)], + tool_config.tool_prompt_format, + ) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + yield chunk + + return stream_generator() else: - return await provider.chat_completion(**params) + response = await provider.chat_completion(**params) + completion_tokens = await self._count_tokens( + [response.completion_message], + tool_config.tool_prompt_format, + ) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics + return response async def completion( self, @@ -239,10 +357,41 @@ class InferenceRouter(Inference): stream=stream, logprobs=logprobs, ) + + prompt_tokens = await 
self._count_tokens(content) + if stream: - return (chunk async for chunk in await provider.completion(**params)) + + async def stream_generator(): + completion_text = "" + async for chunk in await provider.completion(**params): + if hasattr(chunk, "delta"): + completion_text += chunk.delta + if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + completion_tokens = await self._count_tokens(completion_text) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + yield chunk + + return stream_generator() else: - return await provider.completion(**params) + response = await provider.completion(**params) + completion_tokens = await self._count_tokens(response.content) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + metrics = await self._compute_and_log_token_usage( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics + return response async def embeddings( self, diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 2cc70a738..7ca009b13 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -28,7 +28,7 @@ from typing_extensions import Annotated from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import ( - preserve_headers_context_async_generator, + PROVIDER_DATA_VAR, request_provider_data_context, ) from llama_stack.distribution.resolver import InvalidProviderError @@ -38,6 +38,7 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, 
) +from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig @@ -45,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import ( TelemetryAdapter, ) from llama_stack.providers.utils.telemetry.tracing import ( + CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace, @@ -182,7 +184,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str): try: if is_streaming: - gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs))) + gen = preserve_contexts_async_generator( + sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR] + ) return StreamingResponse(gen, media_type="text/event-stream") else: value = func(**kwargs) diff --git a/llama_stack/distribution/utils/context.py b/llama_stack/distribution/utils/context.py new file mode 100644 index 000000000..107ce7127 --- /dev/null +++ b/llama_stack/distribution/utils/context.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from contextvars import ContextVar +from typing import AsyncGenerator, List, TypeVar + +T = TypeVar("T") + + +def preserve_contexts_async_generator( + gen: AsyncGenerator[T, None], context_vars: List[ContextVar] +) -> AsyncGenerator[T, None]: + """ + Wraps an async generator to preserve context variables across iterations. + This is needed because we start a new asyncio event loop for each streaming request, + and we need to preserve the context across the event loop boundary. 
+    """
+    context_values = {context_var.name: context_var.get() for context_var in context_vars}
+
+    async def wrapper():
+        while True:
+            try:
+                for context_var in context_vars:
+                    context_var.set(context_values[context_var.name])
+                item = await gen.__anext__()
+                context_values.update({context_var.name: context_var.get() for context_var in context_vars})
+                yield item
+            except StopAsyncIteration:
+                break
+    return wrapper()
diff --git a/llama_stack/distribution/utils/tests/test_context.py b/llama_stack/distribution/utils/tests/test_context.py
new file mode 100644
index 000000000..84944bfe8
--- /dev/null
+++ b/llama_stack/distribution/utils/tests/test_context.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from contextvars import ContextVar
+
+import pytest
+
+from llama_stack.distribution.utils.context import preserve_contexts_async_generator
+
+
+@pytest.mark.asyncio
+async def test_preserve_contexts_with_exception():
+    # Create context variable
+    context_var = ContextVar("exception_var", default="initial")
+    token = context_var.set("start_value")
+
+    # Create an async generator that raises an exception
+    async def exception_generator():
+        yield context_var.get()
+        context_var.set("modified")
+        raise ValueError("Test exception")
+        yield None  # This will never be reached
+
+    # Wrap the generator
+    wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
+
+    # First iteration should work
+    value = await wrapped_gen.__anext__()
+    assert value == "start_value"
+
+    # Second iteration should raise the exception
+    with pytest.raises(ValueError, match="Test exception"):
+        await wrapped_gen.__anext__()
+
+    # Clean up
+    context_var.reset(token)
+
+
+@pytest.mark.asyncio
+async def test_preserve_contexts_empty_generator():
+    # Create context variable
+    context_var = ContextVar("empty_var",
default="initial") + token = context_var.set("value") + + # Create an empty async generator + async def empty_generator(): + if False: # This condition ensures the generator yields nothing + yield None + + # Wrap the generator + wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var]) + + # The generator should raise StopAsyncIteration immediately + with pytest.raises(StopAsyncIteration): + await wrapped_gen.__anext__() + + # Context variable should remain unchanged + assert context_var.get() == "value" + + # Clean up + context_var.reset(token) + + +@pytest.mark.asyncio +async def test_preserve_contexts_across_event_loops(): + """ + Test that context variables are preserved across event loop boundaries with nested generators. + This simulates the real-world scenario where: + 1. A new event loop is created for each streaming request + 2. The async generator runs inside that loop + 3. There are multiple levels of nested generators + 4. Context needs to be preserved across these boundaries + """ + # Create context variables + request_id = ContextVar("request_id", default=None) + user_id = ContextVar("user_id", default=None) + + # Set initial values + + # Results container to verify values across thread boundaries + results = [] + + # Inner-most generator (level 2) + async def inner_generator(): + # Should have the context from the outer scope + yield (1, request_id.get(), user_id.get()) + + # Modify one context variable + user_id.set("user-modified") + + # Should reflect the modification + yield (2, request_id.get(), user_id.get()) + + # Middle generator (level 1) + async def middle_generator(): + inner_gen = inner_generator() + + # Forward the first yield from inner + item = await inner_gen.__anext__() + yield item + + # Forward the second yield from inner + item = await inner_gen.__anext__() + yield item + + request_id.set("req-modified") + + # Add our own yield with both modified variables + yield (3, request_id.get(), user_id.get()) + + # 
Function to run in a separate thread with a new event loop + def run_in_new_loop(): + # Create a new event loop for this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # Outer generator (runs in the new loop) + async def outer_generator(): + request_id.set("req-12345") + user_id.set("user-6789") + # Wrap the middle generator + wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id]) + + # Process all items from the middle generator + async for item in wrapped_gen: + # Store results for verification + results.append(item) + + # Run the outer generator in the new loop + loop.run_until_complete(outer_generator()) + finally: + loop.close() + + # Run the generator chain in a separate thread with a new event loop + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_in_new_loop) + future.result() # Wait for completion + + # Verify the results + assert len(results) == 3 + + # First yield should have original values + assert results[0] == (1, "req-12345", "user-6789") + + # Second yield should have modified user_id + assert results[1] == (2, "req-12345", "user-modified") + + # Third yield should have both modified values + assert results[2] == (3, "req-modified", "user-modified") diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index e713a057f..4cdb420b2 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None: self.config = config self.datasetio_api = deps.get(Api.datasetio) + self.meter = None resource = Resource.create( { @@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): return 
_GLOBAL_STORAGE["gauges"][name] def _log_metric(self, event: MetricEvent) -> None: + if self.meter is None: + return if isinstance(event.value, int): counter = self._get_or_create_counter(event.metric, event.unit) counter.add(event.value, attributes=event.attributes) diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py index 7df33a715..2b40797f9 100644 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ b/llama_stack/templates/open-benchmark/open_benchmark.py @@ -4,8 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Tuple +from typing import Dict, List, Tuple +from llama_stack.apis.common.content_types import URL from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ( BenchmarkInput, @@ -15,21 +16,27 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) -from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig +from llama_stack.providers.inline.vector_io.sqlite_vec.config import ( + SQLiteVectorIOConfig, +) from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig from llama_stack.providers.remote.inference.gemini.config import GeminiConfig from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.together.config import TogetherImplConfig from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig -from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig -from llama_stack.providers.utils.inference.model_registry import ( - ProviderModelEntry, +from llama_stack.providers.remote.vector_io.pgvector.config import ( + PGVectorVectorIOConfig, +) +from 
llama_stack.providers.utils.inference.model_registry import ProviderModelEntry +from llama_stack.templates.template import ( + DistributionTemplate, + RunConfigSettings, + get_model_registry, ) -from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry -def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]: +def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]: # in this template, we allow each API key to be optional providers = [ ( @@ -164,7 +171,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( dataset_id="simpleqa", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"), metadata={ "path": "llamastack/simpleqa", "split": "train", @@ -178,7 +185,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( dataset_id="mmlu_cot", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"), metadata={ "path": "llamastack/mmlu_cot", "name": "all", @@ -193,7 +200,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( dataset_id="gpqa_cot", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"), metadata={ "path": "llamastack/gpqa_0shot_cot", "name": "gpqa_main", @@ -208,7 +215,7 @@ def get_distribution_template() -> DistributionTemplate: DatasetInput( dataset_id="math_500", provider_id="huggingface", - url={"uri": "https://huggingface.co/datasets/llamastack/math_500"}, + url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"), metadata={ "path": "llamastack/math_500", "split": "test", diff --git a/llama_stack/templates/template.py 
b/llama_stack/templates/template.py index aa1ce144f..a5c8e80bc 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -30,7 +30,9 @@ from llama_stack.providers.utils.inference.model_registry import ProviderModelEn from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]: +def get_model_registry( + available_models: Dict[str, List[ProviderModelEntry]], +) -> List[ModelInput]: models = [] for provider_id, entries in available_models.items(): for entry in entries: @@ -193,7 +195,7 @@ class DistributionTemplate(BaseModel): default_models.append( DefaultModel( model_id=model_entry.provider_model_id, - doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "", + doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""), ) )