forked from phoenix-oss/llama-stack-mirror
feat(providers): support non-llama models for inference providers (#1200)
This PR begins the process of supporting non-llama models within Llama Stack. We start simple by adding support for this functionality within a few existing providers: fireworks, together and ollama. ## Test Plan ```bash LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/inference/test_text_inference.py \ --inference-model accounts/fireworks/models/phi-3-vision-128k-instruct ``` ^ this passes most of the tests but as expected fails the tool calling related tests since they are very specific to Llama models ``` inference/test_text_inference.py::test_text_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_completion_log_probs_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_completion_log_probs_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_text_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct-completion-01] PASSED inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet do humans live on?-Earth] PASSED inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet has rings around it with a name starting w ith letter S?-Saturn] PASSED inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What's the name of the Sun in latin?-Sol] PASSED inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What is the name of the US captial?-Washington] PASSED inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_required[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_none[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_text_chat_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct] ERROR inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-True] PASSED inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-False] PASSED ```
This commit is contained in:
parent
9bbe34694d
commit
ab54b8cd58
7 changed files with 103 additions and 74 deletions
|
@ -209,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present:
|
||||
if media_present or not llama_model:
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_openai_dict(m, download=True) for m in request.messages
|
||||
]
|
||||
else:
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(
|
||||
request, self.get_llama_model(request.model)
|
||||
)
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
else:
|
||||
assert not media_present, "Fireworks does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
|
|
|
@ -178,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.register_helper.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present:
|
||||
if media_present or not llama_model:
|
||||
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
|
||||
# flatten the list of lists
|
||||
input_dict["messages"] = [item for sublist in contents for item in sublist]
|
||||
|
@ -187,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
input_dict["raw"] = True
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(
|
||||
request,
|
||||
self.register_helper.get_llama_model(request.model),
|
||||
llama_model,
|
||||
)
|
||||
else:
|
||||
assert not media_present, "Ollama does not support media for Completion requests"
|
||||
|
|
|
@ -203,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present:
|
||||
if media_present or not llama_model:
|
||||
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
|
||||
else:
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(
|
||||
request, self.get_llama_model(request.model)
|
||||
)
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
else:
|
||||
assert not media_present, "Together does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
|
|
|
@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
provider_resource_id = model.provider_resource_id
|
||||
else:
|
||||
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
|
||||
if provider_resource_id:
|
||||
model.provider_resource_id = provider_resource_id
|
||||
else:
|
||||
if model.metadata.get("llama_model") is None:
|
||||
raise ValueError(
|
||||
f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
|
||||
"Please specify a llama_model in metadata or use a supported model identifier"
|
||||
)
|
||||
llama_model = model.metadata.get("llama_model")
|
||||
if llama_model is None:
|
||||
return model
|
||||
|
||||
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
||||
if existing_llama_model:
|
||||
if existing_llama_model != model.metadata["llama_model"]:
|
||||
if existing_llama_model != llama_model:
|
||||
raise ValueError(
|
||||
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
||||
)
|
||||
else:
|
||||
if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
||||
if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
||||
raise ValueError(
|
||||
f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
|
||||
f"Invalid llama_model '{llama_model}' specified in metadata. "
|
||||
f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
|
||||
)
|
||||
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||
)
|
||||
|
||||
return model
|
||||
|
|
|
@ -42,28 +42,30 @@ def pytest_addoption(parser):
|
|||
)
|
||||
parser.addoption(
|
||||
"--inference-model",
|
||||
action="store",
|
||||
default=TEXT_MODEL,
|
||||
help="Specify the inference model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--vision-inference-model",
|
||||
action="store",
|
||||
default=VISION_MODEL,
|
||||
help="Specify the vision inference model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--safety-shield",
|
||||
action="store",
|
||||
default="meta-llama/Llama-Guard-3-1B",
|
||||
help="Specify the safety shield model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--embedding-model",
|
||||
action="store",
|
||||
default=TEXT_MODEL,
|
||||
default=None,
|
||||
help="Specify the embedding model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--embedding-dimension",
|
||||
type=int,
|
||||
default=384,
|
||||
help="Output dimensionality of the embedding model to use for testing",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -78,7 +80,7 @@ def provider_data():
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama_stack_client(provider_data):
|
||||
def llama_stack_client(provider_data, text_model_id):
|
||||
if os.environ.get("LLAMA_STACK_CONFIG"):
|
||||
client = LlamaStackAsLibraryClient(
|
||||
get_env_or_fail("LLAMA_STACK_CONFIG"),
|
||||
|
@ -95,6 +97,45 @@ def llama_stack_client(provider_data):
|
|||
)
|
||||
else:
|
||||
raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_provider_type(llama_stack_client):
|
||||
providers = llama_stack_client.providers.list()
|
||||
inference_providers = [p for p in providers if p.api == "inference"]
|
||||
assert len(inference_providers) > 0, "No inference providers found"
|
||||
return inference_providers[0].provider_type
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension):
|
||||
client = llama_stack_client
|
||||
|
||||
providers = [p for p in client.providers.list() if p.api == "inference"]
|
||||
assert len(providers) > 0, "No inference providers found"
|
||||
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
|
||||
if text_model_id:
|
||||
client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
|
||||
if vision_model_id:
|
||||
client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
|
||||
|
||||
if embedding_model_id and embedding_dimension:
|
||||
# try to find a provider that supports embeddings, if sentence-transformers is not available
|
||||
selected_provider = None
|
||||
for p in providers:
|
||||
if p.provider_type == "inline::sentence-transformers":
|
||||
selected_provider = p
|
||||
break
|
||||
|
||||
selected_provider = selected_provider or providers[0]
|
||||
client.models.register(
|
||||
model_id=embedding_model_id,
|
||||
provider_id=selected_provider.provider_id,
|
||||
model_type="embedding",
|
||||
metadata={"embedding_dimension": embedding_dimension},
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
|
@ -117,3 +158,9 @@ def pytest_generate_tests(metafunc):
|
|||
[metafunc.config.getoption("--embedding-model")],
|
||||
scope="session",
|
||||
)
|
||||
if "embedding_dimension" in metafunc.fixturenames:
|
||||
metafunc.parametrize(
|
||||
"embedding_dimension",
|
||||
[metafunc.config.getoption("--embedding-dimension")],
|
||||
scope="session",
|
||||
)
|
||||
|
|
|
@ -28,14 +28,6 @@ def provider_tool_format(inference_provider_type):
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_provider_type(llama_stack_client):
|
||||
providers = llama_stack_client.providers.list()
|
||||
inference_providers = [p for p in providers if p.api == "inference"]
|
||||
assert len(inference_providers) > 0, "No inference providers found"
|
||||
return inference_providers[0].provider_type
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_weather_tool_definition():
|
||||
return {
|
||||
|
@ -50,8 +42,8 @@ def get_weather_tool_definition():
|
|||
}
|
||||
|
||||
|
||||
def test_text_completion_non_streaming(llama_stack_client, text_model_id):
|
||||
response = llama_stack_client.inference.completion(
|
||||
def test_text_completion_non_streaming(client_with_models, text_model_id):
|
||||
response = client_with_models.inference.completion(
|
||||
content="Complete the sentence using one word: Roses are red, violets are ",
|
||||
stream=False,
|
||||
model_id=text_model_id,
|
||||
|
@ -63,8 +55,8 @@ def test_text_completion_non_streaming(llama_stack_client, text_model_id):
|
|||
# assert "blue" in response.content.lower().strip()
|
||||
|
||||
|
||||
def test_text_completion_streaming(llama_stack_client, text_model_id):
|
||||
response = llama_stack_client.inference.completion(
|
||||
def test_text_completion_streaming(client_with_models, text_model_id):
|
||||
response = client_with_models.inference.completion(
|
||||
content="Complete the sentence using one word: Roses are red, violets are ",
|
||||
stream=True,
|
||||
model_id=text_model_id,
|
||||
|
@ -78,11 +70,11 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
|
|||
assert len(content_str) > 10
|
||||
|
||||
|
||||
def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, inference_provider_type):
|
||||
def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type):
|
||||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
||||
response = llama_stack_client.inference.completion(
|
||||
response = client_with_models.inference.completion(
|
||||
content="Complete the sentence: Micheael Jordan is born in ",
|
||||
stream=False,
|
||||
model_id=text_model_id,
|
||||
|
@ -98,11 +90,11 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, i
|
|||
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
|
||||
|
||||
|
||||
def test_completion_log_probs_streaming(llama_stack_client, text_model_id, inference_provider_type):
|
||||
def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type):
|
||||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
||||
response = llama_stack_client.inference.completion(
|
||||
response = client_with_models.inference.completion(
|
||||
content="Complete the sentence: Micheael Jordan is born in ",
|
||||
stream=True,
|
||||
model_id=text_model_id,
|
||||
|
@ -123,7 +115,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id, infer
|
|||
|
||||
|
||||
@pytest.mark.parametrize("test_case", ["completion-01"])
|
||||
def test_text_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
|
||||
def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
|
||||
class AnswerFormat(BaseModel):
|
||||
name: str
|
||||
year_born: str
|
||||
|
@ -132,7 +124,7 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
|
|||
tc = TestCase(test_case)
|
||||
|
||||
user_input = tc["user_input"]
|
||||
response = llama_stack_client.inference.completion(
|
||||
response = client_with_models.inference.completion(
|
||||
model_id=text_model_id,
|
||||
content=user_input,
|
||||
stream=False,
|
||||
|
@ -161,8 +153,8 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
|
|||
),
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected):
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{
|
||||
|
@ -184,8 +176,8 @@ def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, q
|
|||
("What is the name of the US captial?", "Washington"),
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_streaming(llama_stack_client, text_model_id, question, expected):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected):
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[{"role": "user", "content": question}],
|
||||
stream=True,
|
||||
|
@ -196,9 +188,9 @@ def test_text_chat_completion_streaming(llama_stack_client, text_model_id, quest
|
|||
|
||||
|
||||
def test_text_chat_completion_with_tool_calling_and_non_streaming(
|
||||
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
@ -233,9 +225,9 @@ def extract_tool_invocation_content(response):
|
|||
|
||||
|
||||
def test_text_chat_completion_with_tool_calling_and_streaming(
|
||||
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
@ -251,13 +243,12 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
|
|||
|
||||
|
||||
def test_text_chat_completion_with_tool_choice_required(
|
||||
llama_stack_client,
|
||||
client_with_models,
|
||||
text_model_id,
|
||||
get_weather_tool_definition,
|
||||
provider_tool_format,
|
||||
inference_provider_type,
|
||||
):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
@ -275,9 +266,9 @@ def test_text_chat_completion_with_tool_choice_required(
|
|||
|
||||
|
||||
def test_text_chat_completion_with_tool_choice_none(
|
||||
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
@ -292,7 +283,7 @@ def test_text_chat_completion_with_tool_choice_none(
|
|||
|
||||
|
||||
@pytest.mark.parametrize("test_case", ["chat_completion-01"])
|
||||
def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
|
||||
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
|
||||
class AnswerFormat(BaseModel):
|
||||
first_name: str
|
||||
last_name: str
|
||||
|
@ -301,7 +292,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
|
|||
|
||||
tc = TestCase(test_case)
|
||||
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=tc["messages"],
|
||||
response_format={
|
||||
|
@ -325,7 +316,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
|
|||
False,
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming):
|
||||
def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming):
|
||||
# TODO: more dynamic lookup on tool_prompt_format for model family
|
||||
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
|
||||
request = {
|
||||
|
@ -381,7 +372,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_clie
|
|||
"stream": streaming,
|
||||
}
|
||||
|
||||
response = llama_stack_client.inference.chat_completion(**request)
|
||||
response = client_with_models.inference.chat_completion(**request)
|
||||
|
||||
if streaming:
|
||||
for chunk in response:
|
||||
|
|
|
@ -10,14 +10,6 @@ import pathlib
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_provider_type(llama_stack_client):
|
||||
providers = llama_stack_client.providers.list()
|
||||
inference_providers = [p for p in providers if p.api == "inference"]
|
||||
assert len(inference_providers) > 0, "No inference providers found"
|
||||
return inference_providers[0].provider_type
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_path():
|
||||
return pathlib.Path(__file__).parent / "dog.png"
|
||||
|
@ -35,7 +27,7 @@ def base64_image_url(base64_image_data, image_path):
|
|||
return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"
|
||||
|
||||
|
||||
def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id):
|
||||
def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
|
@ -53,7 +45,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id
|
|||
},
|
||||
],
|
||||
}
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=vision_model_id,
|
||||
messages=[message],
|
||||
stream=False,
|
||||
|
@ -63,7 +55,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id
|
|||
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
|
||||
|
||||
|
||||
def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
|
||||
def test_image_chat_completion_streaming(client_with_models, vision_model_id):
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
|
@ -81,7 +73,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
|
|||
},
|
||||
],
|
||||
}
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=vision_model_id,
|
||||
messages=[message],
|
||||
stream=True,
|
||||
|
@ -94,7 +86,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("type_", ["url", "data"])
|
||||
def test_image_chat_completion_base64(llama_stack_client, vision_model_id, base64_image_data, base64_image_url, type_):
|
||||
def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
|
||||
image_spec = {
|
||||
"url": {
|
||||
"type": "image",
|
||||
|
@ -122,7 +114,7 @@ def test_image_chat_completion_base64(llama_stack_client, vision_model_id, base6
|
|||
},
|
||||
],
|
||||
}
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=vision_model_id,
|
||||
messages=[message],
|
||||
stream=False,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue