[RFC] feat: log model input

Summary:

This PR logs the final prompt sent to providers after it has been run through prompt_adapter. This will be useful for debugging.

If this approach looks good, I will extend it to the other inference providers.
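
For context, here is a minimal sketch (not part of this diff) of how the logged input could be inspected for debugging once a turn completes; the `turn` object and how it is obtained are assumptions for illustration:

# Sketch: read the exact model input recorded on each inference step of a completed turn.
# `turn` is assumed to be a Turn returned by the agents API after this change lands.
for step in turn.steps:
    if step.step_type == "inference":
        # model_input is a raw prompt string for completion-style providers,
        # or a list of {"role": ..., "content": ...} dicts for chat-style providers.
        print("model input:", step.model_input)
        print("model response:", step.model_response.content)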


Test Plan:
Commit 0a192b548a (parent a66b4c4c81) by Eric Huang, 2025-02-18 21:15:32 -08:00
5 changed files with 35 additions and 1 deletion


@@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
+    Message,
 )
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
@@ -70,6 +71,7 @@ class InferenceStep(StepCommon):
     step_type: Literal[StepType.inference.value] = StepType.inference.value
     model_response: CompletionMessage
+    model_input: Union[str, List[Dict[str, str]]]

 @json_schema_type


@@ -206,11 +206,13 @@ class ChatCompletionResponseEventType(Enum):
     :cvar start: Inference has started
     :cvar complete: Inference is complete and a full response is available
     :cvar progress: Inference is in progress and a partial response is available
+    :cvar prepare: Inference is preparing to start
     """

     start = "start"
     complete = "complete"
     progress = "progress"
+    prepare = "prepare"

 @json_schema_type
@@ -227,6 +229,8 @@ class ChatCompletionResponseEvent(BaseModel):
     delta: ContentDelta
     logprobs: Optional[List[TokenLogProbs]] = None
     stop_reason: Optional[StopReason] = None
+    input_prompt: Optional[str] = None
+    input_messages: Optional[List[Dict[str, str]]] = None

 class ResponseFormatType(Enum):
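
To make the new field types concrete, an illustrative example of the shapes the two optional fields are expected to carry (the values below are made up, not taken from this diff):

# Chat-style providers would populate input_messages with the post-prompt_adapter message list:
input_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

# Completion-style providers would populate input_prompt with the fully rendered prompt string:
input_prompt = (
    "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\nWhat is the capital of France?<|eot_id|>"
)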


@@ -514,6 +514,9 @@ class ChatAgent(ShieldRunnerMixin):
             elif event.event_type == ChatCompletionResponseEventType.complete:
                 stop_reason = StopReason.end_of_turn
                 continue
+            elif event.event_type == ChatCompletionResponseEventType.prepare:
+                model_input = event.input_prompt or event.input_messages
+                continue

             delta = event.delta
             if delta.type == "tool_call":
@@ -582,6 +585,7 @@ class ChatAgent(ShieldRunnerMixin):
                         model_response=copy.deepcopy(message),
                         started_at=inference_start_time,
                         completed_at=datetime.now(),
+                        model_input=model_input,
                     ),
                 )
             )


@@ -10,7 +10,7 @@ from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -27,6 +27,9 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.models.llama.datatypes import CoreModelId
@@ -234,6 +237,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)

+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                delta=TextDelta(text=""),
+                event_type=ChatCompletionResponseEventType.prepare,
+                input_prompt=params.get("prompt", None),
+                input_messages=params.get("messages", None),
+            )
+        )

         async def _to_async_generator():
             if "messages" in params:
@@ -269,6 +280,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         if input_dict["prompt"].startswith("<|begin_of_text|>"):
             input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

+        print(f"prompt: {input_dict['prompt']}")
         return {
             "model": request.model,
             **input_dict,


@@ -16,6 +16,7 @@ from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
     TextContentItem,
+    TextDelta,
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -31,6 +32,9 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import CoreModelId
@@ -308,6 +312,14 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)

+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                delta=TextDelta(text=""),
+                event_type=ChatCompletionResponseEventType.prepare,
+                input_prompt=params.get("prompt", None),
+                input_messages=params.get("messages", None),
+            )
+        )

         async def _generate_and_convert_to_openai_compat():
             if "messages" in params: