# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator, List, Optional

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import (
    GreedySamplingStrategy,
    SamplingParams,
    StopReason,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)
from pydantic import BaseModel

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    TextContentItem,
    TextDelta,
    ToolCallDelta,
    ToolCallParseStatus,
)
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    CompletionResponse,
    CompletionResponseStreamChunk,
    Message,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_image_content_to_url,
)


class OpenAICompatCompletionChoiceDelta(BaseModel):
    content: str


class OpenAICompatCompletionChoice(BaseModel):
    finish_reason: Optional[str] = None
    text: Optional[str] = None
    delta: Optional[OpenAICompatCompletionChoiceDelta] = None


class OpenAICompatCompletionResponse(BaseModel):
    choices: List[OpenAICompatCompletionChoice]


def get_sampling_strategy_options(params: SamplingParams) -> dict:
    options = {}
    if isinstance(params.strategy, GreedySamplingStrategy):
        options["temperature"] = 0.0
    elif isinstance(params.strategy, TopPSamplingStrategy):
        options["temperature"] = params.strategy.temperature
        options["top_p"] = params.strategy.top_p
    elif isinstance(params.strategy, TopKSamplingStrategy):
        options["top_k"] = params.strategy.top_k
    else:
        raise ValueError(f"Unsupported sampling strategy: {params.strategy}")

    return options


def get_sampling_options(params: SamplingParams) -> dict:
    options = {}
    if params:
        options.update(get_sampling_strategy_options(params))
        if params.max_tokens:
            options["max_tokens"] = params.max_tokens

        if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
            options["repeat_penalty"] = params.repetition_penalty

    return options


def text_from_choice(choice) -> str:
    if hasattr(choice, "delta") and choice.delta:
        return choice.delta.content

    if hasattr(choice, "message"):
        return choice.message.content

    return choice.text


def get_stop_reason(finish_reason: str) -> StopReason:
    if finish_reason in ["stop", "eos"]:
        return StopReason.end_of_turn
    elif finish_reason == "eom":
        return StopReason.end_of_message
    elif finish_reason == "length":
        return StopReason.out_of_tokens

    return StopReason.out_of_tokens


def process_completion_response(
    response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse:
    choice = response.choices[0]
    # drop suffix if present and return stop reason as end of turn
    if choice.text.endswith("<|eot_id|>"):
        return CompletionResponse(
            stop_reason=StopReason.end_of_turn,
            content=choice.text[: -len("<|eot_id|>")],
        )
    # drop suffix if present and return stop reason as end of message
    if choice.text.endswith("<|eom_id|>"):
        return CompletionResponse(
            stop_reason=StopReason.end_of_message,
            content=choice.text[: -len("<|eom_id|>")],
        )
    return CompletionResponse(
        stop_reason=get_stop_reason(choice.finish_reason),
        content=choice.text,
    )
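
# Usage sketch (illustrative only; the payload values below are hypothetical,
# and `formatter` is assumed to be a ChatFormat instance supplied by the caller):
# a provider adapter can wrap a raw OpenAI-style completion payload in the
# compat models above and convert it with process_completion_response.
#
#   compat = OpenAICompatCompletionResponse(
#       choices=[
#           OpenAICompatCompletionChoice(
#               finish_reason="stop", text="Hello!<|eot_id|>"
#           )
#       ]
#   )
#   response = process_completion_response(compat, formatter)
#   # response.content == "Hello!"
#   # response.stop_reason == StopReason.end_of_turn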

def process_chat_completion_response(
    response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> ChatCompletionResponse:
    choice = response.choices[0]

    raw_message = formatter.decode_assistant_message_from_content(
        text_from_choice(choice), get_stop_reason(choice.finish_reason)
    )
    return ChatCompletionResponse(
        completion_message=CompletionMessage(
            content=raw_message.content,
            stop_reason=raw_message.stop_reason,
            tool_calls=raw_message.tool_calls,
        ),
        logprobs=None,
    )


async def process_completion_stream_response(
    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
    stop_reason = None

    async for chunk in stream:
        choice = chunk.choices[0]
        finish_reason = choice.finish_reason

        text = text_from_choice(choice)
        if text == "<|eot_id|>":
            stop_reason = StopReason.end_of_turn
            text = ""
            continue
        elif text == "<|eom_id|>":
            stop_reason = StopReason.end_of_message
            text = ""
            continue
        yield CompletionResponseStreamChunk(
            delta=text,
            stop_reason=stop_reason,
        )
        if finish_reason:
            if finish_reason in ["stop", "eos", "eos_token"]:
                stop_reason = StopReason.end_of_turn
            elif finish_reason == "length":
                stop_reason = StopReason.out_of_tokens
            break

    yield CompletionResponseStreamChunk(
        delta="",
        stop_reason=stop_reason,
    )


async def process_chat_completion_stream_response(
    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.start,
            delta=TextDelta(text=""),
        )
    )

    buffer = ""
    ipython = False
    stop_reason = None

    async for chunk in stream:
        choice = chunk.choices[0]
        finish_reason = choice.finish_reason

        if finish_reason:
            if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
                stop_reason = StopReason.end_of_turn
            elif stop_reason is None and finish_reason == "length":
                stop_reason = StopReason.out_of_tokens
            break

        text = text_from_choice(choice)
        if not text:
            # Sometimes you get empty chunks from providers
            continue

        # check if it's a tool call (aka starts with <|python_tag|>)
        if not ipython and text.startswith("<|python_tag|>"):
            ipython = True
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=ToolCallDelta(
                        content="",
                        parse_status=ToolCallParseStatus.started,
                    ),
                )
            )
            buffer += text
            continue

        if text == "<|eot_id|>":
            stop_reason = StopReason.end_of_turn
            text = ""
            continue
        elif text == "<|eom_id|>":
            stop_reason = StopReason.end_of_message
            text = ""
            continue

        if ipython:
            buffer += text
            delta = ToolCallDelta(
                content=text,
                parse_status=ToolCallParseStatus.in_progress,
            )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=delta,
                    stop_reason=stop_reason,
                )
            )
        else:
            buffer += text
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=TextDelta(text=text),
                    stop_reason=stop_reason,
                )
            )

    # parse tool calls and report errors
    message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
    parsed_tool_calls = len(message.tool_calls) > 0
    if ipython and not parsed_tool_calls:
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=ToolCallDelta(
                    content="",
                    parse_status=ToolCallParseStatus.failed,
                ),
                stop_reason=stop_reason,
            )
        )

    for tool_call in message.tool_calls:
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=ToolCallDelta(
                    content=tool_call,
                    parse_status=ToolCallParseStatus.succeeded,
                ),
                stop_reason=stop_reason,
            )
        )
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta=TextDelta(text=""),
            stop_reason=stop_reason,
        )
    )


async def convert_message_to_openai_dict(
    message: Message, download: bool = False
) -> dict:
    async def _convert_content(content) -> dict:
        if isinstance(content, ImageContentItem):
            return {
                "type": "image_url",
                "image_url": {
                    "url": await convert_image_content_to_url(
                        content, download=download
                    ),
                },
            }
        else:
            text = content.text if isinstance(content, TextContentItem) else content
            assert isinstance(text, str)
            return {"type": "text", "text": text}

    if isinstance(message.content, list):
        content = [await _convert_content(c) for c in message.content]
    else:
        content = [await _convert_content(message.content)]

    return {
        "role": message.role,
        "content": content,
    }
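
# Usage sketch (illustrative only): a streaming chat adapter could map each
# provider chunk into OpenAICompatCompletionResponse and delegate to
# process_chat_completion_stream_response. `client_stream` and `formatter`
# below are assumed to be supplied by the calling provider.
#
#   async def _stream_chat(client_stream, formatter: ChatFormat):
#       async def _to_openai_compat():
#           async for raw in client_stream:
#               choice = raw.choices[0]
#               yield OpenAICompatCompletionResponse(
#                   choices=[
#                       OpenAICompatCompletionChoice(
#                           finish_reason=choice.finish_reason,
#                           delta=OpenAICompatCompletionChoiceDelta(
#                               content=choice.delta.content or ""
#                           ),
#                       )
#                   ]
#               )
#
#       async for chunk in process_chat_completion_stream_response(
#           _to_openai_compat(), formatter
#       ):
#           yield chunk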