move all implementations to use updated type

This commit is contained in:
Ashwin Bharambe 2025-01-13 20:04:19 -08:00
parent aced2ce07e
commit 9a5803a429
8 changed files with 139 additions and 208 deletions

View file

@ -18,6 +18,7 @@ from llama_models.llama3.api.datatypes import (
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
@ -27,8 +28,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
UserMessage,
)
@ -196,7 +195,9 @@ class TestInference:
1 <= len(chunks) <= 6
) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
for chunk in chunks:
if chunk.delta: # if there's a token, we expect logprobs
if (
chunk.delta.type == "text" and chunk.delta.text
): # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty"
assert all(
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
@ -463,7 +464,7 @@ class TestInference:
if "Llama3.1" in inference_model:
assert all(
isinstance(chunk.event.delta, ToolCallDelta)
chunk.event.delta.type == "tool_call"
for chunk in grouped[ChatCompletionResponseEventType.progress]
)
first = grouped[ChatCompletionResponseEventType.progress][0]
@ -475,7 +476,7 @@ class TestInference:
last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.success
assert isinstance(last.event.delta.content, ToolCall)
assert last.event.delta.content.type == "tool_call"
call = last.event.delta.content
assert call.tool_name == "get_weather"