Merge remote-tracking branch 'origin/main' into support_more_data_format

commit 8d7bb1140f
Botao Chen, 2025-01-14 11:55:13 -08:00
20 changed files with 381 additions and 414 deletions

View file

@@ -40,7 +40,12 @@ from llama_stack.apis.agents import (
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.common.content_types import TextContentItem, URL
+from llama_stack.apis.common.content_types import (
+    TextContentItem,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    URL,
+)
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,
@@ -49,8 +54,6 @@ from llama_stack.apis.inference import (
     SamplingParams,
     StopReason,
     SystemMessage,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolDefinition,
     ToolResponse,
     ToolResponseMessage,
@@ -411,8 +414,8 @@ class ChatAgent(ShieldRunnerMixin):
                     payload=AgentTurnResponseStepProgressPayload(
                         step_type=StepType.tool_execution.value,
                         step_id=step_id,
-                        tool_call_delta=ToolCallDelta(
-                            parse_status=ToolCallParseStatus.success,
+                        delta=ToolCallDelta(
+                            parse_status=ToolCallParseStatus.succeeded,
                             content=ToolCall(
                                 call_id="",
                                 tool_name=MEMORY_QUERY_TOOL,
@@ -507,8 +510,8 @@
                    continue

                delta = event.delta
-               if isinstance(delta, ToolCallDelta):
-                   if delta.parse_status == ToolCallParseStatus.success:
+               if delta.type == "tool_call":
+                   if delta.parse_status == ToolCallParseStatus.succeeded:
                        tool_calls.append(delta.content)
                    if stream:
                        yield AgentTurnResponseStreamChunk(
@@ -516,21 +519,20 @@
                                payload=AgentTurnResponseStepProgressPayload(
                                    step_type=StepType.inference.value,
                                    step_id=step_id,
-                                   text_delta="",
-                                   tool_call_delta=delta,
+                                   delta=delta,
                                )
                            )
                        )
-               elif isinstance(delta, str):
-                   content += delta
+               elif delta.type == "text":
+                   content += delta.text

                if stream and event.stop_reason is None:
                    yield AgentTurnResponseStreamChunk(
                        event=AgentTurnResponseEvent(
                            payload=AgentTurnResponseStepProgressPayload(
                                step_type=StepType.inference.value,
                                step_id=step_id,
-                               text_delta=event.delta,
+                               delta=delta,
                            )
                        )
                    )
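Note: taken together, the hunks above replace the two payload fields (`text_delta` / `tool_call_delta`) and the `isinstance` checks with a single `delta` field discriminated by `delta.type`, and rename the parse-status members `success`/`failure` to `succeeded`/`failed`. A minimal consumer sketch of that pattern follows; the `collect_turn` helper and the exact chunk shape are illustrative assumptions, while `ToolCallParseStatus` and the `"text"`/`"tool_call"` discriminators come straight from the diff.

from llama_stack.apis.common.content_types import ToolCallParseStatus


async def collect_turn(stream):
    """Hypothetical consumer: fold a chunk stream into (text, tool_calls)."""
    content = ""
    tool_calls = []
    async for chunk in stream:
        delta = chunk.event.delta  # single unified field after this change
        if delta.type == "text":
            content += delta.text  # TextDelta carries the raw string
        elif delta.type == "tool_call":
            # Keep only fully parsed calls; `succeeded` replaces `success`.
            if delta.parse_status == ToolCallParseStatus.succeeded:
                tool_calls.append(delta.content)
    return content, tool_calls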

View file

@@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.sku_list import resolve_model

+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -32,8 +37,6 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     TokenLogProbs,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolChoice,
 )
 from llama_stack.apis.models import Model, ModelType
@@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl(
                 ]

             yield CompletionResponseStreamChunk(
-                delta=text,
+                delta=TextDelta(text=text),
                 stop_reason=stop_reason,
                 logprobs=logprobs if request.logprobs else None,
             )

         if stop_reason is None:
             yield CompletionResponseStreamChunk(
-                delta="",
+                delta=TextDelta(text=""),
                 stop_reason=StopReason.out_of_tokens,
             )
@@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.start,
-                delta="",
+                delta=TextDelta(text=""),
             )
         )
@@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl(
                     parse_status=ToolCallParseStatus.in_progress,
                 )
             else:
-                delta = text
+                delta = TextDelta(text=text)

             if stop_reason is None:
                 if request.logprobs:
@@ -428,7 +431,7 @@ class MetaReferenceInferenceImpl(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
                         content="",
-                        parse_status=ToolCallParseStatus.failure,
+                        parse_status=ToolCallParseStatus.failed,
                     ),
                     stop_reason=stop_reason,
                 )
@@ -440,7 +443,7 @@ class MetaReferenceInferenceImpl(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
                         content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
+                        parse_status=ToolCallParseStatus.succeeded,
                     ),
                     stop_reason=stop_reason,
                 )
@@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.complete,
-                delta="",
+                delta=TextDelta(text=""),
                 stop_reason=stop_reason,
             )
         )
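Note: on the provider side, every streamed chunk now wraps plain strings in `TextDelta`, including the empty sentinel deltas used for the `start`/`complete` events and the out-of-tokens case. Below is a short sketch of that emission pattern, assuming the imports behave as shown in the hunks above; the `emit_text` generator itself is illustrative, not code from the commit.

from llama_stack.apis.common.content_types import TextDelta
from llama_stack.apis.inference import CompletionResponseStreamChunk, StopReason


def emit_text(pieces):
    """Illustrative generator: wrap each raw string in a TextDelta."""
    for text in pieces:
        yield CompletionResponseStreamChunk(delta=TextDelta(text=text))
    # The stream still terminates with an empty delta when the token budget runs out.
    yield CompletionResponseStreamChunk(
        delta=TextDelta(text=""),
        stop_reason=StopReason.out_of_tokens,
    )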

View file

@@ -30,13 +30,10 @@ from llama_stack.apis.telemetry import (
     Trace,
     UnstructuredLogEvent,
 )
-from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
     ConsoleSpanProcessor,
 )
 from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
     SQLiteSpanProcessor,
 )
@@ -52,6 +49,7 @@ _GLOBAL_STORAGE = {
     "up_down_counters": {},
 }
 _global_lock = threading.Lock()
+_TRACER_PROVIDER = None


 def string_to_trace_id(s: str) -> int:
@@ -80,31 +78,34 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             }
         )
-        provider = TracerProvider(resource=resource)
-        trace.set_tracer_provider(provider)
-        if TelemetrySink.OTEL in self.config.sinks:
-            otlp_exporter = OTLPSpanExporter(
-                endpoint=self.config.otel_endpoint,
-            )
-            span_processor = BatchSpanProcessor(otlp_exporter)
-            trace.get_tracer_provider().add_span_processor(span_processor)
-            metric_reader = PeriodicExportingMetricReader(
-                OTLPMetricExporter(
+        global _TRACER_PROVIDER
+        if _TRACER_PROVIDER is None:
+            provider = TracerProvider(resource=resource)
+            trace.set_tracer_provider(provider)
+            _TRACER_PROVIDER = provider
+            if TelemetrySink.OTEL in self.config.sinks:
+                otlp_exporter = OTLPSpanExporter(
                     endpoint=self.config.otel_endpoint,
                 )
-            )
-            metric_provider = MeterProvider(
-                resource=resource, metric_readers=[metric_reader]
-            )
-            metrics.set_meter_provider(metric_provider)
-            self.meter = metrics.get_meter(__name__)
-        if TelemetrySink.SQLITE in self.config.sinks:
-            trace.get_tracer_provider().add_span_processor(
-                SQLiteSpanProcessor(self.config.sqlite_db_path)
-            )
-            self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
-        if TelemetrySink.CONSOLE in self.config.sinks:
-            trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
+                span_processor = BatchSpanProcessor(otlp_exporter)
+                trace.get_tracer_provider().add_span_processor(span_processor)
+                metric_reader = PeriodicExportingMetricReader(
+                    OTLPMetricExporter(
+                        endpoint=self.config.otel_endpoint,
+                    )
+                )
+                metric_provider = MeterProvider(
+                    resource=resource, metric_readers=[metric_reader]
+                )
+                metrics.set_meter_provider(metric_provider)
+                self.meter = metrics.get_meter(__name__)
+            if TelemetrySink.SQLITE in self.config.sinks:
+                trace.get_tracer_provider().add_span_processor(
+                    SQLiteSpanProcessor(self.config.sqlite_db_path)
+                )
+                self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
+            if TelemetrySink.CONSOLE in self.config.sinks:
+                trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
         self._lock = _global_lock

     async def initialize(self) -> None:
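Note: this hunk guards `TracerProvider` construction behind a module-level singleton. OpenTelemetry's `set_tracer_provider()` is effectively once-per-process, so constructing a second `TelemetryAdapter` (for example across test fixtures) must reuse the first provider and its span processors rather than install a new one. A standalone sketch of the same guard using only the `opentelemetry` SDK; the explicit lock is an addition for clarity and is not part of the commit.

import threading

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

_TRACER_PROVIDER = None
_init_lock = threading.Lock()  # assumption: the commit itself relies on single-threaded init


def get_or_create_tracer_provider(resource=None) -> TracerProvider:
    """Install a process-wide TracerProvider exactly once, then reuse it."""
    global _TRACER_PROVIDER
    with _init_lock:
        if _TRACER_PROVIDER is None:
            provider = TracerProvider(resource=resource)
            trace.set_tracer_provider(provider)  # once-per-process side effect
            _TRACER_PROVIDER = provider
    return _TRACER_PROVIDER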