addressing comments

This commit is contained in:
Hardik Shah 2024-07-31 22:07:45 -07:00
parent 0e985648f5
commit fd8adc1e50
3 changed files with 9 additions and 2 deletions

View file

@@ -13,7 +13,7 @@ async def get_inference_api_instance(config: InferenceConfig):
return InferenceImpl(config.impl_config) return InferenceImpl(config.impl_config)
elif config.impl_config.impl_type == ImplType.ollama.value: elif config.impl_config.impl_type == ImplType.ollama.value:
from .inference import OllamaInference from .ollama import OllamaInference
return OllamaInference(config.impl_config) return OllamaInference(config.impl_config)

View file

@@ -13,6 +13,7 @@ hydra-zen
json-strong-typing json-strong-typing
llama-models llama-models
matplotlib matplotlib
ollama
omegaconf omegaconf
pandas pandas
Pillow Pillow

View file

@@ -9,6 +9,9 @@ from llama_models.llama3_1.api.datatypes import (
StopReason, StopReason,
SystemMessage, SystemMessage,
) )
from llama_toolchain.inference.api_instance import (
get_inference_api_instance,
)
from llama_toolchain.inference.api.datatypes import ( from llama_toolchain.inference.api.datatypes import (
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
) )
@@ -16,6 +19,7 @@ from llama_toolchain.inference.api.endpoints import (
ChatCompletionRequest ChatCompletionRequest
) )
from llama_toolchain.inference.api.config import ( from llama_toolchain.inference.api.config import (
InferenceConfig,
OllamaImplConfig OllamaImplConfig
) )
from llama_toolchain.inference.ollama import ( from llama_toolchain.inference.ollama import (
@@ -32,7 +36,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
) )
# setup ollama # setup ollama
self.api = OllamaInference(ollama_config) self.api = await get_inference_api_instance(
InferenceConfig(impl_config=ollama_config)
)
await self.api.initialize() await self.api.initialize()
current_date = datetime.now() current_date = datetime.now()