forked from phoenix-oss/llama-stack-mirror
Merge branch 'main' into pr1573
This commit is contained in:
commit
f840018088
12 changed files with 402 additions and 64 deletions
|
@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class CompletionResponse(BaseModel):
|
class CompletionResponse(MetricResponseMixin):
|
||||||
"""Response from a completion request.
|
"""Response from a completion request.
|
||||||
|
|
||||||
:param content: The generated completion text
|
:param content: The generated completion text
|
||||||
|
@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class CompletionResponseStreamChunk(BaseModel):
|
class CompletionResponseStreamChunk(MetricResponseMixin):
|
||||||
"""A chunk of a streamed completion response.
|
"""A chunk of a streamed completion response.
|
||||||
|
|
||||||
:param delta: New content generated since last chunk. This can be one or more tokens.
|
:param delta: New content generated since last chunk. This can be one or more tokens.
|
||||||
|
@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
|
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
|
||||||
"""A chunk of a streamed chat completion response.
|
"""A chunk of a streamed chat completion response.
|
||||||
|
|
||||||
:param event: The event containing the new content
|
:param event: The event containing the new content
|
||||||
|
@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
|
class ChatCompletionResponse(MetricResponseMixin):
|
||||||
"""Response from a chat completion request.
|
"""Response from a chat completion request.
|
||||||
|
|
||||||
:param completion_message: The complete response message
|
:param completion_message: The complete response message
|
||||||
|
|
|
@ -33,7 +33,7 @@ from llama_stack.distribution.build import print_pip_install_help
|
||||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
from llama_stack.distribution.datatypes import Api
|
from llama_stack.distribution.datatypes import Api
|
||||||
from llama_stack.distribution.request_headers import (
|
from llama_stack.distribution.request_headers import (
|
||||||
preserve_headers_context_async_generator,
|
PROVIDER_DATA_VAR,
|
||||||
request_provider_data_context,
|
request_provider_data_context,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.resolver import ProviderRegistry
|
from llama_stack.distribution.resolver import ProviderRegistry
|
||||||
|
@ -44,8 +44,10 @@ from llama_stack.distribution.stack import (
|
||||||
redact_sensitive_fields,
|
redact_sensitive_fields,
|
||||||
replace_env_vars,
|
replace_env_vars,
|
||||||
)
|
)
|
||||||
|
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||||
from llama_stack.distribution.utils.exec import in_notebook
|
from llama_stack.distribution.utils.exec import in_notebook
|
||||||
from llama_stack.providers.utils.telemetry.tracing import (
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
|
CURRENT_TRACE_CONTEXT,
|
||||||
end_trace,
|
end_trace,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
start_trace,
|
start_trace,
|
||||||
|
@ -384,8 +386,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
finally:
|
finally:
|
||||||
await end_trace()
|
await end_trace()
|
||||||
|
|
||||||
# Wrap the generator to preserve context across iterations
|
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
|
||||||
wrapped_gen = preserve_headers_context_async_generator(gen())
|
|
||||||
mock_response = httpx.Response(
|
mock_response = httpx.Response(
|
||||||
status_code=httpx.codes.OK,
|
status_code=httpx.codes.OK,
|
||||||
content=wrapped_gen,
|
content=wrapped_gen,
|
||||||
|
|
|
@ -7,14 +7,14 @@
|
||||||
import contextvars
|
import contextvars
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, AsyncGenerator, ContextManager, Dict, Optional, TypeVar
|
from typing import Any, ContextManager, Dict, Optional
|
||||||
|
|
||||||
from .utils.dynamic import instantiate_class_type
|
from .utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Context variable for request provider data
|
# Context variable for request provider data
|
||||||
_provider_data_var = contextvars.ContextVar("provider_data", default=None)
|
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||||
|
|
||||||
|
|
||||||
class RequestProviderDataContext(ContextManager):
|
class RequestProviderDataContext(ContextManager):
|
||||||
|
@ -26,40 +26,13 @@ class RequestProviderDataContext(ContextManager):
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
# Save the current value and set the new one
|
# Save the current value and set the new one
|
||||||
self.token = _provider_data_var.set(self.provider_data)
|
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
# Restore the previous value
|
# Restore the previous value
|
||||||
if self.token is not None:
|
if self.token is not None:
|
||||||
_provider_data_var.reset(self.token)
|
PROVIDER_DATA_VAR.reset(self.token)
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
def preserve_headers_context_async_generator(gen: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
|
||||||
"""
|
|
||||||
Wraps an async generator to preserve request headers context variables across iterations.
|
|
||||||
|
|
||||||
This ensures that context variables set during generator creation are
|
|
||||||
available during each iteration of the generator, even if the original
|
|
||||||
context manager has exited.
|
|
||||||
"""
|
|
||||||
# Capture the current context value right now
|
|
||||||
context_value = _provider_data_var.get()
|
|
||||||
|
|
||||||
async def wrapper():
|
|
||||||
while True:
|
|
||||||
# Set context before each anext() call
|
|
||||||
_ = _provider_data_var.set(context_value)
|
|
||||||
try:
|
|
||||||
item = await gen.__anext__()
|
|
||||||
yield item
|
|
||||||
except StopAsyncIteration:
|
|
||||||
break
|
|
||||||
|
|
||||||
return wrapper()
|
|
||||||
|
|
||||||
|
|
||||||
class NeedsRequestProviderData:
|
class NeedsRequestProviderData:
|
||||||
|
@ -72,7 +45,7 @@ class NeedsRequestProviderData:
|
||||||
if not validator_class:
|
if not validator_class:
|
||||||
raise ValueError(f"Provider {provider_type} does not have a validator")
|
raise ValueError(f"Provider {provider_type} does not have a validator")
|
||||||
|
|
||||||
val = _provider_data_var.get()
|
val = PROVIDER_DATA_VAR.get()
|
||||||
if not val:
|
if not val:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -165,7 +165,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
|
||||||
module="llama_stack.distribution.routers",
|
module="llama_stack.distribution.routers",
|
||||||
routing_table_api=info.routing_table_api,
|
routing_table_api=info.routing_table_api,
|
||||||
api_dependencies=[info.routing_table_api],
|
api_dependencies=[info.routing_table_api],
|
||||||
deps__=[info.routing_table_api.value],
|
# Add telemetry as an optional dependency to all auto-routed providers
|
||||||
|
optional_api_dependencies=[Api.telemetry],
|
||||||
|
deps__=([info.routing_table_api.value, Api.telemetry.value]),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ async def get_routing_table_impl(
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
||||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
|
||||||
from .routers import (
|
from .routers import (
|
||||||
DatasetIORouter,
|
DatasetIORouter,
|
||||||
EvalRouter,
|
EvalRouter,
|
||||||
|
@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
|
||||||
"eval": EvalRouter,
|
"eval": EvalRouter,
|
||||||
"tool_runtime": ToolRuntimeRouter,
|
"tool_runtime": ToolRuntimeRouter,
|
||||||
}
|
}
|
||||||
|
api_to_deps = {
|
||||||
|
"inference": {"telemetry": Api.telemetry},
|
||||||
|
}
|
||||||
if api.value not in api_to_routers:
|
if api.value not in api_to_routers:
|
||||||
raise ValueError(f"API {api.value} not found in router map")
|
raise ValueError(f"API {api.value} not found in router map")
|
||||||
|
|
||||||
impl = api_to_routers[api.value](routing_table)
|
api_to_dep_impl = {}
|
||||||
|
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
|
||||||
|
if dep_api in deps:
|
||||||
|
api_to_dep_impl[dep_name] = deps[dep_api]
|
||||||
|
|
||||||
|
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -4,7 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
import time
|
||||||
|
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
URL,
|
URL,
|
||||||
|
@ -20,6 +21,10 @@ from llama_stack.apis.eval import (
|
||||||
JobStatus,
|
JobStatus,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
|
ChatCompletionResponse,
|
||||||
|
ChatCompletionResponseEventType,
|
||||||
|
ChatCompletionResponseStreamChunk,
|
||||||
|
CompletionMessage,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
EmbeddingTaskType,
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
|
@ -27,13 +32,14 @@ from llama_stack.apis.inference import (
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
StopReason,
|
||||||
TextTruncation,
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import Model, ModelType
|
||||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||||
from llama_stack.apis.scoring import (
|
from llama_stack.apis.scoring import (
|
||||||
ScoreBatchResponse,
|
ScoreBatchResponse,
|
||||||
|
@ -42,6 +48,7 @@ from llama_stack.apis.scoring import (
|
||||||
ScoringFnParams,
|
ScoringFnParams,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
|
from llama_stack.apis.telemetry import MetricEvent, Telemetry
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
RAGDocument,
|
RAGDocument,
|
||||||
RAGQueryConfig,
|
RAGQueryConfig,
|
||||||
|
@ -52,7 +59,10 @@ from llama_stack.apis.tools import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||||
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
from llama_stack.providers.datatypes import RoutingTable
|
||||||
|
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
@ -119,9 +129,14 @@ class InferenceRouter(Inference):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
|
telemetry: Optional[Telemetry] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("Initializing InferenceRouter")
|
logger.debug("Initializing InferenceRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
self.telemetry = telemetry
|
||||||
|
if self.telemetry:
|
||||||
|
self.tokenizer = Tokenizer.get_instance()
|
||||||
|
self.formatter = ChatFormat(self.tokenizer)
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
logger.debug("InferenceRouter.initialize")
|
logger.debug("InferenceRouter.initialize")
|
||||||
|
@ -144,6 +159,71 @@ class InferenceRouter(Inference):
|
||||||
)
|
)
|
||||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||||
|
|
||||||
|
def _construct_metrics(
|
||||||
|
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
|
||||||
|
) -> List[MetricEvent]:
|
||||||
|
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_tokens: Number of tokens in the prompt
|
||||||
|
completion_tokens: Number of tokens in the completion
|
||||||
|
total_tokens: Total number of tokens used
|
||||||
|
model: Model object containing model_id and provider_id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of MetricEvent objects with token usage metrics
|
||||||
|
"""
|
||||||
|
span = get_current_span()
|
||||||
|
if span is None:
|
||||||
|
logger.warning("No span found for token usage metrics")
|
||||||
|
return []
|
||||||
|
metrics = [
|
||||||
|
("prompt_tokens", prompt_tokens),
|
||||||
|
("completion_tokens", completion_tokens),
|
||||||
|
("total_tokens", total_tokens),
|
||||||
|
]
|
||||||
|
metric_events = []
|
||||||
|
for metric_name, value in metrics:
|
||||||
|
metric_events.append(
|
||||||
|
MetricEvent(
|
||||||
|
trace_id=span.trace_id,
|
||||||
|
span_id=span.span_id,
|
||||||
|
metric=metric_name,
|
||||||
|
value=value,
|
||||||
|
timestamp=time.time(),
|
||||||
|
unit="tokens",
|
||||||
|
attributes={
|
||||||
|
"model_id": model.model_id,
|
||||||
|
"provider_id": model.provider_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return metric_events
|
||||||
|
|
||||||
|
async def _compute_and_log_token_usage(
|
||||||
|
self,
|
||||||
|
prompt_tokens: int,
|
||||||
|
completion_tokens: int,
|
||||||
|
total_tokens: int,
|
||||||
|
model: Model,
|
||||||
|
) -> List[MetricEvent]:
|
||||||
|
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||||
|
if self.telemetry:
|
||||||
|
for metric in metrics:
|
||||||
|
await self.telemetry.log_event(metric)
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
async def _count_tokens(
|
||||||
|
self,
|
||||||
|
messages: List[Message] | InterleavedContent,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
|
) -> Optional[int]:
|
||||||
|
if isinstance(messages, list):
|
||||||
|
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
|
||||||
|
else:
|
||||||
|
encoded = self.formatter.encode_content(messages)
|
||||||
|
return len(encoded.tokens) if encoded and encoded.tokens else 0
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -156,8 +236,9 @@ class InferenceRouter(Inference):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
"core",
|
||||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||||
)
|
)
|
||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
|
@ -206,10 +287,47 @@ class InferenceRouter(Inference):
|
||||||
tool_config=tool_config,
|
tool_config=tool_config,
|
||||||
)
|
)
|
||||||
provider = self.routing_table.get_provider_impl(model_id)
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
|
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return (chunk async for chunk in await provider.chat_completion(**params))
|
|
||||||
|
async def stream_generator():
|
||||||
|
completion_text = ""
|
||||||
|
async for chunk in await provider.chat_completion(**params):
|
||||||
|
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||||
|
if chunk.event.delta.type == "text":
|
||||||
|
completion_text += chunk.event.delta.text
|
||||||
|
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||||
|
completion_tokens = await self._count_tokens(
|
||||||
|
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
|
||||||
|
tool_config.tool_prompt_format,
|
||||||
|
)
|
||||||
|
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||||
|
metrics = await self._compute_and_log_token_usage(
|
||||||
|
prompt_tokens or 0,
|
||||||
|
completion_tokens or 0,
|
||||||
|
total_tokens,
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return stream_generator()
|
||||||
else:
|
else:
|
||||||
return await provider.chat_completion(**params)
|
response = await provider.chat_completion(**params)
|
||||||
|
completion_tokens = await self._count_tokens(
|
||||||
|
[response.completion_message],
|
||||||
|
tool_config.tool_prompt_format,
|
||||||
|
)
|
||||||
|
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||||
|
metrics = await self._compute_and_log_token_usage(
|
||||||
|
prompt_tokens or 0,
|
||||||
|
completion_tokens or 0,
|
||||||
|
total_tokens,
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||||
|
return response
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
|
@ -239,10 +357,41 @@ class InferenceRouter(Inference):
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prompt_tokens = await self._count_tokens(content)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return (chunk async for chunk in await provider.completion(**params))
|
|
||||||
|
async def stream_generator():
|
||||||
|
completion_text = ""
|
||||||
|
async for chunk in await provider.completion(**params):
|
||||||
|
if hasattr(chunk, "delta"):
|
||||||
|
completion_text += chunk.delta
|
||||||
|
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||||
|
completion_tokens = await self._count_tokens(completion_text)
|
||||||
|
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||||
|
metrics = await self._compute_and_log_token_usage(
|
||||||
|
prompt_tokens or 0,
|
||||||
|
completion_tokens or 0,
|
||||||
|
total_tokens,
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return stream_generator()
|
||||||
else:
|
else:
|
||||||
return await provider.completion(**params)
|
response = await provider.completion(**params)
|
||||||
|
completion_tokens = await self._count_tokens(response.content)
|
||||||
|
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||||
|
metrics = await self._compute_and_log_token_usage(
|
||||||
|
prompt_tokens or 0,
|
||||||
|
completion_tokens or 0,
|
||||||
|
total_tokens,
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||||
|
return response
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -28,7 +28,7 @@ from typing_extensions import Annotated
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.distribution.request_headers import (
|
from llama_stack.distribution.request_headers import (
|
||||||
preserve_headers_context_async_generator,
|
PROVIDER_DATA_VAR,
|
||||||
request_provider_data_context,
|
request_provider_data_context,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.resolver import InvalidProviderError
|
from llama_stack.distribution.resolver import InvalidProviderError
|
||||||
|
@ -38,6 +38,7 @@ from llama_stack.distribution.stack import (
|
||||||
replace_env_vars,
|
replace_env_vars,
|
||||||
validate_env_pair,
|
validate_env_pair,
|
||||||
)
|
)
|
||||||
|
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||||
|
@ -45,6 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||||
TelemetryAdapter,
|
TelemetryAdapter,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.telemetry.tracing import (
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
|
CURRENT_TRACE_CONTEXT,
|
||||||
end_trace,
|
end_trace,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
start_trace,
|
start_trace,
|
||||||
|
@ -182,7 +184,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
gen = preserve_headers_context_async_generator(sse_generator(func(**kwargs)))
|
gen = preserve_contexts_async_generator(
|
||||||
|
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||||
|
)
|
||||||
return StreamingResponse(gen, media_type="text/event-stream")
|
return StreamingResponse(gen, media_type="text/event-stream")
|
||||||
else:
|
else:
|
||||||
value = func(**kwargs)
|
value = func(**kwargs)
|
||||||
|
|
33
llama_stack/distribution/utils/context.py
Normal file
33
llama_stack/distribution/utils/context.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import AsyncGenerator, List, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def preserve_contexts_async_generator(
|
||||||
|
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
|
||||||
|
) -> AsyncGenerator[T, None]:
|
||||||
|
"""
|
||||||
|
Wraps an async generator to preserve context variables across iterations.
|
||||||
|
This is needed because we start a new asyncio event loop for each streaming request,
|
||||||
|
and we need to preserve the context across the event loop boundary.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def wrapper():
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
item = await gen.__anext__()
|
||||||
|
context_values = {context_var.name: context_var.get() for context_var in context_vars}
|
||||||
|
yield item
|
||||||
|
for context_var in context_vars:
|
||||||
|
_ = context_var.set(context_values[context_var.name])
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
return wrapper()
|
155
llama_stack/distribution/utils/tests/test_context.py
Normal file
155
llama_stack/distribution/utils/tests/test_context.py
Normal file
|
@ -0,0 +1,155 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preserve_contexts_with_exception():
|
||||||
|
# Create context variable
|
||||||
|
context_var = ContextVar("exception_var", default="initial")
|
||||||
|
token = context_var.set("start_value")
|
||||||
|
|
||||||
|
# Create an async generator that raises an exception
|
||||||
|
async def exception_generator():
|
||||||
|
yield context_var.get()
|
||||||
|
context_var.set("modified")
|
||||||
|
raise ValueError("Test exception")
|
||||||
|
yield None # This will never be reached
|
||||||
|
|
||||||
|
# Wrap the generator
|
||||||
|
wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
|
||||||
|
|
||||||
|
# First iteration should work
|
||||||
|
value = await wrapped_gen.__anext__()
|
||||||
|
assert value == "start_value"
|
||||||
|
|
||||||
|
# Second iteration should raise the exception
|
||||||
|
with pytest.raises(ValueError, match="Test exception"):
|
||||||
|
await wrapped_gen.__anext__()
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
context_var.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preserve_contexts_empty_generator():
|
||||||
|
# Create context variable
|
||||||
|
context_var = ContextVar("empty_var", default="initial")
|
||||||
|
token = context_var.set("value")
|
||||||
|
|
||||||
|
# Create an empty async generator
|
||||||
|
async def empty_generator():
|
||||||
|
if False: # This condition ensures the generator yields nothing
|
||||||
|
yield None
|
||||||
|
|
||||||
|
# Wrap the generator
|
||||||
|
wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
|
||||||
|
|
||||||
|
# The generator should raise StopAsyncIteration immediately
|
||||||
|
with pytest.raises(StopAsyncIteration):
|
||||||
|
await wrapped_gen.__anext__()
|
||||||
|
|
||||||
|
# Context variable should remain unchanged
|
||||||
|
assert context_var.get() == "value"
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
context_var.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preserve_contexts_across_event_loops():
|
||||||
|
"""
|
||||||
|
Test that context variables are preserved across event loop boundaries with nested generators.
|
||||||
|
This simulates the real-world scenario where:
|
||||||
|
1. A new event loop is created for each streaming request
|
||||||
|
2. The async generator runs inside that loop
|
||||||
|
3. There are multiple levels of nested generators
|
||||||
|
4. Context needs to be preserved across these boundaries
|
||||||
|
"""
|
||||||
|
# Create context variables
|
||||||
|
request_id = ContextVar("request_id", default=None)
|
||||||
|
user_id = ContextVar("user_id", default=None)
|
||||||
|
|
||||||
|
# Set initial values
|
||||||
|
|
||||||
|
# Results container to verify values across thread boundaries
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# Inner-most generator (level 2)
|
||||||
|
async def inner_generator():
|
||||||
|
# Should have the context from the outer scope
|
||||||
|
yield (1, request_id.get(), user_id.get())
|
||||||
|
|
||||||
|
# Modify one context variable
|
||||||
|
user_id.set("user-modified")
|
||||||
|
|
||||||
|
# Should reflect the modification
|
||||||
|
yield (2, request_id.get(), user_id.get())
|
||||||
|
|
||||||
|
# Middle generator (level 1)
|
||||||
|
async def middle_generator():
|
||||||
|
inner_gen = inner_generator()
|
||||||
|
|
||||||
|
# Forward the first yield from inner
|
||||||
|
item = await inner_gen.__anext__()
|
||||||
|
yield item
|
||||||
|
|
||||||
|
# Forward the second yield from inner
|
||||||
|
item = await inner_gen.__anext__()
|
||||||
|
yield item
|
||||||
|
|
||||||
|
request_id.set("req-modified")
|
||||||
|
|
||||||
|
# Add our own yield with both modified variables
|
||||||
|
yield (3, request_id.get(), user_id.get())
|
||||||
|
|
||||||
|
# Function to run in a separate thread with a new event loop
|
||||||
|
def run_in_new_loop():
|
||||||
|
# Create a new event loop for this thread
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Outer generator (runs in the new loop)
|
||||||
|
async def outer_generator():
|
||||||
|
request_id.set("req-12345")
|
||||||
|
user_id.set("user-6789")
|
||||||
|
# Wrap the middle generator
|
||||||
|
wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
|
||||||
|
|
||||||
|
# Process all items from the middle generator
|
||||||
|
async for item in wrapped_gen:
|
||||||
|
# Store results for verification
|
||||||
|
results.append(item)
|
||||||
|
|
||||||
|
# Run the outer generator in the new loop
|
||||||
|
loop.run_until_complete(outer_generator())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
# Run the generator chain in a separate thread with a new event loop
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
|
future = executor.submit(run_in_new_loop)
|
||||||
|
future.result() # Wait for completion
|
||||||
|
|
||||||
|
# Verify the results
|
||||||
|
assert len(results) == 3
|
||||||
|
|
||||||
|
# First yield should have original values
|
||||||
|
assert results[0] == (1, "req-12345", "user-6789")
|
||||||
|
|
||||||
|
# Second yield should have modified user_id
|
||||||
|
assert results[1] == (2, "req-12345", "user-modified")
|
||||||
|
|
||||||
|
# Third yield should have both modified values
|
||||||
|
assert results[2] == (3, "req-modified", "user-modified")
|
|
@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||||
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
|
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.datasetio_api = deps.get(Api.datasetio)
|
self.datasetio_api = deps.get(Api.datasetio)
|
||||||
|
self.meter = None
|
||||||
|
|
||||||
resource = Resource.create(
|
resource = Resource.create(
|
||||||
{
|
{
|
||||||
|
@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||||
return _GLOBAL_STORAGE["gauges"][name]
|
return _GLOBAL_STORAGE["gauges"][name]
|
||||||
|
|
||||||
def _log_metric(self, event: MetricEvent) -> None:
|
def _log_metric(self, event: MetricEvent) -> None:
|
||||||
|
if self.meter is None:
|
||||||
|
return
|
||||||
if isinstance(event.value, int):
|
if isinstance(event.value, int):
|
||||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||||
counter.add(event.value, attributes=event.attributes)
|
counter.add(event.value, attributes=event.attributes)
|
||||||
|
|
|
@ -4,8 +4,9 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.models.models import ModelType
|
from llama_stack.apis.models.models import ModelType
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
BenchmarkInput,
|
BenchmarkInput,
|
||||||
|
@ -15,21 +16,27 @@ from llama_stack.distribution.datatypes import (
|
||||||
ShieldInput,
|
ShieldInput,
|
||||||
ToolGroupInput,
|
ToolGroupInput,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
|
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
||||||
|
SQLiteVectorIOConfig,
|
||||||
|
)
|
||||||
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
|
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
|
||||||
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
|
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
|
||||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||||
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
||||||
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
|
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
|
||||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
PGVectorVectorIOConfig,
|
||||||
ProviderModelEntry,
|
)
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||||
|
from llama_stack.templates.template import (
|
||||||
|
DistributionTemplate,
|
||||||
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
)
|
)
|
||||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
|
||||||
|
|
||||||
|
|
||||||
def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
|
def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
|
||||||
# in this template, we allow each API key to be optional
|
# in this template, we allow each API key to be optional
|
||||||
providers = [
|
providers = [
|
||||||
(
|
(
|
||||||
|
@ -164,7 +171,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
DatasetInput(
|
DatasetInput(
|
||||||
dataset_id="simpleqa",
|
dataset_id="simpleqa",
|
||||||
provider_id="huggingface",
|
provider_id="huggingface",
|
||||||
url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
|
url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"),
|
||||||
metadata={
|
metadata={
|
||||||
"path": "llamastack/simpleqa",
|
"path": "llamastack/simpleqa",
|
||||||
"split": "train",
|
"split": "train",
|
||||||
|
@ -178,7 +185,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
DatasetInput(
|
DatasetInput(
|
||||||
dataset_id="mmlu_cot",
|
dataset_id="mmlu_cot",
|
||||||
provider_id="huggingface",
|
provider_id="huggingface",
|
||||||
url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"},
|
url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"),
|
||||||
metadata={
|
metadata={
|
||||||
"path": "llamastack/mmlu_cot",
|
"path": "llamastack/mmlu_cot",
|
||||||
"name": "all",
|
"name": "all",
|
||||||
|
@ -193,7 +200,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
DatasetInput(
|
DatasetInput(
|
||||||
dataset_id="gpqa_cot",
|
dataset_id="gpqa_cot",
|
||||||
provider_id="huggingface",
|
provider_id="huggingface",
|
||||||
url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"},
|
url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"),
|
||||||
metadata={
|
metadata={
|
||||||
"path": "llamastack/gpqa_0shot_cot",
|
"path": "llamastack/gpqa_0shot_cot",
|
||||||
"name": "gpqa_main",
|
"name": "gpqa_main",
|
||||||
|
@ -208,7 +215,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
DatasetInput(
|
DatasetInput(
|
||||||
dataset_id="math_500",
|
dataset_id="math_500",
|
||||||
provider_id="huggingface",
|
provider_id="huggingface",
|
||||||
url={"uri": "https://huggingface.co/datasets/llamastack/math_500"},
|
url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"),
|
||||||
metadata={
|
metadata={
|
||||||
"path": "llamastack/math_500",
|
"path": "llamastack/math_500",
|
||||||
"split": "test",
|
"split": "test",
|
||||||
|
|
|
@ -30,7 +30,9 @@ from llama_stack.providers.utils.inference.model_registry import ProviderModelEn
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]:
|
def get_model_registry(
|
||||||
|
available_models: Dict[str, List[ProviderModelEntry]],
|
||||||
|
) -> List[ModelInput]:
|
||||||
models = []
|
models = []
|
||||||
for provider_id, entries in available_models.items():
|
for provider_id, entries in available_models.items():
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
|
@ -193,7 +195,7 @@ class DistributionTemplate(BaseModel):
|
||||||
default_models.append(
|
default_models.append(
|
||||||
DefaultModel(
|
DefaultModel(
|
||||||
model_id=model_entry.provider_model_id,
|
model_id=model_entry.provider_model_id,
|
||||||
doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "",
|
doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue