update inference config to take model and not model_dir

Hardik Shah 2024-08-06 15:02:41 -07:00
parent 08c3802f45
commit 039861f1c7
9 changed files with 400 additions and 101 deletions
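The change replaces the checkpoint-directory based setup (an InlineImplConfig wrapped in an InferenceConfig) with a MetaReferenceImplConfig built from a model name. Below is a minimal sketch of the new setup pattern, using only names that appear in the diffs that follow; the empty deps dict passed to get_provider_impl mirrors the updated test, and the surrounding asyncio scaffolding is illustrative rather than part of the commit:

import asyncio

from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig
from llama_toolchain.inference.meta_reference.inference import get_provider_impl


async def main() -> None:
    # Identify the model by name instead of pointing at a checkpoint directory;
    # the updated test still expects checkpoints under ~/.llama/checkpoints/<model>/original/.
    config = MetaReferenceImplConfig(
        model="Meta-Llama3.1-8B-Instruct",
        max_seq_len=2048,
    )
    api = await get_provider_impl(config, {})  # second argument: provider deps (empty in the test)
    await api.initialize()


asyncio.run(main())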


@@ -14,24 +14,18 @@ from llama_models.llama3_1.api.datatypes import (
StopReason,
SystemMessage,
)
from llama_toolchain.inference.api.config import (
InferenceConfig,
InlineImplConfig,
RemoteImplConfig,
ModelCheckpointConfig,
PytorchCheckpoint,
CheckpointQuantizationFormat,
)
from llama_toolchain.inference.api_instance import (
get_inference_api_instance,
)
from llama_toolchain.inference.api.datatypes import (
ChatCompletionResponseEventType,
)
from llama_toolchain.inference.meta_reference.inference import get_provider_impl
from llama_toolchain.inference.meta_reference.config import (
MetaReferenceImplConfig,
)
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
MODEL = "Meta-Llama3.1-8B-Instruct"
HELPER_MSG = """
This test needs llama-3.1-8b-instruct models.
Please download using the llama cli
@@ -50,32 +44,18 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
@classmethod
async def asyncSetUpClass(cls):
# assert model exists on local
model_dir = os.path.expanduser(
"~/.llama/checkpoints/Meta-Llama-3.1-8B-Instruct/original/"
)
model_dir = os.path.expanduser(f"~/.llama/checkpoints/{MODEL}/original/")
assert os.path.isdir(model_dir), HELPER_MSG
tokenizer_path = os.path.join(model_dir, "tokenizer.model")
assert os.path.exists(tokenizer_path), HELPER_MSG
inline_config = InlineImplConfig(
checkpoint_config=ModelCheckpointConfig(
checkpoint=PytorchCheckpoint(
checkpoint_dir=model_dir,
tokenizer_path=tokenizer_path,
model_parallel_size=1,
quantization_format=CheckpointQuantizationFormat.bf16,
)
),
config = MetaReferenceImplConfig(
model=MODEL,
max_seq_len=2048,
)
inference_config = InferenceConfig(impl_config=inline_config)
# -- For faster testing iteration --
# remote_config = RemoteImplConfig(url="http://localhost:5000")
# inference_config = InferenceConfig(impl_config=remote_config)
cls.api = await get_inference_api_instance(inference_config)
cls.api = await get_provider_impl(config, {})
await cls.api.initialize()
current_date = datetime.now()
@@ -134,7 +114,7 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
await cls.api.shutdown()
async def asyncSetUp(self):
self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
self.valid_supported_model = MODEL
async def test_text(self):
request = ChatCompletionRequest(


@@ -10,14 +10,12 @@ from llama_models.llama3_1.api.datatypes import (
SamplingStrategy,
SystemMessage,
)
from llama_toolchain.inference.api_instance import (
get_inference_api_instance,
)
from llama_toolchain.inference.api.datatypes import (
ChatCompletionResponseEventType,
)
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
from llama_toolchain.inference.api.config import InferenceConfig, OllamaImplConfig
from llama_toolchain.inference.ollama.config import OllamaImplConfig
from llama_toolchain.inference.ollama.ollama import get_provider_impl
class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
@@ -30,9 +28,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
)
# setup ollama
self.api = await get_inference_api_instance(
InferenceConfig(impl_config=ollama_config)
)
self.api = await get_provider_impl(ollama_config)
await self.api.initialize()
current_date = datetime.now()