diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 367648ded..3c2b7efa5 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
+    Message,
 )
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
@@ -70,6 +71,7 @@ class InferenceStep(StepCommon):
 
     step_type: Literal[StepType.inference.value] = StepType.inference.value
     model_response: CompletionMessage
+    model_input: Union[str, List[Dict[str, str]]]
 
 
 @json_schema_type
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index a3fb69477..2129763ce 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -206,11 +206,13 @@ class ChatCompletionResponseEventType(Enum):
     :cvar start: Inference has started
     :cvar complete: Inference is complete and a full response is available
     :cvar progress: Inference is in progress and a partial response is available
+    :cvar prepare: Inference is preparing to start
     """
 
     start = "start"
     complete = "complete"
     progress = "progress"
+    prepare = "prepare"
 
 
 @json_schema_type
@@ -227,6 +229,8 @@
     delta: ContentDelta
     logprobs: Optional[List[TokenLogProbs]] = None
     stop_reason: Optional[StopReason] = None
+    input_prompt: Optional[str] = None
+    input_messages: Optional[List[Dict[str, str]]] = None
 
 
 class ResponseFormatType(Enum):
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 1c21df57f..f453919d4 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -514,6 +514,9 @@
             elif event.event_type == ChatCompletionResponseEventType.complete:
                 stop_reason = StopReason.end_of_turn
                 continue
+            elif event.event_type == ChatCompletionResponseEventType.prepare:
+                model_input = event.input_prompt or event.input_messages
+                continue
 
             delta = event.delta
             if delta.type == "tool_call":
@@ -582,6 +585,7 @@
                         model_response=copy.deepcopy(message),
                         started_at=inference_start_time,
                         completed_at=datetime.now(),
+                        model_input=model_input,
                     ),
                 )
             )
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index acf37b248..b810afaf2 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -10,7 +10,7 @@ from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -27,6 +27,9 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.models.llama.datatypes import CoreModelId
@@ -234,6 +237,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                delta=TextDelta(text=""),
+                event_type=ChatCompletionResponseEventType.prepare,
+                input_prompt=params.get("prompt", None),
+                input_messages=params.get("messages", None),
+            )
+        )
 
         async def _to_async_generator():
             if "messages" in params:
@@ -269,6 +280,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         if input_dict["prompt"].startswith("<|begin_of_text|>"):
             input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
 
+        print(f"prompt: {input_dict['prompt']}")
         return {
             "model": request.model,
             **input_dict,
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index f524c0734..bd836129a 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -16,6 +16,7 @@ from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
     TextContentItem,
+    TextDelta,
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -31,6 +32,9 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import CoreModelId
@@ -308,6 +312,14 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                delta=TextDelta(text=""),
+                event_type=ChatCompletionResponseEventType.prepare,
+                input_prompt=params.get("prompt", None),
+                input_messages=params.get("messages", None),
+            )
+        )
 
         async def _generate_and_convert_to_openai_compat():
             if "messages" in params:
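
Note for reviewers: the sketch below illustrates how a consumer of the streaming chat-completion API could use the new "prepare" event to capture the model input before any "progress" deltas arrive, mirroring the ChatAgent handling in this patch. It is a minimal, self-contained sketch: the stand-in dataclasses and the consume_stream/_demo helpers are hypothetical and not part of llama_stack; only the field names (event_type, delta, input_prompt, input_messages) come from this diff.

# Hypothetical, self-contained sketch; the dataclasses below only mirror the
# fields touched in this diff and are not the real llama_stack types.
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, Dict, List, Optional, Tuple, Union


@dataclass
class TextDelta:
    text: str


@dataclass
class Event:
    event_type: str  # "prepare" | "start" | "progress" | "complete"
    delta: TextDelta
    input_prompt: Optional[str] = None
    input_messages: Optional[List[Dict[str, str]]] = None


async def consume_stream(
    events: AsyncIterator[Event],
) -> Tuple[Union[str, List[Dict[str, str]], None], str]:
    """Collect the model input from the 'prepare' event and accumulate streamed text."""
    model_input: Union[str, List[Dict[str, str]], None] = None
    pieces: List[str] = []
    async for event in events:
        if event.event_type == "prepare":
            # Same precedence as ChatAgent: rendered prompt first, else raw messages.
            model_input = event.input_prompt or event.input_messages
        elif event.event_type == "progress":
            pieces.append(event.delta.text)
        elif event.event_type == "complete":
            break
    return model_input, "".join(pieces)


async def _demo() -> None:
    async def fake_stream() -> AsyncIterator[Event]:
        yield Event("prepare", TextDelta(""), input_prompt="<|user|> hello")
        yield Event("progress", TextDelta("Hi"))
        yield Event("progress", TextDelta(" there"))
        yield Event("complete", TextDelta(""))

    model_input, text = await consume_stream(fake_stream())
    print(f"model_input={model_input!r} output={text!r}")


if __name__ == "__main__":
    asyncio.run(_demo())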