From fd8adc1e50ccc99c331d124a2b1914ff469a8038 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Wed, 31 Jul 2024 22:07:45 -0700 Subject: [PATCH] addressing comments --- llama_toolchain/inference/api_instance.py | 2 +- requirements.txt | 1 + tests/test_ollama_inference.py | 8 +++++++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/llama_toolchain/inference/api_instance.py b/llama_toolchain/inference/api_instance.py index 25b5ecf4b..975de3446 100644 --- a/llama_toolchain/inference/api_instance.py +++ b/llama_toolchain/inference/api_instance.py @@ -13,7 +13,7 @@ async def get_inference_api_instance(config: InferenceConfig): return InferenceImpl(config.impl_config) elif config.impl_config.impl_type == ImplType.ollama.value: - from .inference import OllamaInference + from .ollama import OllamaInference return OllamaInference(config.impl_config) diff --git a/requirements.txt b/requirements.txt index a51bc74d9..05d642f81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ hydra-zen json-strong-typing llama-models matplotlib +ollama omegaconf pandas Pillow diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py index 6f6b5d1a8..d9dbab6e9 100644 --- a/tests/test_ollama_inference.py +++ b/tests/test_ollama_inference.py @@ -9,6 +9,9 @@ from llama_models.llama3_1.api.datatypes import ( StopReason, SystemMessage, ) +from llama_toolchain.inference.api_instance import ( + get_inference_api_instance, +) from llama_toolchain.inference.api.datatypes import ( ChatCompletionResponseEventType, ) @@ -16,6 +19,7 @@ from llama_toolchain.inference.api.endpoints import ( ChatCompletionRequest ) from llama_toolchain.inference.api.config import ( + InferenceConfig, OllamaImplConfig ) from llama_toolchain.inference.ollama import ( @@ -32,7 +36,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): ) # setup ollama - self.api = OllamaInference(ollama_config) + self.api = await get_inference_api_instance( + InferenceConfig(impl_config=ollama_config) + ) await self.api.initialize() current_date = datetime.now()