From aced2ce07e02295d13db11be39156adb24ce07f3 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 13 Jan 2025 19:38:44 -0800 Subject: [PATCH 1/6] introduce and use a generic ContentDelta --- llama_stack/apis/common/content_types.py | 43 +++++++++++++++++++++++- llama_stack/apis/inference/inference.py | 18 ++-------- 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index 629e0e94d..3b61fa243 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -5,10 +5,12 @@ # the root directory of this source tree. import base64 +from enum import Enum from typing import Annotated, List, Literal, Optional, Union -from llama_models.schema_utils import json_schema_type, register_schema +from llama_models.llama3.api.datatypes import ToolCall +from llama_models.schema_utils import json_schema_type, register_schema from pydantic import BaseModel, Field, field_serializer, model_validator @@ -60,3 +62,42 @@ InterleavedContent = register_schema( Union[str, InterleavedContentItem, List[InterleavedContentItem]], name="InterleavedContent", ) + + +class TextDelta(BaseModel): + type: Literal["text"] = "text" + text: str + + +class ImageDelta(BaseModel): + type: Literal["image"] = "image" + data: bytes + + +@json_schema_type +class ToolCallParseStatus(Enum): + started = "started" + in_progress = "in_progress" + failed = "failed" + succeeded = "succeeded" + + +@json_schema_type +class ToolCallDelta(BaseModel): + type: Literal["tool_call"] = "tool_call" + + # you either send an in-progress tool call so the client can stream a long + # code generation or you send the final parsed tool call at the end of the + # stream + content: Union[str, ToolCall] + parse_status: ToolCallParseStatus + + +# streaming completions send a stream of ContentDeltas +ContentDelta = register_schema( + Annotated[ + Union[TextDelta, ImageDelta, ToolCallDelta], + Field(discriminator="type"), + ], + name="ContentDelta", +) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 4a453700c..b525aa331 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -29,7 +29,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth from pydantic import BaseModel, Field, field_validator from typing_extensions import Annotated -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent from llama_stack.apis.models import Model from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -147,26 +147,12 @@ class ChatCompletionResponseEventType(Enum): progress = "progress" -@json_schema_type -class ToolCallParseStatus(Enum): - started = "started" - in_progress = "in_progress" - failure = "failure" - success = "success" - - -@json_schema_type -class ToolCallDelta(BaseModel): - content: Union[str, ToolCall] - parse_status: ToolCallParseStatus - - @json_schema_type class ChatCompletionResponseEvent(BaseModel): """Chat completion response event.""" event_type: ChatCompletionResponseEventType - delta: Union[str, ToolCallDelta] + delta: ContentDelta logprobs: Optional[List[TokenLogProbs]] = None stop_reason: Optional[StopReason] = None From 9a5803a429770fd7f23aec0482001b6bf8c3d0f5 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 13 Jan 2025 20:04:19 -0800 Subject: [PATCH 2/6] move 
all implementations to use updated type --- llama_stack/apis/agents/agents.py | 6 +- llama_stack/apis/agents/event_logger.py | 124 ++++++++++------ .../agents/meta_reference/agent_instance.py | 22 +-- .../inference/meta_reference/inference.py | 17 ++- .../remote/inference/groq/groq_utils.py | 14 +- .../remote/inference/nvidia/openai_utils.py | 133 ++---------------- .../tests/inference/test_text_inference.py | 11 +- .../utils/inference/openai_compat.py | 20 +-- 8 files changed, 139 insertions(+), 208 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index fb9df21e6..c3f3d21f0 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -22,12 +22,11 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated -from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL from llama_stack.apis.inference import ( CompletionMessage, SamplingParams, ToolCall, - ToolCallDelta, ToolChoice, ToolPromptFormat, ToolResponse, @@ -216,8 +215,7 @@ class AgentTurnResponseStepProgressPayload(BaseModel): step_type: StepType step_id: str - text_delta: Optional[str] = None - tool_call_delta: Optional[ToolCallDelta] = None + delta: ContentDelta @json_schema_type diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py index 40a69d19c..41004ccb0 100644 --- a/llama_stack/apis/agents/event_logger.py +++ b/llama_stack/apis/agents/event_logger.py @@ -11,9 +11,13 @@ from llama_models.llama3.api.tool_utils import ToolUtils from termcolor import cprint from llama_stack.apis.agents import AgentTurnResponseEventType, StepType - +from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ToolResponseMessage +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) + class LogEvent: def __init__( @@ -57,8 +61,11 @@ class EventLogger: # since it does not produce event but instead # a Message if isinstance(chunk, ToolResponseMessage): - yield chunk, LogEvent( - role="CustomTool", content=chunk.content, color="grey" + yield ( + chunk, + LogEvent( + role="CustomTool", content=chunk.content, color="grey" + ), ) continue @@ -80,14 +87,20 @@ class EventLogger: ): violation = event.payload.step_details.violation if not violation: - yield event, LogEvent( - role=step_type, content="No Violation", color="magenta" + yield ( + event, + LogEvent( + role=step_type, content="No Violation", color="magenta" + ), ) else: - yield event, LogEvent( - role=step_type, - content=f"{violation.metadata} {violation.user_message}", - color="red", + yield ( + event, + LogEvent( + role=step_type, + content=f"{violation.metadata} {violation.user_message}", + color="red", + ), ) # handle inference @@ -95,8 +108,11 @@ class EventLogger: if stream: if event_type == EventType.step_start.value: # TODO: Currently this event is never received - yield event, LogEvent( - role=step_type, content="", end="", color="yellow" + yield ( + event, + LogEvent( + role=step_type, content="", end="", color="yellow" + ), ) elif event_type == EventType.step_progress.value: # HACK: if previous was not step/event was not inference's step_progress @@ -107,24 +123,34 @@ class EventLogger: previous_event_type != EventType.step_progress.value and previous_step_type 
!= StepType.inference ): - yield event, LogEvent( - role=step_type, content="", end="", color="yellow" + yield ( + event, + LogEvent( + role=step_type, content="", end="", color="yellow" + ), ) - if event.payload.tool_call_delta: - if isinstance(event.payload.tool_call_delta.content, str): - yield event, LogEvent( - role=None, - content=event.payload.tool_call_delta.content, - end="", - color="cyan", + delta = event.payload.delta + if delta.type == "tool_call": + if delta.parse_status == ToolCallParseStatus.success: + yield ( + event, + LogEvent( + role=None, + content=delta.content, + end="", + color="cyan", + ), ) else: - yield event, LogEvent( - role=None, - content=event.payload.text_delta, - end="", - color="yellow", + yield ( + event, + LogEvent( + role=None, + content=delta.text, + end="", + color="yellow", + ), ) else: # step_complete @@ -140,10 +166,13 @@ class EventLogger: ) else: content = response.content - yield event, LogEvent( - role=step_type, - content=content, - color="yellow", + yield ( + event, + LogEvent( + role=step_type, + content=content, + color="yellow", + ), ) # handle tool_execution @@ -155,16 +184,22 @@ class EventLogger: ): details = event.payload.step_details for t in details.tool_calls: - yield event, LogEvent( - role=step_type, - content=f"Tool:{t.tool_name} Args:{t.arguments}", - color="green", + yield ( + event, + LogEvent( + role=step_type, + content=f"Tool:{t.tool_name} Args:{t.arguments}", + color="green", + ), ) for r in details.tool_responses: - yield event, LogEvent( - role=step_type, - content=f"Tool:{r.tool_name} Response:{r.content}", - color="green", + yield ( + event, + LogEvent( + role=step_type, + content=f"Tool:{r.tool_name} Response:{r.content}", + color="green", + ), ) if ( @@ -172,15 +207,16 @@ class EventLogger: and event_type == EventType.step_complete.value ): details = event.payload.step_details - inserted_context = interleaved_text_media_as_str( - details.inserted_context - ) + inserted_context = interleaved_content_as_str(details.inserted_context) content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}" - yield event, LogEvent( - role=step_type, - content=content, - color="cyan", + yield ( + event, + LogEvent( + role=step_type, + content=content, + color="cyan", + ), ) previous_event_type = event_type diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 24448a28f..be33d75c3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -40,7 +40,12 @@ from llama_stack.apis.agents import ( ToolExecutionStep, Turn, ) -from llama_stack.apis.common.content_types import TextContentItem, URL +from llama_stack.apis.common.content_types import ( + TextContentItem, + ToolCallDelta, + ToolCallParseStatus, + URL, +) from llama_stack.apis.inference import ( ChatCompletionResponseEventType, CompletionMessage, @@ -49,8 +54,6 @@ from llama_stack.apis.inference import ( SamplingParams, StopReason, SystemMessage, - ToolCallDelta, - ToolCallParseStatus, ToolDefinition, ToolResponse, ToolResponseMessage, @@ -411,7 +414,7 @@ class ChatAgent(ShieldRunnerMixin): payload=AgentTurnResponseStepProgressPayload( step_type=StepType.tool_execution.value, step_id=step_id, - tool_call_delta=ToolCallDelta( + delta=ToolCallDelta( parse_status=ToolCallParseStatus.success, content=ToolCall( call_id="", @@ -507,7 +510,7 @@ class 
ChatAgent(ShieldRunnerMixin): continue delta = event.delta - if isinstance(delta, ToolCallDelta): + if delta.type == "tool_call": if delta.parse_status == ToolCallParseStatus.success: tool_calls.append(delta.content) if stream: @@ -516,21 +519,20 @@ class ChatAgent(ShieldRunnerMixin): payload=AgentTurnResponseStepProgressPayload( step_type=StepType.inference.value, step_id=step_id, - text_delta="", - tool_call_delta=delta, + delta=delta, ) ) ) - elif isinstance(delta, str): - content += delta + elif delta.type == "text": + content += delta.text if stream and event.stop_reason is None: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( step_type=StepType.inference.value, step_id=step_id, - text_delta=event.delta, + delta=delta, ) ) ) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 5b502a581..e099580af 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import ( ) from llama_models.sku_list import resolve_model +from llama_stack.apis.common.content_types import ( + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -32,8 +37,6 @@ from llama_stack.apis.inference import ( Message, ResponseFormat, TokenLogProbs, - ToolCallDelta, - ToolCallParseStatus, ToolChoice, ) from llama_stack.apis.models import Model, ModelType @@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl( ] yield CompletionResponseStreamChunk( - delta=text, + delta=TextDelta(text=text), stop_reason=stop_reason, logprobs=logprobs if request.logprobs else None, ) if stop_reason is None: yield CompletionResponseStreamChunk( - delta="", + delta=TextDelta(text=""), stop_reason=StopReason.out_of_tokens, ) @@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.start, - delta="", + delta=TextDelta(text=""), ) ) @@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl( parse_status=ToolCallParseStatus.in_progress, ) else: - delta = text + delta = TextDelta(text=text) if stop_reason is None: if request.logprobs: @@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta="", + delta=TextDelta(text=""), stop_reason=stop_reason, ) ) diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py index 032f4c8d4..b87c0c94c 100644 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -30,6 +30,11 @@ from groq.types.shared.function_definition import FunctionDefinition from llama_models.llama3.api.datatypes import ToolParamDefinition +from llama_stack.apis.common.content_types import ( + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -40,8 +45,6 @@ from llama_stack.apis.inference import ( Message, StopReason, ToolCall, - ToolCallDelta, - ToolCallParseStatus, ToolDefinition, ToolPromptFormat, ) @@ -162,7 +165,7 @@ def convert_chat_completion_response( 
def _map_finish_reason_to_stop_reason( - finish_reason: Literal["stop", "length", "tool_calls"] + finish_reason: Literal["stop", "length", "tool_calls"], ) -> StopReason: """ Convert a Groq chat completion finish_reason to a StopReason. @@ -185,7 +188,6 @@ def _map_finish_reason_to_stop_reason( async def convert_chat_completion_response_stream( stream: Stream[ChatCompletionChunk], ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - event_type = ChatCompletionResponseEventType.start for chunk in stream: choice = chunk.choices[0] @@ -194,7 +196,7 @@ async def convert_chat_completion_response_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta=choice.delta.content or "", + delta=TextDelta(text=choice.delta.content or ""), logprobs=None, stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason), ) @@ -221,7 +223,7 @@ async def convert_chat_completion_response_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=event_type, - delta=choice.delta.content or "", + delta=TextDelta(text=choice.delta.content or ""), logprobs=None, ) ) diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index dcc7c5fca..955b65aa5 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -34,6 +34,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import ( from openai.types.completion import Completion as OpenAICompletion from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs +from llama_stack.apis.common.content_types import ( + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -48,8 +53,6 @@ from llama_stack.apis.inference import ( Message, SystemMessage, TokenLogProbs, - ToolCallDelta, - ToolCallParseStatus, ToolResponseMessage, UserMessage, ) @@ -432,69 +435,6 @@ async def convert_openai_chat_completion_stream( """ Convert a stream of OpenAI chat completion chunks into a stream of ChatCompletionResponseStreamChunk. 
- - OpenAI ChatCompletionChunk: - choices: List[Choice] - - OpenAI Choice: # different from the non-streamed Choice - delta: ChoiceDelta - finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]] - logprobs: Optional[ChoiceLogprobs] - - OpenAI ChoiceDelta: - content: Optional[str] - role: Optional[Literal["system", "user", "assistant", "tool"]] - tool_calls: Optional[List[ChoiceDeltaToolCall]] - - OpenAI ChoiceDeltaToolCall: - index: int - id: Optional[str] - function: Optional[ChoiceDeltaToolCallFunction] - type: Optional[Literal["function"]] - - OpenAI ChoiceDeltaToolCallFunction: - name: Optional[str] - arguments: Optional[str] - - -> - - ChatCompletionResponseStreamChunk: - event: ChatCompletionResponseEvent - - ChatCompletionResponseEvent: - event_type: ChatCompletionResponseEventType - delta: Union[str, ToolCallDelta] - logprobs: Optional[List[TokenLogProbs]] - stop_reason: Optional[StopReason] - - ChatCompletionResponseEventType: - start = "start" - progress = "progress" - complete = "complete" - - ToolCallDelta: - content: Union[str, ToolCall] - parse_status: ToolCallParseStatus - - ToolCall: - call_id: str - tool_name: str - arguments: str - - ToolCallParseStatus: - started = "started" - in_progress = "in_progress" - failure = "failure" - success = "success" - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob - - StopReason: - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" """ # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... @@ -543,7 +483,7 @@ async def convert_openai_chat_completion_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=next(event_type), - delta=choice.delta.content, + delta=TextDelta(text=choice.delta.content), logprobs=_convert_openai_logprobs(choice.logprobs), ) ) @@ -570,7 +510,7 @@ async def convert_openai_chat_completion_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=next(event_type), - delta=choice.delta.content or "", # content is not optional + delta=TextDelta(text=choice.delta.content or ""), logprobs=_convert_openai_logprobs(choice.logprobs), ) ) @@ -578,7 +518,7 @@ async def convert_openai_chat_completion_stream( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta="", + delta=TextDelta(text=""), stop_reason=stop_reason, ) ) @@ -653,18 +593,6 @@ def _convert_openai_completion_logprobs( ) -> Optional[List[TokenLogProbs]]: """ Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs. - - OpenAI CompletionLogprobs: - text_offset: Optional[List[int]] - token_logprobs: Optional[List[float]] - tokens: Optional[List[str]] - top_logprobs: Optional[List[Dict[str, float]]] - - -> - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob """ if not logprobs: return None @@ -679,28 +607,6 @@ def convert_openai_completion_choice( ) -> CompletionResponse: """ Convert an OpenAI Completion Choice into a CompletionResponse. 
- - OpenAI Completion Choice: - text: str - finish_reason: str - logprobs: Optional[ChoiceLogprobs] - - -> - - CompletionResponse: - completion_message: CompletionMessage - logprobs: Optional[List[TokenLogProbs]] - - CompletionMessage: - role: Literal["assistant"] - content: str | ImageMedia | List[str | ImageMedia] - stop_reason: StopReason - tool_calls: List[ToolCall] - - class StopReason(Enum): - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" """ return CompletionResponse( content=choice.text, @@ -715,32 +621,11 @@ async def convert_openai_completion_stream( """ Convert a stream of OpenAI Completions into a stream of ChatCompletionResponseStreamChunks. - - OpenAI Completion: - id: str - choices: List[OpenAICompletionChoice] - created: int - model: str - system_fingerprint: Optional[str] - usage: Optional[OpenAICompletionUsage] - - OpenAI CompletionChoice: - finish_reason: str - index: int - logprobs: Optional[OpenAILogprobs] - text: str - - -> - - CompletionResponseStreamChunk: - delta: str - stop_reason: Optional[StopReason] - logprobs: Optional[List[TokenLogProbs]] """ async for chunk in stream: choice = chunk.choices[0] yield CompletionResponseStreamChunk( - delta=choice.text, + delta=TextDelta(text=choice.text), stop_reason=_convert_openai_finish_reason(choice.finish_reason), logprobs=_convert_openai_completion_logprobs(choice.logprobs), ) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 19cc8393c..24093cb59 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import ( from pydantic import BaseModel, ValidationError +from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEventType, @@ -27,8 +28,6 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, LogProbConfig, SystemMessage, - ToolCallDelta, - ToolCallParseStatus, ToolChoice, UserMessage, ) @@ -196,7 +195,9 @@ class TestInference: 1 <= len(chunks) <= 6 ) # why 6 and not 5? the response may have an extra closing chunk, e.g. 
for usage or stop_reason for chunk in chunks: - if chunk.delta: # if there's a token, we expect logprobs + if ( + chunk.delta.type == "text" and chunk.delta.text + ): # if there's a token, we expect logprobs assert chunk.logprobs, "Logprobs should not be empty" assert all( len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs @@ -463,7 +464,7 @@ class TestInference: if "Llama3.1" in inference_model: assert all( - isinstance(chunk.event.delta, ToolCallDelta) + chunk.event.delta.type == "tool_call" for chunk in grouped[ChatCompletionResponseEventType.progress] ) first = grouped[ChatCompletionResponseEventType.progress][0] @@ -475,7 +476,7 @@ class TestInference: last = grouped[ChatCompletionResponseEventType.progress][-1] # assert last.event.stop_reason == expected_stop_reason assert last.event.delta.parse_status == ToolCallParseStatus.success - assert isinstance(last.event.delta.content, ToolCall) + assert last.event.delta.content.type == "tool_call" call = last.event.delta.content assert call.tool_name == "get_weather" diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index ba63be2b6..e70ad4033 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import SamplingParams, StopReason from pydantic import BaseModel -from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem +from llama_stack.apis.common.content_types import ( + ImageContentItem, + TextContentItem, + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionResponse, @@ -22,8 +28,6 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, Message, - ToolCallDelta, - ToolCallParseStatus, ) from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -138,7 +142,7 @@ async def process_completion_stream_response( text = "" continue yield CompletionResponseStreamChunk( - delta=text, + delta=TextDelta(text=text), stop_reason=stop_reason, ) if finish_reason: @@ -149,7 +153,7 @@ async def process_completion_stream_response( break yield CompletionResponseStreamChunk( - delta="", + delta=TextDelta(text=""), stop_reason=stop_reason, ) @@ -160,7 +164,7 @@ async def process_chat_completion_stream_response( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.start, - delta="", + delta=TextDelta(text=""), ) ) @@ -227,7 +231,7 @@ async def process_chat_completion_stream_response( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.progress, - delta=text, + delta=TextDelta(text=text), stop_reason=stop_reason, ) ) @@ -262,7 +266,7 @@ async def process_chat_completion_stream_response( yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.complete, - delta="", + delta=TextDelta(text=""), stop_reason=stop_reason, ) ) From d9d34433fc8814f445f83db559b824f2a2104ba2 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 13 Jan 2025 20:06:49 -0800 Subject: [PATCH 3/6] Update spec --- docs/resources/llama-stack-spec.html | 82 +++++++++++++++---- docs/resources/llama-stack-spec.yaml | 55 ++++++++++--- llama_stack/apis/agents/event_logger.py | 2 +- 
.../agents/meta_reference/agent_instance.py | 4 +- .../inference/meta_reference/inference.py | 4 +- .../remote/inference/groq/groq_utils.py | 2 +- .../remote/inference/nvidia/openai_utils.py | 2 +- .../tests/inference/test_text_inference.py | 2 +- .../utils/inference/openai_compat.py | 4 +- 9 files changed, 118 insertions(+), 39 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 0ce216479..5ed8701a4 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -3843,8 +3843,8 @@ "properties": { "role": { "type": "string", - "const": "ipython", - "default": "ipython" + "const": "tool", + "default": "tool" }, "call_id": { "type": "string" @@ -4185,14 +4185,7 @@ "$ref": "#/components/schemas/ChatCompletionResponseEventType" }, "delta": { - "oneOf": [ - { - "type": "string" - }, - { - "$ref": "#/components/schemas/ToolCallDelta" - } - ] + "$ref": "#/components/schemas/ContentDelta" }, "logprobs": { "type": "array", @@ -4232,6 +4225,50 @@ ], "title": "SSE-stream of these events." }, + "ContentDelta": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text", + "default": "text" + }, + "text": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "text" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "image", + "default": "image" + }, + "data": { + "type": "string", + "contentEncoding": "base64" + } + }, + "additionalProperties": false, + "required": [ + "type", + "data" + ] + }, + { + "$ref": "#/components/schemas/ToolCallDelta" + } + ] + }, "TokenLogProbs": { "type": "object", "properties": { @@ -4250,6 +4287,11 @@ "ToolCallDelta": { "type": "object", "properties": { + "type": { + "type": "string", + "const": "tool_call", + "default": "tool_call" + }, "content": { "oneOf": [ { @@ -4266,6 +4308,7 @@ }, "additionalProperties": false, "required": [ + "type", "content", "parse_status" ] @@ -4275,8 +4318,8 @@ "enum": [ "started", "in_progress", - "failure", - "success" + "failed", + "succeeded" ] }, "CompletionRequest": { @@ -4777,18 +4820,16 @@ "step_id": { "type": "string" }, - "text_delta": { - "type": "string" - }, - "tool_call_delta": { - "$ref": "#/components/schemas/ToolCallDelta" + "delta": { + "$ref": "#/components/schemas/ContentDelta" } }, "additionalProperties": false, "required": [ "event_type", "step_type", - "step_id" + "step_id", + "delta" ] }, "AgentTurnResponseStepStartPayload": { @@ -8758,6 +8799,10 @@ "name": "CompletionResponseStreamChunk", "description": "streamed completion response.\n\n" }, + { + "name": "ContentDelta", + "description": "" + }, { "name": "CreateAgentRequest", "description": "" @@ -9392,6 +9437,7 @@ "CompletionRequest", "CompletionResponse", "CompletionResponseStreamChunk", + "ContentDelta", "CreateAgentRequest", "CreateAgentSessionRequest", "CreateAgentTurnRequest", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 031178ce9..2a573959f 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -150,6 +150,8 @@ components: AgentTurnResponseStepProgressPayload: additionalProperties: false properties: + delta: + $ref: '#/components/schemas/ContentDelta' event_type: const: step_progress default: step_progress @@ -163,14 +165,11 @@ components: - shield_call - memory_retrieval type: string - text_delta: - type: string - tool_call_delta: - $ref: 
'#/components/schemas/ToolCallDelta' required: - event_type - step_type - step_id + - delta type: object AgentTurnResponseStepStartPayload: additionalProperties: false @@ -462,9 +461,7 @@ components: additionalProperties: false properties: delta: - oneOf: - - type: string - - $ref: '#/components/schemas/ToolCallDelta' + $ref: '#/components/schemas/ContentDelta' event_type: $ref: '#/components/schemas/ChatCompletionResponseEventType' logprobs: @@ -571,6 +568,34 @@ components: - delta title: streamed completion response. type: object + ContentDelta: + oneOf: + - additionalProperties: false + properties: + text: + type: string + type: + const: text + default: text + type: string + required: + - type + - text + type: object + - additionalProperties: false + properties: + data: + contentEncoding: base64 + type: string + type: + const: image + default: image + type: string + required: + - type + - data + type: object + - $ref: '#/components/schemas/ToolCallDelta' CreateAgentRequest: additionalProperties: false properties: @@ -2664,7 +2689,12 @@ components: - $ref: '#/components/schemas/ToolCall' parse_status: $ref: '#/components/schemas/ToolCallParseStatus' + type: + const: tool_call + default: tool_call + type: string required: + - type - content - parse_status type: object @@ -2672,8 +2702,8 @@ components: enum: - started - in_progress - - failure - - success + - failed + - succeeded type: string ToolChoice: enum: @@ -2888,8 +2918,8 @@ components: content: $ref: '#/components/schemas/InterleavedContent' role: - const: ipython - default: ipython + const: tool + default: tool type: string tool_name: oneOf: @@ -5500,6 +5530,8 @@ tags: ' name: CompletionResponseStreamChunk +- description: + name: ContentDelta - description: name: CreateAgentRequest @@ -5939,6 +5971,7 @@ x-tagGroups: - CompletionRequest - CompletionResponse - CompletionResponseStreamChunk + - ContentDelta - CreateAgentRequest - CreateAgentSessionRequest - CreateAgentTurnRequest diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py index 41004ccb0..9e2f14805 100644 --- a/llama_stack/apis/agents/event_logger.py +++ b/llama_stack/apis/agents/event_logger.py @@ -132,7 +132,7 @@ class EventLogger: delta = event.payload.delta if delta.type == "tool_call": - if delta.parse_status == ToolCallParseStatus.success: + if delta.parse_status == ToolCallParseStatus.succeeded: yield ( event, LogEvent( diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index be33d75c3..2299e80d1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -415,7 +415,7 @@ class ChatAgent(ShieldRunnerMixin): step_type=StepType.tool_execution.value, step_id=step_id, delta=ToolCallDelta( - parse_status=ToolCallParseStatus.success, + parse_status=ToolCallParseStatus.succeeded, content=ToolCall( call_id="", tool_name=MEMORY_QUERY_TOOL, @@ -511,7 +511,7 @@ class ChatAgent(ShieldRunnerMixin): delta = event.delta if delta.type == "tool_call": - if delta.parse_status == ToolCallParseStatus.success: + if delta.parse_status == ToolCallParseStatus.succeeded: tool_calls.append(delta.content) if stream: yield AgentTurnResponseStreamChunk( diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index e099580af..d64d32f03 100644 --- 
a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -431,7 +431,7 @@ class MetaReferenceInferenceImpl( event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( content="", - parse_status=ToolCallParseStatus.failure, + parse_status=ToolCallParseStatus.failed, ), stop_reason=stop_reason, ) @@ -443,7 +443,7 @@ class MetaReferenceInferenceImpl( event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( content=tool_call, - parse_status=ToolCallParseStatus.success, + parse_status=ToolCallParseStatus.succeeded, ), stop_reason=stop_reason, ) diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py index b87c0c94c..11f684847 100644 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -215,7 +215,7 @@ async def convert_chat_completion_response_stream( event_type=event_type, delta=ToolCallDelta( content=tool_call, - parse_status=ToolCallParseStatus.success, + parse_status=ToolCallParseStatus.succeeded, ), ) ) diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index 955b65aa5..975812844 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -501,7 +501,7 @@ async def convert_openai_chat_completion_stream( event_type=next(event_type), delta=ToolCallDelta( content=_convert_openai_tool_calls(choice.delta.tool_calls)[0], - parse_status=ToolCallParseStatus.success, + parse_status=ToolCallParseStatus.succeeded, ), logprobs=_convert_openai_logprobs(choice.logprobs), ) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 24093cb59..932ae36e6 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -475,7 +475,7 @@ class TestInference: last = grouped[ChatCompletionResponseEventType.progress][-1] # assert last.event.stop_reason == expected_stop_reason - assert last.event.delta.parse_status == ToolCallParseStatus.success + assert last.event.delta.parse_status == ToolCallParseStatus.succeeded assert last.event.delta.content.type == "tool_call" call = last.event.delta.content diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index e70ad4033..82e01c364 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -245,7 +245,7 @@ async def process_chat_completion_stream_response( event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( content="", - parse_status=ToolCallParseStatus.failure, + parse_status=ToolCallParseStatus.failed, ), stop_reason=stop_reason, ) @@ -257,7 +257,7 @@ async def process_chat_completion_stream_response( event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( content=tool_call, - parse_status=ToolCallParseStatus.success, + parse_status=ToolCallParseStatus.succeeded, ), stop_reason=stop_reason, ) From 2c2969f3312bb70ea4e745ebe3c9a1c4c7c45308 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 13 Jan 2025 23:16:16 -0800 Subject: [PATCH 4/6] Fixes; make inference tests pass with newer 
tool call types --- llama_stack/distribution/store/registry.py | 2 +- .../utils/inference/openai_compat.py | 4 +-- .../utils/inference/prompt_adapter.py | 1 + tests/client-sdk/conftest.py | 6 ++++ tests/client-sdk/inference/test_inference.py | 36 ++++++++----------- 5 files changed, 24 insertions(+), 25 deletions(-) diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index d26b4447c..010d137ec 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -35,7 +35,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v4" +KEY_VERSION = "v5" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 82e01c364..4c46954cf 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -142,7 +142,7 @@ async def process_completion_stream_response( text = "" continue yield CompletionResponseStreamChunk( - delta=TextDelta(text=text), + delta=text, stop_reason=stop_reason, ) if finish_reason: @@ -153,7 +153,7 @@ async def process_completion_stream_response( break yield CompletionResponseStreamChunk( - delta=TextDelta(text=""), + delta="", stop_reason=stop_reason, ) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 2d66dc60b..de4918f5c 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -265,6 +265,7 @@ def chat_completion_request_to_messages( For eg. for llama_3_1, add system message with the appropriate tools or add user messsage for custom tools, etc. 
""" + assert llama_model is not None, "llama_model is required" model = resolve_model(llama_model) if model is None: log.error(f"Could not resolve model {llama_model}") diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index 16e6d1bbd..b40d54ee5 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -12,6 +12,11 @@ from llama_stack.providers.tests.env import get_env_or_fail from llama_stack_client import LlamaStackClient +def pytest_configure(config): + config.option.tbstyle = "short" + config.option.disable_warnings = True + + @pytest.fixture(scope="session") def provider_data(): # check env for tavily secret, brave secret and inject all into provider data @@ -29,6 +34,7 @@ def llama_stack_client(provider_data): client = LlamaStackAsLibraryClient( get_env_or_fail("LLAMA_STACK_CONFIG"), provider_data=provider_data, + skip_logger_removal=True, ) client.initialize() elif os.environ.get("LLAMA_STACK_BASE_URL"): diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py index ef6219389..a50dba3a0 100644 --- a/tests/client-sdk/inference/test_inference.py +++ b/tests/client-sdk/inference/test_inference.py @@ -6,9 +6,9 @@ import pytest -from llama_stack_client.lib.inference.event_logger import EventLogger from pydantic import BaseModel + PROVIDER_TOOL_PROMPT_FORMAT = { "remote::ollama": "python_list", "remote::together": "json", @@ -39,7 +39,7 @@ def text_model_id(llama_stack_client): available_models = [ model.identifier for model in llama_stack_client.models.list() - if model.identifier.startswith("meta-llama") + if model.identifier.startswith("meta-llama") and "405" not in model.identifier ] assert len(available_models) > 0 return available_models[0] @@ -208,12 +208,9 @@ def test_text_chat_completion_streaming( stream=True, ) streamed_content = [ - str(log.content.lower().strip()) - for log in EventLogger().log(response) - if log is not None + str(chunk.event.delta.text.lower().strip()) for chunk in response ] assert len(streamed_content) > 0 - assert "assistant>" in streamed_content[0] assert expected.lower() in "".join(streamed_content) @@ -250,17 +247,16 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming( def extract_tool_invocation_content(response): text_content: str = "" tool_invocation_content: str = "" - for log in EventLogger().log(response): - if log is None: - continue - if isinstance(log.content, str): - text_content += log.content - elif isinstance(log.content, object): - if isinstance(log.content.content, str): - continue - elif isinstance(log.content.content, object): - tool_invocation_content += f"[{log.content.content.tool_name}, {log.content.content.arguments}]" - + for chunk in response: + delta = chunk.event.delta + if delta.type == "text": + text_content += delta.text + elif delta.type == "tool_call": + if isinstance(delta.content, str): + tool_invocation_content += delta.content + else: + call = delta.content + tool_invocation_content += f"[{call.tool_name}, {call.arguments}]" return text_content, tool_invocation_content @@ -280,7 +276,6 @@ def test_text_chat_completion_with_tool_calling_and_streaming( ) text_content, tool_invocation_content = extract_tool_invocation_content(response) - assert "Assistant>" in text_content assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" @@ -368,10 +363,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id): stream=True, ) streamed_content = [ - 
str(log.content.lower().strip()) - for log in EventLogger().log(response) - if log is not None + str(chunk.event.delta.text.lower().strip()) for chunk in response ] assert len(streamed_content) > 0 - assert "assistant>" in streamed_content[0] assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"}) From 194d12b304cd0ec68f79235beea0a7fb2cbb16b9 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 14 Jan 2025 10:58:46 -0800 Subject: [PATCH 5/6] [bugfix] fix streaming GeneratorExit exception with LlamaStackAsLibraryClient (#760) # What does this PR do? #### Issue - Using Jupyter notebook with LlamaStackAsLibraryClient + streaming gives exception ``` Exception ignored in: Traceback (most recent call last): File "/opt/anaconda3/envs/fresh/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 404, in _aiter_ yield part RuntimeError: async generator ignored GeneratorExit ``` - Reproduce w/ https://github.com/meta-llama/llama-stack/blob/notebook-streaming-debug/inline.ipynb #### Fix - Issue likely comes from stream_across_asyncio_run_boundary closing connection too soon when interacting in jupyter environment - This uses an alternative way to convert AsyncStream to SyncStream return type by sync version of LlamaStackAsLibraryClient, which calls AsyncLlamaStackAsLibraryClient calling async impls under the hood #### Additional changes - Moved tracing logic into AsyncLlamaStackAsLibraryClient.request s.t. streaming / non-streaming request for LlamaStackAsLibraryClient shares same code ## Test Plan - Test w/ together & fireworks & ollama with streaming and non-streaming using notebook in: https://github.com/meta-llama/llama-stack/blob/notebook-streaming-debug/inline.ipynb - Note: need to restart kernel and run pip install -e . in jupyter interpreter for local code change to take effect image ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. 
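For reviewers, a minimal sketch of the pattern the fix relies on: drive the async stream from a dedicated event loop and expose it as a plain synchronous generator, so nothing tears the async generator down prematurely. `fake_stream` and `as_sync_iterator` are illustrative names only; the real change (in the diff below) wraps `self.async_client.request(...)` inside `LlamaStackAsLibraryClient.request`.

```python
import asyncio
from typing import AsyncIterator, Callable, Iterator, TypeVar

T = TypeVar("T")


async def fake_stream() -> AsyncIterator[int]:
    # Stand-in for the AsyncStream returned by AsyncLlamaStackAsLibraryClient.request().
    for i in range(3):
        await asyncio.sleep(0)
        yield i


def as_sync_iterator(make_stream: Callable[[], AsyncIterator[T]]) -> Iterator[T]:
    # Own the event loop for the whole lifetime of the stream, so the async
    # generator is not closed early (the source of the GeneratorExit above).
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    stream = make_stream()
    try:
        while True:
            try:
                yield loop.run_until_complete(stream.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()


for chunk in as_sync_iterator(fake_stream):
    print(chunk)  # prints 0, 1, 2
```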
--- llama_stack/distribution/library_client.py | 130 +++++---------------- 1 file changed, 31 insertions(+), 99 deletions(-) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 50af2cdea..0c124e64b 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -9,12 +9,10 @@ import inspect import json import logging import os -import queue -import threading from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path -from typing import Any, Generator, get_args, get_origin, Optional, TypeVar +from typing import Any, get_args, get_origin, Optional, TypeVar import httpx import yaml @@ -64,71 +62,6 @@ def in_notebook(): return True -def stream_across_asyncio_run_boundary( - async_gen_maker, - pool_executor: ThreadPoolExecutor, - path: Optional[str] = None, - provider_data: Optional[dict[str, Any]] = None, -) -> Generator[T, None, None]: - result_queue = queue.Queue() - stop_event = threading.Event() - - async def consumer(): - # make sure we make the generator in the event loop context - gen = await async_gen_maker() - await start_trace(path, {"__location__": "library_client"}) - if provider_data: - set_request_provider_data( - {"X-LlamaStack-Provider-Data": json.dumps(provider_data)} - ) - try: - async for item in await gen: - result_queue.put(item) - except Exception as e: - print(f"Error in generator {e}") - result_queue.put(e) - except asyncio.CancelledError: - return - finally: - result_queue.put(StopIteration) - stop_event.set() - await end_trace() - - def run_async(): - # Run our own loop to avoid double async generator cleanup which is done - # by asyncio.run() - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - task = loop.create_task(consumer()) - loop.run_until_complete(task) - finally: - # Handle pending tasks like a generator's athrow() - pending = asyncio.all_tasks(loop) - if pending: - loop.run_until_complete( - asyncio.gather(*pending, return_exceptions=True) - ) - loop.close() - - future = pool_executor.submit(run_async) - - try: - # yield results as they come in - while not stop_event.is_set() or not result_queue.empty(): - try: - item = result_queue.get(timeout=0.1) - if item is StopIteration: - break - if isinstance(item, Exception): - raise item - yield item - except queue.Empty: - continue - finally: - future.result() - - def convert_pydantic_to_json_value(value: Any) -> Any: if isinstance(value, Enum): return value.value @@ -184,7 +117,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): ): super().__init__() self.async_client = AsyncLlamaStackAsLibraryClient( - config_path_or_template_name, custom_provider_registry + config_path_or_template_name, custom_provider_registry, provider_data ) self.pool_executor = ThreadPoolExecutor(max_workers=4) self.skip_logger_removal = skip_logger_removal @@ -210,39 +143,30 @@ class LlamaStackAsLibraryClient(LlamaStackClient): root_logger.removeHandler(handler) print(f"Removed handler {handler.__class__.__name__} from root logger") - def _get_path( - self, - cast_to: Any, - options: Any, - *, - stream=False, - stream_cls=None, - ): - return options.url - def request(self, *args, **kwargs): - path = self._get_path(*args, **kwargs) if kwargs.get("stream"): - return stream_across_asyncio_run_boundary( - lambda: self.async_client.request(*args, **kwargs), - self.pool_executor, - path=path, - provider_data=self.provider_data, - ) - else: + # NOTE: We are using AsyncLlamaStackClient 
under the hood + # A new event loop is needed to convert the AsyncStream + # from async client into SyncStream return type for streaming + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) - async def _traced_request(): - if self.provider_data: - set_request_provider_data( - {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)} - ) - await start_trace(path, {"__location__": "library_client"}) + def sync_generator(): try: - return await self.async_client.request(*args, **kwargs) + async_stream = loop.run_until_complete( + self.async_client.request(*args, **kwargs) + ) + while True: + chunk = loop.run_until_complete(async_stream.__anext__()) + yield chunk + except StopAsyncIteration: + pass finally: - await end_trace() + loop.close() - return asyncio.run(_traced_request()) + return sync_generator() + else: + return asyncio.run(self.async_client.request(*args, **kwargs)) class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): @@ -250,9 +174,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self, config_path_or_template_name: str, custom_provider_registry: Optional[ProviderRegistry] = None, + provider_data: Optional[dict[str, Any]] = None, ): super().__init__() - # when using the library client, we should not log to console since many # of our logs are intended for server-side usage current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") @@ -273,6 +197,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self.config_path_or_template_name = config_path_or_template_name self.config = config self.custom_provider_registry = custom_provider_registry + self.provider_data = provider_data async def initialize(self): try: @@ -329,17 +254,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if not self.endpoint_impls: raise ValueError("Client not initialized") + if self.provider_data: + set_request_provider_data( + {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)} + ) + await start_trace(options.url, {"__location__": "library_client"}) if stream: - return self._call_streaming( + response = await self._call_streaming( cast_to=cast_to, options=options, stream_cls=stream_cls, ) else: - return await self._call_non_streaming( + response = await self._call_non_streaming( cast_to=cast_to, options=options, ) + await end_trace() + return response async def _call_non_streaming( self, From a174938fbd8768f920df98909e341c9b4f1a6a65 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 14 Jan 2025 11:31:50 -0800 Subject: [PATCH 6/6] Fix telemetry to work on reinstantiating new lib cli (#761) # What does this PR do? Since we maintain global state in our telemetry pipeline, reinstantiating lib cli will cause us to add duplicate span processors causing sqlite to lock out because of constraint violations since we now have two span processor writing to sqlite. This PR changes the telemetry adapter for otel to only instantiate the provider once and add the span processsors only once. 
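To make the intent concrete, here is a minimal, self-contained sketch of the initialize-once guard described above; `_ProviderStub` and `setup_telemetry` are illustrative stand-ins, not the real OpenTelemetry `TracerProvider` or the adapter code (the actual change to `telemetry.py` is in the diff below).

```python
from typing import List, Optional


class _ProviderStub:
    """Illustrative stand-in for opentelemetry.sdk.trace.TracerProvider."""

    def __init__(self) -> None:
        self.span_processors: List[str] = []

    def add_span_processor(self, name: str) -> None:
        self.span_processors.append(name)


# Module-level state survives re-creating the adapter in the same process,
# which is exactly what happens when the library client is instantiated again.
_TRACER_PROVIDER: Optional[_ProviderStub] = None


def setup_telemetry(sinks: List[str]) -> _ProviderStub:
    global _TRACER_PROVIDER
    if _TRACER_PROVIDER is None:
        provider = _ProviderStub()
        for sink in sinks:
            # Each configured sink gets its span processor attached exactly once.
            provider.add_span_processor(f"{sink}-span-processor")
        _TRACER_PROVIDER = provider
    return _TRACER_PROVIDER


# The second call is a no-op: no duplicate sqlite span processor, no lockout.
first = setup_telemetry(["sqlite", "console"])
second = setup_telemetry(["sqlite", "console"])
assert first is second
assert first.span_processors == ["sqlite-span-processor", "console-span-processor"]
```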
Also fixes an issue llama stack build ## Test Plan tested with notebook at https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d#scrollTo=9496f75c --- llama_stack/cli/stack/build.py | 5 +- .../telemetry/meta_reference/telemetry.py | 53 ++++++++++--------- .../providers/utils/telemetry/tracing.py | 3 +- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index 084374c8a..85e6cb962 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -4,9 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import argparse - import importlib.resources - import os import shutil from functools import lru_cache @@ -14,14 +12,12 @@ from pathlib import Path from typing import List, Optional from llama_stack.cli.subcommand import Subcommand - from llama_stack.distribution.datatypes import ( BuildConfig, DistributionSpec, Provider, StackRunConfig, ) - from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.utils.dynamic import instantiate_class_type @@ -296,6 +292,7 @@ class StackBuild(Subcommand): / f"templates/{template_name}/run.yaml" ) with importlib.resources.as_file(template_path) as path: + run_config_file = build_dir / f"{build_config.name}-run.yaml" shutil.copy(path, run_config_file) # Find all ${env.VARIABLE} patterns cprint("Build Successful!", color="green") diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index efc37b553..332a150cf 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -30,13 +30,10 @@ from llama_stack.apis.telemetry import ( Trace, UnstructuredLogEvent, ) - from llama_stack.distribution.datatypes import Api - from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( ConsoleSpanProcessor, ) - from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import ( SQLiteSpanProcessor, ) @@ -52,6 +49,7 @@ _GLOBAL_STORAGE = { "up_down_counters": {}, } _global_lock = threading.Lock() +_TRACER_PROVIDER = None def string_to_trace_id(s: str) -> int: @@ -80,31 +78,34 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): } ) - provider = TracerProvider(resource=resource) - trace.set_tracer_provider(provider) - if TelemetrySink.OTEL in self.config.sinks: - otlp_exporter = OTLPSpanExporter( - endpoint=self.config.otel_endpoint, - ) - span_processor = BatchSpanProcessor(otlp_exporter) - trace.get_tracer_provider().add_span_processor(span_processor) - metric_reader = PeriodicExportingMetricReader( - OTLPMetricExporter( + global _TRACER_PROVIDER + if _TRACER_PROVIDER is None: + provider = TracerProvider(resource=resource) + trace.set_tracer_provider(provider) + _TRACER_PROVIDER = provider + if TelemetrySink.OTEL in self.config.sinks: + otlp_exporter = OTLPSpanExporter( endpoint=self.config.otel_endpoint, ) - ) - metric_provider = MeterProvider( - resource=resource, metric_readers=[metric_reader] - ) - metrics.set_meter_provider(metric_provider) - self.meter = metrics.get_meter(__name__) - if TelemetrySink.SQLITE in self.config.sinks: - trace.get_tracer_provider().add_span_processor( - SQLiteSpanProcessor(self.config.sqlite_db_path) - 
) - self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) - if TelemetrySink.CONSOLE in self.config.sinks: - trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) + span_processor = BatchSpanProcessor(otlp_exporter) + trace.get_tracer_provider().add_span_processor(span_processor) + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=self.config.otel_endpoint, + ) + ) + metric_provider = MeterProvider( + resource=resource, metric_readers=[metric_reader] + ) + metrics.set_meter_provider(metric_provider) + self.meter = metrics.get_meter(__name__) + if TelemetrySink.SQLITE in self.config.sinks: + trace.get_tracer_provider().add_span_processor( + SQLiteSpanProcessor(self.config.sqlite_db_path) + ) + self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path) + if TelemetrySink.CONSOLE in self.config.sinks: + trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) self._lock = _global_lock async def initialize(self) -> None: diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index f304d58f6..d84024941 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -127,7 +127,8 @@ class TraceContext: def setup_logger(api: Telemetry, level: int = logging.INFO): global BACKGROUND_LOGGER - BACKGROUND_LOGGER = BackgroundLogger(api) + if BACKGROUND_LOGGER is None: + BACKGROUND_LOGGER = BackgroundLogger(api) logger = logging.getLogger() logger.setLevel(level) logger.addHandler(TelemetryHandler())
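As a closing note on the series: all of the call-site changes above rest on the `ContentDelta` discriminated union introduced in PATCH 1/6. The snippet below is a self-contained sketch of that mechanism using pydantic v2, with simplified local stand-ins for `TextDelta` and `ToolCallDelta` (`parse_status` as plain strings rather than the real enum, and no `ImageDelta` or parsed `ToolCall` variant); it shows how the `type` discriminator lets consumers branch on `delta.type`, exactly as the updated event logger and tests now do.

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


# Local stand-ins mirroring the shapes defined in content_types.py.
class TextDelta(BaseModel):
    type: Literal["text"] = "text"
    text: str


class ToolCallDelta(BaseModel):
    type: Literal["tool_call"] = "tool_call"
    content: str
    parse_status: Literal["started", "in_progress", "failed", "succeeded"]


ContentDelta = Annotated[
    Union[TextDelta, ToolCallDelta],
    Field(discriminator="type"),
]

adapter = TypeAdapter(ContentDelta)

# The "type" field selects the concrete model, so streaming consumers can
# branch on delta.type instead of isinstance checks against Union[str, ToolCallDelta].
for payload in (
    {"type": "text", "text": "Hello"},
    {"type": "tool_call", "content": "get_weather(", "parse_status": "in_progress"},
):
    delta = adapter.validate_python(payload)
    if delta.type == "text":
        print("text chunk:", delta.text)
    else:
        print("tool call chunk:", delta.parse_status, delta.content)
```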