mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-02 20:40:36 +00:00
move all implementations to use updated type
This commit is contained in:
parent
aced2ce07e
commit
9a5803a429
8 changed files with 139 additions and 208 deletions
|
@ -40,7 +40,12 @@ from llama_stack.apis.agents import (
|
|||
ToolExecutionStep,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import TextContentItem, URL
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextContentItem,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponseEventType,
|
||||
CompletionMessage,
|
||||
|
@ -49,8 +54,6 @@ from llama_stack.apis.inference import (
|
|||
SamplingParams,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolDefinition,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
|
@ -411,7 +414,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
tool_call_delta=ToolCallDelta(
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
content=ToolCall(
|
||||
call_id="",
|
||||
|
@ -507,7 +510,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
continue
|
||||
|
||||
delta = event.delta
|
||||
if isinstance(delta, ToolCallDelta):
|
||||
if delta.type == "tool_call":
|
||||
if delta.parse_status == ToolCallParseStatus.success:
|
||||
tool_calls.append(delta.content)
|
||||
if stream:
|
||||
|
@ -516,21 +519,20 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_id=step_id,
|
||||
text_delta="",
|
||||
tool_call_delta=delta,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(delta, str):
|
||||
content += delta
|
||||
elif delta.type == "text":
|
||||
content += delta.text
|
||||
if stream and event.stop_reason is None:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_id=step_id,
|
||||
text_delta=event.delta,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
|
@ -16,6 +16,11 @@ from llama_models.llama3.api.datatypes import (
|
|||
)
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
@ -32,8 +37,6 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
ResponseFormat,
|
||||
TokenLogProbs,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolChoice,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
|
@ -190,14 +193,14 @@ class MetaReferenceInferenceImpl(
|
|||
]
|
||||
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=text,
|
||||
delta=TextDelta(text=text),
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta="",
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=StopReason.out_of_tokens,
|
||||
)
|
||||
|
||||
|
@ -352,7 +355,7 @@ class MetaReferenceInferenceImpl(
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
delta=TextDelta(text=""),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -392,7 +395,7 @@ class MetaReferenceInferenceImpl(
|
|||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
else:
|
||||
delta = text
|
||||
delta = TextDelta(text=text)
|
||||
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
|
@ -449,7 +452,7 @@ class MetaReferenceInferenceImpl(
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -30,6 +30,11 @@ from groq.types.shared.function_definition import FunctionDefinition
|
|||
|
||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
@ -40,8 +45,6 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
@ -162,7 +165,7 @@ def convert_chat_completion_response(
|
|||
|
||||
|
||||
def _map_finish_reason_to_stop_reason(
|
||||
finish_reason: Literal["stop", "length", "tool_calls"]
|
||||
finish_reason: Literal["stop", "length", "tool_calls"],
|
||||
) -> StopReason:
|
||||
"""
|
||||
Convert a Groq chat completion finish_reason to a StopReason.
|
||||
|
@ -185,7 +188,6 @@ def _map_finish_reason_to_stop_reason(
|
|||
async def convert_chat_completion_response_stream(
|
||||
stream: Stream[ChatCompletionChunk],
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
|
||||
event_type = ChatCompletionResponseEventType.start
|
||||
for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
|
@ -194,7 +196,7 @@ async def convert_chat_completion_response_stream(
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=choice.delta.content or "",
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=None,
|
||||
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
||||
)
|
||||
|
@ -221,7 +223,7 @@ async def convert_chat_completion_response_stream(
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=choice.delta.content or "",
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=None,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -34,6 +34,11 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
|||
from openai.types.completion import Completion as OpenAICompletion
|
||||
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
@ -48,8 +53,6 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
SystemMessage,
|
||||
TokenLogProbs,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
@ -432,69 +435,6 @@ async def convert_openai_chat_completion_stream(
|
|||
"""
|
||||
Convert a stream of OpenAI chat completion chunks into a stream
|
||||
of ChatCompletionResponseStreamChunk.
|
||||
|
||||
OpenAI ChatCompletionChunk:
|
||||
choices: List[Choice]
|
||||
|
||||
OpenAI Choice: # different from the non-streamed Choice
|
||||
delta: ChoiceDelta
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
|
||||
logprobs: Optional[ChoiceLogprobs]
|
||||
|
||||
OpenAI ChoiceDelta:
|
||||
content: Optional[str]
|
||||
role: Optional[Literal["system", "user", "assistant", "tool"]]
|
||||
tool_calls: Optional[List[ChoiceDeltaToolCall]]
|
||||
|
||||
OpenAI ChoiceDeltaToolCall:
|
||||
index: int
|
||||
id: Optional[str]
|
||||
function: Optional[ChoiceDeltaToolCallFunction]
|
||||
type: Optional[Literal["function"]]
|
||||
|
||||
OpenAI ChoiceDeltaToolCallFunction:
|
||||
name: Optional[str]
|
||||
arguments: Optional[str]
|
||||
|
||||
->
|
||||
|
||||
ChatCompletionResponseStreamChunk:
|
||||
event: ChatCompletionResponseEvent
|
||||
|
||||
ChatCompletionResponseEvent:
|
||||
event_type: ChatCompletionResponseEventType
|
||||
delta: Union[str, ToolCallDelta]
|
||||
logprobs: Optional[List[TokenLogProbs]]
|
||||
stop_reason: Optional[StopReason]
|
||||
|
||||
ChatCompletionResponseEventType:
|
||||
start = "start"
|
||||
progress = "progress"
|
||||
complete = "complete"
|
||||
|
||||
ToolCallDelta:
|
||||
content: Union[str, ToolCall]
|
||||
parse_status: ToolCallParseStatus
|
||||
|
||||
ToolCall:
|
||||
call_id: str
|
||||
tool_name: str
|
||||
arguments: str
|
||||
|
||||
ToolCallParseStatus:
|
||||
started = "started"
|
||||
in_progress = "in_progress"
|
||||
failure = "failure"
|
||||
success = "success"
|
||||
|
||||
TokenLogProbs:
|
||||
logprobs_by_token: Dict[str, float]
|
||||
- token, logprob
|
||||
|
||||
StopReason:
|
||||
end_of_turn = "end_of_turn"
|
||||
end_of_message = "end_of_message"
|
||||
out_of_tokens = "out_of_tokens"
|
||||
"""
|
||||
|
||||
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
|
||||
|
@ -543,7 +483,7 @@ async def convert_openai_chat_completion_stream(
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_type),
|
||||
delta=choice.delta.content,
|
||||
delta=TextDelta(text=choice.delta.content),
|
||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||
)
|
||||
)
|
||||
|
@ -570,7 +510,7 @@ async def convert_openai_chat_completion_stream(
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_type),
|
||||
delta=choice.delta.content or "", # content is not optional
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||
)
|
||||
)
|
||||
|
@ -578,7 +518,7 @@ async def convert_openai_chat_completion_stream(
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
@ -653,18 +593,6 @@ def _convert_openai_completion_logprobs(
|
|||
) -> Optional[List[TokenLogProbs]]:
|
||||
"""
|
||||
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
|
||||
|
||||
OpenAI CompletionLogprobs:
|
||||
text_offset: Optional[List[int]]
|
||||
token_logprobs: Optional[List[float]]
|
||||
tokens: Optional[List[str]]
|
||||
top_logprobs: Optional[List[Dict[str, float]]]
|
||||
|
||||
->
|
||||
|
||||
TokenLogProbs:
|
||||
logprobs_by_token: Dict[str, float]
|
||||
- token, logprob
|
||||
"""
|
||||
if not logprobs:
|
||||
return None
|
||||
|
@ -679,28 +607,6 @@ def convert_openai_completion_choice(
|
|||
) -> CompletionResponse:
|
||||
"""
|
||||
Convert an OpenAI Completion Choice into a CompletionResponse.
|
||||
|
||||
OpenAI Completion Choice:
|
||||
text: str
|
||||
finish_reason: str
|
||||
logprobs: Optional[ChoiceLogprobs]
|
||||
|
||||
->
|
||||
|
||||
CompletionResponse:
|
||||
completion_message: CompletionMessage
|
||||
logprobs: Optional[List[TokenLogProbs]]
|
||||
|
||||
CompletionMessage:
|
||||
role: Literal["assistant"]
|
||||
content: str | ImageMedia | List[str | ImageMedia]
|
||||
stop_reason: StopReason
|
||||
tool_calls: List[ToolCall]
|
||||
|
||||
class StopReason(Enum):
|
||||
end_of_turn = "end_of_turn"
|
||||
end_of_message = "end_of_message"
|
||||
out_of_tokens = "out_of_tokens"
|
||||
"""
|
||||
return CompletionResponse(
|
||||
content=choice.text,
|
||||
|
@ -715,32 +621,11 @@ async def convert_openai_completion_stream(
|
|||
"""
|
||||
Convert a stream of OpenAI Completions into a stream
|
||||
of ChatCompletionResponseStreamChunks.
|
||||
|
||||
OpenAI Completion:
|
||||
id: str
|
||||
choices: List[OpenAICompletionChoice]
|
||||
created: int
|
||||
model: str
|
||||
system_fingerprint: Optional[str]
|
||||
usage: Optional[OpenAICompletionUsage]
|
||||
|
||||
OpenAI CompletionChoice:
|
||||
finish_reason: str
|
||||
index: int
|
||||
logprobs: Optional[OpenAILogprobs]
|
||||
text: str
|
||||
|
||||
->
|
||||
|
||||
CompletionResponseStreamChunk:
|
||||
delta: str
|
||||
stop_reason: Optional[StopReason]
|
||||
logprobs: Optional[List[TokenLogProbs]]
|
||||
"""
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=choice.text,
|
||||
delta=TextDelta(text=choice.text),
|
||||
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
|
||||
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
|
||||
)
|
||||
|
|
|
@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import (
|
|||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.content_types import ToolCallParseStatus
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
|
@ -27,8 +28,6 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
SystemMessage,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolChoice,
|
||||
UserMessage,
|
||||
)
|
||||
|
@ -196,7 +195,9 @@ class TestInference:
|
|||
1 <= len(chunks) <= 6
|
||||
) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
|
||||
for chunk in chunks:
|
||||
if chunk.delta: # if there's a token, we expect logprobs
|
||||
if (
|
||||
chunk.delta.type == "text" and chunk.delta.text
|
||||
): # if there's a token, we expect logprobs
|
||||
assert chunk.logprobs, "Logprobs should not be empty"
|
||||
assert all(
|
||||
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
|
||||
|
@ -463,7 +464,7 @@ class TestInference:
|
|||
|
||||
if "Llama3.1" in inference_model:
|
||||
assert all(
|
||||
isinstance(chunk.event.delta, ToolCallDelta)
|
||||
chunk.event.delta.type == "tool_call"
|
||||
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
||||
)
|
||||
first = grouped[ChatCompletionResponseEventType.progress][0]
|
||||
|
@ -475,7 +476,7 @@ class TestInference:
|
|||
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
||||
# assert last.event.stop_reason == expected_stop_reason
|
||||
assert last.event.delta.parse_status == ToolCallParseStatus.success
|
||||
assert isinstance(last.event.delta.content, ToolCall)
|
||||
assert last.event.delta.content.type == "tool_call"
|
||||
|
||||
call = last.event.delta.content
|
||||
assert call.tool_name == "get_weather"
|
||||
|
|
|
@ -11,7 +11,13 @@ from llama_models.llama3.api.chat_format import ChatFormat
|
|||
from llama_models.llama3.api.datatypes import SamplingParams, StopReason
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
TextContentItem,
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
|
@ -22,8 +28,6 @@ from llama_stack.apis.inference import (
|
|||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Message,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
@ -138,7 +142,7 @@ async def process_completion_stream_response(
|
|||
text = ""
|
||||
continue
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=text,
|
||||
delta=TextDelta(text=text),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
if finish_reason:
|
||||
|
@ -149,7 +153,7 @@ async def process_completion_stream_response(
|
|||
break
|
||||
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta="",
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
|
@ -160,7 +164,7 @@ async def process_chat_completion_stream_response(
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
delta=TextDelta(text=""),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -227,7 +231,7 @@ async def process_chat_completion_stream_response(
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=text,
|
||||
delta=TextDelta(text=text),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
@ -262,7 +266,7 @@ async def process_chat_completion_stream_response(
|
|||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue