From 9a5803a429770fd7f23aec0482001b6bf8c3d0f5 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 13 Jan 2025 20:04:19 -0800 Subject: [PATCH] move all implementations to use updated type --- llama_stack/apis/agents/agents.py | 6 +- llama_stack/apis/agents/event_logger.py | 124 ++++++++++------ .../agents/meta_reference/agent_instance.py | 22 +-- .../inference/meta_reference/inference.py | 17 ++- .../remote/inference/groq/groq_utils.py | 14 +- .../remote/inference/nvidia/openai_utils.py | 133 ++---------------- .../tests/inference/test_text_inference.py | 11 +- .../utils/inference/openai_compat.py | 20 +-- 8 files changed, 139 insertions(+), 208 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index fb9df21e6..c3f3d21f0 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -22,12 +22,11 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated -from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL from llama_stack.apis.inference import ( CompletionMessage, SamplingParams, ToolCall, - ToolCallDelta, ToolChoice, ToolPromptFormat, ToolResponse, @@ -216,8 +215,7 @@ class AgentTurnResponseStepProgressPayload(BaseModel): step_type: StepType step_id: str - text_delta: Optional[str] = None - tool_call_delta: Optional[ToolCallDelta] = None + delta: ContentDelta @json_schema_type diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py index 40a69d19c..41004ccb0 100644 --- a/llama_stack/apis/agents/event_logger.py +++ b/llama_stack/apis/agents/event_logger.py @@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils from termcolor import cprint from llama_stack.apis.agents import AgentTurnResponseEventType, StepType - +from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ToolResponseMessage +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) + class LogEvent: def __init__( @@ -57,8 +61,11 @@ class EventLogger: # since it does not produce event but instead # a Message if isinstance(chunk, ToolResponseMessage): - yield chunk, LogEvent( - role="CustomTool", content=chunk.content, color="grey" + yield ( + chunk, + LogEvent( + role="CustomTool", content=chunk.content, color="grey" + ), ) continue @@ -80,14 +87,20 @@ class EventLogger: ): violation = event.payload.step_details.violation if not violation: - yield event, LogEvent( - role=step_type, content="No Violation", color="magenta" + yield ( + event, + LogEvent( + role=step_type, content="No Violation", color="magenta" + ), ) else: - yield event, LogEvent( - role=step_type, - content=f"{violation.metadata} {violation.user_message}", - color="red", + yield ( + event, + LogEvent( + role=step_type, + content=f"{violation.metadata} {violation.user_message}", + color="red", + ), ) # handle inference @@ -95,8 +108,11 @@ class EventLogger: if stream: if event_type == EventType.step_start.value: # TODO: Currently this event is never received - yield event, LogEvent( - role=step_type, content="", end="", color="yellow" + yield ( + event, + LogEvent( + role=step_type, content="", end="", color="yellow" + ), ) elif event_type == EventType.step_progress.value: # HACK: if previous was not step/event 
was not inference's step_progress @@ -107,24 +123,34 @@ class EventLogger: previous_event_type != EventType.step_progress.value and previous_step_type != StepType.inference ): - yield event, LogEvent( - role=step_type, content="", end="", color="yellow" + yield ( + event, + LogEvent( + role=step_type, content="", end="", color="yellow" + ), ) - if event.payload.tool_call_delta: - if isinstance(event.payload.tool_call_delta.content, str): - yield event, LogEvent( - role=None, - content=event.payload.tool_call_delta.content, - end="", - color="cyan", + delta = event.payload.delta + if delta.type == "tool_call": + if delta.parse_status == ToolCallParseStatus.success: + yield ( + event, + LogEvent( + role=None, + content=delta.content, + end="", + color="cyan", + ), ) else: - yield event, LogEvent( - role=None, - content=event.payload.text_delta, - end="", - color="yellow", + yield ( + event, + LogEvent( + role=None, + content=delta.text, + end="", + color="yellow", + ), ) else: # step_complete @@ -140,10 +166,13 @@ class EventLogger: ) else: content = response.content - yield event, LogEvent( - role=step_type, - content=content, - color="yellow", + yield ( + event, + LogEvent( + role=step_type, + content=content, + color="yellow", + ), ) # handle tool_execution @@ -155,16 +184,22 @@ class EventLogger: ): details = event.payload.step_details for t in details.tool_calls: - yield event, LogEvent( - role=step_type, - content=f"Tool:{t.tool_name} Args:{t.arguments}", - color="green", + yield ( + event, + LogEvent( + role=step_type, + content=f"Tool:{t.tool_name} Args:{t.arguments}", + color="green", + ), ) for r in details.tool_responses: - yield event, LogEvent( - role=step_type, - content=f"Tool:{r.tool_name} Response:{r.content}", - color="green", + yield ( + event, + LogEvent( + role=step_type, + content=f"Tool:{r.tool_name} Response:{r.content}", + color="green", + ), ) if ( @@ -172,15 +207,16 @@ class EventLogger: and event_type == EventType.step_complete.value ): details = event.payload.step_details - inserted_context = interleaved_text_media_as_str( - details.inserted_context - ) + inserted_context = interleaved_content_as_str(details.inserted_context) content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}" - yield event, LogEvent( - role=step_type, - content=content, - color="cyan", + yield ( + event, + LogEvent( + role=step_type, + content=content, + color="cyan", + ), ) previous_event_type = event_type diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 24448a28f..be33d75c3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -40,7 +40,12 @@ from llama_stack.apis.agents import ( ToolExecutionStep, Turn, ) -from llama_stack.apis.common.content_types import TextContentItem, URL +from llama_stack.apis.common.content_types import ( + TextContentItem, + ToolCallDelta, + ToolCallParseStatus, + URL, +) from llama_stack.apis.inference import ( ChatCompletionResponseEventType, CompletionMessage, @@ -49,8 +54,6 @@ from llama_stack.apis.inference import ( SamplingParams, StopReason, SystemMessage, - ToolCallDelta, - ToolCallParseStatus, ToolDefinition, ToolResponse, ToolResponseMessage, @@ -411,7 +414,7 @@ class ChatAgent(ShieldRunnerMixin): payload=AgentTurnResponseStepProgressPayload( step_type=StepType.tool_execution.value, step_id=step_id, - 
tool_call_delta=ToolCallDelta( + delta=ToolCallDelta( parse_status=ToolCallParseStatus.success, content=ToolCall( call_id="", @@ -507,7 +510,7 @@ class ChatAgent(ShieldRunnerMixin): continue delta = event.delta - if isinstance(delta, ToolCallDelta): + if delta.type == "tool_call": if delta.parse_status == ToolCallParseStatus.success: tool_calls.append(delta.content) if stream: @@ -516,21 +519,20 @@ class ChatAgent(ShieldRunnerMixin): payload=AgentTurnResponseStepProgressPayload( step_type=StepType.inference.value, step_id=step_id, - text_delta="", - tool_call_delta=delta, + delta=delta, ) ) ) - elif isinstance(delta, str): - content += delta + elif delta.type == "text": + content += delta.text if stream and event.stop_reason is None: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( step_type=StepType.inference.value, step_id=step_id, - text_delta=event.delta, + delta=delta, ) ) ) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 5b502a581..e099580af 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import ( ) from llama_models.sku_list import resolve_model +from llama_stack.apis.common.content_types import ( + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -32,8 +37,6 @@ from llama_stack.apis.inference import ( Message, ResponseFormat, TokenLogProbs, - ToolCallDelta, - ToolCallParseStatus, ToolChoice, ) from llama_stack.apis.models import Model, ModelType @@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl( ] yield CompletionResponseStreamChunk( - delta=text, + delta=TextDelta(text=text), stop_reason=stop_reason, logprobs=logprobs if request.logprobs else None, ) if stop_reason is None: yield CompletionResponseStreamChunk( - delta="", + delta=TextDelta(text=""), stop_reason=StopReason.out_of_tokens, ) @@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.start, - delta="", + delta=TextDelta(text=""), ) ) @@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl( parse_status=ToolCallParseStatus.in_progress, ) else: - delta = text + delta = TextDelta(text=text) if stop_reason is None: if request.logprobs: @@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta="", + delta=TextDelta(text=""), stop_reason=stop_reason, ) ) diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py index 032f4c8d4..b87c0c94c 100644 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -30,6 +30,11 @@ from groq.types.shared.function_definition import FunctionDefinition from llama_models.llama3.api.datatypes import ToolParamDefinition +from llama_stack.apis.common.content_types import ( + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -40,8 +45,6 @@ from llama_stack.apis.inference import ( Message, 
StopReason, ToolCall, - ToolCallDelta, - ToolCallParseStatus, ToolDefinition, ToolPromptFormat, ) @@ -162,7 +165,7 @@ def convert_chat_completion_response( def _map_finish_reason_to_stop_reason( - finish_reason: Literal["stop", "length", "tool_calls"] + finish_reason: Literal["stop", "length", "tool_calls"], ) -> StopReason: """ Convert a Groq chat completion finish_reason to a StopReason. @@ -185,7 +188,6 @@ def _map_finish_reason_to_stop_reason( async def convert_chat_completion_response_stream( stream: Stream[ChatCompletionChunk], ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - event_type = ChatCompletionResponseEventType.start for chunk in stream: choice = chunk.choices[0] @@ -194,7 +196,7 @@ async def convert_chat_completion_response_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta=choice.delta.content or "", + delta=TextDelta(text=choice.delta.content or ""), logprobs=None, stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason), ) @@ -221,7 +223,7 @@ async def convert_chat_completion_response_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=event_type, - delta=choice.delta.content or "", + delta=TextDelta(text=choice.delta.content or ""), logprobs=None, ) ) diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index dcc7c5fca..955b65aa5 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -34,6 +34,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import ( from openai.types.completion import Completion as OpenAICompletion from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs +from llama_stack.apis.common.content_types import ( + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -48,8 +53,6 @@ from llama_stack.apis.inference import ( Message, SystemMessage, TokenLogProbs, - ToolCallDelta, - ToolCallParseStatus, ToolResponseMessage, UserMessage, ) @@ -432,69 +435,6 @@ async def convert_openai_chat_completion_stream( """ Convert a stream of OpenAI chat completion chunks into a stream of ChatCompletionResponseStreamChunk. 
- - OpenAI ChatCompletionChunk: - choices: List[Choice] - - OpenAI Choice: # different from the non-streamed Choice - delta: ChoiceDelta - finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]] - logprobs: Optional[ChoiceLogprobs] - - OpenAI ChoiceDelta: - content: Optional[str] - role: Optional[Literal["system", "user", "assistant", "tool"]] - tool_calls: Optional[List[ChoiceDeltaToolCall]] - - OpenAI ChoiceDeltaToolCall: - index: int - id: Optional[str] - function: Optional[ChoiceDeltaToolCallFunction] - type: Optional[Literal["function"]] - - OpenAI ChoiceDeltaToolCallFunction: - name: Optional[str] - arguments: Optional[str] - - -> - - ChatCompletionResponseStreamChunk: - event: ChatCompletionResponseEvent - - ChatCompletionResponseEvent: - event_type: ChatCompletionResponseEventType - delta: Union[str, ToolCallDelta] - logprobs: Optional[List[TokenLogProbs]] - stop_reason: Optional[StopReason] - - ChatCompletionResponseEventType: - start = "start" - progress = "progress" - complete = "complete" - - ToolCallDelta: - content: Union[str, ToolCall] - parse_status: ToolCallParseStatus - - ToolCall: - call_id: str - tool_name: str - arguments: str - - ToolCallParseStatus: - started = "started" - in_progress = "in_progress" - failure = "failure" - success = "success" - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob - - StopReason: - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" """ # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... @@ -543,7 +483,7 @@ async def convert_openai_chat_completion_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=next(event_type), - delta=choice.delta.content, + delta=TextDelta(text=choice.delta.content), logprobs=_convert_openai_logprobs(choice.logprobs), ) ) @@ -570,7 +510,7 @@ async def convert_openai_chat_completion_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=next(event_type), - delta=choice.delta.content or "", # content is not optional + delta=TextDelta(text=choice.delta.content or ""), logprobs=_convert_openai_logprobs(choice.logprobs), ) ) @@ -578,7 +518,7 @@ async def convert_openai_chat_completion_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta="", + delta=TextDelta(text=""), stop_reason=stop_reason, ) ) @@ -653,18 +593,6 @@ def _convert_openai_completion_logprobs( ) -> Optional[List[TokenLogProbs]]: """ Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs. - - OpenAI CompletionLogprobs: - text_offset: Optional[List[int]] - token_logprobs: Optional[List[float]] - tokens: Optional[List[str]] - top_logprobs: Optional[List[Dict[str, float]]] - - -> - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob """ if not logprobs: return None @@ -679,28 +607,6 @@ def convert_openai_completion_choice( ) -> CompletionResponse: """ Convert an OpenAI Completion Choice into a CompletionResponse. 
- - OpenAI Completion Choice: - text: str - finish_reason: str - logprobs: Optional[ChoiceLogprobs] - - -> - - CompletionResponse: - completion_message: CompletionMessage - logprobs: Optional[List[TokenLogProbs]] - - CompletionMessage: - role: Literal["assistant"] - content: str | ImageMedia | List[str | ImageMedia] - stop_reason: StopReason - tool_calls: List[ToolCall] - - class StopReason(Enum): - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" """ return CompletionResponse( content=choice.text, @@ -715,32 +621,11 @@ async def convert_openai_completion_stream( """ Convert a stream of OpenAI Completions into a stream of ChatCompletionResponseStreamChunks. - - OpenAI Completion: - id: str - choices: List[OpenAICompletionChoice] - created: int - model: str - system_fingerprint: Optional[str] - usage: Optional[OpenAICompletionUsage] - - OpenAI CompletionChoice: - finish_reason: str - index: int - logprobs: Optional[OpenAILogprobs] - text: str - - -> - - CompletionResponseStreamChunk: - delta: str - stop_reason: Optional[StopReason] - logprobs: Optional[List[TokenLogProbs]] """ async for chunk in stream: choice = chunk.choices[0] yield CompletionResponseStreamChunk( - delta=choice.text, + delta=TextDelta(text=choice.text), stop_reason=_convert_openai_finish_reason(choice.finish_reason), logprobs=_convert_openai_completion_logprobs(choice.logprobs), ) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 19cc8393c..24093cb59 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import ( from pydantic import BaseModel, ValidationError +from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEventType, @@ -27,8 +28,6 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, LogProbConfig, SystemMessage, - ToolCallDelta, - ToolCallParseStatus, ToolChoice, UserMessage, ) @@ -196,7 +195,9 @@ class TestInference: 1 <= len(chunks) <= 6 ) # why 6 and not 5? the response may have an extra closing chunk, e.g. 
for usage or stop_reason for chunk in chunks: - if chunk.delta: # if there's a token, we expect logprobs + if ( + chunk.delta.type == "text" and chunk.delta.text + ): # if there's a token, we expect logprobs assert chunk.logprobs, "Logprobs should not be empty" assert all( len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs @@ -463,7 +464,7 @@ class TestInference: if "Llama3.1" in inference_model: assert all( - isinstance(chunk.event.delta, ToolCallDelta) + chunk.event.delta.type == "tool_call" for chunk in grouped[ChatCompletionResponseEventType.progress] ) first = grouped[ChatCompletionResponseEventType.progress][0] @@ -475,7 +476,7 @@ class TestInference: last = grouped[ChatCompletionResponseEventType.progress][-1] # assert last.event.stop_reason == expected_stop_reason assert last.event.delta.parse_status == ToolCallParseStatus.success - assert isinstance(last.event.delta.content, ToolCall) + assert last.event.delta.content.type == "tool_call" call = last.event.delta.content assert call.tool_name == "get_weather" diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index ba63be2b6..e70ad4033 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import SamplingParams, StopReason from pydantic import BaseModel -from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem +from llama_stack.apis.common.content_types import ( + ImageContentItem, + TextContentItem, + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionResponse, @@ -22,8 +28,6 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, Message, - ToolCallDelta, - ToolCallParseStatus, ) from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -138,7 +142,7 @@ async def process_completion_stream_response( text = "" continue yield CompletionResponseStreamChunk( - delta=text, + delta=TextDelta(text=text), stop_reason=stop_reason, ) if finish_reason: @@ -149,7 +153,7 @@ async def process_completion_stream_response( break yield CompletionResponseStreamChunk( - delta="", + delta=TextDelta(text=""), stop_reason=stop_reason, ) @@ -160,7 +164,7 @@ async def process_chat_completion_stream_response( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.start, - delta="", + delta=TextDelta(text=""), ) ) @@ -227,7 +231,7 @@ async def process_chat_completion_stream_response( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.progress, - delta=text, + delta=TextDelta(text=text), stop_reason=stop_reason, ) ) @@ -262,7 +266,7 @@ async def process_chat_completion_stream_response( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta="", + delta=TextDelta(text=""), stop_reason=stop_reason, ) )
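
All call sites in this patch converge on the same consumer pattern: instead of
switching on `isinstance(delta, ToolCallDelta)` versus a plain `str`, callers now
branch on the discriminated `type` field of the `ContentDelta` union
(`TextDelta` vs `ToolCallDelta`). The sketch below illustrates that pattern for a
chat completion stream. The import paths are the ones used in this patch, but
`accumulate_response` is an illustrative helper, not part of the llama_stack API;
treat it as a sketch of the consumption pattern rather than a definitive
implementation.

from typing import AsyncIterator, List, Tuple

from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
)


async def accumulate_response(
    stream: AsyncIterator[ChatCompletionResponseStreamChunk],
) -> Tuple[str, List]:
    # Hypothetical consumer: collects streamed text and fully parsed tool calls.
    text = ""
    tool_calls = []
    async for chunk in stream:
        event = chunk.event
        delta = event.delta
        if delta.type == "text":
            # TextDelta carries the incremental token text in `.text`.
            text += delta.text
        elif delta.type == "tool_call":
            # ToolCallDelta carries a partial string while parsing and a
            # complete ToolCall in `.content` once parse_status is `success`.
            if delta.parse_status == ToolCallParseStatus.success:
                tool_calls.append(delta.content)
        if event.event_type == ChatCompletionResponseEventType.complete:
            break
    return text, tool_calls

The same branch shape appears in agent_instance.py, event_logger.py, and the
inference tests above; completion (non-chat) streams are simpler, since their
`chunk.delta` is always a `TextDelta`.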