mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
move all implementations to use updated type
This commit is contained in:
parent
aced2ce07e
commit
9a5803a429
8 changed files with 139 additions and 208 deletions
|
@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import (
|
|||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.content_types import ToolCallParseStatus
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
|
@ -27,8 +28,6 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
SystemMessage,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolChoice,
|
||||
UserMessage,
|
||||
)
|
||||
|
@ -196,7 +195,9 @@ class TestInference:
|
|||
1 <= len(chunks) <= 6
|
||||
) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
|
||||
for chunk in chunks:
|
||||
if chunk.delta: # if there's a token, we expect logprobs
|
||||
if (
|
||||
chunk.delta.type == "text" and chunk.delta.text
|
||||
): # if there's a token, we expect logprobs
|
||||
assert chunk.logprobs, "Logprobs should not be empty"
|
||||
assert all(
|
||||
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
|
||||
|
@ -463,7 +464,7 @@ class TestInference:
|
|||
|
||||
if "Llama3.1" in inference_model:
|
||||
assert all(
|
||||
isinstance(chunk.event.delta, ToolCallDelta)
|
||||
chunk.event.delta.type == "tool_call"
|
||||
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
||||
)
|
||||
first = grouped[ChatCompletionResponseEventType.progress][0]
|
||||
|
@ -475,7 +476,7 @@ class TestInference:
|
|||
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
||||
# assert last.event.stop_reason == expected_stop_reason
|
||||
assert last.event.delta.parse_status == ToolCallParseStatus.success
|
||||
assert isinstance(last.event.delta.content, ToolCall)
|
||||
assert last.event.delta.content.type == "tool_call"
|
||||
|
||||
call = last.event.delta.content
|
||||
assert call.tool_name == "get_weather"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue