Mirror of https://github.com/meta-llama/llama-stack.git
Merge remote-tracking branch 'origin/main' into support_more_data_format

Commit 8d7bb1140f · 20 changed files with 381 additions and 414 deletions
@@ -40,7 +40,12 @@ from llama_stack.apis.agents import (
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.common.content_types import TextContentItem, URL
+from llama_stack.apis.common.content_types import (
+    TextContentItem,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    URL,
+)
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,
@@ -49,8 +54,6 @@ from llama_stack.apis.inference import (
     SamplingParams,
     StopReason,
     SystemMessage,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolDefinition,
     ToolResponse,
     ToolResponseMessage,
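For reference, a sketch of the import surface these two hunks converge on. Every name below appears verbatim in the hunks; only the side-by-side grouping is editorial:

# Delta and content primitives now come from the shared content_types module...
from llama_stack.apis.common.content_types import (
    TextContentItem,
    ToolCallDelta,
    ToolCallParseStatus,
    URL,
)
# ...while llama_stack.apis.inference keeps the message and sampling types.
from llama_stack.apis.inference import (
    CompletionMessage,
    SamplingParams,
    StopReason,
    SystemMessage,
    ToolDefinition,
    ToolResponse,
    ToolResponseMessage,
)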
@@ -411,8 +414,8 @@ class ChatAgent(ShieldRunnerMixin):
                     payload=AgentTurnResponseStepProgressPayload(
                         step_type=StepType.tool_execution.value,
                         step_id=step_id,
-                        tool_call_delta=ToolCallDelta(
-                            parse_status=ToolCallParseStatus.success,
+                        delta=ToolCallDelta(
+                            parse_status=ToolCallParseStatus.succeeded,
                             content=ToolCall(
                                 call_id="",
                                 tool_name=MEMORY_QUERY_TOOL,
@@ -507,8 +510,8 @@ class ChatAgent(ShieldRunnerMixin):
                 continue

             delta = event.delta
-            if isinstance(delta, ToolCallDelta):
-                if delta.parse_status == ToolCallParseStatus.success:
+            if delta.type == "tool_call":
+                if delta.parse_status == ToolCallParseStatus.succeeded:
                     tool_calls.append(delta.content)
             if stream:
                 yield AgentTurnResponseStreamChunk(
@@ -516,21 +519,20 @@ class ChatAgent(ShieldRunnerMixin):
                         payload=AgentTurnResponseStepProgressPayload(
                             step_type=StepType.inference.value,
                             step_id=step_id,
-                            text_delta="",
-                            tool_call_delta=delta,
+                            delta=delta,
                         )
                     )
                 )

-            elif isinstance(delta, str):
-                content += delta
+            elif delta.type == "text":
+                content += delta.text
                 if stream and event.stop_reason is None:
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
                             payload=AgentTurnResponseStepProgressPayload(
                                 step_type=StepType.inference.value,
                                 step_id=step_id,
-                                text_delta=event.delta,
+                                delta=delta,
                             )
                         )
                     )
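Taken together, the two hunks above replace isinstance checks on bare strings with a `type` discriminator on typed delta objects. A minimal self-contained sketch of that consumer pattern, using stand-in dataclasses (the real TextDelta/ToolCallDelta/ToolCallParseStatus live in llama_stack.apis.common.content_types; everything else here is illustrative):

from dataclasses import dataclass
from enum import Enum
from typing import Iterable, List, Tuple, Union


class ToolCallParseStatus(Enum):
    # Stand-in; note the renamed members in this diff:
    # succeeded (was success), failed (was failure).
    in_progress = "in_progress"
    failed = "failed"
    succeeded = "succeeded"


@dataclass
class TextDelta:
    text: str
    type: str = "text"


@dataclass
class ToolCallDelta:
    content: object
    parse_status: ToolCallParseStatus
    type: str = "tool_call"


def accumulate(
    deltas: Iterable[Union[TextDelta, ToolCallDelta]],
) -> Tuple[str, List[object]]:
    """Fold a stream of typed deltas into final text plus completed tool calls."""
    content: str = ""
    tool_calls: List[object] = []
    for delta in deltas:
        if delta.type == "tool_call":
            # Only fully parsed calls are collected; in-progress or failed
            # parses are forwarded to the client as progress chunks instead.
            if delta.parse_status == ToolCallParseStatus.succeeded:
                tool_calls.append(delta.content)
        elif delta.type == "text":
            content += delta.text  # was: content += delta (a bare str)
    return content, tool_calls


text, calls = accumulate(
    [TextDelta(text="hi"), ToolCallDelta("call", ToolCallParseStatus.succeeded)]
)
assert text == "hi" and calls == ["call"]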
@@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.sku_list import resolve_model

+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -32,8 +37,6 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     TokenLogProbs,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolChoice,
 )
 from llama_stack.apis.models import Model, ModelType
@@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl(
                 ]

             yield CompletionResponseStreamChunk(
-                delta=text,
+                delta=TextDelta(text=text),
                 stop_reason=stop_reason,
                 logprobs=logprobs if request.logprobs else None,
             )

         if stop_reason is None:
             yield CompletionResponseStreamChunk(
-                delta="",
+                delta=TextDelta(text=""),
                 stop_reason=StopReason.out_of_tokens,
             )

@@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.start,
-                delta="",
+                delta=TextDelta(text=""),
             )
         )

@@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl(
                     parse_status=ToolCallParseStatus.in_progress,
                 )
             else:
-                delta = text
+                delta = TextDelta(text=text)

             if stop_reason is None:
                 if request.logprobs:
@@ -428,7 +431,7 @@ class MetaReferenceInferenceImpl(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
                         content="",
-                        parse_status=ToolCallParseStatus.failure,
+                        parse_status=ToolCallParseStatus.failed,
                     ),
                     stop_reason=stop_reason,
                 )
@@ -440,7 +443,7 @@ class MetaReferenceInferenceImpl(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
                         content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
+                        parse_status=ToolCallParseStatus.succeeded,
                     ),
                     stop_reason=stop_reason,
                 )
@@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.complete,
-                delta="",
+                delta=TextDelta(text=""),
                 stop_reason=stop_reason,
             )
         )
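The hunks in this file all make the same move: every chunk's `delta` becomes a typed object, and the parse statuses are now spelled `succeeded`/`failed`. A compact sketch of the resulting producer shape, with a stand-in TextDelta (the event phases and the empty-string sentinels are from the diff; the generator framing is illustrative):

from dataclasses import dataclass


@dataclass
class TextDelta:
    """Stand-in for llama_stack.apis.common.content_types.TextDelta."""
    text: str


def stream_shape(fragments):
    # start: was delta="", now a typed empty delta
    yield ("start", TextDelta(text=""))
    for text in fragments:
        # progress: was delta=text, now wrapped
        yield ("progress", TextDelta(text=text))
    # complete: again a typed empty delta, alongside the stop reason
    yield ("complete", TextDelta(text=""))


events = list(stream_shape(["Hel", "lo"]))
assert events[0][0] == "start" and events[-1][1].text == ""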
@@ -30,13 +30,10 @@ from llama_stack.apis.telemetry import (
     Trace,
     UnstructuredLogEvent,
 )
-
 from llama_stack.distribution.datatypes import Api
-
 from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
     ConsoleSpanProcessor,
 )
-
 from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
     SQLiteSpanProcessor,
 )
@@ -52,6 +49,7 @@ _GLOBAL_STORAGE = {
     "up_down_counters": {},
 }
 _global_lock = threading.Lock()
+_TRACER_PROVIDER = None


 def string_to_trace_id(s: str) -> int:
@@ -80,31 +78,34 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             }
         )

-        provider = TracerProvider(resource=resource)
-        trace.set_tracer_provider(provider)
-        if TelemetrySink.OTEL in self.config.sinks:
-            otlp_exporter = OTLPSpanExporter(
-                endpoint=self.config.otel_endpoint,
-            )
-            span_processor = BatchSpanProcessor(otlp_exporter)
-            trace.get_tracer_provider().add_span_processor(span_processor)
-            metric_reader = PeriodicExportingMetricReader(
-                OTLPMetricExporter(
-                    endpoint=self.config.otel_endpoint,
-                )
-            )
-            metric_provider = MeterProvider(
-                resource=resource, metric_readers=[metric_reader]
-            )
-            metrics.set_meter_provider(metric_provider)
-            self.meter = metrics.get_meter(__name__)
-        if TelemetrySink.SQLITE in self.config.sinks:
-            trace.get_tracer_provider().add_span_processor(
-                SQLiteSpanProcessor(self.config.sqlite_db_path)
-            )
-            self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
-        if TelemetrySink.CONSOLE in self.config.sinks:
-            trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
+        global _TRACER_PROVIDER
+        if _TRACER_PROVIDER is None:
+            provider = TracerProvider(resource=resource)
+            trace.set_tracer_provider(provider)
+            _TRACER_PROVIDER = provider
+            if TelemetrySink.OTEL in self.config.sinks:
+                otlp_exporter = OTLPSpanExporter(
+                    endpoint=self.config.otel_endpoint,
+                )
+                span_processor = BatchSpanProcessor(otlp_exporter)
+                trace.get_tracer_provider().add_span_processor(span_processor)
+                metric_reader = PeriodicExportingMetricReader(
+                    OTLPMetricExporter(
+                        endpoint=self.config.otel_endpoint,
+                    )
+                )
+                metric_provider = MeterProvider(
+                    resource=resource, metric_readers=[metric_reader]
+                )
+                metrics.set_meter_provider(metric_provider)
+                self.meter = metrics.get_meter(__name__)
+            if TelemetrySink.SQLITE in self.config.sinks:
+                trace.get_tracer_provider().add_span_processor(
+                    SQLiteSpanProcessor(self.config.sqlite_db_path)
+                )
+                self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
+            if TelemetrySink.CONSOLE in self.config.sinks:
+                trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
+        self._lock = _global_lock

     async def initialize(self) -> None:
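The net effect of this hunk is a process-wide singleton: the TracerProvider (with its span processors and meter provider) is built once, and later TelemetryAdapter instances reuse it. A self-contained sketch of the guard with a stand-in provider class; note the diff itself performs the None check without holding `_global_lock`, so taking the lock below is an added assumption, not something the change does:

import threading

_TRACER_PROVIDER = None  # module-level singleton, mirroring the diff
_global_lock = threading.Lock()


class FakeTracerProvider:
    """Stand-in for opentelemetry.sdk.trace.TracerProvider."""


def ensure_tracer_provider():
    global _TRACER_PROVIDER
    with _global_lock:  # assumption: serialize first-time construction
        if _TRACER_PROVIDER is None:
            # In the real adapter this is also where span processors and
            # the meter provider are registered, exactly once per process.
            _TRACER_PROVIDER = FakeTracerProvider()
    return _TRACER_PROVIDER


assert ensure_tracer_provider() is ensure_tracer_provider()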