diff --git a/tests/test_inference.py b/tests/test_inference.py
new file mode 100644
index 000000000..6d11ba415
--- /dev/null
+++ b/tests/test_inference.py
@@ -0,0 +1,114 @@
+# Run this test using the following command:
+# python -m unittest tests/test_inference.py
+
+import os
+import unittest
+
+from llama_models.llama3_1.api.datatypes import (
+    InstructModel,
+    UserMessage
+)
+
+from llama_toolchain.inference.api.config import (
+    ImplType,
+    InferenceConfig,
+    InlineImplConfig,
+    ModelCheckpointConfig,
+    PytorchCheckpoint,
+    CheckpointQuantizationFormat,
+)
+from llama_toolchain.inference.api.datatypes import (
+    ChatCompletionResponseEventType,
+)
+from llama_toolchain.inference.api.endpoints import (
+    ChatCompletionRequest
+)
+from llama_toolchain.inference.inference import InferenceImpl
+from llama_toolchain.inference.event_logger import EventLogger
+
+
+HELPER_MSG = """
+This test needs llama-3.1-8b-instruct models.
+Please download using the llama cli
+
+llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token <HF_TOKEN>
+"""
+
+
+class InferenceTests(unittest.IsolatedAsyncioTestCase):
+
+    async def asyncSetUp(self):
+        # assert the model checkpoint exists locally
+        model_dir = os.path.expanduser("~/.llama/checkpoints/Meta-Llama-3.1-8B-Instruct/original/")
+        assert os.path.isdir(model_dir), HELPER_MSG
+
+        tokenizer_path = os.path.join(model_dir, "tokenizer.model")
+        assert os.path.exists(tokenizer_path), HELPER_MSG
+
+        inference_config = InlineImplConfig(
+            checkpoint_config=ModelCheckpointConfig(
+                checkpoint=PytorchCheckpoint(
+                    checkpoint_dir=model_dir,
+                    tokenizer_path=tokenizer_path,
+                    model_parallel_size=1,
+                    quantization_format=CheckpointQuantizationFormat.bf16,
+                )
+            ),
+            max_seq_len=2048,
+        )
+
+        self.inference = InferenceImpl(inference_config)
+        await self.inference.initialize()
+
+    async def asyncTearDown(self):
+        await self.inference.shutdown()
+
+    async def test_inline_inference_no_streaming(self):
+        request = ChatCompletionRequest(
+            model=InstructModel.llama3_8b_chat,
+            messages=[
+                UserMessage(
+                    content="What is the capital of France?",
+                ),
+            ],
+            stream=False,
+        )
+        iterator = self.inference.chat_completion(request)
+
+        async for chunk in iterator:
+            response = chunk
+
+        result = response.completion_message.content
+        self.assertTrue("Paris" in result, result)
+
+    async def test_inline_inference_streaming(self):
+        request = ChatCompletionRequest(
+            model=InstructModel.llama3_8b_chat,
+            messages=[
+                UserMessage(
+                    content="What is the capital of France?",
+                ),
+            ],
+            stream=True,
+        )
+        iterator = self.inference.chat_completion(request)
+
+        events = []
+        async for chunk in iterator:
+            events.append(chunk.event)
+
+
+        self.assertEqual(
+            events[0].event_type,
+            ChatCompletionResponseEventType.start
+        )
+        self.assertEqual(
+            events[-1].event_type,
+            ChatCompletionResponseEventType.complete
+        )
+
+        response = ""
+        for e in events[1:-1]:
+            response += e.delta
+
+        self.assertTrue("Paris" in response, response)
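
For context on what the streaming test asserts: chat_completion is expected to yield a start event, then intermediate events whose .delta strings make up the reply, then a complete event. Below is a minimal consumer sketch under that assumption; the event type names and the .delta field come from the test above, while collect_streamed_text itself is just a hypothetical helper, not part of the patch.

from llama_toolchain.inference.api.datatypes import ChatCompletionResponseEventType

# Illustrative sketch only: rebuild the streamed reply the way the test does.
# Assumes each intermediate event carries a .delta string, as the test asserts.
async def collect_streamed_text(iterator) -> str:
    events = [chunk.event async for chunk in iterator]
    assert events[0].event_type == ChatCompletionResponseEventType.start
    assert events[-1].event_type == ChatCompletionResponseEventType.complete
    return "".join(e.delta for e in events[1:-1])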