Works with ollama 0.4.0 pre-release with the vision model

Ashwin Bharambe 2024-11-05 14:59:18 -08:00
parent 03013dafc1
commit d543eb442b
5 changed files with 137 additions and 57 deletions


@@ -19,12 +19,11 @@ def pytest_addoption(parser):
 def pytest_configure(config):
-    config.addinivalue_line(
-        "markers", "llama_8b: mark test to run only with the given model"
-    )
-    config.addinivalue_line(
-        "markers", "llama_3b: mark test to run only with the given model"
-    )
+    for model in ["llama_8b", "llama_3b", "llama_vision"]:
+        config.addinivalue_line(
+            "markers", f"{model}: mark test to run only with the given model"
+        )
     for fixture_name in INFERENCE_FIXTURES:
         config.addinivalue_line(
             "markers",
@@ -37,6 +36,14 @@ MODEL_PARAMS = [
     pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
 ]
+VISION_MODEL_PARAMS = [
+    pytest.param(
+        "Llama3.2-11B-Vision-Instruct",
+        marks=pytest.mark.llama_vision,
+        id="llama_vision",
+    ),
+]
 def pytest_generate_tests(metafunc):
     if "inference_model" in metafunc.fixturenames:
@@ -44,7 +51,11 @@ def pytest_generate_tests(metafunc):
         if model:
             params = [pytest.param(model, id="")]
         else:
-            params = MODEL_PARAMS
+            cls_name = metafunc.cls.__name__
+            if "Vision" in cls_name:
+                params = VISION_MODEL_PARAMS
+            else:
+                params = MODEL_PARAMS
         metafunc.parametrize(
             "inference_model",

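Note on the conftest change above (not part of the diff): the vision tests no longer need a dedicated fixture; pytest_generate_tests now picks the parameter list from the test class name. A minimal sketch of the combined behavior, assuming the surrounding conftest code that the hunks above truncate (the option lookup and any remaining parametrize() arguments are assumptions):

    # Sketch only -- mirrors the logic added in the hunks above.
    def pytest_generate_tests(metafunc):
        if "inference_model" in metafunc.fixturenames:
            # --inference-model on the pytest command line overrides everything
            model = metafunc.config.getoption("--inference-model")
            if model:
                params = [pytest.param(model, id="")]
            elif "Vision" in metafunc.cls.__name__:
                # e.g. TestVisionModelInference gets Llama3.2-11B-Vision-Instruct
                params = VISION_MODEL_PARAMS
            else:
                params = MODEL_PARAMS
            # the real call's remaining arguments are truncated in the hunk above
            metafunc.parametrize("inference_model", params)

With this in place the separate vision fixture removed in the next file becomes redundant, and the llama_vision marker registered by the loop above lets `-m llama_vision` select just the vision-model runs.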

@@ -29,11 +29,6 @@ def inference_model(request):
     return request.config.getoption("--inference-model", None)
-@pytest.fixture(scope="session")
-def vision_inference_model():
-    return "Llama3.2-11B-Vision-Instruct"
 @pytest.fixture(scope="session")
 def inference_remote() -> ProviderFixture:
     return remote_stack_fixture()


@@ -21,19 +21,20 @@ THIS_DIR = Path(__file__).parent
 class TestVisionModelInference:
     @pytest.mark.asyncio
     async def test_vision_chat_completion_non_streaming(
-        self, vision_inference_model, inference_stack
+        self, inference_model, inference_stack
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(
-            vision_inference_model
-        )
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::together",
             "remote::fireworks",
+            "remote::ollama",
         ):
-            pytest.skip("Other inference providers don't support completion() yet")
+            pytest.skip(
+                "Other inference providers don't support vision chat completion() yet"
+            )
         images = [
             ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
@@ -51,7 +52,7 @@ class TestVisionModelInference:
         ]
         for image, expected_strings in zip(images, expected_strings_to_check):
             response = await inference_impl.chat_completion(
-                model=vision_inference_model,
+                model=inference_model,
                 messages=[
                     SystemMessage(content="You are a helpful assistant."),
                     UserMessage(
@@ -69,19 +70,20 @@ class TestVisionModelInference:
     @pytest.mark.asyncio
     async def test_vision_chat_completion_streaming(
-        self, vision_inference_model, inference_stack
+        self, inference_model, inference_stack
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(
-            vision_inference_model
-        )
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::together",
             "remote::fireworks",
+            "remote::ollama",
         ):
-            pytest.skip("Other inference providers don't support completion() yet")
+            pytest.skip(
+                "Other inference providers don't support vision chat completion() yet"
+            )
         images = [
             ImageMedia(
@@ -97,7 +99,7 @@ class TestVisionModelInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=vision_inference_model,
+                model=inference_model,
                 messages=[
                     SystemMessage(content="You are a helpful assistant."),
                     UserMessage(
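The UserMessage bodies are truncated in the hunks above. As a rough, hypothetical sketch of the call shape these tests exercise (the prompt text and content layout are assumptions, not the repository's code):

    # Hypothetical sketch, not the file's code: non-streaming vision request.
    image = ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg"))
    response = await inference_impl.chat_completion(
        model=inference_model,  # "Llama3.2-11B-Vision-Instruct" via VISION_MODEL_PARAMS
        messages=[
            SystemMessage(content="You are a helpful assistant."),
            # assumed: the user content mixes the image with a text instruction
            UserMessage(content=[image, "Describe what is in this image."]),
        ],
        stream=False,  # the streaming test iterates `async for r in await ...` instead
    )

The "remote::ollama" entries added to the provider allow-lists are what let these tests run against the ollama 0.4.0 pre-release vision support named in the commit title.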