mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 23:03:49 +00:00
Works with ollama 0.4.0 pre-release with the vision model
This commit is contained in:
parent
03013dafc1
commit
d543eb442b
5 changed files with 137 additions and 57 deletions
|
@ -19,12 +19,11 @@ def pytest_addoption(parser):
|
|||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line(
|
||||
"markers", "llama_8b: mark test to run only with the given model"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "llama_3b: mark test to run only with the given model"
|
||||
)
|
||||
for model in ["llama_8b", "llama_3b", "llama_vision"]:
|
||||
config.addinivalue_line(
|
||||
"markers", f"{model}: mark test to run only with the given model"
|
||||
)
|
||||
|
||||
for fixture_name in INFERENCE_FIXTURES:
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
|
@ -37,6 +36,14 @@ MODEL_PARAMS = [
|
|||
pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
|
||||
]
|
||||
|
||||
VISION_MODEL_PARAMS = [
|
||||
pytest.param(
|
||||
"Llama3.2-11B-Vision-Instruct",
|
||||
marks=pytest.mark.llama_vision,
|
||||
id="llama_vision",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
if "inference_model" in metafunc.fixturenames:
|
||||
|
@ -44,7 +51,11 @@ def pytest_generate_tests(metafunc):
|
|||
if model:
|
||||
params = [pytest.param(model, id="")]
|
||||
else:
|
||||
params = MODEL_PARAMS
|
||||
cls_name = metafunc.cls.__name__
|
||||
if "Vision" in cls_name:
|
||||
params = VISION_MODEL_PARAMS
|
||||
else:
|
||||
params = MODEL_PARAMS
|
||||
|
||||
metafunc.parametrize(
|
||||
"inference_model",
|
||||
|
|
|
@ -29,11 +29,6 @@ def inference_model(request):
|
|||
return request.config.getoption("--inference-model", None)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vision_inference_model():
|
||||
return "Llama3.2-11B-Vision-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_remote() -> ProviderFixture:
|
||||
return remote_stack_fixture()
|
||||
|
|
|
@ -21,19 +21,20 @@ THIS_DIR = Path(__file__).parent
|
|||
class TestVisionModelInference:
|
||||
@pytest.mark.asyncio
|
||||
async def test_vision_chat_completion_non_streaming(
|
||||
self, vision_inference_model, inference_stack
|
||||
self, inference_model, inference_stack
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(
|
||||
vision_inference_model
|
||||
)
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
"remote::ollama",
|
||||
):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
pytest.skip(
|
||||
"Other inference providers don't support vision chat completion() yet"
|
||||
)
|
||||
|
||||
images = [
|
||||
ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
|
||||
|
@ -51,7 +52,7 @@ class TestVisionModelInference:
|
|||
]
|
||||
for image, expected_strings in zip(images, expected_strings_to_check):
|
||||
response = await inference_impl.chat_completion(
|
||||
model=vision_inference_model,
|
||||
model=inference_model,
|
||||
messages=[
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(
|
||||
|
@ -69,19 +70,20 @@ class TestVisionModelInference:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vision_chat_completion_streaming(
|
||||
self, vision_inference_model, inference_stack
|
||||
self, inference_model, inference_stack
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(
|
||||
vision_inference_model
|
||||
)
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
"remote::ollama",
|
||||
):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
pytest.skip(
|
||||
"Other inference providers don't support vision chat completion() yet"
|
||||
)
|
||||
|
||||
images = [
|
||||
ImageMedia(
|
||||
|
@ -97,7 +99,7 @@ class TestVisionModelInference:
|
|||
response = [
|
||||
r
|
||||
async for r in await inference_impl.chat_completion(
|
||||
model=vision_inference_model,
|
||||
model=inference_model,
|
||||
messages=[
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue