feat(providers): support non-llama models for inference providers (#1200)

This PR begins the process of supporting non-Llama models within Llama Stack. We start simply by adding this support to a few existing providers: Fireworks, Together, and Ollama.
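For context, here is a minimal sketch of what this enables from the client side. The base URL is illustrative and the `fireworks` provider id is an assumption about the running distribution; the model id matches the one used in the test plan below. The key point is that `llama_model` metadata is no longer required when registering a non-Llama model.

```python
# Sketch only: assumes a running stack with a Fireworks inference provider.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # illustrative URL

# Register a non-Llama model; no llama_model metadata is needed anymore.
client.models.register(
    model_id="accounts/fireworks/models/phi-3-vision-128k-instruct",
    provider_id="fireworks",  # assumed provider id in the run config
)

response = client.inference.chat_completion(
    model_id="accounts/fireworks/models/phi-3-vision-128k-instruct",
    messages=[{"role": "user", "content": "Which planet do humans live on?"}],
    stream=False,
)
print(response.completion_message.content)
```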

## Test Plan

```bash
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/inference/test_text_inference.py \
  --inference-model accounts/fireworks/models/phi-3-vision-128k-instruct
```

^ This passes most of the tests but, as expected, fails the tool-calling-related tests, since those are very specific to Llama models:

```
inference/test_text_inference.py::test_text_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_completion_log_probs_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_completion_log_probs_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_text_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct-completion-01] PASSED
inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet do humans live on?-Earth] PASSED
inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet has rings around it with a name starting with letter S?-Saturn] PASSED
inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What's the name of the Sun in latin?-Sol] PASSED
inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What is the name of the US captial?-Washington] PASSED
inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED
inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED
inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_required[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED
inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_none[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_text_chat_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct] ERROR
inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-True] PASSED
inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-False] PASSED
```
Ashwin Bharambe authored on 2025-02-21 13:21:28 -08:00 (committed by GitHub)
commit ab54b8cd58, parent 9bbe34694d
7 changed files with 103 additions and 74 deletions

```diff
@@ -209,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
```

```diff
@@ -178,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                 # flatten the list of lists
                 input_dict["messages"] = [item for sublist in contents for item in sublist]
@@ -187,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["raw"] = True
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
-                    self.register_helper.get_llama_model(request.model),
+                    llama_model,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
```

```diff
@@ -203,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
```

```diff
@@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             provider_resource_id = model.provider_resource_id
         else:
             provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
-            if model.metadata.get("llama_model") is None:
-                raise ValueError(
-                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
-                    "Please specify a llama_model in metadata or use a supported model identifier"
-                )
+            llama_model = model.metadata.get("llama_model")
+            if llama_model is None:
+                return model
+
             existing_llama_model = self.get_llama_model(model.provider_resource_id)
             if existing_llama_model:
-                if existing_llama_model != model.metadata["llama_model"]:
+                if existing_llama_model != llama_model:
                     raise ValueError(
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
-                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Invalid llama_model '{llama_model}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                 )
         return model
```
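This registry change is what makes the adapters' `llama_model` lookup return `None` for non-Llama models: a model whose provider id is not in the static Llama registry, and which carries no `llama_model` metadata, is now accepted as-is instead of raising. A rough standalone sketch of the new control flow (names and structure simplified; this is not the real `ModelRegistryHelper`):

```python
# Toy stand-in for ModelRegistryHelper.register_model after this change.
KNOWN_LLAMA_REPOS = {"meta-llama/Llama-3.1-8B-Instruct": "Llama3.1-8B-Instruct"}


class ToyRegistry:
    def __init__(self):
        self.provider_id_to_llama_model_map = {}

    def register_model(self, provider_model_id: str, metadata: dict) -> str:
        llama_model = metadata.get("llama_model")
        if llama_model is None:
            # New behavior: unknown, non-Llama models are passed through untouched.
            return provider_model_id
        if llama_model not in KNOWN_LLAMA_REPOS:
            raise ValueError(f"Invalid llama_model '{llama_model}' specified in metadata")
        self.provider_id_to_llama_model_map[provider_model_id] = KNOWN_LLAMA_REPOS[llama_model]
        return provider_model_id


registry = ToyRegistry()
# Previously this kind of registration raised; now a non-Llama model registers cleanly.
registry.register_model("accounts/fireworks/models/phi-3-vision-128k-instruct", metadata={})
```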

```diff
@@ -42,28 +42,30 @@ def pytest_addoption(parser):
     )
     parser.addoption(
         "--inference-model",
-        action="store",
         default=TEXT_MODEL,
         help="Specify the inference model to use for testing",
     )
     parser.addoption(
         "--vision-inference-model",
-        action="store",
         default=VISION_MODEL,
         help="Specify the vision inference model to use for testing",
     )
     parser.addoption(
         "--safety-shield",
-        action="store",
         default="meta-llama/Llama-Guard-3-1B",
         help="Specify the safety shield model to use for testing",
     )
     parser.addoption(
         "--embedding-model",
-        action="store",
-        default=TEXT_MODEL,
+        default=None,
         help="Specify the embedding model to use for testing",
     )
+    parser.addoption(
+        "--embedding-dimension",
+        type=int,
+        default=384,
+        help="Output dimensionality of the embedding model to use for testing",
+    )


 @pytest.fixture(scope="session")
@@ -78,7 +80,7 @@ def provider_data():


 @pytest.fixture(scope="session")
-def llama_stack_client(provider_data):
+def llama_stack_client(provider_data, text_model_id):
     if os.environ.get("LLAMA_STACK_CONFIG"):
         client = LlamaStackAsLibraryClient(
             get_env_or_fail("LLAMA_STACK_CONFIG"),
@@ -95,6 +97,45 @@ def llama_stack_client(provider_data):
         )
     else:
         raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+    return client
+
+
+@pytest.fixture(scope="session")
+def inference_provider_type(llama_stack_client):
+    providers = llama_stack_client.providers.list()
+    inference_providers = [p for p in providers if p.api == "inference"]
+    assert len(inference_providers) > 0, "No inference providers found"
+    return inference_providers[0].provider_type
+
+
+@pytest.fixture(scope="session")
+def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension):
+    client = llama_stack_client
+
+    providers = [p for p in client.providers.list() if p.api == "inference"]
+    assert len(providers) > 0, "No inference providers found"
+
+    inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
+    if text_model_id:
+        client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
+    if vision_model_id:
+        client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
+    if embedding_model_id and embedding_dimension:
+        # try to find a provider that supports embeddings, if sentence-transformers is not available
+        selected_provider = None
+        for p in providers:
+            if p.provider_type == "inline::sentence-transformers":
+                selected_provider = p
+                break
+
+        selected_provider = selected_provider or providers[0]
+        client.models.register(
+            model_id=embedding_model_id,
+            provider_id=selected_provider.provider_id,
+            model_type="embedding",
+            metadata={"embedding_dimension": embedding_dimension},
+        )
     return client
@@ -117,3 +158,9 @@ def pytest_generate_tests(metafunc):
             [metafunc.config.getoption("--embedding-model")],
             scope="session",
         )
+    if "embedding_dimension" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "embedding_dimension",
+            [metafunc.config.getoption("--embedding-dimension")],
+            scope="session",
+        )
```
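With these fixtures in place, tests no longer depend on the target models being pre-registered in the distribution: depending on `client_with_models` registers them on the fly against the first suitable inference provider. A hypothetical test using the new fixtures might look like the sketch below (not part of this PR, shown only to illustrate the wiring; it mirrors the call shape of the existing completion tests).

```python
# Hypothetical example of a test relying on the new fixtures; the model referenced by
# text_model_id is registered by client_with_models before the test body runs.
def test_completion_with_any_registered_model(client_with_models, text_model_id):
    response = client_with_models.inference.completion(
        content="Complete the sentence using one word: Roses are red, violets are ",
        stream=False,
        model_id=text_model_id,
    )
    assert len(response.content.strip()) > 0
```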

```diff
@@ -28,14 +28,6 @@ def provider_tool_format(inference_provider_type):
     )


-@pytest.fixture(scope="session")
-def inference_provider_type(llama_stack_client):
-    providers = llama_stack_client.providers.list()
-    inference_providers = [p for p in providers if p.api == "inference"]
-    assert len(inference_providers) > 0, "No inference providers found"
-    return inference_providers[0].provider_type
-
-
 @pytest.fixture
 def get_weather_tool_definition():
     return {
@@ -50,8 +42,8 @@ def get_weather_tool_definition():
     }


-def test_text_completion_non_streaming(llama_stack_client, text_model_id):
-    response = llama_stack_client.inference.completion(
+def test_text_completion_non_streaming(client_with_models, text_model_id):
+    response = client_with_models.inference.completion(
         content="Complete the sentence using one word: Roses are red, violets are ",
         stream=False,
         model_id=text_model_id,
@@ -63,8 +55,8 @@ def test_text_completion_non_streaming(llama_stack_client, text_model_id):
     # assert "blue" in response.content.lower().strip()


-def test_text_completion_streaming(llama_stack_client, text_model_id):
-    response = llama_stack_client.inference.completion(
+def test_text_completion_streaming(client_with_models, text_model_id):
+    response = client_with_models.inference.completion(
         content="Complete the sentence using one word: Roses are red, violets are ",
         stream=True,
         model_id=text_model_id,
@@ -78,11 +70,11 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
     assert len(content_str) > 10


-def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, inference_provider_type):
+def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type):
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")

-    response = llama_stack_client.inference.completion(
+    response = client_with_models.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=False,
         model_id=text_model_id,
@@ -98,11 +90,11 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, i
     assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)


-def test_completion_log_probs_streaming(llama_stack_client, text_model_id, inference_provider_type):
+def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type):
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")

-    response = llama_stack_client.inference.completion(
+    response = client_with_models.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=True,
         model_id=text_model_id,
@@ -123,7 +115,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id, infer


 @pytest.mark.parametrize("test_case", ["completion-01"])
-def test_text_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
+def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
     class AnswerFormat(BaseModel):
         name: str
         year_born: str
@@ -132,7 +124,7 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
     tc = TestCase(test_case)

     user_input = tc["user_input"]
-    response = llama_stack_client.inference.completion(
+    response = client_with_models.inference.completion(
         model_id=text_model_id,
         content=user_input,
         stream=False,
@@ -161,8 +153,8 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
         ),
     ],
 )
-def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected):
-    response = llama_stack_client.inference.chat_completion(
+def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected):
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {
@@ -184,8 +176,8 @@ def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, q
         ("What is the name of the US captial?", "Washington"),
     ],
 )
-def test_text_chat_completion_streaming(llama_stack_client, text_model_id, question, expected):
-    response = llama_stack_client.inference.chat_completion(
+def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected):
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[{"role": "user", "content": question}],
         stream=True,
@@ -196,9 +188,9 @@ def test_text_chat_completion_streaming(llama_stack_client, text_model_id, quest


 def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
+    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -233,9 +225,9 @@ def extract_tool_invocation_content(response):


 def test_text_chat_completion_with_tool_calling_and_streaming(
-    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
+    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -251,13 +243,12 @@ def test_text_chat_completion_with_tool_calling_and_streaming(


 def test_text_chat_completion_with_tool_choice_required(
-    llama_stack_client,
+    client_with_models,
     text_model_id,
     get_weather_tool_definition,
     provider_tool_format,
-    inference_provider_type,
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -275,9 +266,9 @@ def test_text_chat_completion_with_tool_choice_required(


 def test_text_chat_completion_with_tool_choice_none(
-    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
+    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -292,7 +283,7 @@ def test_text_chat_completion_with_tool_choice_none(


 @pytest.mark.parametrize("test_case", ["chat_completion-01"])
-def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
+def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
     class AnswerFormat(BaseModel):
         first_name: str
         last_name: str
@@ -301,7 +292,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i

     tc = TestCase(test_case)

-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         response_format={
@@ -325,7 +316,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
         False,
     ],
 )
-def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming):
+def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming):
     # TODO: more dynamic lookup on tool_prompt_format for model family
     tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
     request = {
@@ -381,7 +372,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_clie
         "stream": streaming,
     }

-    response = llama_stack_client.inference.chat_completion(**request)
+    response = client_with_models.inference.chat_completion(**request)

     if streaming:
         for chunk in response:
```

```diff
@@ -10,14 +10,6 @@ import pathlib
 import pytest


-@pytest.fixture(scope="session")
-def inference_provider_type(llama_stack_client):
-    providers = llama_stack_client.providers.list()
-    inference_providers = [p for p in providers if p.api == "inference"]
-    assert len(inference_providers) > 0, "No inference providers found"
-    return inference_providers[0].provider_type
-
-
 @pytest.fixture
 def image_path():
     return pathlib.Path(__file__).parent / "dog.png"
@@ -35,7 +27,7 @@ def base64_image_url(base64_image_data, image_path):
     return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"


-def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id):
+def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
         "content": [
@@ -53,7 +45,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id
             },
         ],
     }
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=False,
@@ -63,7 +55,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id
     assert any(expected in message_content for expected in {"dog", "puppy", "pup"})


-def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
+def test_image_chat_completion_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
         "content": [
@@ -81,7 +73,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
             },
         ],
     }
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=True,
@@ -94,7 +86,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):

 @pytest.mark.parametrize("type_", ["url", "data"])
-def test_image_chat_completion_base64(llama_stack_client, vision_model_id, base64_image_data, base64_image_url, type_):
+def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
     image_spec = {
         "url": {
             "type": "image",
@@ -122,7 +114,7 @@ def test_image_chat_completion_base64(llama_stack_client, vision_model_id, base6
             },
         ],
     }
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=False,
```