mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-09 19:58:29 +00:00)
[RFC] feat: log model input
Summary: This PR logs the final prompt sent to providers, after it has been run through prompt_adapter. This will be useful for debugging. If this approach looks good, the same logging will be added to the other inference providers.

Test Plan:
This commit is contained in: parent a66b4c4c81, commit 0a192b548a
5 changed files with 35 additions and 1 deletion
@@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
+    Message,
 )
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef

@@ -70,6 +71,7 @@ class InferenceStep(StepCommon):

     step_type: Literal[StepType.inference.value] = StepType.inference.value
     model_response: CompletionMessage
+    model_input: Union[str, List[Dict[str, str]]]


 @json_schema_type
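For reference, the new model_input field holds either the raw rendered prompt string (completion-style providers) or the list of role/content dicts sent to chat-style providers. A minimal sketch of a debugging helper that handles both shapes (the helper itself is hypothetical, not part of this PR):

from typing import Dict, List, Union

def format_model_input(model_input: Union[str, List[Dict[str, str]]]) -> str:
    # Hypothetical helper: render InferenceStep.model_input for log output,
    # whichever of the two shapes introduced above it takes.
    if isinstance(model_input, str):
        return model_input
    return "\n".join(f"[{m.get('role', '?')}] {m.get('content', '')}" for m in model_input)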
@@ -206,11 +206,13 @@ class ChatCompletionResponseEventType(Enum):
     :cvar start: Inference has started
     :cvar complete: Inference is complete and a full response is available
     :cvar progress: Inference is in progress and a partial response is available
+    :cvar prepare: Inference is preparing to start
     """

     start = "start"
     complete = "complete"
     progress = "progress"
+    prepare = "prepare"


 @json_schema_type

@@ -227,6 +229,8 @@ class ChatCompletionResponseEvent(BaseModel):
     delta: ContentDelta
     logprobs: Optional[List[TokenLogProbs]] = None
     stop_reason: Optional[StopReason] = None
+    input_prompt: Optional[str] = None
+    input_messages: Optional[List[Dict[str, str]]] = None


 class ResponseFormatType(Enum):
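A sketch of how a stream consumer could pick up the logged input from the new prepare event, assuming the stream yields ChatCompletionResponseStreamChunk objects and that text deltas expose type == "text" and a text attribute:

from llama_stack.apis.inference import ChatCompletionResponseEventType

async def collect_response(stream):
    # Accumulate generated text and capture the provider input that the
    # prepare event (added above) now carries.
    model_input = None
    text = ""
    async for chunk in stream:
        event = chunk.event
        if event.event_type == ChatCompletionResponseEventType.prepare:
            model_input = event.input_prompt or event.input_messages
        elif event.event_type == ChatCompletionResponseEventType.progress and event.delta.type == "text":
            text += event.delta.text
    return model_input, text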
@@ -514,6 +514,9 @@ class ChatAgent(ShieldRunnerMixin):
                 elif event.event_type == ChatCompletionResponseEventType.complete:
                     stop_reason = StopReason.end_of_turn
                     continue
+                elif event.event_type == ChatCompletionResponseEventType.prepare:
+                    model_input = event.input_prompt or event.input_messages
+                    continue

                 delta = event.delta
                 if delta.type == "tool_call":

@@ -582,6 +585,7 @@
                         model_response=copy.deepcopy(message),
                         started_at=inference_start_time,
                         completed_at=datetime.now(),
+                        model_input=model_input,
                     ),
                 )
             )
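Once a turn finishes, the logged input can be read back off the recorded inference steps. A hypothetical debugging snippet, assuming turn.steps is the list of steps the agent persists and StepType is importable from llama_stack.apis.agents:

from llama_stack.apis.agents import StepType

def dump_inference_inputs(turn) -> None:
    # Print what was actually sent to the provider for each inference step.
    for step in turn.steps:
        if step.step_type == StepType.inference.value:
            print("model_input:", step.model_input)
            print("model_response:", step.model_response.content)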
@@ -10,7 +10,7 @@ from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,

@@ -27,6 +27,9 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.models.llama.datatypes import CoreModelId

@@ -234,6 +237,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                delta=TextDelta(text=""),
+                event_type=ChatCompletionResponseEventType.prepare,
+                input_prompt=params.get("prompt", None),
+                input_messages=params.get("messages", None),
+            )
+        )

         async def _to_async_generator():
             if "messages" in params:

@@ -269,6 +280,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
             if input_dict["prompt"].startswith("<|begin_of_text|>"):
                 input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

+        print(f"prompt: {input_dict['prompt']}")
         return {
             "model": request.model,
             **input_dict,
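Since the same prepare chunk is emitted verbatim in the Ollama adapter below, the construction could be factored into a small shared helper. A sketch using the types imported above (the helper name is hypothetical, not part of this PR):

from llama_stack.apis.common.content_types import TextDelta
from llama_stack.apis.inference import (
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
)

def prepare_chunk(params: dict) -> ChatCompletionResponseStreamChunk:
    # Surface whichever of "prompt"/"messages" the provider will receive.
    return ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            delta=TextDelta(text=""),
            event_type=ChatCompletionResponseEventType.prepare,
            input_prompt=params.get("prompt"),
            input_messages=params.get("messages"),
        )
    )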
@@ -16,6 +16,7 @@ from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
     TextContentItem,
+    TextDelta,
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,

@@ -31,6 +32,9 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import CoreModelId

@@ -308,6 +312,14 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                delta=TextDelta(text=""),
+                event_type=ChatCompletionResponseEventType.prepare,
+                input_prompt=params.get("prompt", None),
+                input_messages=params.get("messages", None),
+            )
+        )

         async def _generate_and_convert_to_openai_compat():
             if "messages" in params:
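A rough way to eyeball the new behavior from a REPL, assuming an already-constructed adapter instance and a valid ChatCompletionRequest (an illustrative check, not an automated test):

async def peek_prepare(adapter, request):
    # The first chunk yielded by _stream_chat_completion is now the prepare
    # event carrying the final provider input.
    stream = adapter._stream_chat_completion(request)
    first = await anext(stream)  # Python 3.10+
    print("input_prompt:", first.event.input_prompt)
    print("input_messages:", first.event.input_messages)

# e.g. asyncio.run(peek_prepare(adapter, request))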