minor import fixes

This commit is contained in:
Hardik Shah 2024-08-26 14:21:35 -07:00
parent dc433f6c90
commit c3708859aa
7 changed files with 16 additions and 11 deletions

View file

@ -18,9 +18,11 @@ from llama_models.llama3.api.datatypes import (
ToolResponseMessage,
UserMessage,
)
from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
)
from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig
from llama_toolchain.inference.meta_reference.inference import get_provider_impl
@ -221,12 +223,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(
events[-1].event_type, ChatCompletionResponseEventType.complete
)
self.assertEqual(events[-1].stop_reason, StopReason.end_of_message)
self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn)
# last but one event should be eom with tool call
self.assertEqual(
events[-2].event_type, ChatCompletionResponseEventType.progress
)
self.assertEqual(events[-2].stop_reason, StopReason.end_of_message)
self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn)
self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point")
async def test_multi_turn(self):