Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-10 04:08:31 +00:00
[RFC] feat: log model input
Summary: This PR logs the final prompt sent to providers, after it has been run through prompt_adapter. This will be useful for debugging. Will add the same logging to other inference providers if this approach looks good.
Test Plan:
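For reference, a minimal sketch of how a caller could pick the new prepare chunk off a streaming response and log the exact model input, assuming stream yields the ChatCompletionResponseStreamChunk objects defined in this PR; the log_model_input wrapper is hypothetical and only illustrates the event shape added here.

import logging

from llama_stack.apis.inference import ChatCompletionResponseEventType

logger = logging.getLogger(__name__)


async def log_model_input(stream):
    # Pass chunks through unchanged, but log the prepare event added in this PR,
    # which carries the final prompt/messages sent to the provider.
    async for chunk in stream:
        event = chunk.event
        if event.event_type == ChatCompletionResponseEventType.prepare:
            if event.input_prompt is not None:
                logger.debug("model input (prompt): %s", event.input_prompt)
            elif event.input_messages is not None:
                for msg in event.input_messages:
                    logger.debug("model input (%s): %s", msg.get("role"), msg.get("content"))
            continue  # the prepare chunk only carries an empty delta
        yield chunk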
This commit is contained in:
parent a66b4c4c81
commit 0a192b548a
5 changed files with 35 additions and 1 deletion
@@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
+    Message,
 )
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
@@ -70,6 +71,7 @@ class InferenceStep(StepCommon):

     step_type: Literal[StepType.inference.value] = StepType.inference.value
     model_response: CompletionMessage
+    model_input: Union[str, List[Dict[str, str]]]


 @json_schema_type
@@ -206,11 +206,13 @@ class ChatCompletionResponseEventType(Enum):
     :cvar start: Inference has started
     :cvar complete: Inference is complete and a full response is available
     :cvar progress: Inference is in progress and a partial response is available
+    :cvar prepare: Inference is preparing to start
     """

     start = "start"
     complete = "complete"
     progress = "progress"
+    prepare = "prepare"


 @json_schema_type
@@ -227,6 +229,8 @@ class ChatCompletionResponseEvent(BaseModel):
     delta: ContentDelta
     logprobs: Optional[List[TokenLogProbs]] = None
     stop_reason: Optional[StopReason] = None
+    input_prompt: Optional[str] = None
+    input_messages: Optional[List[Dict[str, str]]] = None


 class ResponseFormatType(Enum):
@@ -514,6 +514,9 @@ class ChatAgent(ShieldRunnerMixin):
             elif event.event_type == ChatCompletionResponseEventType.complete:
                 stop_reason = StopReason.end_of_turn
                 continue
+            elif event.event_type == ChatCompletionResponseEventType.prepare:
+                model_input = event.input_prompt or event.input_messages
+                continue

             delta = event.delta
             if delta.type == "tool_call":
@@ -582,6 +585,7 @@ class ChatAgent(ShieldRunnerMixin):
                     model_response=copy.deepcopy(message),
                     started_at=inference_start_time,
                     completed_at=datetime.now(),
+                    model_input=model_input,
                 ),
             )
         )
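The model_input captured above ends up on the persisted InferenceStep. A rough sketch of how it could be inspected when debugging a completed turn, assuming steps is a list of step objects shaped like the datatypes in this PR; dump_model_inputs is a hypothetical helper, not an existing API.

from typing import Any, List


def dump_model_inputs(steps: List[Any]) -> None:
    # model_input is either the rendered prompt string or a list of
    # {"role": ..., "content": ...} dicts, per the InferenceStep change above.
    for step in steps:
        if getattr(step, "step_type", None) != "inference":
            continue
        model_input = getattr(step, "model_input", None)
        if isinstance(model_input, str):
            print("prompt sent to provider:")
            print(model_input)
        elif isinstance(model_input, list):
            print("messages sent to provider:")
            for msg in model_input:
                print(f"  {msg.get('role')}: {msg.get('content')}")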
@@ -10,7 +10,7 @@ from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, TextDelta
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -27,6 +27,9 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.models.llama.datatypes import CoreModelId
@@ -234,6 +237,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                delta=TextDelta(text=""),
+                event_type=ChatCompletionResponseEventType.prepare,
+                input_prompt=params.get("prompt", None),
+                input_messages=params.get("messages", None),
+            )
+        )

         async def _to_async_generator():
             if "messages" in params:
@@ -269,6 +280,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             if input_dict["prompt"].startswith("<|begin_of_text|>"):
                 input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

+        print(f"prompt: {input_dict['prompt']}")
         return {
             "model": request.model,
             **input_dict,
@@ -16,6 +16,7 @@ from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
     TextContentItem,
+    TextDelta,
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -31,6 +32,9 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import CoreModelId
@@ -308,6 +312,14 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                delta=TextDelta(text=""),
+                event_type=ChatCompletionResponseEventType.prepare,
+                input_prompt=params.get("prompt", None),
+                input_messages=params.get("messages", None),
+            )
+        )

         async def _generate_and_convert_to_openai_compat():
             if "messages" in params: