[RFC] feat: log model input

Summary:

This PR logs the final prompt sent to providers, after it has been run through prompt_adapter. This will be useful for debugging.

Will add to other inference providers if this looks good.
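For reviewers, here is a minimal sketch (not part of this diff) of how a stream consumer can pick up the new prepare event; log_model_input is a hypothetical helper, and stream stands for the async generator returned by an adapter's _stream_chat_completion:

from llama_stack.apis.inference import ChatCompletionResponseEventType

async def log_model_input(stream):
    # Watch for the new `prepare` event and print the exact input handed to
    # the provider: raw-completion paths set input_prompt, chat paths set
    # input_messages. All other chunks are simply ignored here.
    async for chunk in stream:
        event = chunk.event
        if event.event_type == ChatCompletionResponseEventType.prepare:
            print(f"model input: {event.input_prompt or event.input_messages}")

This mirrors what ChatAgent does below, where the captured value is stored on the new InferenceStep.model_input field.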


Test Plan:
Eric Huang 2025-02-18 21:15:32 -08:00
parent a66b4c4c81
commit 0a192b548a
5 changed files with 35 additions and 1 deletion

View file

@@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
+    Message,
 )
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
@@ -70,6 +71,7 @@ class InferenceStep(StepCommon):
     step_type: Literal[StepType.inference.value] = StepType.inference.value
     model_response: CompletionMessage
+    model_input: Union[str, List[Dict[str, str]]]
 @json_schema_type

View file

@@ -206,11 +206,13 @@ class ChatCompletionResponseEventType(Enum):
     :cvar start: Inference has started
     :cvar complete: Inference is complete and a full response is available
     :cvar progress: Inference is in progress and a partial response is available
+    :cvar prepare: Inference is preparing to start
     """
     start = "start"
     complete = "complete"
     progress = "progress"
+    prepare = "prepare"
 @json_schema_type
@@ -227,6 +229,8 @@ class ChatCompletionResponseEvent(BaseModel):
     delta: ContentDelta
     logprobs: Optional[List[TokenLogProbs]] = None
     stop_reason: Optional[StopReason] = None
+    input_prompt: Optional[str] = None
+    input_messages: Optional[List[Dict[str, str]]] = None
 class ResponseFormatType(Enum):

View file

@@ -514,6 +514,9 @@ class ChatAgent(ShieldRunnerMixin):
             elif event.event_type == ChatCompletionResponseEventType.complete:
                 stop_reason = StopReason.end_of_turn
                 continue
+            elif event.event_type == ChatCompletionResponseEventType.prepare:
+                model_input = event.input_prompt or event.input_messages
+                continue
             delta = event.delta
             if delta.type == "tool_call":
@@ -582,6 +585,7 @@ class ChatAgent(ShieldRunnerMixin):
                        model_response=copy.deepcopy(message),
                        started_at=inference_start_time,
                        completed_at=datetime.now(),
+                        model_input=model_input,
                     ),
                 )
             )

View file

@@ -10,7 +10,7 @@ from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -27,6 +27,9 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.models.llama.datatypes import CoreModelId
@@ -234,6 +237,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                delta=TextDelta(text=""),
+                event_type=ChatCompletionResponseEventType.prepare,
+                input_prompt=params.get("prompt", None),
+                input_messages=params.get("messages", None),
+            )
+        )
         async def _to_async_generator():
             if "messages" in params:
@@ -269,6 +280,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
             if input_dict["prompt"].startswith("<|begin_of_text|>"):
                 input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
+            print(f"prompt: {input_dict['prompt']}")
         return {
             "model": request.model,
             **input_dict,

View file

@@ -16,6 +16,7 @@ from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
     TextContentItem,
+    TextDelta,
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -31,6 +32,9 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import CoreModelId
@@ -308,6 +312,14 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                delta=TextDelta(text=""),
+                event_type=ChatCompletionResponseEventType.prepare,
+                input_prompt=params.get("prompt", None),
+                input_messages=params.get("messages", None),
+            )
+        )
         async def _generate_and_convert_to_openai_compat():
             if "messages" in params: