Merge remote-tracking branch 'origin/main' into support_more_data_format

commit 8d7bb1140f
Author: Botao Chen
Date: 2025-01-14 11:55:13 -08:00

20 changed files with 381 additions and 414 deletions


@@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import SamplingParams, StopReason
 from pydantic import BaseModel
 
-from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+from llama_stack.apis.common.content_types import (
+    ImageContentItem,
+    TextContentItem,
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
@@ -22,8 +28,6 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     Message,
-    ToolCallDelta,
-    ToolCallParseStatus,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -160,7 +164,7 @@ async def process_chat_completion_stream_response(
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,
-            delta="",
+            delta=TextDelta(text=""),
         )
     )
@@ -227,7 +231,7 @@ async def process_chat_completion_stream_response(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.progress,
-                delta=text,
+                delta=TextDelta(text=text),
                 stop_reason=stop_reason,
             )
         )
@@ -241,7 +245,7 @@ async def process_chat_completion_stream_response(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
                         content="",
-                        parse_status=ToolCallParseStatus.failure,
+                        parse_status=ToolCallParseStatus.failed,
                     ),
                     stop_reason=stop_reason,
                 )
@@ -253,7 +257,7 @@ async def process_chat_completion_stream_response(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
                         content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
+                        parse_status=ToolCallParseStatus.succeeded,
                    ),
                    stop_reason=stop_reason,
                )
@@ -262,7 +266,7 @@ async def process_chat_completion_stream_response(
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.complete,
-            delta="",
+            delta=TextDelta(text=""),
             stop_reason=stop_reason,
         )
     )
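
The hunks above migrate streaming events from bare strings to typed delta payloads (TextDelta, ToolCallDelta) and rename the parse-status values failure/success to failed/succeeded. A minimal consumer sketch of the new shapes follows; the stream iterator is a hypothetical stand-in for the generator being patched, not part of the change:

from llama_stack.apis.common.content_types import (
    TextDelta,
    ToolCallDelta,
    ToolCallParseStatus,
)


async def collect_response(stream):
    """Accumulate text and completed tool calls from typed deltas."""
    text, tool_calls = "", []
    async for chunk in stream:
        delta = chunk.event.delta
        if isinstance(delta, TextDelta):
            text += delta.text  # was: text += delta (a plain str)
        elif isinstance(delta, ToolCallDelta):
            # Note the renamed enum values: failed/succeeded, not failure/success.
            if delta.parse_status == ToolCallParseStatus.succeeded:
                tool_calls.append(delta.content)
    return text, tool_calls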


@@ -265,6 +265,7 @@ def chat_completion_request_to_messages(
     For eg. for llama_3_1, add system message with the appropriate tools or
     add user message for custom tools, etc.
     """
+    assert llama_model is not None, "llama_model is required"
     model = resolve_model(llama_model)
     if model is None:
         log.error(f"Could not resolve model {llama_model}")
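
This hunk adds a fail-fast guard: a missing llama_model now raises immediately with a clear message instead of surfacing downstream as a resolve_model(None) failure. A hedged sketch of the new behavior; the module path and the request parameter name are inferred from the imports above, and request=None only works here because the assert runs before the request is inspected:

from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

try:
    chat_completion_request_to_messages(request=None, llama_model=None)
except AssertionError as exc:
    print(exc)  # llama_model is required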


@@ -127,7 +127,8 @@ class TraceContext:
 
 def setup_logger(api: Telemetry, level: int = logging.INFO):
     global BACKGROUND_LOGGER
-    BACKGROUND_LOGGER = BackgroundLogger(api)
+    if BACKGROUND_LOGGER is None:
+        BACKGROUND_LOGGER = BackgroundLogger(api)
     logger = logging.getLogger()
     logger.setLevel(level)
     logger.addHandler(TelemetryHandler())
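
The new is-None guard makes setup_logger idempotent with respect to the background worker: repeated calls reuse the existing BackgroundLogger instead of constructing a fresh one each time. A self-contained sketch of the effect; the stub class below stands in for the real BackgroundLogger, which owns a queue and a worker thread:

_instances = []


class BackgroundLoggerStub:
    def __init__(self, api):
        self.api = api
        _instances.append(self)  # the real class would also start a worker thread


BACKGROUND_LOGGER = None


def setup_logger(api):
    global BACKGROUND_LOGGER
    if BACKGROUND_LOGGER is None:  # guard added by this commit
        BACKGROUND_LOGGER = BackgroundLoggerStub(api)


setup_logger(api="telemetry")
setup_logger(api="telemetry")  # second call reuses the instance
assert len(_instances) == 1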