diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 0ce216479..5ed8701a4 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -3843,8 +3843,8 @@ "properties": { "role": { "type": "string", - "const": "ipython", - "default": "ipython" + "const": "tool", + "default": "tool" }, "call_id": { "type": "string" @@ -4185,14 +4185,7 @@ "$ref": "#/components/schemas/ChatCompletionResponseEventType" }, "delta": { - "oneOf": [ - { - "type": "string" - }, - { - "$ref": "#/components/schemas/ToolCallDelta" - } - ] + "$ref": "#/components/schemas/ContentDelta" }, "logprobs": { "type": "array", @@ -4232,6 +4225,50 @@ ], "title": "SSE-stream of these events." }, + "ContentDelta": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text", + "default": "text" + }, + "text": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "text" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "image", + "default": "image" + }, + "data": { + "type": "string", + "contentEncoding": "base64" + } + }, + "additionalProperties": false, + "required": [ + "type", + "data" + ] + }, + { + "$ref": "#/components/schemas/ToolCallDelta" + } + ] + }, "TokenLogProbs": { "type": "object", "properties": { @@ -4250,6 +4287,11 @@ "ToolCallDelta": { "type": "object", "properties": { + "type": { + "type": "string", + "const": "tool_call", + "default": "tool_call" + }, "content": { "oneOf": [ { @@ -4266,6 +4308,7 @@ }, "additionalProperties": false, "required": [ + "type", "content", "parse_status" ] @@ -4275,8 +4318,8 @@ "enum": [ "started", "in_progress", - "failure", - "success" + "failed", + "succeeded" ] }, "CompletionRequest": { @@ -4777,18 +4820,16 @@ "step_id": { "type": "string" }, - "text_delta": { - "type": "string" - }, - "tool_call_delta": { - "$ref": "#/components/schemas/ToolCallDelta" + "delta": { + "$ref": "#/components/schemas/ContentDelta" } }, "additionalProperties": false, "required": [ "event_type", "step_type", - "step_id" + "step_id", + "delta" ] }, "AgentTurnResponseStepStartPayload": { @@ -8758,6 +8799,10 @@ "name": "CompletionResponseStreamChunk", "description": "streamed completion response.\n\n" }, + { + "name": "ContentDelta", + "description": "" + }, { "name": "CreateAgentRequest", "description": "" @@ -9392,6 +9437,7 @@ "CompletionRequest", "CompletionResponse", "CompletionResponseStreamChunk", + "ContentDelta", "CreateAgentRequest", "CreateAgentSessionRequest", "CreateAgentTurnRequest", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 031178ce9..2a573959f 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -150,6 +150,8 @@ components: AgentTurnResponseStepProgressPayload: additionalProperties: false properties: + delta: + $ref: '#/components/schemas/ContentDelta' event_type: const: step_progress default: step_progress @@ -163,14 +165,11 @@ components: - shield_call - memory_retrieval type: string - text_delta: - type: string - tool_call_delta: - $ref: '#/components/schemas/ToolCallDelta' required: - event_type - step_type - step_id + - delta type: object AgentTurnResponseStepStartPayload: additionalProperties: false @@ -462,9 +461,7 @@ components: additionalProperties: false properties: delta: - oneOf: - - type: string - - $ref: '#/components/schemas/ToolCallDelta' + $ref: '#/components/schemas/ContentDelta' 
event_type: $ref: '#/components/schemas/ChatCompletionResponseEventType' logprobs: @@ -571,6 +568,34 @@ components: - delta title: streamed completion response. type: object + ContentDelta: + oneOf: + - additionalProperties: false + properties: + text: + type: string + type: + const: text + default: text + type: string + required: + - type + - text + type: object + - additionalProperties: false + properties: + data: + contentEncoding: base64 + type: string + type: + const: image + default: image + type: string + required: + - type + - data + type: object + - $ref: '#/components/schemas/ToolCallDelta' CreateAgentRequest: additionalProperties: false properties: @@ -2664,7 +2689,12 @@ components: - $ref: '#/components/schemas/ToolCall' parse_status: $ref: '#/components/schemas/ToolCallParseStatus' + type: + const: tool_call + default: tool_call + type: string required: + - type - content - parse_status type: object @@ -2672,8 +2702,8 @@ components: enum: - started - in_progress - - failure - - success + - failed + - succeeded type: string ToolChoice: enum: @@ -2888,8 +2918,8 @@ components: content: $ref: '#/components/schemas/InterleavedContent' role: - const: ipython - default: ipython + const: tool + default: tool type: string tool_name: oneOf: @@ -5500,6 +5530,8 @@ tags: ' name: CompletionResponseStreamChunk +- description: + name: ContentDelta - description: name: CreateAgentRequest @@ -5939,6 +5971,7 @@ x-tagGroups: - CompletionRequest - CompletionResponse - CompletionResponseStreamChunk + - ContentDelta - CreateAgentRequest - CreateAgentSessionRequest - CreateAgentTurnRequest diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index fb9df21e6..c3f3d21f0 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -22,12 +22,11 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated -from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL from llama_stack.apis.inference import ( CompletionMessage, SamplingParams, ToolCall, - ToolCallDelta, ToolChoice, ToolPromptFormat, ToolResponse, @@ -216,8 +215,7 @@ class AgentTurnResponseStepProgressPayload(BaseModel): step_type: StepType step_id: str - text_delta: Optional[str] = None - tool_call_delta: Optional[ToolCallDelta] = None + delta: ContentDelta @json_schema_type diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py index 40a69d19c..9e2f14805 100644 --- a/llama_stack/apis/agents/event_logger.py +++ b/llama_stack/apis/agents/event_logger.py @@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils from termcolor import cprint from llama_stack.apis.agents import AgentTurnResponseEventType, StepType - +from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ToolResponseMessage +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) + class LogEvent: def __init__( @@ -57,8 +61,11 @@ class EventLogger: # since it does not produce event but instead # a Message if isinstance(chunk, ToolResponseMessage): - yield chunk, LogEvent( - role="CustomTool", content=chunk.content, color="grey" + yield ( + chunk, + LogEvent( + role="CustomTool", content=chunk.content, color="grey" + ), ) continue @@ -80,14 +87,20 @@ class 
EventLogger: ): violation = event.payload.step_details.violation if not violation: - yield event, LogEvent( - role=step_type, content="No Violation", color="magenta" + yield ( + event, + LogEvent( + role=step_type, content="No Violation", color="magenta" + ), ) else: - yield event, LogEvent( - role=step_type, - content=f"{violation.metadata} {violation.user_message}", - color="red", + yield ( + event, + LogEvent( + role=step_type, + content=f"{violation.metadata} {violation.user_message}", + color="red", + ), ) # handle inference @@ -95,8 +108,11 @@ class EventLogger: if stream: if event_type == EventType.step_start.value: # TODO: Currently this event is never received - yield event, LogEvent( - role=step_type, content="", end="", color="yellow" + yield ( + event, + LogEvent( + role=step_type, content="", end="", color="yellow" + ), ) elif event_type == EventType.step_progress.value: # HACK: if previous was not step/event was not inference's step_progress @@ -107,24 +123,34 @@ class EventLogger: previous_event_type != EventType.step_progress.value and previous_step_type != StepType.inference ): - yield event, LogEvent( - role=step_type, content="", end="", color="yellow" + yield ( + event, + LogEvent( + role=step_type, content="", end="", color="yellow" + ), ) - if event.payload.tool_call_delta: - if isinstance(event.payload.tool_call_delta.content, str): - yield event, LogEvent( - role=None, - content=event.payload.tool_call_delta.content, - end="", - color="cyan", + delta = event.payload.delta + if delta.type == "tool_call": + if delta.parse_status == ToolCallParseStatus.succeeded: + yield ( + event, + LogEvent( + role=None, + content=delta.content, + end="", + color="cyan", + ), ) else: - yield event, LogEvent( - role=None, - content=event.payload.text_delta, - end="", - color="yellow", + yield ( + event, + LogEvent( + role=None, + content=delta.text, + end="", + color="yellow", + ), ) else: # step_complete @@ -140,10 +166,13 @@ class EventLogger: ) else: content = response.content - yield event, LogEvent( - role=step_type, - content=content, - color="yellow", + yield ( + event, + LogEvent( + role=step_type, + content=content, + color="yellow", + ), ) # handle tool_execution @@ -155,16 +184,22 @@ class EventLogger: ): details = event.payload.step_details for t in details.tool_calls: - yield event, LogEvent( - role=step_type, - content=f"Tool:{t.tool_name} Args:{t.arguments}", - color="green", + yield ( + event, + LogEvent( + role=step_type, + content=f"Tool:{t.tool_name} Args:{t.arguments}", + color="green", + ), ) for r in details.tool_responses: - yield event, LogEvent( - role=step_type, - content=f"Tool:{r.tool_name} Response:{r.content}", - color="green", + yield ( + event, + LogEvent( + role=step_type, + content=f"Tool:{r.tool_name} Response:{r.content}", + color="green", + ), ) if ( @@ -172,15 +207,16 @@ class EventLogger: and event_type == EventType.step_complete.value ): details = event.payload.step_details - inserted_context = interleaved_text_media_as_str( - details.inserted_context - ) + inserted_context = interleaved_content_as_str(details.inserted_context) content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}" - yield event, LogEvent( - role=step_type, - content=content, - color="cyan", + yield ( + event, + LogEvent( + role=step_type, + content=content, + color="cyan", + ), ) previous_event_type = event_type diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index 629e0e94d..3b61fa243 100644 --- 
a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -5,10 +5,12 @@ # the root directory of this source tree. import base64 +from enum import Enum from typing import Annotated, List, Literal, Optional, Union -from llama_models.schema_utils import json_schema_type, register_schema +from llama_models.llama3.api.datatypes import ToolCall +from llama_models.schema_utils import json_schema_type, register_schema from pydantic import BaseModel, Field, field_serializer, model_validator @@ -60,3 +62,42 @@ InterleavedContent = register_schema( Union[str, InterleavedContentItem, List[InterleavedContentItem]], name="InterleavedContent", ) + + +class TextDelta(BaseModel): + type: Literal["text"] = "text" + text: str + + +class ImageDelta(BaseModel): + type: Literal["image"] = "image" + data: bytes + + +@json_schema_type +class ToolCallParseStatus(Enum): + started = "started" + in_progress = "in_progress" + failed = "failed" + succeeded = "succeeded" + + +@json_schema_type +class ToolCallDelta(BaseModel): + type: Literal["tool_call"] = "tool_call" + + # you either send an in-progress tool call so the client can stream a long + # code generation or you send the final parsed tool call at the end of the + # stream + content: Union[str, ToolCall] + parse_status: ToolCallParseStatus + + +# streaming completions send a stream of ContentDeltas +ContentDelta = register_schema( + Annotated[ + Union[TextDelta, ImageDelta, ToolCallDelta], + Field(discriminator="type"), + ], + name="ContentDelta", +) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 4a453700c..b525aa331 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -29,7 +29,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth from pydantic import BaseModel, Field, field_validator from typing_extensions import Annotated -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent from llama_stack.apis.models import Model from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -147,26 +147,12 @@ class ChatCompletionResponseEventType(Enum): progress = "progress" -@json_schema_type -class ToolCallParseStatus(Enum): - started = "started" - in_progress = "in_progress" - failure = "failure" - success = "success" - - -@json_schema_type -class ToolCallDelta(BaseModel): - content: Union[str, ToolCall] - parse_status: ToolCallParseStatus - - @json_schema_type class ChatCompletionResponseEvent(BaseModel): """Chat completion response event.""" event_type: ChatCompletionResponseEventType - delta: Union[str, ToolCallDelta] + delta: ContentDelta logprobs: Optional[List[TokenLogProbs]] = None stop_reason: Optional[StopReason] = None diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index 084374c8a..85e6cb962 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -4,9 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
import argparse - import importlib.resources - import os import shutil from functools import lru_cache @@ -14,14 +12,12 @@ from pathlib import Path from typing import List, Optional from llama_stack.cli.subcommand import Subcommand - from llama_stack.distribution.datatypes import ( BuildConfig, DistributionSpec, Provider, StackRunConfig, ) - from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.utils.dynamic import instantiate_class_type @@ -296,6 +292,7 @@ class StackBuild(Subcommand): / f"templates/{template_name}/run.yaml" ) with importlib.resources.as_file(template_path) as path: + run_config_file = build_dir / f"{build_config.name}-run.yaml" shutil.copy(path, run_config_file) # Find all ${env.VARIABLE} patterns cprint("Build Successful!", color="green") diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 50af2cdea..0c124e64b 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -9,12 +9,10 @@ import inspect import json import logging import os -import queue -import threading from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path -from typing import Any, Generator, get_args, get_origin, Optional, TypeVar +from typing import Any, get_args, get_origin, Optional, TypeVar import httpx import yaml @@ -64,71 +62,6 @@ def in_notebook(): return True -def stream_across_asyncio_run_boundary( - async_gen_maker, - pool_executor: ThreadPoolExecutor, - path: Optional[str] = None, - provider_data: Optional[dict[str, Any]] = None, -) -> Generator[T, None, None]: - result_queue = queue.Queue() - stop_event = threading.Event() - - async def consumer(): - # make sure we make the generator in the event loop context - gen = await async_gen_maker() - await start_trace(path, {"__location__": "library_client"}) - if provider_data: - set_request_provider_data( - {"X-LlamaStack-Provider-Data": json.dumps(provider_data)} - ) - try: - async for item in await gen: - result_queue.put(item) - except Exception as e: - print(f"Error in generator {e}") - result_queue.put(e) - except asyncio.CancelledError: - return - finally: - result_queue.put(StopIteration) - stop_event.set() - await end_trace() - - def run_async(): - # Run our own loop to avoid double async generator cleanup which is done - # by asyncio.run() - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - task = loop.create_task(consumer()) - loop.run_until_complete(task) - finally: - # Handle pending tasks like a generator's athrow() - pending = asyncio.all_tasks(loop) - if pending: - loop.run_until_complete( - asyncio.gather(*pending, return_exceptions=True) - ) - loop.close() - - future = pool_executor.submit(run_async) - - try: - # yield results as they come in - while not stop_event.is_set() or not result_queue.empty(): - try: - item = result_queue.get(timeout=0.1) - if item is StopIteration: - break - if isinstance(item, Exception): - raise item - yield item - except queue.Empty: - continue - finally: - future.result() - - def convert_pydantic_to_json_value(value: Any) -> Any: if isinstance(value, Enum): return value.value @@ -184,7 +117,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): ): super().__init__() self.async_client = AsyncLlamaStackAsLibraryClient( - config_path_or_template_name, custom_provider_registry + config_path_or_template_name, 
custom_provider_registry, provider_data ) self.pool_executor = ThreadPoolExecutor(max_workers=4) self.skip_logger_removal = skip_logger_removal @@ -210,39 +143,30 @@ class LlamaStackAsLibraryClient(LlamaStackClient): root_logger.removeHandler(handler) print(f"Removed handler {handler.__class__.__name__} from root logger") - def _get_path( - self, - cast_to: Any, - options: Any, - *, - stream=False, - stream_cls=None, - ): - return options.url - def request(self, *args, **kwargs): - path = self._get_path(*args, **kwargs) if kwargs.get("stream"): - return stream_across_asyncio_run_boundary( - lambda: self.async_client.request(*args, **kwargs), - self.pool_executor, - path=path, - provider_data=self.provider_data, - ) - else: + # NOTE: We are using AsyncLlamaStackClient under the hood + # A new event loop is needed to convert the AsyncStream + # from async client into SyncStream return type for streaming + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) - async def _traced_request(): - if self.provider_data: - set_request_provider_data( - {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)} - ) - await start_trace(path, {"__location__": "library_client"}) + def sync_generator(): try: - return await self.async_client.request(*args, **kwargs) + async_stream = loop.run_until_complete( + self.async_client.request(*args, **kwargs) + ) + while True: + chunk = loop.run_until_complete(async_stream.__anext__()) + yield chunk + except StopAsyncIteration: + pass finally: - await end_trace() + loop.close() - return asyncio.run(_traced_request()) + return sync_generator() + else: + return asyncio.run(self.async_client.request(*args, **kwargs)) class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): @@ -250,9 +174,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self, config_path_or_template_name: str, custom_provider_registry: Optional[ProviderRegistry] = None, + provider_data: Optional[dict[str, Any]] = None, ): super().__init__() - # when using the library client, we should not log to console since many # of our logs are intended for server-side usage current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") @@ -273,6 +197,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self.config_path_or_template_name = config_path_or_template_name self.config = config self.custom_provider_registry = custom_provider_registry + self.provider_data = provider_data async def initialize(self): try: @@ -329,17 +254,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if not self.endpoint_impls: raise ValueError("Client not initialized") + if self.provider_data: + set_request_provider_data( + {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)} + ) + await start_trace(options.url, {"__location__": "library_client"}) if stream: - return self._call_streaming( + response = await self._call_streaming( cast_to=cast_to, options=options, stream_cls=stream_cls, ) else: - return await self._call_non_streaming( + response = await self._call_non_streaming( cast_to=cast_to, options=options, ) + await end_trace() + return response async def _call_non_streaming( self, diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index d26b4447c..010d137ec 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -35,7 +35,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v4" +KEY_VERSION = "v5" 
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 24448a28f..2299e80d1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -40,7 +40,12 @@ from llama_stack.apis.agents import ( ToolExecutionStep, Turn, ) -from llama_stack.apis.common.content_types import TextContentItem, URL +from llama_stack.apis.common.content_types import ( + TextContentItem, + ToolCallDelta, + ToolCallParseStatus, + URL, +) from llama_stack.apis.inference import ( ChatCompletionResponseEventType, CompletionMessage, @@ -49,8 +54,6 @@ from llama_stack.apis.inference import ( SamplingParams, StopReason, SystemMessage, - ToolCallDelta, - ToolCallParseStatus, ToolDefinition, ToolResponse, ToolResponseMessage, @@ -411,8 +414,8 @@ class ChatAgent(ShieldRunnerMixin): payload=AgentTurnResponseStepProgressPayload( step_type=StepType.tool_execution.value, step_id=step_id, - tool_call_delta=ToolCallDelta( - parse_status=ToolCallParseStatus.success, + delta=ToolCallDelta( + parse_status=ToolCallParseStatus.succeeded, content=ToolCall( call_id="", tool_name=MEMORY_QUERY_TOOL, @@ -507,8 +510,8 @@ class ChatAgent(ShieldRunnerMixin): continue delta = event.delta - if isinstance(delta, ToolCallDelta): - if delta.parse_status == ToolCallParseStatus.success: + if delta.type == "tool_call": + if delta.parse_status == ToolCallParseStatus.succeeded: tool_calls.append(delta.content) if stream: yield AgentTurnResponseStreamChunk( @@ -516,21 +519,20 @@ class ChatAgent(ShieldRunnerMixin): payload=AgentTurnResponseStepProgressPayload( step_type=StepType.inference.value, step_id=step_id, - text_delta="", - tool_call_delta=delta, + delta=delta, ) ) ) - elif isinstance(delta, str): - content += delta + elif delta.type == "text": + content += delta.text if stream and event.stop_reason is None: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( step_type=StepType.inference.value, step_id=step_id, - text_delta=event.delta, + delta=delta, ) ) ) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 5b502a581..d64d32f03 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import ( ) from llama_models.sku_list import resolve_model +from llama_stack.apis.common.content_types import ( + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -32,8 +37,6 @@ from llama_stack.apis.inference import ( Message, ResponseFormat, TokenLogProbs, - ToolCallDelta, - ToolCallParseStatus, ToolChoice, ) from llama_stack.apis.models import Model, ModelType @@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl( ] yield CompletionResponseStreamChunk( - delta=text, + delta=TextDelta(text=text), stop_reason=stop_reason, logprobs=logprobs if request.logprobs else None, ) if stop_reason is None: yield CompletionResponseStreamChunk( - delta="", + delta=TextDelta(text=""), stop_reason=StopReason.out_of_tokens, ) @@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl( yield 
ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.start, - delta="", + delta=TextDelta(text=""), ) ) @@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl( parse_status=ToolCallParseStatus.in_progress, ) else: - delta = text + delta = TextDelta(text=text) if stop_reason is None: if request.logprobs: @@ -428,7 +431,7 @@ class MetaReferenceInferenceImpl( event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( content="", - parse_status=ToolCallParseStatus.failure, + parse_status=ToolCallParseStatus.failed, ), stop_reason=stop_reason, ) @@ -440,7 +443,7 @@ class MetaReferenceInferenceImpl( event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( content=tool_call, - parse_status=ToolCallParseStatus.success, + parse_status=ToolCallParseStatus.succeeded, ), stop_reason=stop_reason, ) @@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta="", + delta=TextDelta(text=""), stop_reason=stop_reason, ) ) diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index efc37b553..332a150cf 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -30,13 +30,10 @@ from llama_stack.apis.telemetry import ( Trace, UnstructuredLogEvent, ) - from llama_stack.distribution.datatypes import Api - from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( ConsoleSpanProcessor, ) - from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import ( SQLiteSpanProcessor, ) @@ -52,6 +49,7 @@ _GLOBAL_STORAGE = { "up_down_counters": {}, } _global_lock = threading.Lock() +_TRACER_PROVIDER = None def string_to_trace_id(s: str) -> int: @@ -80,31 +78,34 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): } ) - provider = TracerProvider(resource=resource) - trace.set_tracer_provider(provider) - if TelemetrySink.OTEL in self.config.sinks: - otlp_exporter = OTLPSpanExporter( - endpoint=self.config.otel_endpoint, - ) - span_processor = BatchSpanProcessor(otlp_exporter) - trace.get_tracer_provider().add_span_processor(span_processor) - metric_reader = PeriodicExportingMetricReader( - OTLPMetricExporter( + global _TRACER_PROVIDER + if _TRACER_PROVIDER is None: + provider = TracerProvider(resource=resource) + trace.set_tracer_provider(provider) + _TRACER_PROVIDER = provider + if TelemetrySink.OTEL in self.config.sinks: + otlp_exporter = OTLPSpanExporter( endpoint=self.config.otel_endpoint, ) - ) - metric_provider = MeterProvider( - resource=resource, metric_readers=[metric_reader] - ) - metrics.set_meter_provider(metric_provider) - self.meter = metrics.get_meter(__name__) - if TelemetrySink.SQLITE in self.config.sinks: - trace.get_tracer_provider().add_span_processor( - SQLiteSpanProcessor(self.config.sqlite_db_path) - ) - self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) - if TelemetrySink.CONSOLE in self.config.sinks: - trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) + span_processor = BatchSpanProcessor(otlp_exporter) + trace.get_tracer_provider().add_span_processor(span_processor) + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=self.config.otel_endpoint, + ) + ) + metric_provider 
= MeterProvider( + resource=resource, metric_readers=[metric_reader] + ) + metrics.set_meter_provider(metric_provider) + self.meter = metrics.get_meter(__name__) + if TelemetrySink.SQLITE in self.config.sinks: + trace.get_tracer_provider().add_span_processor( + SQLiteSpanProcessor(self.config.sqlite_db_path) + ) + self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) + if TelemetrySink.CONSOLE in self.config.sinks: + trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) self._lock = _global_lock async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py index 032f4c8d4..11f684847 100644 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -30,6 +30,11 @@ from groq.types.shared.function_definition import FunctionDefinition from llama_models.llama3.api.datatypes import ToolParamDefinition +from llama_stack.apis.common.content_types import ( + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -40,8 +45,6 @@ from llama_stack.apis.inference import ( Message, StopReason, ToolCall, - ToolCallDelta, - ToolCallParseStatus, ToolDefinition, ToolPromptFormat, ) @@ -162,7 +165,7 @@ def convert_chat_completion_response( def _map_finish_reason_to_stop_reason( - finish_reason: Literal["stop", "length", "tool_calls"] + finish_reason: Literal["stop", "length", "tool_calls"], ) -> StopReason: """ Convert a Groq chat completion finish_reason to a StopReason. @@ -185,7 +188,6 @@ def _map_finish_reason_to_stop_reason( async def convert_chat_completion_response_stream( stream: Stream[ChatCompletionChunk], ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - event_type = ChatCompletionResponseEventType.start for chunk in stream: choice = chunk.choices[0] @@ -194,7 +196,7 @@ async def convert_chat_completion_response_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta=choice.delta.content or "", + delta=TextDelta(text=choice.delta.content or ""), logprobs=None, stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason), ) @@ -213,7 +215,7 @@ async def convert_chat_completion_response_stream( event_type=event_type, delta=ToolCallDelta( content=tool_call, - parse_status=ToolCallParseStatus.success, + parse_status=ToolCallParseStatus.succeeded, ), ) ) @@ -221,7 +223,7 @@ async def convert_chat_completion_response_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=event_type, - delta=choice.delta.content or "", + delta=TextDelta(text=choice.delta.content or ""), logprobs=None, ) ) diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index dcc7c5fca..975812844 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -34,6 +34,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import ( from openai.types.completion import Completion as OpenAICompletion from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs +from llama_stack.apis.common.content_types import ( + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import 
( ChatCompletionRequest, ChatCompletionResponse, @@ -48,8 +53,6 @@ from llama_stack.apis.inference import ( Message, SystemMessage, TokenLogProbs, - ToolCallDelta, - ToolCallParseStatus, ToolResponseMessage, UserMessage, ) @@ -432,69 +435,6 @@ async def convert_openai_chat_completion_stream( """ Convert a stream of OpenAI chat completion chunks into a stream of ChatCompletionResponseStreamChunk. - - OpenAI ChatCompletionChunk: - choices: List[Choice] - - OpenAI Choice: # different from the non-streamed Choice - delta: ChoiceDelta - finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]] - logprobs: Optional[ChoiceLogprobs] - - OpenAI ChoiceDelta: - content: Optional[str] - role: Optional[Literal["system", "user", "assistant", "tool"]] - tool_calls: Optional[List[ChoiceDeltaToolCall]] - - OpenAI ChoiceDeltaToolCall: - index: int - id: Optional[str] - function: Optional[ChoiceDeltaToolCallFunction] - type: Optional[Literal["function"]] - - OpenAI ChoiceDeltaToolCallFunction: - name: Optional[str] - arguments: Optional[str] - - -> - - ChatCompletionResponseStreamChunk: - event: ChatCompletionResponseEvent - - ChatCompletionResponseEvent: - event_type: ChatCompletionResponseEventType - delta: Union[str, ToolCallDelta] - logprobs: Optional[List[TokenLogProbs]] - stop_reason: Optional[StopReason] - - ChatCompletionResponseEventType: - start = "start" - progress = "progress" - complete = "complete" - - ToolCallDelta: - content: Union[str, ToolCall] - parse_status: ToolCallParseStatus - - ToolCall: - call_id: str - tool_name: str - arguments: str - - ToolCallParseStatus: - started = "started" - in_progress = "in_progress" - failure = "failure" - success = "success" - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob - - StopReason: - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" """ # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... @@ -543,7 +483,7 @@ async def convert_openai_chat_completion_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=next(event_type), - delta=choice.delta.content, + delta=TextDelta(text=choice.delta.content), logprobs=_convert_openai_logprobs(choice.logprobs), ) ) @@ -561,7 +501,7 @@ async def convert_openai_chat_completion_stream( event_type=next(event_type), delta=ToolCallDelta( content=_convert_openai_tool_calls(choice.delta.tool_calls)[0], - parse_status=ToolCallParseStatus.success, + parse_status=ToolCallParseStatus.succeeded, ), logprobs=_convert_openai_logprobs(choice.logprobs), ) @@ -570,7 +510,7 @@ async def convert_openai_chat_completion_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=next(event_type), - delta=choice.delta.content or "", # content is not optional + delta=TextDelta(text=choice.delta.content or ""), logprobs=_convert_openai_logprobs(choice.logprobs), ) ) @@ -578,7 +518,7 @@ async def convert_openai_chat_completion_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta="", + delta=TextDelta(text=""), stop_reason=stop_reason, ) ) @@ -653,18 +593,6 @@ def _convert_openai_completion_logprobs( ) -> Optional[List[TokenLogProbs]]: """ Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs. 
- - OpenAI CompletionLogprobs: - text_offset: Optional[List[int]] - token_logprobs: Optional[List[float]] - tokens: Optional[List[str]] - top_logprobs: Optional[List[Dict[str, float]]] - - -> - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob """ if not logprobs: return None @@ -679,28 +607,6 @@ def convert_openai_completion_choice( ) -> CompletionResponse: """ Convert an OpenAI Completion Choice into a CompletionResponse. - - OpenAI Completion Choice: - text: str - finish_reason: str - logprobs: Optional[ChoiceLogprobs] - - -> - - CompletionResponse: - completion_message: CompletionMessage - logprobs: Optional[List[TokenLogProbs]] - - CompletionMessage: - role: Literal["assistant"] - content: str | ImageMedia | List[str | ImageMedia] - stop_reason: StopReason - tool_calls: List[ToolCall] - - class StopReason(Enum): - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" """ return CompletionResponse( content=choice.text, @@ -715,32 +621,11 @@ async def convert_openai_completion_stream( """ Convert a stream of OpenAI Completions into a stream of ChatCompletionResponseStreamChunks. - - OpenAI Completion: - id: str - choices: List[OpenAICompletionChoice] - created: int - model: str - system_fingerprint: Optional[str] - usage: Optional[OpenAICompletionUsage] - - OpenAI CompletionChoice: - finish_reason: str - index: int - logprobs: Optional[OpenAILogprobs] - text: str - - -> - - CompletionResponseStreamChunk: - delta: str - stop_reason: Optional[StopReason] - logprobs: Optional[List[TokenLogProbs]] """ async for chunk in stream: choice = chunk.choices[0] yield CompletionResponseStreamChunk( - delta=choice.text, + delta=TextDelta(text=choice.text), stop_reason=_convert_openai_finish_reason(choice.finish_reason), logprobs=_convert_openai_completion_logprobs(choice.logprobs), ) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 19cc8393c..932ae36e6 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import ( from pydantic import BaseModel, ValidationError +from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEventType, @@ -27,8 +28,6 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, LogProbConfig, SystemMessage, - ToolCallDelta, - ToolCallParseStatus, ToolChoice, UserMessage, ) @@ -196,7 +195,9 @@ class TestInference: 1 <= len(chunks) <= 6 ) # why 6 and not 5? the response may have an extra closing chunk, e.g. 
for usage or stop_reason for chunk in chunks: - if chunk.delta: # if there's a token, we expect logprobs + if ( + chunk.delta.type == "text" and chunk.delta.text + ): # if there's a token, we expect logprobs assert chunk.logprobs, "Logprobs should not be empty" assert all( len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs @@ -463,7 +464,7 @@ class TestInference: if "Llama3.1" in inference_model: assert all( - isinstance(chunk.event.delta, ToolCallDelta) + chunk.event.delta.type == "tool_call" for chunk in grouped[ChatCompletionResponseEventType.progress] ) first = grouped[ChatCompletionResponseEventType.progress][0] @@ -474,8 +475,8 @@ class TestInference: last = grouped[ChatCompletionResponseEventType.progress][-1] # assert last.event.stop_reason == expected_stop_reason - assert last.event.delta.parse_status == ToolCallParseStatus.success - assert isinstance(last.event.delta.content, ToolCall) + assert last.event.delta.parse_status == ToolCallParseStatus.succeeded + assert last.event.delta.content.type == "tool_call" call = last.event.delta.content assert call.tool_name == "get_weather" diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index ba63be2b6..4c46954cf 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import SamplingParams, StopReason from pydantic import BaseModel -from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem +from llama_stack.apis.common.content_types import ( + ImageContentItem, + TextContentItem, + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionResponse, @@ -22,8 +28,6 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, Message, - ToolCallDelta, - ToolCallParseStatus, ) from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -160,7 +164,7 @@ async def process_chat_completion_stream_response( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.start, - delta="", + delta=TextDelta(text=""), ) ) @@ -227,7 +231,7 @@ async def process_chat_completion_stream_response( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.progress, - delta=text, + delta=TextDelta(text=text), stop_reason=stop_reason, ) ) @@ -241,7 +245,7 @@ async def process_chat_completion_stream_response( event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( content="", - parse_status=ToolCallParseStatus.failure, + parse_status=ToolCallParseStatus.failed, ), stop_reason=stop_reason, ) @@ -253,7 +257,7 @@ async def process_chat_completion_stream_response( event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( content=tool_call, - parse_status=ToolCallParseStatus.success, + parse_status=ToolCallParseStatus.succeeded, ), stop_reason=stop_reason, ) @@ -262,7 +266,7 @@ async def process_chat_completion_stream_response( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta="", + delta=TextDelta(text=""), stop_reason=stop_reason, ) ) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py 
b/llama_stack/providers/utils/inference/prompt_adapter.py index 2d66dc60b..de4918f5c 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -265,6 +265,7 @@ def chat_completion_request_to_messages( For eg. for llama_3_1, add system message with the appropriate tools or add user messsage for custom tools, etc. """ + assert llama_model is not None, "llama_model is required" model = resolve_model(llama_model) if model is None: log.error(f"Could not resolve model {llama_model}") diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index f304d58f6..d84024941 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -127,7 +127,8 @@ class TraceContext: def setup_logger(api: Telemetry, level: int = logging.INFO): global BACKGROUND_LOGGER - BACKGROUND_LOGGER = BackgroundLogger(api) + if BACKGROUND_LOGGER is None: + BACKGROUND_LOGGER = BackgroundLogger(api) logger = logging.getLogger() logger.setLevel(level) logger.addHandler(TelemetryHandler()) diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index 16e6d1bbd..b40d54ee5 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -12,6 +12,11 @@ from llama_stack.providers.tests.env import get_env_or_fail from llama_stack_client import LlamaStackClient +def pytest_configure(config): + config.option.tbstyle = "short" + config.option.disable_warnings = True + + @pytest.fixture(scope="session") def provider_data(): # check env for tavily secret, brave secret and inject all into provider data @@ -29,6 +34,7 @@ def llama_stack_client(provider_data): client = LlamaStackAsLibraryClient( get_env_or_fail("LLAMA_STACK_CONFIG"), provider_data=provider_data, + skip_logger_removal=True, ) client.initialize() elif os.environ.get("LLAMA_STACK_BASE_URL"): diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py index ef6219389..a50dba3a0 100644 --- a/tests/client-sdk/inference/test_inference.py +++ b/tests/client-sdk/inference/test_inference.py @@ -6,9 +6,9 @@ import pytest -from llama_stack_client.lib.inference.event_logger import EventLogger from pydantic import BaseModel + PROVIDER_TOOL_PROMPT_FORMAT = { "remote::ollama": "python_list", "remote::together": "json", @@ -39,7 +39,7 @@ def text_model_id(llama_stack_client): available_models = [ model.identifier for model in llama_stack_client.models.list() - if model.identifier.startswith("meta-llama") + if model.identifier.startswith("meta-llama") and "405" not in model.identifier ] assert len(available_models) > 0 return available_models[0] @@ -208,12 +208,9 @@ def test_text_chat_completion_streaming( stream=True, ) streamed_content = [ - str(log.content.lower().strip()) - for log in EventLogger().log(response) - if log is not None + str(chunk.event.delta.text.lower().strip()) for chunk in response ] assert len(streamed_content) > 0 - assert "assistant>" in streamed_content[0] assert expected.lower() in "".join(streamed_content) @@ -250,17 +247,16 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming( def extract_tool_invocation_content(response): text_content: str = "" tool_invocation_content: str = "" - for log in EventLogger().log(response): - if log is None: - continue - if isinstance(log.content, str): - text_content += log.content - elif isinstance(log.content, object): - if isinstance(log.content.content, 
str): - continue - elif isinstance(log.content.content, object): - tool_invocation_content += f"[{log.content.content.tool_name}, {log.content.content.arguments}]" - + for chunk in response: + delta = chunk.event.delta + if delta.type == "text": + text_content += delta.text + elif delta.type == "tool_call": + if isinstance(delta.content, str): + tool_invocation_content += delta.content + else: + call = delta.content + tool_invocation_content += f"[{call.tool_name}, {call.arguments}]" return text_content, tool_invocation_content @@ -280,7 +276,6 @@ def test_text_chat_completion_with_tool_calling_and_streaming( ) text_content, tool_invocation_content = extract_tool_invocation_content(response) - assert "Assistant>" in text_content assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" @@ -368,10 +363,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id): stream=True, ) streamed_content = [ - str(log.content.lower().strip()) - for log in EventLogger().log(response) - if log is not None + str(chunk.event.delta.text.lower().strip()) for chunk in response ] assert len(streamed_content) > 0 - assert "assistant>" in streamed_content[0] assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})
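
Taken together, these hunks replace the old `Union[str, ToolCallDelta]` streaming payloads with a single discriminated `ContentDelta` union (`TextDelta`, `ImageDelta`, `ToolCallDelta`, discriminated on `type`), and rename the `ToolCallParseStatus` values to `failed`/`succeeded`. Below is a minimal consumer sketch, assuming chunks shaped like the updated `ChatCompletionResponseStreamChunk`; the `collect_stream` and `response` names are illustrative and not part of this change:

    from llama_stack.apis.common.content_types import ToolCallParseStatus

    def collect_stream(response):
        """Accumulate text and fully parsed tool calls from a streamed chat completion."""
        text_parts, tool_calls = [], []
        for chunk in response:
            delta = chunk.event.delta
            if delta.type == "text":
                text_parts.append(delta.text)
            elif delta.type == "tool_call":
                # in-progress deltas stream partial strings; only the final,
                # successfully parsed ToolCall is kept
                if delta.parse_status == ToolCallParseStatus.succeeded:
                    tool_calls.append(delta.content)
            elif delta.type == "image":
                pass  # base64-encoded image data; ignored in this sketch
        return "".join(text_parts), tool_calls

This mirrors how the updated event logger and client-sdk tests dispatch on `delta.type` rather than on isinstance checks, which is what allows new delta variants to be added to `ContentDelta` without breaking existing consumers.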