forked from phoenix-oss/llama-stack-mirror
Merge branch 'main' into pr1573
Commit f840018088
12 changed files with 402 additions and 64 deletions
@@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):


 @json_schema_type
-class CompletionResponse(BaseModel):
+class CompletionResponse(MetricResponseMixin):
     """Response from a completion request.

     :param content: The generated completion text
@@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):


 @json_schema_type
-class CompletionResponseStreamChunk(BaseModel):
+class CompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed completion response.

     :param delta: New content generated since last chunk. This can be one or more tokens.
@@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):


 @json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
+class ChatCompletionResponseStreamChunk(MetricResponseMixin):
     """A chunk of a streamed chat completion response.

     :param event: The event containing the new content
@@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):


 @json_schema_type
-class ChatCompletionResponse(MetricResponseMixin, BaseModel):
+class ChatCompletionResponse(MetricResponseMixin):
     """Response from a chat completion request.

     :param completion_message: The complete response message
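Note: `MetricResponseMixin` itself is not shown in this diff. A minimal sketch of what such a mixin plausibly looks like is below; the field name and metric shape are assumptions for illustration, not the actual llama-stack definition.

```python
# Hypothetical sketch of a metrics-carrying response mixin (assumed shape,
# not the real llama-stack class): response models inherit from it so that
# token-usage metrics can be attached to any completion or chat response.
from typing import List, Optional

from pydantic import BaseModel


class MetricSketch(BaseModel):
    """Assumed shape of one metric entry attached to a response."""

    metric: str
    value: float
    unit: Optional[str] = None


class MetricResponseMixinSketch(BaseModel):
    """Anything inheriting this gains an optional list of metrics."""

    metrics: Optional[List[MetricSketch]] = None
```

Rebasing the response classes onto such a mixin is what lets the router changes later in this commit set `response.metrics` and `chunk.metrics` without altering each schema by hand.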
@@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
 from llama_stack.distribution.request_headers import (
-    preserve_headers_context_async_generator,
+    PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import ProviderRegistry
@@ -44,8 +44,10 @@ from llama_stack.distribution.stack import (
     redact_sensitive_fields,
     replace_env_vars,
 )
+from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.distribution.utils.exec import in_notebook
 from llama_stack.providers.utils.telemetry.tracing import (
+    CURRENT_TRACE_CONTEXT,
     end_trace,
     setup_logger,
     start_trace,
@@ -384,8 +386,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             finally:
                 await end_trace()

         # Wrap the generator to preserve context across iterations
-        wrapped_gen = preserve_headers_context_async_generator(gen())
+        wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])

         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=wrapped_gen,
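Why the wrapping is needed: an async generator body runs in whatever context it is iterated from, so a ContextVar value installed while the request context manager was active is gone by the time the stream is consumed. A minimal, self-contained sketch of the failure mode (not taken from the patch):

```python
# Demonstrates why the streaming generator must be wrapped: ContextVar values
# set around generator creation are not visible once that scope has exited.
import asyncio
from contextvars import ContextVar
from typing import Optional

request_id: ContextVar[Optional[str]] = ContextVar("request_id", default=None)


async def stream_events():
    for i in range(2):
        # Each step runs in the caller's context at iteration time.
        yield (i, request_id.get())


async def main() -> None:
    token = request_id.set("req-1")
    gen = stream_events()      # created while request_id == "req-1"
    request_id.reset(token)    # the request-scoped context manager has exited
    async for item in gen:
        print(item)            # (0, None), (1, None) -- the value is gone


asyncio.run(main())
```

`preserve_contexts_async_generator` (added later in this commit) re-applies the captured values around each iteration so the trace context and provider data survive.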
@@ -7,14 +7,14 @@
 import contextvars
 import json
 import logging
-from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
+from typing import Any, ContextManager, Dict, Optional

 from .utils.dynamic import instantiate_class_type

 log = logging.getLogger(__name__)

 # Context variable for request provider data
-_provider_data_var = contextvars.ContextVar("provider_data", default=None)
+PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)


 class RequestProviderDataContext(ContextManager):
@@ -26,40 +26,13 @@ class RequestProviderDataContext(ContextManager):

     def __enter__(self):
         # Save the current value and set the new one
-        self.token = _provider_data_var.set(self.provider_data)
+        self.token = PROVIDER_DATA_VAR.set(self.provider_data)
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
         # Restore the previous value
         if self.token is not None:
-            _provider_data_var.reset(self.token)
-
-
-T = TypeVar("T")
-
-
-def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
-    """
-    Wraps an async generator to preserve request headers context variables across iterations.
-
-    This ensures that context variables set during generator creation are
-    available during each iteration of the generator, even if the original
-    context manager has exited.
-    """
-    # Capture the current context value right now
-    context_value = _provider_data_var.get()
-
-    async def wrapper():
-        while True:
-            # Set context before each anext() call
-            _ = _provider_data_var.set(context_value)
-            try:
-                item = await gen.__anext__()
-                yield item
-            except StopAsyncIteration:
-                break
-
-    return wrapper()
+            PROVIDER_DATA_VAR.reset(self.token)


 class NeedsRequestProviderData:
@@ -72,7 +45,7 @@ class NeedsRequestProviderData:
         if not validator_class:
             raise ValueError(f"Provider {provider_type} does not have a validator")

-        val = _provider_data_var.get()
+        val = PROVIDER_DATA_VAR.get()
         if not val:
             return None
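The module-level ContextVar plus token-based reset used above is a standard pattern; here is a standalone sketch of the same idea with illustrative names (the real class is `RequestProviderDataContext`):

```python
# Minimal sketch of a ContextVar-backed request scope (illustrative names).
from contextvars import ContextVar
from typing import Any, Dict, Optional

PROVIDER_DATA: ContextVar[Optional[Dict[str, Any]]] = ContextVar("provider_data", default=None)


class ProviderDataScope:
    def __init__(self, provider_data: Optional[Dict[str, Any]] = None) -> None:
        self.provider_data = provider_data
        self.token = None

    def __enter__(self):
        # Save the current value and install the new one.
        self.token = PROVIDER_DATA.set(self.provider_data)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore whatever was set before this scope was entered.
        if self.token is not None:
            PROVIDER_DATA.reset(self.token)


with ProviderDataScope({"api_key": "sk-test"}):
    assert PROVIDER_DATA.get() == {"api_key": "sk-test"}
assert PROVIDER_DATA.get() is None
```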
@@ -165,7 +165,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
                 module="llama_stack.distribution.routers",
                 routing_table_api=info.routing_table_api,
                 api_dependencies=[info.routing_table_api],
-                deps__=[info.routing_table_api.value],
+                # Add telemetry as an optional dependency to all auto-routed providers
+                optional_api_dependencies=[Api.telemetry],
+                deps__=([info.routing_table_api.value, Api.telemetry.value]),
             ),
         )
     }
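Here `api_dependencies` are hard requirements, while `optional_api_dependencies` (telemetry, in this case) may be absent at run time. A hedged sketch of how a resolver could assemble the dependency dict under that contract; this is illustrative, not the llama-stack resolver:

```python
# Illustrative resolution of required vs. optional provider dependencies.
# The two field names mirror the spec above; the resolver logic is assumed.
from typing import Any, Dict, List


def resolve_deps(
    api_dependencies: List[str],
    optional_api_dependencies: List[str],
    available_impls: Dict[str, Any],
) -> Dict[str, Any]:
    deps: Dict[str, Any] = {}
    for api in api_dependencies:
        if api not in available_impls:
            raise ValueError(f"Missing required dependency: {api}")
        deps[api] = available_impls[api]
    for api in optional_api_dependencies:
        # Optional deps (e.g. telemetry) are injected only when configured.
        if api in available_impls:
            deps[api] = available_impls[api]
    return deps
```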
@@ -45,7 +45,7 @@ async def get_routing_table_impl(
     return impl


-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
         EvalRouter,
@@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
         "eval": EvalRouter,
         "tool_runtime": ToolRuntimeRouter,
     }
+    api_to_deps = {
+        "inference": {"telemetry": Api.telemetry},
+    }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")

-    impl = api_to_routers[api.value](routing_table)
+    api_to_dep_impl = {}
+    for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
+        if dep_api in deps:
+            api_to_dep_impl[dep_name] = deps[dep_api]
+
+    impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
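The pattern above is keyword-argument injection keyed off the API name: only routers that declare a dependency receive it, and missing optional dependencies are simply skipped. A self-contained sketch with dummy classes (not the llama-stack routers):

```python
# Self-contained sketch of the kwarg-injection pattern (dummy classes).
from typing import Any, Dict


class InferenceRouterSketch:
    def __init__(self, routing_table: Any, telemetry: Any = None) -> None:
        self.routing_table = routing_table
        self.telemetry = telemetry


API_TO_ROUTERS = {"inference": InferenceRouterSketch}
API_TO_DEPS = {"inference": {"telemetry": "telemetry-api"}}


def build_router(api: str, routing_table: Any, deps: Dict[str, Any]) -> Any:
    kwargs = {
        name: deps[dep_key]
        for name, dep_key in API_TO_DEPS.get(api, {}).items()
        if dep_key in deps  # missing optional deps are simply skipped
    }
    return API_TO_ROUTERS[api](routing_table, **kwargs)


router = build_router("inference", routing_table=object(), deps={"telemetry-api": object()})
assert router.telemetry is not None
```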
@@ -4,7 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, AsyncGenerator, Dict, List, Optional
+import time
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from llama_stack.apis.common.content_types import (
     URL,
@@ -20,6 +21,10 @@ from llama_stack.apis.eval import (
     JobStatus,
 )
 from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
@@ -27,13 +32,14 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     SamplingParams,
+    StopReason,
     TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import ModelType
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
     ScoreBatchResponse,
@@ -42,6 +48,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.apis.telemetry import MetricEvent, Telemetry
 from llama_stack.apis.tools import (
     RAGDocument,
     RAGQueryConfig,
@@ -52,7 +59,10 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
+from llama_stack.models.llama.llama3.chat_format import ChatFormat
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.utils.telemetry.tracing import get_current_span

 logger = get_logger(name=__name__, category="core")

@@ -119,9 +129,14 @@ class InferenceRouter(Inference):
     def __init__(
         self,
         routing_table: RoutingTable,
+        telemetry: Optional[Telemetry] = None,
     ) -> None:
         logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
+        self.telemetry = telemetry
+        if self.telemetry:
+            self.tokenizer = Tokenizer.get_instance()
+            self.formatter = ChatFormat(self.tokenizer)

     async def initialize(self) -> None:
         logger.debug("InferenceRouter.initialize")
@@ -144,6 +159,71 @@ class InferenceRouter(Inference):
         )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

+    def _construct_metrics(
+        self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
+    ) -> List[MetricEvent]:
+        """Constructs a list of MetricEvent objects containing token usage metrics.
+
+        Args:
+            prompt_tokens: Number of tokens in the prompt
+            completion_tokens: Number of tokens in the completion
+            total_tokens: Total number of tokens used
+            model: Model object containing model_id and provider_id
+
+        Returns:
+            List of MetricEvent objects with token usage metrics
+        """
+        span = get_current_span()
+        if span is None:
+            logger.warning("No span found for token usage metrics")
+            return []
+        metrics = [
+            ("prompt_tokens", prompt_tokens),
+            ("completion_tokens", completion_tokens),
+            ("total_tokens", total_tokens),
+        ]
+        metric_events = []
+        for metric_name, value in metrics:
+            metric_events.append(
+                MetricEvent(
+                    trace_id=span.trace_id,
+                    span_id=span.span_id,
+                    metric=metric_name,
+                    value=value,
+                    timestamp=time.time(),
+                    unit="tokens",
+                    attributes={
+                        "model_id": model.model_id,
+                        "provider_id": model.provider_id,
+                    },
+                )
+            )
+        return metric_events
+
+    async def _compute_and_log_token_usage(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        total_tokens: int,
+        model: Model,
+    ) -> List[MetricEvent]:
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        if self.telemetry:
+            for metric in metrics:
+                await self.telemetry.log_event(metric)
+        return metrics
+
+    async def _count_tokens(
+        self,
+        messages: List[Message] | InterleavedContent,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+    ) -> Optional[int]:
+        if isinstance(messages, list):
+            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
+        else:
+            encoded = self.formatter.encode_content(messages)
+        return len(encoded.tokens) if encoded and encoded.tokens else 0
+
     async def chat_completion(
         self,
         model_id: str,
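To make the bookkeeping above concrete, the sketch below fakes the metric construction with stand-in dataclasses and prints the three token counters; `MetricEvent`, `Model`, and the span fields are stand-ins here, not the llama-stack types:

```python
# Runnable illustration of the token-usage metrics built above, using
# stand-in dataclasses instead of the llama-stack MetricEvent/Model types.
import time
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeMetricEvent:
    trace_id: str
    span_id: str
    metric: str
    value: int
    timestamp: float
    unit: str
    attributes: Dict[str, str] = field(default_factory=dict)


def construct_metrics(prompt_tokens: int, completion_tokens: int, model_id: str, provider_id: str) -> List[FakeMetricEvent]:
    total_tokens = prompt_tokens + completion_tokens
    return [
        FakeMetricEvent(
            trace_id="trace-1",
            span_id="span-1",
            metric=name,
            value=value,
            timestamp=time.time(),
            unit="tokens",
            attributes={"model_id": model_id, "provider_id": provider_id},
        )
        for name, value in [
            ("prompt_tokens", prompt_tokens),
            ("completion_tokens", completion_tokens),
            ("total_tokens", total_tokens),
        ]
    ]


for event in construct_metrics(12, 30, "my-model", "my-provider"):
    print(event.metric, event.value)  # prompt_tokens 12, completion_tokens 30, total_tokens 42
```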
@@ -156,8 +236,9 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
             "core",
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
         if sampling_params is None:
@@ -206,10 +287,47 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+
         if stream:
-            return (chunk async for chunk in await provider.chat_completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.chat_completion(**params):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                        if chunk.event.delta.type == "text":
+                            completion_text += chunk.event.delta.text
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                        completion_tokens = await self._count_tokens(
+                            [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
+                            tool_config.tool_prompt_format,
+                        )
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        metrics = await self._compute_and_log_token_usage(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.chat_completion(**params)
+            response = await provider.chat_completion(**params)
+            completion_tokens = await self._count_tokens(
+                [response.completion_message],
+                tool_config.tool_prompt_format,
+            )
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            metrics = await self._compute_and_log_token_usage(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+            )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
+            return response

     async def completion(
         self,
@@ -239,10 +357,41 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )

+        prompt_tokens = await self._count_tokens(content)
+
         if stream:
-            return (chunk async for chunk in await provider.completion(**params))
+
+            async def stream_generator():
+                completion_text = ""
+                async for chunk in await provider.completion(**params):
+                    if hasattr(chunk, "delta"):
+                        completion_text += chunk.delta
+                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                        completion_tokens = await self._count_tokens(completion_text)
+                        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+                        metrics = await self._compute_and_log_token_usage(
+                            prompt_tokens or 0,
+                            completion_tokens or 0,
+                            total_tokens,
+                            model,
+                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                    yield chunk
+
+            return stream_generator()
         else:
-            return await provider.completion(**params)
+            response = await provider.completion(**params)
+            completion_tokens = await self._count_tokens(response.content)
+            total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
+            metrics = await self._compute_and_log_token_usage(
+                prompt_tokens or 0,
+                completion_tokens or 0,
+                total_tokens,
+                model,
+            )
+            response.metrics = metrics if response.metrics is None else response.metrics + metrics
+            return response

     async def embeddings(
         self,
@@ -28,7 +28,7 @@ from typing_extensions import Annotated
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import (
-    preserve_headers_context_async_generator,
+    PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import InvalidProviderError
@@ -38,6 +38,7 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
     validate_env_pair,
 )
+from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
@@ -45,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
     TelemetryAdapter,
 )
 from llama_stack.providers.utils.telemetry.tracing import (
+    CURRENT_TRACE_CONTEXT,
     end_trace,
     setup_logger,
     start_trace,
@@ -182,7 +184,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):

         try:
             if is_streaming:
-                gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
+                gen = preserve_contexts_async_generator(
+                    sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
+                )
                 return StreamingResponse(gen, media_type="text/event-stream")
             else:
                 value = func(**kwargs)
llama_stack/distribution/utils/context.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from contextvars import ContextVar
+from typing import AsyncGenerator, List, TypeVar
+
+T = TypeVar("T")
+
+
+def preserve_contexts_async_generator(
+    gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
+) -> AsyncGenerator[T, None]:
+    """
+    Wraps an async generator to preserve context variables across iterations.
+    This is needed because we start a new asyncio event loop for each streaming request,
+    and we need to preserve the context across the event loop boundary.
+    """
+
+    async def wrapper():
+        while True:
+            try:
+                item = await gen.__anext__()
+                context_values = {context_var.name: context_var.get() for context_var in context_vars}
+                yield item
+                for context_var in context_vars:
+                    _ = context_var.set(context_values[context_var.name])
+            except StopAsyncIteration:
+                break
+
+    return wrapper()
llama_stack/distribution/utils/tests/test_context.py (new file, 155 lines)
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from contextvars import ContextVar
+
+import pytest
+
+from llama_stack.distribution.utils.context import preserve_contexts_async_generator
+
+
+@pytest.mark.asyncio
+async def test_preserve_contexts_with_exception():
+    # Create context variable
+    context_var = ContextVar("exception_var", default="initial")
+    token = context_var.set("start_value")
+
+    # Create an async generator that raises an exception
+    async def exception_generator():
+        yield context_var.get()
+        context_var.set("modified")
+        raise ValueError("Test exception")
+        yield None  # This will never be reached
+
+    # Wrap the generator
+    wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
+
+    # First iteration should work
+    value = await wrapped_gen.__anext__()
+    assert value == "start_value"
+
+    # Second iteration should raise the exception
+    with pytest.raises(ValueError, match="Test exception"):
+        await wrapped_gen.__anext__()
+
+    # Clean up
+    context_var.reset(token)
+
+
+@pytest.mark.asyncio
+async def test_preserve_contexts_empty_generator():
+    # Create context variable
+    context_var = ContextVar("empty_var", default="initial")
+    token = context_var.set("value")
+
+    # Create an empty async generator
+    async def empty_generator():
+        if False:  # This condition ensures the generator yields nothing
+            yield None
+
+    # Wrap the generator
+    wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
+
+    # The generator should raise StopAsyncIteration immediately
+    with pytest.raises(StopAsyncIteration):
+        await wrapped_gen.__anext__()
+
+    # Context variable should remain unchanged
+    assert context_var.get() == "value"
+
+    # Clean up
+    context_var.reset(token)
+
+
+@pytest.mark.asyncio
+async def test_preserve_contexts_across_event_loops():
+    """
+    Test that context variables are preserved across event loop boundaries with nested generators.
+    This simulates the real-world scenario where:
+    1. A new event loop is created for each streaming request
+    2. The async generator runs inside that loop
+    3. There are multiple levels of nested generators
+    4. Context needs to be preserved across these boundaries
+    """
+    # Create context variables
+    request_id = ContextVar("request_id", default=None)
+    user_id = ContextVar("user_id", default=None)
+
+    # Set initial values
+
+    # Results container to verify values across thread boundaries
+    results = []
+
+    # Inner-most generator (level 2)
+    async def inner_generator():
+        # Should have the context from the outer scope
+        yield (1, request_id.get(), user_id.get())
+
+        # Modify one context variable
+        user_id.set("user-modified")
+
+        # Should reflect the modification
+        yield (2, request_id.get(), user_id.get())
+
+    # Middle generator (level 1)
+    async def middle_generator():
+        inner_gen = inner_generator()
+
+        # Forward the first yield from inner
+        item = await inner_gen.__anext__()
+        yield item
+
+        # Forward the second yield from inner
+        item = await inner_gen.__anext__()
+        yield item
+
+        request_id.set("req-modified")
+
+        # Add our own yield with both modified variables
+        yield (3, request_id.get(), user_id.get())
+
+    # Function to run in a separate thread with a new event loop
+    def run_in_new_loop():
+        # Create a new event loop for this thread
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        try:
+            # Outer generator (runs in the new loop)
+            async def outer_generator():
+                request_id.set("req-12345")
+                user_id.set("user-6789")
+                # Wrap the middle generator
+                wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
+
+                # Process all items from the middle generator
+                async for item in wrapped_gen:
+                    # Store results for verification
+                    results.append(item)
+
+            # Run the outer generator in the new loop
+            loop.run_until_complete(outer_generator())
+        finally:
+            loop.close()
+
+    # Run the generator chain in a separate thread with a new event loop
+    with ThreadPoolExecutor(max_workers=1) as executor:
+        future = executor.submit(run_in_new_loop)
+        future.result()  # Wait for completion
+
+    # Verify the results
+    assert len(results) == 3
+
+    # First yield should have original values
+    assert results[0] == (1, "req-12345", "user-6789")
+
+    # Second yield should have modified user_id
+    assert results[1] == (2, "req-12345", "user-modified")
+
+    # Third yield should have both modified values
+    assert results[2] == (3, "req-modified", "user-modified")
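These tests rely on `@pytest.mark.asyncio`, so running them presumably requires the pytest-asyncio plugin. A typical invocation, using the path from the file header above and ordinary pytest options:

```python
# Programmatic equivalent of `pytest -q llama_stack/distribution/utils/tests/test_context.py`.
import pytest

raise SystemExit(pytest.main(["-q", "llama_stack/distribution/utils/tests/test_context.py"]))
```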
@@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
     def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
         self.config = config
         self.datasetio_api = deps.get(Api.datasetio)
+        self.meter = None

         resource = Resource.create(
             {
@@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         return _GLOBAL_STORAGE["gauges"][name]

     def _log_metric(self, event: MetricEvent) -> None:
+        if self.meter is None:
+            return
         if isinstance(event.value, int):
             counter = self._get_or_create_counter(event.metric, event.unit)
             counter.add(event.value, attributes=event.attributes)
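`_get_or_create_counter` is not shown in this diff; with the OpenTelemetry metrics API the underlying calls presumably look roughly like the sketch below (meter and counter names are illustrative, not the adapter's actual wiring):

```python
# Rough sketch of counter-based metric logging via the OpenTelemetry API.
from typing import Dict

from opentelemetry import metrics

meter = metrics.get_meter("llama-stack-sketch")
_counters: Dict[str, metrics.Counter] = {}


def get_or_create_counter(name: str, unit: str) -> metrics.Counter:
    if name not in _counters:
        _counters[name] = meter.create_counter(name=name, unit=unit)
    return _counters[name]


counter = get_or_create_counter("total_tokens", "tokens")
counter.add(42, attributes={"model_id": "my-model", "provider_id": "my-provider"})
```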
@@ -4,8 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Tuple
+from typing import Dict, List, Tuple

+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import (
     BenchmarkInput,
@@ -15,21 +16,27 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
-from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
+from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
+    SQLiteVectorIOConfig,
+)
 from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
 from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
-from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
-from llama_stack.providers.utils.inference.model_registry import (
-    ProviderModelEntry,
+from llama_stack.providers.remote.vector_io.pgvector.config import (
+    PGVectorVectorIOConfig,
 )
+from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
-from llama_stack.templates.template import (
-    DistributionTemplate,
-    RunConfigSettings,
-    get_model_registry,
-)
+from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry


-def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
+def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
     # in this template, we allow each API key to be optional
     providers = [
         (
@@ -164,7 +171,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
             dataset_id="simpleqa",
             provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"),
             metadata={
                 "path": "llamastack/simpleqa",
                 "split": "train",
@@ -178,7 +185,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
             dataset_id="mmlu_cot",
             provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"),
             metadata={
                 "path": "llamastack/mmlu_cot",
                 "name": "all",
@@ -193,7 +200,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
            dataset_id="gpqa_cot",
            provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"),
             metadata={
                 "path": "llamastack/gpqa_0shot_cot",
                 "name": "gpqa_main",
@@ -208,7 +215,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
             dataset_id="math_500",
             provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/math_500"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"),
             metadata={
                 "path": "llamastack/math_500",
                 "split": "test",
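The switch from a raw dict to `URL(uri=...)` means the field is validated when the template is constructed rather than failing later. A tiny sketch with a stand-in model (the real `URL` lives in `llama_stack.apis.common.content_types`):

```python
# Why URL(uri=...) beats a raw dict: a typed model validates at construction.
# URLSketch is a stand-in for the real llama-stack URL model.
from pydantic import BaseModel, ValidationError


class URLSketch(BaseModel):
    uri: str


print(URLSketch(uri="https://huggingface.co/datasets/llamastack/math_500"))

try:
    URLSketch(url="https://example.com")  # misspelled field: required 'uri' is missing
except ValidationError as exc:
    print(f"rejected: {exc}")
```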
@@ -30,7 +30,9 @@ from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


-def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]:
+def get_model_registry(
+    available_models: Dict[str, List[ProviderModelEntry]],
+) -> List[ModelInput]:
     models = []
     for provider_id, entries in available_models.items():
         for entry in entries:
@@ -193,7 +195,7 @@ class DistributionTemplate(BaseModel):
             default_models.append(
                 DefaultModel(
                     model_id=model_entry.provider_model_id,
-                    doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "",
+                    doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
                 )
             )