diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 0ce216479..5ed8701a4 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -3843,8 +3843,8 @@
"properties": {
"role": {
"type": "string",
- "const": "ipython",
- "default": "ipython"
+ "const": "tool",
+ "default": "tool"
},
"call_id": {
"type": "string"
@@ -4185,14 +4185,7 @@
"$ref": "#/components/schemas/ChatCompletionResponseEventType"
},
"delta": {
- "oneOf": [
- {
- "type": "string"
- },
- {
- "$ref": "#/components/schemas/ToolCallDelta"
- }
- ]
+ "$ref": "#/components/schemas/ContentDelta"
},
"logprobs": {
"type": "array",
@@ -4232,6 +4225,50 @@
],
"title": "SSE-stream of these events."
},
+ "ContentDelta": {
+ "oneOf": [
+ {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "text",
+ "default": "text"
+ },
+ "text": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "text"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "image",
+ "default": "image"
+ },
+ "data": {
+ "type": "string",
+ "contentEncoding": "base64"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "data"
+ ]
+ },
+ {
+ "$ref": "#/components/schemas/ToolCallDelta"
+ }
+ ]
+ },
"TokenLogProbs": {
"type": "object",
"properties": {
@@ -4250,6 +4287,11 @@
"ToolCallDelta": {
"type": "object",
"properties": {
+ "type": {
+ "type": "string",
+ "const": "tool_call",
+ "default": "tool_call"
+ },
"content": {
"oneOf": [
{
@@ -4266,6 +4308,7 @@
},
"additionalProperties": false,
"required": [
+ "type",
"content",
"parse_status"
]
@@ -4275,8 +4318,8 @@
"enum": [
"started",
"in_progress",
- "failure",
- "success"
+ "failed",
+ "succeeded"
]
},
"CompletionRequest": {
@@ -4777,18 +4820,16 @@
"step_id": {
"type": "string"
},
- "text_delta": {
- "type": "string"
- },
- "tool_call_delta": {
- "$ref": "#/components/schemas/ToolCallDelta"
+ "delta": {
+ "$ref": "#/components/schemas/ContentDelta"
}
},
"additionalProperties": false,
"required": [
"event_type",
"step_type",
- "step_id"
+ "step_id",
+ "delta"
]
},
"AgentTurnResponseStepStartPayload": {
@@ -8758,6 +8799,10 @@
"name": "CompletionResponseStreamChunk",
"description": "streamed completion response.\n\n"
},
+ {
+ "name": "ContentDelta",
+ "description": ""
+ },
{
"name": "CreateAgentRequest",
"description": ""
@@ -9392,6 +9437,7 @@
"CompletionRequest",
"CompletionResponse",
"CompletionResponseStreamChunk",
+ "ContentDelta",
"CreateAgentRequest",
"CreateAgentSessionRequest",
"CreateAgentTurnRequest",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 031178ce9..2a573959f 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -150,6 +150,8 @@ components:
AgentTurnResponseStepProgressPayload:
additionalProperties: false
properties:
+ delta:
+ $ref: '#/components/schemas/ContentDelta'
event_type:
const: step_progress
default: step_progress
@@ -163,14 +165,11 @@ components:
- shield_call
- memory_retrieval
type: string
- text_delta:
- type: string
- tool_call_delta:
- $ref: '#/components/schemas/ToolCallDelta'
required:
- event_type
- step_type
- step_id
+ - delta
type: object
AgentTurnResponseStepStartPayload:
additionalProperties: false
@@ -462,9 +461,7 @@ components:
additionalProperties: false
properties:
delta:
- oneOf:
- - type: string
- - $ref: '#/components/schemas/ToolCallDelta'
+ $ref: '#/components/schemas/ContentDelta'
event_type:
$ref: '#/components/schemas/ChatCompletionResponseEventType'
logprobs:
@@ -571,6 +568,34 @@ components:
- delta
title: streamed completion response.
type: object
+ ContentDelta:
+ oneOf:
+ - additionalProperties: false
+ properties:
+ text:
+ type: string
+ type:
+ const: text
+ default: text
+ type: string
+ required:
+ - type
+ - text
+ type: object
+ - additionalProperties: false
+ properties:
+ data:
+ contentEncoding: base64
+ type: string
+ type:
+ const: image
+ default: image
+ type: string
+ required:
+ - type
+ - data
+ type: object
+ - $ref: '#/components/schemas/ToolCallDelta'
CreateAgentRequest:
additionalProperties: false
properties:
@@ -2664,7 +2689,12 @@ components:
- $ref: '#/components/schemas/ToolCall'
parse_status:
$ref: '#/components/schemas/ToolCallParseStatus'
+ type:
+ const: tool_call
+ default: tool_call
+ type: string
required:
+ - type
- content
- parse_status
type: object
@@ -2672,8 +2702,8 @@ components:
enum:
- started
- in_progress
- - failure
- - success
+ - failed
+ - succeeded
type: string
ToolChoice:
enum:
@@ -2888,8 +2918,8 @@ components:
content:
$ref: '#/components/schemas/InterleavedContent'
role:
- const: ipython
- default: ipython
+ const: tool
+ default: tool
type: string
tool_name:
oneOf:
@@ -5500,6 +5530,8 @@ tags:
'
name: CompletionResponseStreamChunk
+- description:
+ name: ContentDelta
- description:
name: CreateAgentRequest
@@ -5939,6 +5971,7 @@ x-tagGroups:
- CompletionRequest
- CompletionResponse
- CompletionResponseStreamChunk
+ - ContentDelta
- CreateAgentRequest
- CreateAgentSessionRequest
- CreateAgentTurnRequest
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index fb9df21e6..c3f3d21f0 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -22,12 +22,11 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
from llama_stack.apis.inference import (
CompletionMessage,
SamplingParams,
ToolCall,
- ToolCallDelta,
ToolChoice,
ToolPromptFormat,
ToolResponse,
@@ -216,8 +215,7 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
step_type: StepType
step_id: str
- text_delta: Optional[str] = None
- tool_call_delta: Optional[ToolCallDelta] = None
+ delta: ContentDelta
@json_schema_type
diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py
index 40a69d19c..9e2f14805 100644
--- a/llama_stack/apis/agents/event_logger.py
+++ b/llama_stack/apis/agents/event_logger.py
@@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
-
+from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import ToolResponseMessage
+from llama_stack.providers.utils.inference.prompt_adapter import (
+ interleaved_content_as_str,
+)
+
class LogEvent:
def __init__(
@@ -57,8 +61,11 @@ class EventLogger:
# since it does not produce event but instead
# a Message
if isinstance(chunk, ToolResponseMessage):
- yield chunk, LogEvent(
- role="CustomTool", content=chunk.content, color="grey"
+ yield (
+ chunk,
+ LogEvent(
+ role="CustomTool", content=chunk.content, color="grey"
+ ),
)
continue
@@ -80,14 +87,20 @@ class EventLogger:
):
violation = event.payload.step_details.violation
if not violation:
- yield event, LogEvent(
- role=step_type, content="No Violation", color="magenta"
+ yield (
+ event,
+ LogEvent(
+ role=step_type, content="No Violation", color="magenta"
+ ),
)
else:
- yield event, LogEvent(
- role=step_type,
- content=f"{violation.metadata} {violation.user_message}",
- color="red",
+ yield (
+ event,
+ LogEvent(
+ role=step_type,
+ content=f"{violation.metadata} {violation.user_message}",
+ color="red",
+ ),
)
# handle inference
@@ -95,8 +108,11 @@ class EventLogger:
if stream:
if event_type == EventType.step_start.value:
# TODO: Currently this event is never received
- yield event, LogEvent(
- role=step_type, content="", end="", color="yellow"
+ yield (
+ event,
+ LogEvent(
+ role=step_type, content="", end="", color="yellow"
+ ),
)
elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress
@@ -107,24 +123,34 @@ class EventLogger:
previous_event_type != EventType.step_progress.value
and previous_step_type != StepType.inference
):
- yield event, LogEvent(
- role=step_type, content="", end="", color="yellow"
+ yield (
+ event,
+ LogEvent(
+ role=step_type, content="", end="", color="yellow"
+ ),
)
- if event.payload.tool_call_delta:
- if isinstance(event.payload.tool_call_delta.content, str):
- yield event, LogEvent(
- role=None,
- content=event.payload.tool_call_delta.content,
- end="",
- color="cyan",
+ delta = event.payload.delta
+ if delta.type == "tool_call":
+ if delta.parse_status == ToolCallParseStatus.succeeded:
+ yield (
+ event,
+ LogEvent(
+ role=None,
+ content=delta.content,
+ end="",
+ color="cyan",
+ ),
)
else:
- yield event, LogEvent(
- role=None,
- content=event.payload.text_delta,
- end="",
- color="yellow",
+ yield (
+ event,
+ LogEvent(
+ role=None,
+ content=delta.text,
+ end="",
+ color="yellow",
+ ),
)
else:
# step_complete
@@ -140,10 +166,13 @@ class EventLogger:
)
else:
content = response.content
- yield event, LogEvent(
- role=step_type,
- content=content,
- color="yellow",
+ yield (
+ event,
+ LogEvent(
+ role=step_type,
+ content=content,
+ color="yellow",
+ ),
)
# handle tool_execution
@@ -155,16 +184,22 @@ class EventLogger:
):
details = event.payload.step_details
for t in details.tool_calls:
- yield event, LogEvent(
- role=step_type,
- content=f"Tool:{t.tool_name} Args:{t.arguments}",
- color="green",
+ yield (
+ event,
+ LogEvent(
+ role=step_type,
+ content=f"Tool:{t.tool_name} Args:{t.arguments}",
+ color="green",
+ ),
)
for r in details.tool_responses:
- yield event, LogEvent(
- role=step_type,
- content=f"Tool:{r.tool_name} Response:{r.content}",
- color="green",
+ yield (
+ event,
+ LogEvent(
+ role=step_type,
+ content=f"Tool:{r.tool_name} Response:{r.content}",
+ color="green",
+ ),
)
if (
@@ -172,15 +207,16 @@ class EventLogger:
and event_type == EventType.step_complete.value
):
details = event.payload.step_details
- inserted_context = interleaved_text_media_as_str(
- details.inserted_context
- )
+ inserted_context = interleaved_content_as_str(details.inserted_context)
content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}"
- yield event, LogEvent(
- role=step_type,
- content=content,
- color="cyan",
+ yield (
+ event,
+ LogEvent(
+ role=step_type,
+ content=content,
+ color="cyan",
+ ),
)
previous_event_type = event_type
diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py
index 629e0e94d..3b61fa243 100644
--- a/llama_stack/apis/common/content_types.py
+++ b/llama_stack/apis/common/content_types.py
@@ -5,10 +5,12 @@
# the root directory of this source tree.
import base64
+from enum import Enum
from typing import Annotated, List, Literal, Optional, Union
-from llama_models.schema_utils import json_schema_type, register_schema
+from llama_models.llama3.api.datatypes import ToolCall
+from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field, field_serializer, model_validator
@@ -60,3 +62,42 @@ InterleavedContent = register_schema(
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
name="InterleavedContent",
)
+
+
+class TextDelta(BaseModel):
+ type: Literal["text"] = "text"
+ text: str
+
+
+class ImageDelta(BaseModel):
+ type: Literal["image"] = "image"
+ data: bytes
+
+
+@json_schema_type
+class ToolCallParseStatus(Enum):
+ started = "started"
+ in_progress = "in_progress"
+ failed = "failed"
+ succeeded = "succeeded"
+
+
+@json_schema_type
+class ToolCallDelta(BaseModel):
+ type: Literal["tool_call"] = "tool_call"
+
+    # Either a partial tool call is streamed as a string, so the client can
+    # render a long code generation incrementally, or the final parsed
+    # ToolCall is sent at the end of the stream.
+ content: Union[str, ToolCall]
+ parse_status: ToolCallParseStatus
+
+
+# streaming completions send a stream of ContentDeltas
+ContentDelta = register_schema(
+ Annotated[
+ Union[TextDelta, ImageDelta, ToolCallDelta],
+ Field(discriminator="type"),
+ ],
+ name="ContentDelta",
+)
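
Note on the ContentDelta union registered above: downstream code in this patch (the event logger, the agent loop) branches on the type discriminator rather than using isinstance checks. A minimal illustrative sketch, assuming only the classes defined in this file; render_delta is a hypothetical helper, not part of the change:

from llama_stack.apis.common.content_types import (
    TextDelta,
    ToolCallDelta,
    ToolCallParseStatus,
)


def render_delta(delta) -> str:
    # Branch on the "type" discriminator added in this patch.
    if delta.type == "text":
        return delta.text
    if delta.type == "image":
        return f"<image: {len(delta.data)} bytes>"
    if delta.type == "tool_call":
        # content is either a partial string or a fully parsed ToolCall,
        # depending on parse_status.
        if delta.parse_status == ToolCallParseStatus.succeeded:
            return f"[tool call: {delta.content.tool_name}]"
        return str(delta.content)
    return ""


print(render_delta(TextDelta(text="hello ")), end="")
print(render_delta(ToolCallDelta(content="get_weat", parse_status=ToolCallParseStatus.in_progress)))
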
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 4a453700c..b525aa331 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -29,7 +29,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -147,26 +147,12 @@ class ChatCompletionResponseEventType(Enum):
progress = "progress"
-@json_schema_type
-class ToolCallParseStatus(Enum):
- started = "started"
- in_progress = "in_progress"
- failure = "failure"
- success = "success"
-
-
-@json_schema_type
-class ToolCallDelta(BaseModel):
- content: Union[str, ToolCall]
- parse_status: ToolCallParseStatus
-
-
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""Chat completion response event."""
event_type: ChatCompletionResponseEventType
- delta: Union[str, ToolCallDelta]
+ delta: ContentDelta
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None
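
To make the tightened delta typing above concrete: providers now wrap plain text in TextDelta and tool-call updates in ToolCallDelta when emitting stream chunks, mirroring the meta-reference and openai_compat changes later in this patch. A hedged sketch; the two helper functions are illustrative only:

from llama_stack.apis.common.content_types import (
    TextDelta,
    ToolCallDelta,
    ToolCallParseStatus,
)
from llama_stack.apis.inference import (
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
)


def start_chunk() -> ChatCompletionResponseStreamChunk:
    # A bare string is no longer a valid delta; empty text becomes TextDelta(text="").
    return ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.start,
            delta=TextDelta(text=""),
        )
    )


def failed_tool_call_chunk() -> ChatCompletionResponseStreamChunk:
    # Parse failures use ToolCallParseStatus.failed (renamed from "failure").
    return ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.progress,
            delta=ToolCallDelta(content="", parse_status=ToolCallParseStatus.failed),
        )
    )
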
diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py
index 084374c8a..85e6cb962 100644
--- a/llama_stack/cli/stack/build.py
+++ b/llama_stack/cli/stack/build.py
@@ -4,9 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
-
import importlib.resources
-
import os
import shutil
from functools import lru_cache
@@ -14,14 +12,12 @@ from pathlib import Path
from typing import List, Optional
from llama_stack.cli.subcommand import Subcommand
-
from llama_stack.distribution.datatypes import (
BuildConfig,
DistributionSpec,
Provider,
StackRunConfig,
)
-
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -296,6 +292,7 @@ class StackBuild(Subcommand):
/ f"templates/{template_name}/run.yaml"
)
with importlib.resources.as_file(template_path) as path:
+ run_config_file = build_dir / f"{build_config.name}-run.yaml"
shutil.copy(path, run_config_file)
# Find all ${env.VARIABLE} patterns
cprint("Build Successful!", color="green")
diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 50af2cdea..0c124e64b 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -9,12 +9,10 @@ import inspect
import json
import logging
import os
-import queue
-import threading
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
+from typing import Any, get_args, get_origin, Optional, TypeVar
import httpx
import yaml
@@ -64,71 +62,6 @@ def in_notebook():
return True
-def stream_across_asyncio_run_boundary(
- async_gen_maker,
- pool_executor: ThreadPoolExecutor,
- path: Optional[str] = None,
- provider_data: Optional[dict[str, Any]] = None,
-) -> Generator[T, None, None]:
- result_queue = queue.Queue()
- stop_event = threading.Event()
-
- async def consumer():
- # make sure we make the generator in the event loop context
- gen = await async_gen_maker()
- await start_trace(path, {"__location__": "library_client"})
- if provider_data:
- set_request_provider_data(
- {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
- )
- try:
- async for item in await gen:
- result_queue.put(item)
- except Exception as e:
- print(f"Error in generator {e}")
- result_queue.put(e)
- except asyncio.CancelledError:
- return
- finally:
- result_queue.put(StopIteration)
- stop_event.set()
- await end_trace()
-
- def run_async():
- # Run our own loop to avoid double async generator cleanup which is done
- # by asyncio.run()
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- task = loop.create_task(consumer())
- loop.run_until_complete(task)
- finally:
- # Handle pending tasks like a generator's athrow()
- pending = asyncio.all_tasks(loop)
- if pending:
- loop.run_until_complete(
- asyncio.gather(*pending, return_exceptions=True)
- )
- loop.close()
-
- future = pool_executor.submit(run_async)
-
- try:
- # yield results as they come in
- while not stop_event.is_set() or not result_queue.empty():
- try:
- item = result_queue.get(timeout=0.1)
- if item is StopIteration:
- break
- if isinstance(item, Exception):
- raise item
- yield item
- except queue.Empty:
- continue
- finally:
- future.result()
-
-
def convert_pydantic_to_json_value(value: Any) -> Any:
if isinstance(value, Enum):
return value.value
@@ -184,7 +117,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
- config_path_or_template_name, custom_provider_registry
+ config_path_or_template_name, custom_provider_registry, provider_data
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal
@@ -210,39 +143,30 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
root_logger.removeHandler(handler)
print(f"Removed handler {handler.__class__.__name__} from root logger")
- def _get_path(
- self,
- cast_to: Any,
- options: Any,
- *,
- stream=False,
- stream_cls=None,
- ):
- return options.url
-
def request(self, *args, **kwargs):
- path = self._get_path(*args, **kwargs)
if kwargs.get("stream"):
- return stream_across_asyncio_run_boundary(
- lambda: self.async_client.request(*args, **kwargs),
- self.pool_executor,
- path=path,
- provider_data=self.provider_data,
- )
- else:
+            # NOTE: We use the AsyncLlamaStackClient under the hood.
+            # A dedicated event loop is needed to convert the AsyncStream
+            # returned by the async client into a synchronous generator.
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
- async def _traced_request():
- if self.provider_data:
- set_request_provider_data(
- {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
- )
- await start_trace(path, {"__location__": "library_client"})
+ def sync_generator():
try:
- return await self.async_client.request(*args, **kwargs)
+ async_stream = loop.run_until_complete(
+ self.async_client.request(*args, **kwargs)
+ )
+ while True:
+ chunk = loop.run_until_complete(async_stream.__anext__())
+ yield chunk
+ except StopAsyncIteration:
+ pass
finally:
- await end_trace()
+ loop.close()
- return asyncio.run(_traced_request())
+ return sync_generator()
+ else:
+ return asyncio.run(self.async_client.request(*args, **kwargs))
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@@ -250,9 +174,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self,
config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None,
+ provider_data: Optional[dict[str, Any]] = None,
):
super().__init__()
-
# when using the library client, we should not log to console since many
# of our logs are intended for server-side usage
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
@@ -273,6 +197,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.config_path_or_template_name = config_path_or_template_name
self.config = config
self.custom_provider_registry = custom_provider_registry
+ self.provider_data = provider_data
async def initialize(self):
try:
@@ -329,17 +254,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
+ if self.provider_data:
+ set_request_provider_data(
+ {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
+ )
+ await start_trace(options.url, {"__location__": "library_client"})
if stream:
- return self._call_streaming(
+ response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
- return await self._call_non_streaming(
+ response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
+ await end_trace()
+ return response
async def _call_non_streaming(
self,
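
Usage sketch for the reworked synchronous streaming path above. The template name, model id, and chat_completion parameter names are assumptions for illustration; the iteration pattern follows the updated client-sdk tests at the end of this patch:

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

# Assumed template name; any config path or template accepted by the client works.
client = LlamaStackAsLibraryClient("ollama")
client.initialize()

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # assumed model identifier
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
)

# With stream=True, request() now returns a plain synchronous generator backed
# by a dedicated event loop, so the chunks can be iterated directly.
for chunk in response:
    if chunk.event.delta.type == "text":
        print(chunk.event.delta.text, end="")
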
diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py
index d26b4447c..010d137ec 100644
--- a/llama_stack/distribution/store/registry.py
+++ b/llama_stack/distribution/store/registry.py
@@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v4"
+KEY_VERSION = "v5"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 24448a28f..2299e80d1 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -40,7 +40,12 @@ from llama_stack.apis.agents import (
ToolExecutionStep,
Turn,
)
-from llama_stack.apis.common.content_types import TextContentItem, URL
+from llama_stack.apis.common.content_types import (
+ TextContentItem,
+ ToolCallDelta,
+ ToolCallParseStatus,
+ URL,
+)
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
CompletionMessage,
@@ -49,8 +54,6 @@ from llama_stack.apis.inference import (
SamplingParams,
StopReason,
SystemMessage,
- ToolCallDelta,
- ToolCallParseStatus,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
@@ -411,8 +414,8 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
- tool_call_delta=ToolCallDelta(
- parse_status=ToolCallParseStatus.success,
+ delta=ToolCallDelta(
+ parse_status=ToolCallParseStatus.succeeded,
content=ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
@@ -507,8 +510,8 @@ class ChatAgent(ShieldRunnerMixin):
continue
delta = event.delta
- if isinstance(delta, ToolCallDelta):
- if delta.parse_status == ToolCallParseStatus.success:
+ if delta.type == "tool_call":
+ if delta.parse_status == ToolCallParseStatus.succeeded:
tool_calls.append(delta.content)
if stream:
yield AgentTurnResponseStreamChunk(
@@ -516,21 +519,20 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
- text_delta="",
- tool_call_delta=delta,
+ delta=delta,
)
)
)
- elif isinstance(delta, str):
- content += delta
+ elif delta.type == "text":
+ content += delta.text
if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
- text_delta=event.delta,
+ delta=delta,
)
)
)
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 5b502a581..d64d32f03 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import (
)
from llama_models.sku_list import resolve_model
+from llama_stack.apis.common.content_types import (
+ TextDelta,
+ ToolCallDelta,
+ ToolCallParseStatus,
+)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -32,8 +37,6 @@ from llama_stack.apis.inference import (
Message,
ResponseFormat,
TokenLogProbs,
- ToolCallDelta,
- ToolCallParseStatus,
ToolChoice,
)
from llama_stack.apis.models import Model, ModelType
@@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl(
]
yield CompletionResponseStreamChunk(
- delta=text,
+ delta=TextDelta(text=text),
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if stop_reason is None:
yield CompletionResponseStreamChunk(
- delta="",
+ delta=TextDelta(text=""),
stop_reason=StopReason.out_of_tokens,
)
@@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
- delta="",
+ delta=TextDelta(text=""),
)
)
@@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl(
parse_status=ToolCallParseStatus.in_progress,
)
else:
- delta = text
+ delta = TextDelta(text=text)
if stop_reason is None:
if request.logprobs:
@@ -428,7 +431,7 @@ class MetaReferenceInferenceImpl(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
- parse_status=ToolCallParseStatus.failure,
+ parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
@@ -440,7 +443,7 @@ class MetaReferenceInferenceImpl(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
- parse_status=ToolCallParseStatus.success,
+ parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
@@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
- delta="",
+ delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
index efc37b553..332a150cf 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
@@ -30,13 +30,10 @@ from llama_stack.apis.telemetry import (
Trace,
UnstructuredLogEvent,
)
-
from llama_stack.distribution.datatypes import Api
-
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor,
)
-
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
SQLiteSpanProcessor,
)
@@ -52,6 +49,7 @@ _GLOBAL_STORAGE = {
"up_down_counters": {},
}
_global_lock = threading.Lock()
+_TRACER_PROVIDER = None
def string_to_trace_id(s: str) -> int:
@@ -80,31 +78,34 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
}
)
- provider = TracerProvider(resource=resource)
- trace.set_tracer_provider(provider)
- if TelemetrySink.OTEL in self.config.sinks:
- otlp_exporter = OTLPSpanExporter(
- endpoint=self.config.otel_endpoint,
- )
- span_processor = BatchSpanProcessor(otlp_exporter)
- trace.get_tracer_provider().add_span_processor(span_processor)
- metric_reader = PeriodicExportingMetricReader(
- OTLPMetricExporter(
+ global _TRACER_PROVIDER
+ if _TRACER_PROVIDER is None:
+ provider = TracerProvider(resource=resource)
+ trace.set_tracer_provider(provider)
+ _TRACER_PROVIDER = provider
+ if TelemetrySink.OTEL in self.config.sinks:
+ otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
)
- )
- metric_provider = MeterProvider(
- resource=resource, metric_readers=[metric_reader]
- )
- metrics.set_meter_provider(metric_provider)
- self.meter = metrics.get_meter(__name__)
- if TelemetrySink.SQLITE in self.config.sinks:
- trace.get_tracer_provider().add_span_processor(
- SQLiteSpanProcessor(self.config.sqlite_db_path)
- )
- self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
- if TelemetrySink.CONSOLE in self.config.sinks:
- trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
+ span_processor = BatchSpanProcessor(otlp_exporter)
+ trace.get_tracer_provider().add_span_processor(span_processor)
+ metric_reader = PeriodicExportingMetricReader(
+ OTLPMetricExporter(
+ endpoint=self.config.otel_endpoint,
+ )
+ )
+ metric_provider = MeterProvider(
+ resource=resource, metric_readers=[metric_reader]
+ )
+ metrics.set_meter_provider(metric_provider)
+ self.meter = metrics.get_meter(__name__)
+ if TelemetrySink.SQLITE in self.config.sinks:
+ trace.get_tracer_provider().add_span_processor(
+ SQLiteSpanProcessor(self.config.sqlite_db_path)
+ )
+ self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
+ if TelemetrySink.CONSOLE in self.config.sinks:
+ trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
self._lock = _global_lock
async def initialize(self) -> None:
diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py
index 032f4c8d4..11f684847 100644
--- a/llama_stack/providers/remote/inference/groq/groq_utils.py
+++ b/llama_stack/providers/remote/inference/groq/groq_utils.py
@@ -30,6 +30,11 @@ from groq.types.shared.function_definition import FunctionDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition
+from llama_stack.apis.common.content_types import (
+ TextDelta,
+ ToolCallDelta,
+ ToolCallParseStatus,
+)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -40,8 +45,6 @@ from llama_stack.apis.inference import (
Message,
StopReason,
ToolCall,
- ToolCallDelta,
- ToolCallParseStatus,
ToolDefinition,
ToolPromptFormat,
)
@@ -162,7 +165,7 @@ def convert_chat_completion_response(
def _map_finish_reason_to_stop_reason(
- finish_reason: Literal["stop", "length", "tool_calls"]
+ finish_reason: Literal["stop", "length", "tool_calls"],
) -> StopReason:
"""
Convert a Groq chat completion finish_reason to a StopReason.
@@ -185,7 +188,6 @@ def _map_finish_reason_to_stop_reason(
async def convert_chat_completion_response_stream(
stream: Stream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-
event_type = ChatCompletionResponseEventType.start
for chunk in stream:
choice = chunk.choices[0]
@@ -194,7 +196,7 @@ async def convert_chat_completion_response_stream(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
- delta=choice.delta.content or "",
+ delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
)
@@ -213,7 +215,7 @@ async def convert_chat_completion_response_stream(
event_type=event_type,
delta=ToolCallDelta(
content=tool_call,
- parse_status=ToolCallParseStatus.success,
+ parse_status=ToolCallParseStatus.succeeded,
),
)
)
@@ -221,7 +223,7 @@ async def convert_chat_completion_response_stream(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
- delta=choice.delta.content or "",
+ delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
)
)
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index dcc7c5fca..975812844 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -34,6 +34,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
+from llama_stack.apis.common.content_types import (
+ TextDelta,
+ ToolCallDelta,
+ ToolCallParseStatus,
+)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -48,8 +53,6 @@ from llama_stack.apis.inference import (
Message,
SystemMessage,
TokenLogProbs,
- ToolCallDelta,
- ToolCallParseStatus,
ToolResponseMessage,
UserMessage,
)
@@ -432,69 +435,6 @@ async def convert_openai_chat_completion_stream(
"""
Convert a stream of OpenAI chat completion chunks into a stream
of ChatCompletionResponseStreamChunk.
-
- OpenAI ChatCompletionChunk:
- choices: List[Choice]
-
- OpenAI Choice: # different from the non-streamed Choice
- delta: ChoiceDelta
- finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
- logprobs: Optional[ChoiceLogprobs]
-
- OpenAI ChoiceDelta:
- content: Optional[str]
- role: Optional[Literal["system", "user", "assistant", "tool"]]
- tool_calls: Optional[List[ChoiceDeltaToolCall]]
-
- OpenAI ChoiceDeltaToolCall:
- index: int
- id: Optional[str]
- function: Optional[ChoiceDeltaToolCallFunction]
- type: Optional[Literal["function"]]
-
- OpenAI ChoiceDeltaToolCallFunction:
- name: Optional[str]
- arguments: Optional[str]
-
- ->
-
- ChatCompletionResponseStreamChunk:
- event: ChatCompletionResponseEvent
-
- ChatCompletionResponseEvent:
- event_type: ChatCompletionResponseEventType
- delta: Union[str, ToolCallDelta]
- logprobs: Optional[List[TokenLogProbs]]
- stop_reason: Optional[StopReason]
-
- ChatCompletionResponseEventType:
- start = "start"
- progress = "progress"
- complete = "complete"
-
- ToolCallDelta:
- content: Union[str, ToolCall]
- parse_status: ToolCallParseStatus
-
- ToolCall:
- call_id: str
- tool_name: str
- arguments: str
-
- ToolCallParseStatus:
- started = "started"
- in_progress = "in_progress"
- failure = "failure"
- success = "success"
-
- TokenLogProbs:
- logprobs_by_token: Dict[str, float]
- - token, logprob
-
- StopReason:
- end_of_turn = "end_of_turn"
- end_of_message = "end_of_message"
- out_of_tokens = "out_of_tokens"
"""
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
@@ -543,7 +483,7 @@ async def convert_openai_chat_completion_stream(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
- delta=choice.delta.content,
+ delta=TextDelta(text=choice.delta.content),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
@@ -561,7 +501,7 @@ async def convert_openai_chat_completion_stream(
event_type=next(event_type),
delta=ToolCallDelta(
content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
- parse_status=ToolCallParseStatus.success,
+ parse_status=ToolCallParseStatus.succeeded,
),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
@@ -570,7 +510,7 @@ async def convert_openai_chat_completion_stream(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
- delta=choice.delta.content or "", # content is not optional
+ delta=TextDelta(text=choice.delta.content or ""),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
@@ -578,7 +518,7 @@ async def convert_openai_chat_completion_stream(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
- delta="",
+ delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
@@ -653,18 +593,6 @@ def _convert_openai_completion_logprobs(
) -> Optional[List[TokenLogProbs]]:
"""
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
-
- OpenAI CompletionLogprobs:
- text_offset: Optional[List[int]]
- token_logprobs: Optional[List[float]]
- tokens: Optional[List[str]]
- top_logprobs: Optional[List[Dict[str, float]]]
-
- ->
-
- TokenLogProbs:
- logprobs_by_token: Dict[str, float]
- - token, logprob
"""
if not logprobs:
return None
@@ -679,28 +607,6 @@ def convert_openai_completion_choice(
) -> CompletionResponse:
"""
Convert an OpenAI Completion Choice into a CompletionResponse.
-
- OpenAI Completion Choice:
- text: str
- finish_reason: str
- logprobs: Optional[ChoiceLogprobs]
-
- ->
-
- CompletionResponse:
- completion_message: CompletionMessage
- logprobs: Optional[List[TokenLogProbs]]
-
- CompletionMessage:
- role: Literal["assistant"]
- content: str | ImageMedia | List[str | ImageMedia]
- stop_reason: StopReason
- tool_calls: List[ToolCall]
-
- class StopReason(Enum):
- end_of_turn = "end_of_turn"
- end_of_message = "end_of_message"
- out_of_tokens = "out_of_tokens"
"""
return CompletionResponse(
content=choice.text,
@@ -715,32 +621,11 @@ async def convert_openai_completion_stream(
"""
Convert a stream of OpenAI Completions into a stream
of ChatCompletionResponseStreamChunks.
-
- OpenAI Completion:
- id: str
- choices: List[OpenAICompletionChoice]
- created: int
- model: str
- system_fingerprint: Optional[str]
- usage: Optional[OpenAICompletionUsage]
-
- OpenAI CompletionChoice:
- finish_reason: str
- index: int
- logprobs: Optional[OpenAILogprobs]
- text: str
-
- ->
-
- CompletionResponseStreamChunk:
- delta: str
- stop_reason: Optional[StopReason]
- logprobs: Optional[List[TokenLogProbs]]
"""
async for chunk in stream:
choice = chunk.choices[0]
yield CompletionResponseStreamChunk(
- delta=choice.text,
+ delta=TextDelta(text=choice.text),
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
)
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 19cc8393c..932ae36e6 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import (
from pydantic import BaseModel, ValidationError
+from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
@@ -27,8 +28,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
SystemMessage,
- ToolCallDelta,
- ToolCallParseStatus,
ToolChoice,
UserMessage,
)
@@ -196,7 +195,9 @@ class TestInference:
1 <= len(chunks) <= 6
) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
for chunk in chunks:
- if chunk.delta: # if there's a token, we expect logprobs
+ if (
+ chunk.delta.type == "text" and chunk.delta.text
+ ): # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty"
assert all(
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
@@ -463,7 +464,7 @@ class TestInference:
if "Llama3.1" in inference_model:
assert all(
- isinstance(chunk.event.delta, ToolCallDelta)
+ chunk.event.delta.type == "tool_call"
for chunk in grouped[ChatCompletionResponseEventType.progress]
)
first = grouped[ChatCompletionResponseEventType.progress][0]
@@ -474,8 +475,8 @@ class TestInference:
last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason
- assert last.event.delta.parse_status == ToolCallParseStatus.success
- assert isinstance(last.event.delta.content, ToolCall)
+ assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
+ assert last.event.delta.content.type == "tool_call"
call = last.event.delta.content
assert call.tool_name == "get_weather"
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index ba63be2b6..4c46954cf 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import SamplingParams, StopReason
from pydantic import BaseModel
-from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+from llama_stack.apis.common.content_types import (
+ ImageContentItem,
+ TextContentItem,
+ TextDelta,
+ ToolCallDelta,
+ ToolCallParseStatus,
+)
from llama_stack.apis.inference import (
ChatCompletionResponse,
@@ -22,8 +28,6 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
Message,
- ToolCallDelta,
- ToolCallParseStatus,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -160,7 +164,7 @@ async def process_chat_completion_stream_response(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
- delta="",
+ delta=TextDelta(text=""),
)
)
@@ -227,7 +231,7 @@ async def process_chat_completion_stream_response(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
- delta=text,
+ delta=TextDelta(text=text),
stop_reason=stop_reason,
)
)
@@ -241,7 +245,7 @@ async def process_chat_completion_stream_response(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
- parse_status=ToolCallParseStatus.failure,
+ parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
@@ -253,7 +257,7 @@ async def process_chat_completion_stream_response(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
- parse_status=ToolCallParseStatus.success,
+ parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
@@ -262,7 +266,7 @@ async def process_chat_completion_stream_response(
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
- delta="",
+ delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 2d66dc60b..de4918f5c 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -265,6 +265,7 @@ def chat_completion_request_to_messages(
For eg. for llama_3_1, add system message with the appropriate tools or
add user messsage for custom tools, etc.
"""
+ assert llama_model is not None, "llama_model is required"
model = resolve_model(llama_model)
if model is None:
log.error(f"Could not resolve model {llama_model}")
diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py
index f304d58f6..d84024941 100644
--- a/llama_stack/providers/utils/telemetry/tracing.py
+++ b/llama_stack/providers/utils/telemetry/tracing.py
@@ -127,7 +127,8 @@ class TraceContext:
def setup_logger(api: Telemetry, level: int = logging.INFO):
global BACKGROUND_LOGGER
- BACKGROUND_LOGGER = BackgroundLogger(api)
+ if BACKGROUND_LOGGER is None:
+ BACKGROUND_LOGGER = BackgroundLogger(api)
logger = logging.getLogger()
logger.setLevel(level)
logger.addHandler(TelemetryHandler())
diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py
index 16e6d1bbd..b40d54ee5 100644
--- a/tests/client-sdk/conftest.py
+++ b/tests/client-sdk/conftest.py
@@ -12,6 +12,11 @@ from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack_client import LlamaStackClient
+def pytest_configure(config):
+ config.option.tbstyle = "short"
+ config.option.disable_warnings = True
+
+
@pytest.fixture(scope="session")
def provider_data():
# check env for tavily secret, brave secret and inject all into provider data
@@ -29,6 +34,7 @@ def llama_stack_client(provider_data):
client = LlamaStackAsLibraryClient(
get_env_or_fail("LLAMA_STACK_CONFIG"),
provider_data=provider_data,
+ skip_logger_removal=True,
)
client.initialize()
elif os.environ.get("LLAMA_STACK_BASE_URL"):
diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py
index ef6219389..a50dba3a0 100644
--- a/tests/client-sdk/inference/test_inference.py
+++ b/tests/client-sdk/inference/test_inference.py
@@ -6,9 +6,9 @@
import pytest
-from llama_stack_client.lib.inference.event_logger import EventLogger
from pydantic import BaseModel
+
PROVIDER_TOOL_PROMPT_FORMAT = {
"remote::ollama": "python_list",
"remote::together": "json",
@@ -39,7 +39,7 @@ def text_model_id(llama_stack_client):
available_models = [
model.identifier
for model in llama_stack_client.models.list()
- if model.identifier.startswith("meta-llama")
+ if model.identifier.startswith("meta-llama") and "405" not in model.identifier
]
assert len(available_models) > 0
return available_models[0]
@@ -208,12 +208,9 @@ def test_text_chat_completion_streaming(
stream=True,
)
streamed_content = [
- str(log.content.lower().strip())
- for log in EventLogger().log(response)
- if log is not None
+ str(chunk.event.delta.text.lower().strip()) for chunk in response
]
assert len(streamed_content) > 0
- assert "assistant>" in streamed_content[0]
assert expected.lower() in "".join(streamed_content)
@@ -250,17 +247,16 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
def extract_tool_invocation_content(response):
text_content: str = ""
tool_invocation_content: str = ""
- for log in EventLogger().log(response):
- if log is None:
- continue
- if isinstance(log.content, str):
- text_content += log.content
- elif isinstance(log.content, object):
- if isinstance(log.content.content, str):
- continue
- elif isinstance(log.content.content, object):
- tool_invocation_content += f"[{log.content.content.tool_name}, {log.content.content.arguments}]"
-
+ for chunk in response:
+ delta = chunk.event.delta
+ if delta.type == "text":
+ text_content += delta.text
+ elif delta.type == "tool_call":
+ if isinstance(delta.content, str):
+ tool_invocation_content += delta.content
+ else:
+ call = delta.content
+ tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"
return text_content, tool_invocation_content
@@ -280,7 +276,6 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
)
text_content, tool_invocation_content = extract_tool_invocation_content(response)
- assert "Assistant>" in text_content
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
@@ -368,10 +363,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
stream=True,
)
streamed_content = [
- str(log.content.lower().strip())
- for log in EventLogger().log(response)
- if log is not None
+ str(chunk.event.delta.text.lower().strip()) for chunk in response
]
assert len(streamed_content) > 0
- assert "assistant>" in streamed_content[0]
assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})