Merge remote-tracking branch 'origin/main' into support_more_data_format

Botao Chen 2025-01-14 11:55:13 -08:00
commit 8d7bb1140f
20 changed files with 381 additions and 414 deletions

View file

@ -3843,8 +3843,8 @@
"properties": { "properties": {
"role": { "role": {
"type": "string", "type": "string",
"const": "ipython", "const": "tool",
"default": "ipython" "default": "tool"
}, },
"call_id": { "call_id": {
"type": "string" "type": "string"
@ -4185,14 +4185,7 @@
"$ref": "#/components/schemas/ChatCompletionResponseEventType" "$ref": "#/components/schemas/ChatCompletionResponseEventType"
}, },
"delta": { "delta": {
"oneOf": [ "$ref": "#/components/schemas/ContentDelta"
{
"type": "string"
},
{
"$ref": "#/components/schemas/ToolCallDelta"
}
]
}, },
"logprobs": { "logprobs": {
"type": "array", "type": "array",
@ -4232,6 +4225,50 @@
], ],
"title": "SSE-stream of these events." "title": "SSE-stream of these events."
}, },
"ContentDelta": {
"oneOf": [
{
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "text",
"default": "text"
},
"text": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"text"
]
},
{
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "image",
"default": "image"
},
"data": {
"type": "string",
"contentEncoding": "base64"
}
},
"additionalProperties": false,
"required": [
"type",
"data"
]
},
{
"$ref": "#/components/schemas/ToolCallDelta"
}
]
},
"TokenLogProbs": { "TokenLogProbs": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -4250,6 +4287,11 @@
"ToolCallDelta": { "ToolCallDelta": {
"type": "object", "type": "object",
"properties": { "properties": {
"type": {
"type": "string",
"const": "tool_call",
"default": "tool_call"
},
"content": { "content": {
"oneOf": [ "oneOf": [
{ {
@ -4266,6 +4308,7 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"type",
"content", "content",
"parse_status" "parse_status"
] ]
@ -4275,8 +4318,8 @@
"enum": [ "enum": [
"started", "started",
"in_progress", "in_progress",
"failure", "failed",
"success" "succeeded"
] ]
}, },
"CompletionRequest": { "CompletionRequest": {
@ -4777,18 +4820,16 @@
"step_id": { "step_id": {
"type": "string" "type": "string"
}, },
"text_delta": { "delta": {
"type": "string" "$ref": "#/components/schemas/ContentDelta"
},
"tool_call_delta": {
"$ref": "#/components/schemas/ToolCallDelta"
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"event_type", "event_type",
"step_type", "step_type",
"step_id" "step_id",
"delta"
] ]
}, },
"AgentTurnResponseStepStartPayload": { "AgentTurnResponseStepStartPayload": {
@ -8758,6 +8799,10 @@
"name": "CompletionResponseStreamChunk", "name": "CompletionResponseStreamChunk",
"description": "streamed completion response.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/CompletionResponseStreamChunk\" />" "description": "streamed completion response.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/CompletionResponseStreamChunk\" />"
}, },
{
"name": "ContentDelta",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ContentDelta\" />"
},
{ {
"name": "CreateAgentRequest", "name": "CreateAgentRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/CreateAgentRequest\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/CreateAgentRequest\" />"
@ -9392,6 +9437,7 @@
"CompletionRequest", "CompletionRequest",
"CompletionResponse", "CompletionResponse",
"CompletionResponseStreamChunk", "CompletionResponseStreamChunk",
"ContentDelta",
"CreateAgentRequest", "CreateAgentRequest",
"CreateAgentSessionRequest", "CreateAgentSessionRequest",
"CreateAgentTurnRequest", "CreateAgentTurnRequest",

View file

@ -150,6 +150,8 @@ components:
AgentTurnResponseStepProgressPayload: AgentTurnResponseStepProgressPayload:
additionalProperties: false additionalProperties: false
properties: properties:
delta:
$ref: '#/components/schemas/ContentDelta'
event_type: event_type:
const: step_progress const: step_progress
default: step_progress default: step_progress
@ -163,14 +165,11 @@ components:
- shield_call - shield_call
- memory_retrieval - memory_retrieval
type: string type: string
text_delta:
type: string
tool_call_delta:
$ref: '#/components/schemas/ToolCallDelta'
required: required:
- event_type - event_type
- step_type - step_type
- step_id - step_id
- delta
type: object type: object
AgentTurnResponseStepStartPayload: AgentTurnResponseStepStartPayload:
additionalProperties: false additionalProperties: false
@ -462,9 +461,7 @@ components:
additionalProperties: false additionalProperties: false
properties: properties:
delta: delta:
oneOf: $ref: '#/components/schemas/ContentDelta'
- type: string
- $ref: '#/components/schemas/ToolCallDelta'
event_type: event_type:
$ref: '#/components/schemas/ChatCompletionResponseEventType' $ref: '#/components/schemas/ChatCompletionResponseEventType'
logprobs: logprobs:
@ -571,6 +568,34 @@ components:
- delta - delta
title: streamed completion response. title: streamed completion response.
type: object type: object
ContentDelta:
oneOf:
- additionalProperties: false
properties:
text:
type: string
type:
const: text
default: text
type: string
required:
- type
- text
type: object
- additionalProperties: false
properties:
data:
contentEncoding: base64
type: string
type:
const: image
default: image
type: string
required:
- type
- data
type: object
- $ref: '#/components/schemas/ToolCallDelta'
CreateAgentRequest: CreateAgentRequest:
additionalProperties: false additionalProperties: false
properties: properties:
@ -2664,7 +2689,12 @@ components:
- $ref: '#/components/schemas/ToolCall' - $ref: '#/components/schemas/ToolCall'
parse_status: parse_status:
$ref: '#/components/schemas/ToolCallParseStatus' $ref: '#/components/schemas/ToolCallParseStatus'
type:
const: tool_call
default: tool_call
type: string
required: required:
- type
- content - content
- parse_status - parse_status
type: object type: object
@ -2672,8 +2702,8 @@ components:
enum: enum:
- started - started
- in_progress - in_progress
- failure - failed
- success - succeeded
type: string type: string
ToolChoice: ToolChoice:
enum: enum:
@ -2888,8 +2918,8 @@ components:
content: content:
$ref: '#/components/schemas/InterleavedContent' $ref: '#/components/schemas/InterleavedContent'
role: role:
const: ipython const: tool
default: ipython default: tool
type: string type: string
tool_name: tool_name:
oneOf: oneOf:
@ -5500,6 +5530,8 @@ tags:
<SchemaDefinition schemaRef="#/components/schemas/CompletionResponseStreamChunk" <SchemaDefinition schemaRef="#/components/schemas/CompletionResponseStreamChunk"
/>' />'
name: CompletionResponseStreamChunk name: CompletionResponseStreamChunk
- description: <SchemaDefinition schemaRef="#/components/schemas/ContentDelta" />
name: ContentDelta
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateAgentRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/CreateAgentRequest"
/> />
name: CreateAgentRequest name: CreateAgentRequest
@ -5939,6 +5971,7 @@ x-tagGroups:
- CompletionRequest - CompletionRequest
- CompletionResponse - CompletionResponse
- CompletionResponseStreamChunk - CompletionResponseStreamChunk
- ContentDelta
- CreateAgentRequest - CreateAgentRequest
- CreateAgentSessionRequest - CreateAgentSessionRequest
- CreateAgentTurnRequest - CreateAgentTurnRequest

View file

@ -22,12 +22,11 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent, URL from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
CompletionMessage, CompletionMessage,
SamplingParams, SamplingParams,
ToolCall, ToolCall,
ToolCallDelta,
ToolChoice, ToolChoice,
ToolPromptFormat, ToolPromptFormat,
ToolResponse, ToolResponse,
@ -216,8 +215,7 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
step_type: StepType step_type: StepType
step_id: str step_id: str
text_delta: Optional[str] = None delta: ContentDelta
tool_call_delta: Optional[ToolCallDelta] = None
@json_schema_type @json_schema_type
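
With text_delta and tool_call_delta folded into one required delta field, step-progress payloads are now built from a single ContentDelta. A minimal sketch, assuming AgentTurnResponseStepProgressPayload and StepType are importable from llama_stack.apis.agents as the hunk above suggests; the step_id value is illustrative:

    from llama_stack.apis.agents import AgentTurnResponseStepProgressPayload, StepType
    from llama_stack.apis.common.content_types import TextDelta

    # One required ContentDelta replaces the two optional *_delta fields.
    payload = AgentTurnResponseStepProgressPayload(
        step_type=StepType.inference.value,
        step_id="step-0",                      # illustrative id
        delta=TextDelta(text="partial output"),
    )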

View file

@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import ToolResponseMessage from llama_stack.apis.inference import ToolResponseMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
class LogEvent: class LogEvent:
def __init__( def __init__(
@ -57,8 +61,11 @@ class EventLogger:
# since it does not produce event but instead # since it does not produce event but instead
# a Message # a Message
if isinstance(chunk, ToolResponseMessage): if isinstance(chunk, ToolResponseMessage):
yield chunk, LogEvent( yield (
role="CustomTool", content=chunk.content, color="grey" chunk,
LogEvent(
role="CustomTool", content=chunk.content, color="grey"
),
) )
continue continue
@ -80,14 +87,20 @@ class EventLogger:
): ):
violation = event.payload.step_details.violation violation = event.payload.step_details.violation
if not violation: if not violation:
yield event, LogEvent( yield (
role=step_type, content="No Violation", color="magenta" event,
LogEvent(
role=step_type, content="No Violation", color="magenta"
),
) )
else: else:
yield event, LogEvent( yield (
role=step_type, event,
content=f"{violation.metadata} {violation.user_message}", LogEvent(
color="red", role=step_type,
content=f"{violation.metadata} {violation.user_message}",
color="red",
),
) )
# handle inference # handle inference
@ -95,8 +108,11 @@ class EventLogger:
if stream: if stream:
if event_type == EventType.step_start.value: if event_type == EventType.step_start.value:
# TODO: Currently this event is never received # TODO: Currently this event is never received
yield event, LogEvent( yield (
role=step_type, content="", end="", color="yellow" event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
) )
elif event_type == EventType.step_progress.value: elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress # HACK: if previous was not step/event was not inference's step_progress
@ -107,24 +123,34 @@ class EventLogger:
previous_event_type != EventType.step_progress.value previous_event_type != EventType.step_progress.value
and previous_step_type != StepType.inference and previous_step_type != StepType.inference
): ):
yield event, LogEvent( yield (
role=step_type, content="", end="", color="yellow" event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
) )
if event.payload.tool_call_delta: delta = event.payload.delta
if isinstance(event.payload.tool_call_delta.content, str): if delta.type == "tool_call":
yield event, LogEvent( if delta.parse_status == ToolCallParseStatus.succeeded:
role=None, yield (
content=event.payload.tool_call_delta.content, event,
end="", LogEvent(
color="cyan", role=None,
content=delta.content,
end="",
color="cyan",
),
) )
else: else:
yield event, LogEvent( yield (
role=None, event,
content=event.payload.text_delta, LogEvent(
end="", role=None,
color="yellow", content=delta.text,
end="",
color="yellow",
),
) )
else: else:
# step_complete # step_complete
@ -140,10 +166,13 @@ class EventLogger:
) )
else: else:
content = response.content content = response.content
yield event, LogEvent( yield (
role=step_type, event,
content=content, LogEvent(
color="yellow", role=step_type,
content=content,
color="yellow",
),
) )
# handle tool_execution # handle tool_execution
@ -155,16 +184,22 @@ class EventLogger:
): ):
details = event.payload.step_details details = event.payload.step_details
for t in details.tool_calls: for t in details.tool_calls:
yield event, LogEvent( yield (
role=step_type, event,
content=f"Tool:{t.tool_name} Args:{t.arguments}", LogEvent(
color="green", role=step_type,
content=f"Tool:{t.tool_name} Args:{t.arguments}",
color="green",
),
) )
for r in details.tool_responses: for r in details.tool_responses:
yield event, LogEvent( yield (
role=step_type, event,
content=f"Tool:{r.tool_name} Response:{r.content}", LogEvent(
color="green", role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
),
) )
if ( if (
@ -172,15 +207,16 @@ class EventLogger:
and event_type == EventType.step_complete.value and event_type == EventType.step_complete.value
): ):
details = event.payload.step_details details = event.payload.step_details
inserted_context = interleaved_text_media_as_str( inserted_context = interleaved_content_as_str(details.inserted_context)
details.inserted_context
)
content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}" content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}"
yield event, LogEvent( yield (
role=step_type, event,
content=content, LogEvent(
color="cyan", role=step_type,
content=content,
color="cyan",
),
) )
previous_event_type = event_type previous_event_type = event_type

View file

@ -5,10 +5,12 @@
# the root directory of this source tree. # the root directory of this source tree.
import base64 import base64
from enum import Enum
from typing import Annotated, List, Literal, Optional, Union from typing import Annotated, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type, register_schema from llama_models.llama3.api.datatypes import ToolCall
from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field, field_serializer, model_validator from pydantic import BaseModel, Field, field_serializer, model_validator
@ -60,3 +62,42 @@ InterleavedContent = register_schema(
Union[str, InterleavedContentItem, List[InterleavedContentItem]], Union[str, InterleavedContentItem, List[InterleavedContentItem]],
name="InterleavedContent", name="InterleavedContent",
) )
class TextDelta(BaseModel):
type: Literal["text"] = "text"
text: str
class ImageDelta(BaseModel):
type: Literal["image"] = "image"
data: bytes
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failed = "failed"
succeeded = "succeeded"
@json_schema_type
class ToolCallDelta(BaseModel):
type: Literal["tool_call"] = "tool_call"
# you either send an in-progress tool call so the client can stream a long
# code generation or you send the final parsed tool call at the end of the
# stream
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
# streaming completions send a stream of ContentDeltas
ContentDelta = register_schema(
Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta],
Field(discriminator="type"),
],
name="ContentDelta",
)
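
ContentDelta is registered as an Annotated union discriminated on type, so incoming payloads can be validated straight into the right variant. A small sketch, assuming pydantic v2's TypeAdapter is available (the file already uses v2-style validators):

    from pydantic import TypeAdapter

    from llama_stack.apis.common.content_types import ContentDelta, TextDelta

    # The discriminator routes the dict to TextDelta without manual dispatch.
    adapter = TypeAdapter(ContentDelta)
    delta = adapter.validate_python({"type": "text", "text": "hello"})
    assert isinstance(delta, TextDelta) and delta.text == "hello"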

View file

@ -29,7 +29,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -147,26 +147,12 @@ class ChatCompletionResponseEventType(Enum):
progress = "progress" progress = "progress"
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
@json_schema_type
class ToolCallDelta(BaseModel):
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
@json_schema_type @json_schema_type
class ChatCompletionResponseEvent(BaseModel): class ChatCompletionResponseEvent(BaseModel):
"""Chat completion response event.""" """Chat completion response event."""
event_type: ChatCompletionResponseEventType event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta] delta: ContentDelta
logprobs: Optional[List[TokenLogProbs]] = None logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None stop_reason: Optional[StopReason] = None
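
ChatCompletionResponseEvent.delta is now a ContentDelta rather than Union[str, ToolCallDelta], so even the empty start and complete markers carry an explicit TextDelta. A sketch, assuming both classes are exported from llama_stack.apis.inference as the provider hunks below suggest:

    from llama_stack.apis.common.content_types import TextDelta
    from llama_stack.apis.inference import (
        ChatCompletionResponseEvent,
        ChatCompletionResponseEventType,
    )

    # The former delta="" start marker becomes a typed empty text delta.
    event = ChatCompletionResponseEvent(
        event_type=ChatCompletionResponseEventType.start,
        delta=TextDelta(text=""),
    )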

View file

@ -4,9 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import argparse import argparse
import importlib.resources import importlib.resources
import os import os
import shutil import shutil
from functools import lru_cache from functools import lru_cache
@ -14,14 +12,12 @@ from pathlib import Path
from typing import List, Optional from typing import List, Optional
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
BuildConfig, BuildConfig,
DistributionSpec, DistributionSpec,
Provider, Provider,
StackRunConfig, StackRunConfig,
) )
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -296,6 +292,7 @@ class StackBuild(Subcommand):
/ f"templates/{template_name}/run.yaml" / f"templates/{template_name}/run.yaml"
) )
with importlib.resources.as_file(template_path) as path: with importlib.resources.as_file(template_path) as path:
run_config_file = build_dir / f"{build_config.name}-run.yaml"
shutil.copy(path, run_config_file) shutil.copy(path, run_config_file)
# Find all ${env.VARIABLE} patterns # Find all ${env.VARIABLE} patterns
cprint("Build Successful!", color="green") cprint("Build Successful!", color="green")

View file

@ -9,12 +9,10 @@ import inspect
import json import json
import logging import logging
import os import os
import queue
import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar from typing import Any, get_args, get_origin, Optional, TypeVar
import httpx import httpx
import yaml import yaml
@ -64,71 +62,6 @@ def in_notebook():
return True return True
def stream_across_asyncio_run_boundary(
async_gen_maker,
pool_executor: ThreadPoolExecutor,
path: Optional[str] = None,
provider_data: Optional[dict[str, Any]] = None,
) -> Generator[T, None, None]:
result_queue = queue.Queue()
stop_event = threading.Event()
async def consumer():
# make sure we make the generator in the event loop context
gen = await async_gen_maker()
await start_trace(path, {"__location__": "library_client"})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
)
try:
async for item in await gen:
result_queue.put(item)
except Exception as e:
print(f"Error in generator {e}")
result_queue.put(e)
except asyncio.CancelledError:
return
finally:
result_queue.put(StopIteration)
stop_event.set()
await end_trace()
def run_async():
# Run our own loop to avoid double async generator cleanup which is done
# by asyncio.run()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
task = loop.create_task(consumer())
loop.run_until_complete(task)
finally:
# Handle pending tasks like a generator's athrow()
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()
future = pool_executor.submit(run_async)
try:
# yield results as they come in
while not stop_event.is_set() or not result_queue.empty():
try:
item = result_queue.get(timeout=0.1)
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
except queue.Empty:
continue
finally:
future.result()
def convert_pydantic_to_json_value(value: Any) -> Any: def convert_pydantic_to_json_value(value: Any) -> Any:
if isinstance(value, Enum): if isinstance(value, Enum):
return value.value return value.value
@ -184,7 +117,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
): ):
super().__init__() super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient( self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_template_name, custom_provider_registry config_path_or_template_name, custom_provider_registry, provider_data
) )
self.pool_executor = ThreadPoolExecutor(max_workers=4) self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal self.skip_logger_removal = skip_logger_removal
@ -210,39 +143,30 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
root_logger.removeHandler(handler) root_logger.removeHandler(handler)
print(f"Removed handler {handler.__class__.__name__} from root logger") print(f"Removed handler {handler.__class__.__name__} from root logger")
def _get_path(
self,
cast_to: Any,
options: Any,
*,
stream=False,
stream_cls=None,
):
return options.url
def request(self, *args, **kwargs): def request(self, *args, **kwargs):
path = self._get_path(*args, **kwargs)
if kwargs.get("stream"): if kwargs.get("stream"):
return stream_across_asyncio_run_boundary( # NOTE: We are using AsyncLlamaStackClient under the hood
lambda: self.async_client.request(*args, **kwargs), # A new event loop is needed to convert the AsyncStream
self.pool_executor, # from async client into SyncStream return type for streaming
path=path, loop = asyncio.new_event_loop()
provider_data=self.provider_data, asyncio.set_event_loop(loop)
)
else:
async def _traced_request(): def sync_generator():
if self.provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
)
await start_trace(path, {"__location__": "library_client"})
try: try:
return await self.async_client.request(*args, **kwargs) async_stream = loop.run_until_complete(
self.async_client.request(*args, **kwargs)
)
while True:
chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk
except StopAsyncIteration:
pass
finally: finally:
await end_trace() loop.close()
return asyncio.run(_traced_request()) return sync_generator()
else:
return asyncio.run(self.async_client.request(*args, **kwargs))
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@ -250,9 +174,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self, self,
config_path_or_template_name: str, config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None, custom_provider_registry: Optional[ProviderRegistry] = None,
provider_data: Optional[dict[str, Any]] = None,
): ):
super().__init__() super().__init__()
# when using the library client, we should not log to console since many # when using the library client, we should not log to console since many
# of our logs are intended for server-side usage # of our logs are intended for server-side usage
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
@ -273,6 +197,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.config_path_or_template_name = config_path_or_template_name self.config_path_or_template_name = config_path_or_template_name
self.config = config self.config = config
self.custom_provider_registry = custom_provider_registry self.custom_provider_registry = custom_provider_registry
self.provider_data = provider_data
async def initialize(self): async def initialize(self):
try: try:
@ -329,17 +254,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not self.endpoint_impls: if not self.endpoint_impls:
raise ValueError("Client not initialized") raise ValueError("Client not initialized")
if self.provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
)
await start_trace(options.url, {"__location__": "library_client"})
if stream: if stream:
return self._call_streaming( response = await self._call_streaming(
cast_to=cast_to, cast_to=cast_to,
options=options, options=options,
stream_cls=stream_cls, stream_cls=stream_cls,
) )
else: else:
return await self._call_non_streaming( response = await self._call_non_streaming(
cast_to=cast_to, cast_to=cast_to,
options=options, options=options,
) )
await end_trace()
return response
async def _call_non_streaming( async def _call_non_streaming(
self, self,
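
The removed stream_across_asyncio_run_boundary helper (queue plus worker thread) is replaced by a sync generator that steps the async client's stream on a private event loop. A standalone sketch of that bridging pattern, independent of the client classes:

    import asyncio


    async def numbers():
        for i in range(3):
            yield i


    def as_sync_generator(agen):
        # Step an async generator from synchronous code on a dedicated loop,
        # mirroring the shape of sync_generator() in the new request() above.
        loop = asyncio.new_event_loop()
        try:
            while True:
                try:
                    yield loop.run_until_complete(agen.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            loop.close()


    assert list(as_sync_generator(numbers())) == [0, 1, 2]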

View file

@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry" REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v4" KEY_VERSION = "v5"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@ -40,7 +40,12 @@ from llama_stack.apis.agents import (
ToolExecutionStep, ToolExecutionStep,
Turn, Turn,
) )
from llama_stack.apis.common.content_types import TextContentItem, URL from llama_stack.apis.common.content_types import (
TextContentItem,
ToolCallDelta,
ToolCallParseStatus,
URL,
)
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
CompletionMessage, CompletionMessage,
@ -49,8 +54,6 @@ from llama_stack.apis.inference import (
SamplingParams, SamplingParams,
StopReason, StopReason,
SystemMessage, SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition, ToolDefinition,
ToolResponse, ToolResponse,
ToolResponseMessage, ToolResponseMessage,
@ -411,8 +414,8 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution.value,
step_id=step_id, step_id=step_id,
tool_call_delta=ToolCallDelta( delta=ToolCallDelta(
parse_status=ToolCallParseStatus.success, parse_status=ToolCallParseStatus.succeeded,
content=ToolCall( content=ToolCall(
call_id="", call_id="",
tool_name=MEMORY_QUERY_TOOL, tool_name=MEMORY_QUERY_TOOL,
@ -507,8 +510,8 @@ class ChatAgent(ShieldRunnerMixin):
continue continue
delta = event.delta delta = event.delta
if isinstance(delta, ToolCallDelta): if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.success: if delta.parse_status == ToolCallParseStatus.succeeded:
tool_calls.append(delta.content) tool_calls.append(delta.content)
if stream: if stream:
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -516,21 +519,20 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value, step_type=StepType.inference.value,
step_id=step_id, step_id=step_id,
text_delta="", delta=delta,
tool_call_delta=delta,
) )
) )
) )
elif isinstance(delta, str): elif delta.type == "text":
content += delta content += delta.text
if stream and event.stop_reason is None: if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value, step_type=StepType.inference.value,
step_id=step_id, step_id=step_id,
text_delta=event.delta, delta=delta,
) )
) )
) )
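
The agent now branches on delta.type instead of isinstance checks and forwards the delta unchanged in its step-progress chunks. A minimal sketch of the accumulation logic, with deltas assumed to be ContentDelta instances as defined earlier in this diff:

    def accumulate(deltas):
        # Text deltas extend the assistant content; fully parsed tool calls
        # are collected for execution, matching the branching shown above.
        content, tool_calls = "", []
        for delta in deltas:
            if delta.type == "text":
                content += delta.text
            elif delta.type == "tool_call" and delta.parse_status.value == "succeeded":
                tool_calls.append(delta.content)
        return content, tool_calls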

View file

@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import (
) )
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
@ -32,8 +37,6 @@ from llama_stack.apis.inference import (
Message, Message,
ResponseFormat, ResponseFormat,
TokenLogProbs, TokenLogProbs,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice, ToolChoice,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl(
] ]
yield CompletionResponseStreamChunk( yield CompletionResponseStreamChunk(
delta=text, delta=TextDelta(text=text),
stop_reason=stop_reason, stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None, logprobs=logprobs if request.logprobs else None,
) )
if stop_reason is None: if stop_reason is None:
yield CompletionResponseStreamChunk( yield CompletionResponseStreamChunk(
delta="", delta=TextDelta(text=""),
stop_reason=StopReason.out_of_tokens, stop_reason=StopReason.out_of_tokens,
) )
@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl(
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start, event_type=ChatCompletionResponseEventType.start,
delta="", delta=TextDelta(text=""),
) )
) )
@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl(
parse_status=ToolCallParseStatus.in_progress, parse_status=ToolCallParseStatus.in_progress,
) )
else: else:
delta = text delta = TextDelta(text=text)
if stop_reason is None: if stop_reason is None:
if request.logprobs: if request.logprobs:
@ -428,7 +431,7 @@ class MetaReferenceInferenceImpl(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta( delta=ToolCallDelta(
content="", content="",
parse_status=ToolCallParseStatus.failure, parse_status=ToolCallParseStatus.failed,
), ),
stop_reason=stop_reason, stop_reason=stop_reason,
) )
@ -440,7 +443,7 @@ class MetaReferenceInferenceImpl(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta( delta=ToolCallDelta(
content=tool_call, content=tool_call,
parse_status=ToolCallParseStatus.success, parse_status=ToolCallParseStatus.succeeded,
), ),
stop_reason=stop_reason, stop_reason=stop_reason,
) )
@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl(
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete, event_type=ChatCompletionResponseEventType.complete,
delta="", delta=TextDelta(text=""),
stop_reason=stop_reason, stop_reason=stop_reason,
) )
) )
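
Completion streaming now wraps raw text in TextDelta and uses the renamed failed/succeeded statuses for tool-call chunks. A small sketch, assuming CompletionResponseStreamChunk keeps stop_reason and logprobs optional:

    from llama_stack.apis.common.content_types import (
        TextDelta,
        ToolCallDelta,
        ToolCallParseStatus,
    )
    from llama_stack.apis.inference import CompletionResponseStreamChunk

    # Plain text chunks are wrapped in TextDelta instead of being bare strings.
    chunk = CompletionResponseStreamChunk(delta=TextDelta(text="Paris"))
    # An unparseable tool call at end-of-stream is reported as "failed" (was "failure").
    failed = ToolCallDelta(content="", parse_status=ToolCallParseStatus.failed)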

View file

@ -30,13 +30,10 @@ from llama_stack.apis.telemetry import (
Trace, Trace,
UnstructuredLogEvent, UnstructuredLogEvent,
) )
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor, ConsoleSpanProcessor,
) )
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import ( from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
SQLiteSpanProcessor, SQLiteSpanProcessor,
) )
@ -52,6 +49,7 @@ _GLOBAL_STORAGE = {
"up_down_counters": {}, "up_down_counters": {},
} }
_global_lock = threading.Lock() _global_lock = threading.Lock()
_TRACER_PROVIDER = None
def string_to_trace_id(s: str) -> int: def string_to_trace_id(s: str) -> int:
@ -80,31 +78,34 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
} }
) )
provider = TracerProvider(resource=resource) global _TRACER_PROVIDER
trace.set_tracer_provider(provider) if _TRACER_PROVIDER is None:
if TelemetrySink.OTEL in self.config.sinks: provider = TracerProvider(resource=resource)
otlp_exporter = OTLPSpanExporter( trace.set_tracer_provider(provider)
endpoint=self.config.otel_endpoint, _TRACER_PROVIDER = provider
) if TelemetrySink.OTEL in self.config.sinks:
span_processor = BatchSpanProcessor(otlp_exporter) otlp_exporter = OTLPSpanExporter(
trace.get_tracer_provider().add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint, endpoint=self.config.otel_endpoint,
) )
) span_processor = BatchSpanProcessor(otlp_exporter)
metric_provider = MeterProvider( trace.get_tracer_provider().add_span_processor(span_processor)
resource=resource, metric_readers=[metric_reader] metric_reader = PeriodicExportingMetricReader(
) OTLPMetricExporter(
metrics.set_meter_provider(metric_provider) endpoint=self.config.otel_endpoint,
self.meter = metrics.get_meter(__name__) )
if TelemetrySink.SQLITE in self.config.sinks: )
trace.get_tracer_provider().add_span_processor( metric_provider = MeterProvider(
SQLiteSpanProcessor(self.config.sqlite_db_path) resource=resource, metric_readers=[metric_reader]
) )
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) metrics.set_meter_provider(metric_provider)
if TelemetrySink.CONSOLE in self.config.sinks: self.meter = metrics.get_meter(__name__)
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) if TelemetrySink.SQLITE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(
SQLiteSpanProcessor(self.config.sqlite_db_path)
)
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
self._lock = _global_lock self._lock = _global_lock
async def initialize(self) -> None: async def initialize(self) -> None:
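
TracerProvider setup is now guarded by the module-level _TRACER_PROVIDER so constructing the adapter more than once (for example via the library client) does not install duplicate span processors. A standalone sketch of that initialize-once pattern; the factory is a stand-in, not the OpenTelemetry provider itself:

    _PROVIDER = None


    def get_or_create(factory):
        # Create the shared provider on first use only; later calls reuse it.
        global _PROVIDER
        if _PROVIDER is None:
            _PROVIDER = factory()
        return _PROVIDER


    first = get_or_create(object)
    second = get_or_create(object)
    assert first is second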

View file

@ -30,6 +30,11 @@ from groq.types.shared.function_definition import FunctionDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
@ -40,8 +45,6 @@ from llama_stack.apis.inference import (
Message, Message,
StopReason, StopReason,
ToolCall, ToolCall,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
@ -162,7 +165,7 @@ def convert_chat_completion_response(
def _map_finish_reason_to_stop_reason( def _map_finish_reason_to_stop_reason(
finish_reason: Literal["stop", "length", "tool_calls"] finish_reason: Literal["stop", "length", "tool_calls"],
) -> StopReason: ) -> StopReason:
""" """
Convert a Groq chat completion finish_reason to a StopReason. Convert a Groq chat completion finish_reason to a StopReason.
@ -185,7 +188,6 @@ def _map_finish_reason_to_stop_reason(
async def convert_chat_completion_response_stream( async def convert_chat_completion_response_stream(
stream: Stream[ChatCompletionChunk], stream: Stream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
event_type = ChatCompletionResponseEventType.start event_type = ChatCompletionResponseEventType.start
for chunk in stream: for chunk in stream:
choice = chunk.choices[0] choice = chunk.choices[0]
@ -194,7 +196,7 @@ async def convert_chat_completion_response_stream(
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete, event_type=ChatCompletionResponseEventType.complete,
delta=choice.delta.content or "", delta=TextDelta(text=choice.delta.content or ""),
logprobs=None, logprobs=None,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason), stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
) )
@ -213,7 +215,7 @@ async def convert_chat_completion_response_stream(
event_type=event_type, event_type=event_type,
delta=ToolCallDelta( delta=ToolCallDelta(
content=tool_call, content=tool_call,
parse_status=ToolCallParseStatus.success, parse_status=ToolCallParseStatus.succeeded,
), ),
) )
) )
@ -221,7 +223,7 @@ async def convert_chat_completion_response_stream(
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=event_type, event_type=event_type,
delta=choice.delta.content or "", delta=TextDelta(text=choice.delta.content or ""),
logprobs=None, logprobs=None,
) )
) )

View file

@ -34,6 +34,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
from openai.types.completion import Completion as OpenAICompletion from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
@ -48,8 +53,6 @@ from llama_stack.apis.inference import (
Message, Message,
SystemMessage, SystemMessage,
TokenLogProbs, TokenLogProbs,
ToolCallDelta,
ToolCallParseStatus,
ToolResponseMessage, ToolResponseMessage,
UserMessage, UserMessage,
) )
@ -432,69 +435,6 @@ async def convert_openai_chat_completion_stream(
""" """
Convert a stream of OpenAI chat completion chunks into a stream Convert a stream of OpenAI chat completion chunks into a stream
of ChatCompletionResponseStreamChunk. of ChatCompletionResponseStreamChunk.
OpenAI ChatCompletionChunk:
choices: List[Choice]
OpenAI Choice: # different from the non-streamed Choice
delta: ChoiceDelta
finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
logprobs: Optional[ChoiceLogprobs]
OpenAI ChoiceDelta:
content: Optional[str]
role: Optional[Literal["system", "user", "assistant", "tool"]]
tool_calls: Optional[List[ChoiceDeltaToolCall]]
OpenAI ChoiceDeltaToolCall:
index: int
id: Optional[str]
function: Optional[ChoiceDeltaToolCallFunction]
type: Optional[Literal["function"]]
OpenAI ChoiceDeltaToolCallFunction:
name: Optional[str]
arguments: Optional[str]
->
ChatCompletionResponseStreamChunk:
event: ChatCompletionResponseEvent
ChatCompletionResponseEvent:
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]]
stop_reason: Optional[StopReason]
ChatCompletionResponseEventType:
start = "start"
progress = "progress"
complete = "complete"
ToolCallDelta:
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
ToolCall:
call_id: str
tool_name: str
arguments: str
ToolCallParseStatus:
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
StopReason:
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
""" """
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
@ -543,7 +483,7 @@ async def convert_openai_chat_completion_stream(
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=next(event_type), event_type=next(event_type),
delta=choice.delta.content, delta=TextDelta(text=choice.delta.content),
logprobs=_convert_openai_logprobs(choice.logprobs), logprobs=_convert_openai_logprobs(choice.logprobs),
) )
) )
@ -561,7 +501,7 @@ async def convert_openai_chat_completion_stream(
event_type=next(event_type), event_type=next(event_type),
delta=ToolCallDelta( delta=ToolCallDelta(
content=_convert_openai_tool_calls(choice.delta.tool_calls)[0], content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
parse_status=ToolCallParseStatus.success, parse_status=ToolCallParseStatus.succeeded,
), ),
logprobs=_convert_openai_logprobs(choice.logprobs), logprobs=_convert_openai_logprobs(choice.logprobs),
) )
@ -570,7 +510,7 @@ async def convert_openai_chat_completion_stream(
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=next(event_type), event_type=next(event_type),
delta=choice.delta.content or "", # content is not optional delta=TextDelta(text=choice.delta.content or ""),
logprobs=_convert_openai_logprobs(choice.logprobs), logprobs=_convert_openai_logprobs(choice.logprobs),
) )
) )
@ -578,7 +518,7 @@ async def convert_openai_chat_completion_stream(
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete, event_type=ChatCompletionResponseEventType.complete,
delta="", delta=TextDelta(text=""),
stop_reason=stop_reason, stop_reason=stop_reason,
) )
) )
@ -653,18 +593,6 @@ def _convert_openai_completion_logprobs(
) -> Optional[List[TokenLogProbs]]: ) -> Optional[List[TokenLogProbs]]:
""" """
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs. Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
OpenAI CompletionLogprobs:
text_offset: Optional[List[int]]
token_logprobs: Optional[List[float]]
tokens: Optional[List[str]]
top_logprobs: Optional[List[Dict[str, float]]]
->
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
""" """
if not logprobs: if not logprobs:
return None return None
@ -679,28 +607,6 @@ def convert_openai_completion_choice(
) -> CompletionResponse: ) -> CompletionResponse:
""" """
Convert an OpenAI Completion Choice into a CompletionResponse. Convert an OpenAI Completion Choice into a CompletionResponse.
OpenAI Completion Choice:
text: str
finish_reason: str
logprobs: Optional[ChoiceLogprobs]
->
CompletionResponse:
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]]
CompletionMessage:
role: Literal["assistant"]
content: str | ImageMedia | List[str | ImageMedia]
stop_reason: StopReason
tool_calls: List[ToolCall]
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
""" """
return CompletionResponse( return CompletionResponse(
content=choice.text, content=choice.text,
@ -715,32 +621,11 @@ async def convert_openai_completion_stream(
""" """
Convert a stream of OpenAI Completions into a stream Convert a stream of OpenAI Completions into a stream
of ChatCompletionResponseStreamChunks. of ChatCompletionResponseStreamChunks.
OpenAI Completion:
id: str
choices: List[OpenAICompletionChoice]
created: int
model: str
system_fingerprint: Optional[str]
usage: Optional[OpenAICompletionUsage]
OpenAI CompletionChoice:
finish_reason: str
index: int
logprobs: Optional[OpenAILogprobs]
text: str
->
CompletionResponseStreamChunk:
delta: str
stop_reason: Optional[StopReason]
logprobs: Optional[List[TokenLogProbs]]
""" """
async for chunk in stream: async for chunk in stream:
choice = chunk.choices[0] choice = chunk.choices[0]
yield CompletionResponseStreamChunk( yield CompletionResponseStreamChunk(
delta=choice.text, delta=TextDelta(text=choice.text),
stop_reason=_convert_openai_finish_reason(choice.finish_reason), stop_reason=_convert_openai_finish_reason(choice.finish_reason),
logprobs=_convert_openai_completion_logprobs(choice.logprobs), logprobs=_convert_openai_completion_logprobs(choice.logprobs),
) )

View file

@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import (
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
@ -27,8 +28,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat, JsonSchemaResponseFormat,
LogProbConfig, LogProbConfig,
SystemMessage, SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice, ToolChoice,
UserMessage, UserMessage,
) )
@ -196,7 +195,9 @@ class TestInference:
1 <= len(chunks) <= 6 1 <= len(chunks) <= 6
) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason ) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
for chunk in chunks: for chunk in chunks:
if chunk.delta: # if there's a token, we expect logprobs if (
chunk.delta.type == "text" and chunk.delta.text
): # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty" assert chunk.logprobs, "Logprobs should not be empty"
assert all( assert all(
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
@ -463,7 +464,7 @@ class TestInference:
if "Llama3.1" in inference_model: if "Llama3.1" in inference_model:
assert all( assert all(
isinstance(chunk.event.delta, ToolCallDelta) chunk.event.delta.type == "tool_call"
for chunk in grouped[ChatCompletionResponseEventType.progress] for chunk in grouped[ChatCompletionResponseEventType.progress]
) )
first = grouped[ChatCompletionResponseEventType.progress][0] first = grouped[ChatCompletionResponseEventType.progress][0]
@ -474,8 +475,8 @@ class TestInference:
last = grouped[ChatCompletionResponseEventType.progress][-1] last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason # assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.success assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
assert isinstance(last.event.delta.content, ToolCall) assert last.event.delta.content.type == "tool_call"
call = last.event.delta.content call = last.event.delta.content
assert call.tool_name == "get_weather" assert call.tool_name == "get_weather"

View file

@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import SamplingParams, StopReason from llama_models.llama3.api.datatypes import SamplingParams, StopReason
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse, ChatCompletionResponse,
@ -22,8 +28,6 @@ from llama_stack.apis.inference import (
CompletionResponse, CompletionResponse,
CompletionResponseStreamChunk, CompletionResponseStreamChunk,
Message, Message,
ToolCallDelta,
ToolCallParseStatus,
) )
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
@ -160,7 +164,7 @@ async def process_chat_completion_stream_response(
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start, event_type=ChatCompletionResponseEventType.start,
delta="", delta=TextDelta(text=""),
) )
) )
@ -227,7 +231,7 @@ async def process_chat_completion_stream_response(
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
delta=text, delta=TextDelta(text=text),
stop_reason=stop_reason, stop_reason=stop_reason,
) )
) )
@ -241,7 +245,7 @@ async def process_chat_completion_stream_response(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta( delta=ToolCallDelta(
content="", content="",
parse_status=ToolCallParseStatus.failure, parse_status=ToolCallParseStatus.failed,
), ),
stop_reason=stop_reason, stop_reason=stop_reason,
) )
@ -253,7 +257,7 @@ async def process_chat_completion_stream_response(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta( delta=ToolCallDelta(
content=tool_call, content=tool_call,
parse_status=ToolCallParseStatus.success, parse_status=ToolCallParseStatus.succeeded,
), ),
stop_reason=stop_reason, stop_reason=stop_reason,
) )
@ -262,7 +266,7 @@ async def process_chat_completion_stream_response(
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete, event_type=ChatCompletionResponseEventType.complete,
delta="", delta=TextDelta(text=""),
stop_reason=stop_reason, stop_reason=stop_reason,
) )
) )

View file

@ -265,6 +265,7 @@ def chat_completion_request_to_messages(
For eg. for llama_3_1, add system message with the appropriate tools or For eg. for llama_3_1, add system message with the appropriate tools or
add user message for custom tools, etc. add user message for custom tools, etc.
""" """
assert llama_model is not None, "llama_model is required"
model = resolve_model(llama_model) model = resolve_model(llama_model)
if model is None: if model is None:
log.error(f"Could not resolve model {llama_model}") log.error(f"Could not resolve model {llama_model}")

View file

@ -127,7 +127,8 @@ class TraceContext:
def setup_logger(api: Telemetry, level: int = logging.INFO): def setup_logger(api: Telemetry, level: int = logging.INFO):
global BACKGROUND_LOGGER global BACKGROUND_LOGGER
BACKGROUND_LOGGER = BackgroundLogger(api) if BACKGROUND_LOGGER is None:
BACKGROUND_LOGGER = BackgroundLogger(api)
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(level) logger.setLevel(level)
logger.addHandler(TelemetryHandler()) logger.addHandler(TelemetryHandler())

View file

@ -12,6 +12,11 @@ from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack_client import LlamaStackClient from llama_stack_client import LlamaStackClient
def pytest_configure(config):
config.option.tbstyle = "short"
config.option.disable_warnings = True
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def provider_data(): def provider_data():
# check env for tavily secret, brave secret and inject all into provider data # check env for tavily secret, brave secret and inject all into provider data
@ -29,6 +34,7 @@ def llama_stack_client(provider_data):
client = LlamaStackAsLibraryClient( client = LlamaStackAsLibraryClient(
get_env_or_fail("LLAMA_STACK_CONFIG"), get_env_or_fail("LLAMA_STACK_CONFIG"),
provider_data=provider_data, provider_data=provider_data,
skip_logger_removal=True,
) )
client.initialize() client.initialize()
elif os.environ.get("LLAMA_STACK_BASE_URL"): elif os.environ.get("LLAMA_STACK_BASE_URL"):

View file

@ -6,9 +6,9 @@
import pytest import pytest
from llama_stack_client.lib.inference.event_logger import EventLogger
from pydantic import BaseModel from pydantic import BaseModel
PROVIDER_TOOL_PROMPT_FORMAT = { PROVIDER_TOOL_PROMPT_FORMAT = {
"remote::ollama": "python_list", "remote::ollama": "python_list",
"remote::together": "json", "remote::together": "json",
@ -39,7 +39,7 @@ def text_model_id(llama_stack_client):
available_models = [ available_models = [
model.identifier model.identifier
for model in llama_stack_client.models.list() for model in llama_stack_client.models.list()
if model.identifier.startswith("meta-llama") if model.identifier.startswith("meta-llama") and "405" not in model.identifier
] ]
assert len(available_models) > 0 assert len(available_models) > 0
return available_models[0] return available_models[0]
@ -208,12 +208,9 @@ def test_text_chat_completion_streaming(
stream=True, stream=True,
) )
streamed_content = [ streamed_content = [
str(log.content.lower().strip()) str(chunk.event.delta.text.lower().strip()) for chunk in response
for log in EventLogger().log(response)
if log is not None
] ]
assert len(streamed_content) > 0 assert len(streamed_content) > 0
assert "assistant>" in streamed_content[0]
assert expected.lower() in "".join(streamed_content) assert expected.lower() in "".join(streamed_content)
@ -250,17 +247,16 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
def extract_tool_invocation_content(response): def extract_tool_invocation_content(response):
text_content: str = "" text_content: str = ""
tool_invocation_content: str = "" tool_invocation_content: str = ""
for log in EventLogger().log(response): for chunk in response:
if log is None: delta = chunk.event.delta
continue if delta.type == "text":
if isinstance(log.content, str): text_content += delta.text
text_content += log.content elif delta.type == "tool_call":
elif isinstance(log.content, object): if isinstance(delta.content, str):
if isinstance(log.content.content, str): tool_invocation_content += delta.content
continue else:
elif isinstance(log.content.content, object): call = delta.content
tool_invocation_content += f"[{log.content.content.tool_name}, {log.content.content.arguments}]" tool_invocation_content += f"[{call.tool_name}, {call.arguments}]"
return text_content, tool_invocation_content return text_content, tool_invocation_content
@ -280,7 +276,6 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
) )
text_content, tool_invocation_content = extract_tool_invocation_content(response) text_content, tool_invocation_content = extract_tool_invocation_content(response)
assert "Assistant>" in text_content
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"
@ -368,10 +363,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
stream=True, stream=True,
) )
streamed_content = [ streamed_content = [
str(log.content.lower().strip()) str(chunk.event.delta.text.lower().strip()) for chunk in response
for log in EventLogger().log(response)
if log is not None
] ]
assert len(streamed_content) > 0 assert len(streamed_content) > 0
assert "assistant>" in streamed_content[0]
assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"}) assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})
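
The SDK tests now read streamed text straight off chunk.event.delta instead of going through EventLogger, which also drops the old "assistant>" prefix assertions. A hedged sketch of the extraction shape the tests now rely on; response is assumed to be the streaming iterator returned by the client:

    def streamed_text(response):
        # Collect only text deltas; tool-call and image deltas are skipped here.
        return "".join(
            chunk.event.delta.text
            for chunk in response
            if chunk.event.delta.type == "text"
        )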