forked from phoenix-oss/llama-stack-mirror

move all implementations to use updated type

parent aced2ce07e · commit 9a5803a429

8 changed files with 139 additions and 208 deletions
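Note: for orientation, the sketch below reconstructs the updated delta types this commit migrates everything onto. It is a minimal sketch inferred from usage in the hunks that follow (the `delta.type` discriminators, `TextDelta(text=...)`, and `ToolCallDelta(content=..., parse_status=...)` constructions); the exact definitions in `llama_stack.apis.common.content_types` may differ.

# Minimal sketch, not the upstream source: names inferred from this diff.
from enum import Enum
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Field


class ToolCallParseStatus(Enum):
    started = "started"
    in_progress = "in_progress"
    failure = "failure"
    success = "success"


class TextDelta(BaseModel):
    type: Literal["text"] = "text"
    text: str


class ToolCallDelta(BaseModel):
    type: Literal["tool_call"] = "tool_call"
    # a raw string while the call is still being parsed, a ToolCall once parsed
    content: Any
    parse_status: ToolCallParseStatus


# The discriminated union the payloads below consume as `delta: ContentDelta`.
ContentDelta = Annotated[Union[TextDelta, ToolCallDelta], Field(discriminator="type")]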
@@ -22,12 +22,11 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
-from llama_stack.apis.common.content_types import InterleavedContent, URL
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
 from llama_stack.apis.inference import (
     CompletionMessage,
     SamplingParams,
     ToolCall,
-    ToolCallDelta,
     ToolChoice,
     ToolPromptFormat,
     ToolResponse,
@@ -216,8 +215,7 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
     step_type: StepType
     step_id: str
 
-    text_delta: Optional[str] = None
-    tool_call_delta: Optional[ToolCallDelta] = None
+    delta: ContentDelta
 
 
 @json_schema_type
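With the hunk above applied, the progress payload carries one discriminated delta instead of two optional fields. A sketch of the resulting model (surrounding fields elided; `StepType` simplified to `str` here, upstream uses its enum):

# Sketch of the payload after this change; the real class lives in
# llama_stack.apis.agents and uses the StepType enum.
from pydantic import BaseModel

from llama_stack.apis.common.content_types import ContentDelta


class AgentTurnResponseStepProgressPayload(BaseModel):
    step_type: str
    step_id: str
    delta: ContentDelta  # replaces the old text_delta / tool_call_delta pair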
@@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils
 from termcolor import cprint
 
 from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
+from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import ToolResponseMessage
 
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
 
 
 class LogEvent:
     def __init__(
@@ -57,8 +61,11 @@ class EventLogger:
             # since it does not produce event but instead
             # a Message
             if isinstance(chunk, ToolResponseMessage):
-                yield chunk, LogEvent(
+                yield (
+                    chunk,
+                    LogEvent(
                         role="CustomTool", content=chunk.content, color="grey"
+                    ),
                 )
                 continue
 
@@ -80,14 +87,20 @@ class EventLogger:
             ):
                 violation = event.payload.step_details.violation
                 if not violation:
-                    yield event, LogEvent(
+                    yield (
+                        event,
+                        LogEvent(
                             role=step_type, content="No Violation", color="magenta"
+                        ),
                     )
                 else:
-                    yield event, LogEvent(
+                    yield (
+                        event,
+                        LogEvent(
                             role=step_type,
                             content=f"{violation.metadata} {violation.user_message}",
                             color="red",
+                        ),
                     )
 
             # handle inference
@@ -95,8 +108,11 @@ class EventLogger:
             if stream:
                 if event_type == EventType.step_start.value:
                     # TODO: Currently this event is never received
-                    yield event, LogEvent(
+                    yield (
+                        event,
+                        LogEvent(
                             role=step_type, content="", end="", color="yellow"
+                        ),
                     )
                 elif event_type == EventType.step_progress.value:
                     # HACK: if previous was not step/event was not inference's step_progress
@@ -107,24 +123,34 @@ class EventLogger:
                         previous_event_type != EventType.step_progress.value
                         and previous_step_type != StepType.inference
                     ):
-                        yield event, LogEvent(
+                        yield (
+                            event,
+                            LogEvent(
                                 role=step_type, content="", end="", color="yellow"
+                            ),
                         )
 
-                    if event.payload.tool_call_delta:
-                        if isinstance(event.payload.tool_call_delta.content, str):
-                            yield event, LogEvent(
+                    delta = event.payload.delta
+                    if delta.type == "tool_call":
+                        if delta.parse_status == ToolCallParseStatus.success:
+                            yield (
+                                event,
+                                LogEvent(
                                     role=None,
-                                    content=event.payload.tool_call_delta.content,
+                                    content=delta.content,
                                     end="",
                                     color="cyan",
+                                ),
                             )
                     else:
-                        yield event, LogEvent(
+                        yield (
+                            event,
+                            LogEvent(
                                 role=None,
-                                content=event.payload.text_delta,
+                                content=delta.text,
                                 end="",
                                 color="yellow",
+                            ),
                         )
                 else:
                     # step_complete
@@ -140,10 +166,13 @@ class EventLogger:
                     )
                 else:
                     content = response.content
-                yield event, LogEvent(
+                yield (
+                    event,
+                    LogEvent(
                         role=step_type,
                         content=content,
                         color="yellow",
+                    ),
                 )
 
             # handle tool_execution
@@ -155,16 +184,22 @@ class EventLogger:
             ):
                 details = event.payload.step_details
                 for t in details.tool_calls:
-                    yield event, LogEvent(
+                    yield (
+                        event,
+                        LogEvent(
                             role=step_type,
                             content=f"Tool:{t.tool_name} Args:{t.arguments}",
                             color="green",
+                        ),
                     )
                 for r in details.tool_responses:
-                    yield event, LogEvent(
+                    yield (
+                        event,
+                        LogEvent(
                             role=step_type,
                             content=f"Tool:{r.tool_name} Response:{r.content}",
                             color="green",
+                        ),
                     )
 
             if (
@@ -172,15 +207,16 @@ class EventLogger:
                 and event_type == EventType.step_complete.value
             ):
                 details = event.payload.step_details
-                inserted_context = interleaved_text_media_as_str(
-                    details.inserted_context
-                )
+                inserted_context = interleaved_content_as_str(details.inserted_context)
                 content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}"
 
-                yield event, LogEvent(
+                yield (
+                    event,
+                    LogEvent(
                         role=step_type,
                         content=content,
                         color="cyan",
+                    ),
                 )
 
             previous_event_type = event_type
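All of the event-logger hunks above make the same move: isinstance checks against the old `Union[str, ToolCallDelta]` give way to branching on the `delta.type` discriminator. A condensed sketch of that consumer pattern (`render_delta` is a hypothetical helper, not part of the commit; the import is the one this file actually gains above):

from llama_stack.apis.common.content_types import ToolCallParseStatus


def render_delta(delta) -> str:
    # Branch on the discriminator instead of isinstance checks.
    if delta.type == "tool_call":
        if delta.parse_status == ToolCallParseStatus.success:
            # content is a parsed ToolCall once parsing succeeded
            return f"Tool:{delta.content.tool_name} Args:{delta.content.arguments}"
        return str(delta.content)  # partial tool-call text
    return delta.text  # delta.type == "text"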
@@ -40,7 +40,12 @@ from llama_stack.apis.agents import (
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.common.content_types import TextContentItem, URL
+from llama_stack.apis.common.content_types import (
+    TextContentItem,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    URL,
+)
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,
@@ -49,8 +54,6 @@ from llama_stack.apis.inference import (
     SamplingParams,
     StopReason,
     SystemMessage,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolDefinition,
     ToolResponse,
     ToolResponseMessage,
@@ -411,7 +414,7 @@ class ChatAgent(ShieldRunnerMixin):
                     payload=AgentTurnResponseStepProgressPayload(
                         step_type=StepType.tool_execution.value,
                         step_id=step_id,
-                        tool_call_delta=ToolCallDelta(
+                        delta=ToolCallDelta(
                             parse_status=ToolCallParseStatus.success,
                             content=ToolCall(
                                 call_id="",
@@ -507,7 +510,7 @@ class ChatAgent(ShieldRunnerMixin):
                 continue
 
             delta = event.delta
-            if isinstance(delta, ToolCallDelta):
+            if delta.type == "tool_call":
                 if delta.parse_status == ToolCallParseStatus.success:
                     tool_calls.append(delta.content)
                 if stream:
@@ -516,21 +519,20 @@ class ChatAgent(ShieldRunnerMixin):
                         payload=AgentTurnResponseStepProgressPayload(
                             step_type=StepType.inference.value,
                             step_id=step_id,
-                            text_delta="",
-                            tool_call_delta=delta,
+                            delta=delta,
                         )
                     )
                 )
 
-            elif isinstance(delta, str):
-                content += delta
+            elif delta.type == "text":
+                content += delta.text
                 if stream and event.stop_reason is None:
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
                             payload=AgentTurnResponseStepProgressPayload(
                                 step_type=StepType.inference.value,
                                 step_id=step_id,
-                                text_delta=event.delta,
+                                delta=delta,
                             )
                         )
                     )
@@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.sku_list import resolve_model
 
+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -32,8 +37,6 @@ from llama_stack.apis.inference import (
     Message,
     ResponseFormat,
     TokenLogProbs,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolChoice,
 )
 from llama_stack.apis.models import Model, ModelType
@@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl(
                 ]
 
             yield CompletionResponseStreamChunk(
-                delta=text,
+                delta=TextDelta(text=text),
                 stop_reason=stop_reason,
                 logprobs=logprobs if request.logprobs else None,
             )
 
         if stop_reason is None:
             yield CompletionResponseStreamChunk(
-                delta="",
+                delta=TextDelta(text=""),
                 stop_reason=StopReason.out_of_tokens,
             )
 
@@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.start,
-                delta="",
+                delta=TextDelta(text=""),
             )
         )
 
@@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl(
                     parse_status=ToolCallParseStatus.in_progress,
                 )
             else:
-                delta = text
+                delta = TextDelta(text=text)
 
             if stop_reason is None:
                 if request.logprobs:
@@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl(
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.complete,
-                delta="",
+                delta=TextDelta(text=""),
                 stop_reason=stop_reason,
             )
         )
@@ -30,6 +30,11 @@ from groq.types.shared.function_definition import FunctionDefinition
 
 from llama_models.llama3.api.datatypes import ToolParamDefinition
 
+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -40,8 +45,6 @@ from llama_stack.apis.inference import (
     Message,
     StopReason,
     ToolCall,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolDefinition,
     ToolPromptFormat,
 )
@@ -162,7 +165,7 @@ def convert_chat_completion_response(
 
 
 def _map_finish_reason_to_stop_reason(
-    finish_reason: Literal["stop", "length", "tool_calls"]
+    finish_reason: Literal["stop", "length", "tool_calls"],
 ) -> StopReason:
     """
     Convert a Groq chat completion finish_reason to a StopReason.
@@ -185,7 +188,6 @@ def _map_finish_reason_to_stop_reason(
 async def convert_chat_completion_response_stream(
     stream: Stream[ChatCompletionChunk],
 ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-
     event_type = ChatCompletionResponseEventType.start
     for chunk in stream:
         choice = chunk.choices[0]
@@ -194,7 +196,7 @@ async def convert_chat_completion_response_stream(
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.complete,
-                    delta=choice.delta.content or "",
+                    delta=TextDelta(text=choice.delta.content or ""),
                     logprobs=None,
                     stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
                 )
@@ -221,7 +223,7 @@ async def convert_chat_completion_response_stream(
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=event_type,
-                    delta=choice.delta.content or "",
+                    delta=TextDelta(text=choice.delta.content or ""),
                     logprobs=None,
                 )
             )
@@ -34,6 +34,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
 from openai.types.completion import Completion as OpenAICompletion
 from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
 
+from llama_stack.apis.common.content_types import (
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -48,8 +53,6 @@ from llama_stack.apis.inference import (
     Message,
     SystemMessage,
     TokenLogProbs,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolResponseMessage,
     UserMessage,
 )
@@ -432,69 +435,6 @@ async def convert_openai_chat_completion_stream(
     """
     Convert a stream of OpenAI chat completion chunks into a stream
     of ChatCompletionResponseStreamChunk.
-
-    OpenAI ChatCompletionChunk:
-        choices: List[Choice]
-
-    OpenAI Choice:  # different from the non-streamed Choice
-        delta: ChoiceDelta
-        finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
-        logprobs: Optional[ChoiceLogprobs]
-
-    OpenAI ChoiceDelta:
-        content: Optional[str]
-        role: Optional[Literal["system", "user", "assistant", "tool"]]
-        tool_calls: Optional[List[ChoiceDeltaToolCall]]
-
-    OpenAI ChoiceDeltaToolCall:
-        index: int
-        id: Optional[str]
-        function: Optional[ChoiceDeltaToolCallFunction]
-        type: Optional[Literal["function"]]
-
-    OpenAI ChoiceDeltaToolCallFunction:
-        name: Optional[str]
-        arguments: Optional[str]
-
-    ->
-
-    ChatCompletionResponseStreamChunk:
-        event: ChatCompletionResponseEvent
-
-    ChatCompletionResponseEvent:
-        event_type: ChatCompletionResponseEventType
-        delta: Union[str, ToolCallDelta]
-        logprobs: Optional[List[TokenLogProbs]]
-        stop_reason: Optional[StopReason]
-
-    ChatCompletionResponseEventType:
-        start = "start"
-        progress = "progress"
-        complete = "complete"
-
-    ToolCallDelta:
-        content: Union[str, ToolCall]
-        parse_status: ToolCallParseStatus
-
-    ToolCall:
-        call_id: str
-        tool_name: str
-        arguments: str
-
-    ToolCallParseStatus:
-        started = "started"
-        in_progress = "in_progress"
-        failure = "failure"
-        success = "success"
-
-    TokenLogProbs:
-        logprobs_by_token: Dict[str, float]
-            - token, logprob
-
-    StopReason:
-        end_of_turn = "end_of_turn"
-        end_of_message = "end_of_message"
-        out_of_tokens = "out_of_tokens"
     """
 
     # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
@@ -543,7 +483,7 @@ async def convert_openai_chat_completion_stream(
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=next(event_type),
-                    delta=choice.delta.content,
+                    delta=TextDelta(text=choice.delta.content),
                     logprobs=_convert_openai_logprobs(choice.logprobs),
                 )
             )
@@ -570,7 +510,7 @@ async def convert_openai_chat_completion_stream(
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=next(event_type),
-                    delta=choice.delta.content or "",  # content is not optional
+                    delta=TextDelta(text=choice.delta.content or ""),
                     logprobs=_convert_openai_logprobs(choice.logprobs),
                 )
             )
@@ -578,7 +518,7 @@ async def convert_openai_chat_completion_stream(
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.complete,
-            delta="",
+            delta=TextDelta(text=""),
            stop_reason=stop_reason,
         )
     )
@@ -653,18 +593,6 @@ def _convert_openai_completion_logprobs(
 ) -> Optional[List[TokenLogProbs]]:
     """
     Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
-
-    OpenAI CompletionLogprobs:
-        text_offset: Optional[List[int]]
-        token_logprobs: Optional[List[float]]
-        tokens: Optional[List[str]]
-        top_logprobs: Optional[List[Dict[str, float]]]
-
-    ->
-
-    TokenLogProbs:
-        logprobs_by_token: Dict[str, float]
-            - token, logprob
     """
     if not logprobs:
         return None
@@ -679,28 +607,6 @@ def convert_openai_completion_choice(
 ) -> CompletionResponse:
     """
     Convert an OpenAI Completion Choice into a CompletionResponse.
-
-    OpenAI Completion Choice:
-        text: str
-        finish_reason: str
-        logprobs: Optional[ChoiceLogprobs]
-
-    ->
-
-    CompletionResponse:
-        completion_message: CompletionMessage
-        logprobs: Optional[List[TokenLogProbs]]
-
-    CompletionMessage:
-        role: Literal["assistant"]
-        content: str | ImageMedia | List[str | ImageMedia]
-        stop_reason: StopReason
-        tool_calls: List[ToolCall]
-
-    class StopReason(Enum):
-        end_of_turn = "end_of_turn"
-        end_of_message = "end_of_message"
-        out_of_tokens = "out_of_tokens"
     """
     return CompletionResponse(
         content=choice.text,
@@ -715,32 +621,11 @@ async def convert_openai_completion_stream(
     """
     Convert a stream of OpenAI Completions into a stream
     of ChatCompletionResponseStreamChunks.
-
-    OpenAI Completion:
-        id: str
-        choices: List[OpenAICompletionChoice]
-        created: int
-        model: str
-        system_fingerprint: Optional[str]
-        usage: Optional[OpenAICompletionUsage]
-
-    OpenAI CompletionChoice:
-        finish_reason: str
-        index: int
-        logprobs: Optional[OpenAILogprobs]
-        text: str
-
-    ->
-
-    CompletionResponseStreamChunk:
-        delta: str
-        stop_reason: Optional[StopReason]
-        logprobs: Optional[List[TokenLogProbs]]
     """
     async for chunk in stream:
         choice = chunk.choices[0]
         yield CompletionResponseStreamChunk(
-            delta=choice.text,
+            delta=TextDelta(text=choice.text),
             stop_reason=_convert_openai_finish_reason(choice.finish_reason),
             logprobs=_convert_openai_completion_logprobs(choice.logprobs),
         )
@@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import (
 
 from pydantic import BaseModel, ValidationError
 
+from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
@@ -27,8 +28,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     SystemMessage,
-    ToolCallDelta,
-    ToolCallParseStatus,
     ToolChoice,
     UserMessage,
 )
@@ -196,7 +195,9 @@ class TestInference:
             1 <= len(chunks) <= 6
         )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
         for chunk in chunks:
-            if chunk.delta:  # if there's a token, we expect logprobs
+            if (
+                chunk.delta.type == "text" and chunk.delta.text
+            ):  # if there's a token, we expect logprobs
                 assert chunk.logprobs, "Logprobs should not be empty"
                 assert all(
                     len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
@@ -463,7 +464,7 @@ class TestInference:
 
         if "Llama3.1" in inference_model:
             assert all(
-                isinstance(chunk.event.delta, ToolCallDelta)
+                chunk.event.delta.type == "tool_call"
                 for chunk in grouped[ChatCompletionResponseEventType.progress]
            )
            first = grouped[ChatCompletionResponseEventType.progress][0]
@@ -475,7 +476,7 @@ class TestInference:
        last = grouped[ChatCompletionResponseEventType.progress][-1]
        # assert last.event.stop_reason == expected_stop_reason
        assert last.event.delta.parse_status == ToolCallParseStatus.success
-        assert isinstance(last.event.delta.content, ToolCall)
+        assert last.event.delta.content.type == "tool_call"
 
        call = last.event.delta.content
        assert call.tool_name == "get_weather"
@@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import SamplingParams, StopReason
 from pydantic import BaseModel
 
-from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+from llama_stack.apis.common.content_types import (
+    ImageContentItem,
+    TextContentItem,
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
@@ -22,8 +28,6 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     Message,
-    ToolCallDelta,
-    ToolCallParseStatus,
 )
 
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -138,7 +142,7 @@ async def process_completion_stream_response(
             text = ""
             continue
         yield CompletionResponseStreamChunk(
-            delta=text,
+            delta=TextDelta(text=text),
             stop_reason=stop_reason,
         )
         if finish_reason:
@@ -149,7 +153,7 @@ async def process_completion_stream_response(
             break
 
     yield CompletionResponseStreamChunk(
-        delta="",
+        delta=TextDelta(text=""),
        stop_reason=stop_reason,
    )
 
@@ -160,7 +164,7 @@ async def process_chat_completion_stream_response(
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.start,
-            delta="",
+            delta=TextDelta(text=""),
        )
    )
 
@@ -227,7 +231,7 @@ async def process_chat_completion_stream_response(
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=text,
+                    delta=TextDelta(text=text),
                    stop_reason=stop_reason,
                )
            )
@@ -262,7 +266,7 @@ async def process_chat_completion_stream_response(
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
-            delta="",
+            delta=TextDelta(text=""),
            stop_reason=stop_reason,
        )
    )