feat(providers): support non-llama models for inference providers (#1200)

This PR begins the process of supporting non-Llama models within Llama
Stack. We start simple by adding this support to a few existing providers:
Fireworks, Together, and Ollama.
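
As a concrete example, a non-Llama model can now be registered against one of these providers and used directly. The sketch below is hypothetical (the base URL and provider id are assumptions), but it mirrors the `models.register` and `inference.chat_completion` calls used in the test fixtures further down:

```python
from llama_stack_client import LlamaStackClient

# Hypothetical sketch: base_url and provider_id are assumptions for illustration.
client = LlamaStackClient(base_url="http://localhost:8321")

# Register a non-Llama model; no `llama_model` metadata is required anymore.
client.models.register(
    model_id="accounts/fireworks/models/phi-3-vision-128k-instruct",
    provider_id="fireworks",
)

response = client.inference.chat_completion(
    model_id="accounts/fireworks/models/phi-3-vision-128k-instruct",
    messages=[{"role": "user", "content": "Which planet do humans live on?"}],
    stream=False,
)
print(response.completion_message.content)
```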

## Test Plan

```bash
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/inference/test_text_inference.py \
  --inference-model accounts/fireworks/models/phi-3-vision-128k-instruct
```

^ This passes most of the tests but, as expected, fails the tool-calling tests,
since those are very specific to Llama models:

```
inference/test_text_inference.py::test_text_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_completion_log_probs_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_completion_log_probs_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_text_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct-completion-01] PASSED
inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet do humans live on?-Earth] PASSED
inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet has rings around it with a name starting with letter S?-Saturn] PASSED
inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What's the name of the Sun in latin?-Sol] PASSED
inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What is the name of the US captial?-Washington] PASSED
inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED
inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED
inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_required[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED
inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_none[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_text_chat_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct] ERROR
inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-True] PASSED
inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-False] PASSED
```
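
In principle the same suite can be pointed at the other providers touched here by swapping the distro config and the model id. A hypothetical invocation (placeholder model id; `ollama` assumed as the distro template name):

```bash
# Hypothetical: substitute a model id that the provider actually serves.
LLAMA_STACK_CONFIG=ollama pytest -s -v tests/client-sdk/inference/test_text_inference.py \
  --inference-model <non-llama-model-id>
```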
Ashwin Bharambe, 2025-02-21 13:21:28 -08:00 (committed via GitHub); commit ab54b8cd58, parent 9bbe34694d
7 changed files with 103 additions and 74 deletions


```diff
@@ -209,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
```


```diff
@@ -178,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                 # flatten the list of lists
                 input_dict["messages"] = [item for sublist in contents for item in sublist]
@@ -187,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["raw"] = True
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
-                    self.register_helper.get_llama_model(request.model),
+                    llama_model,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
```


```diff
@@ -203,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
```
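
All three adapters now share the same dispatch: if the request carries media, or the registered model has no associated Llama model, the messages are passed through in OpenAI-style form; otherwise the conversation is rendered into a Llama prompt. A self-contained, simplified sketch of that decision (stand-in names, not the real llama-stack helpers):

```python
from dataclasses import dataclass, field
from typing import Optional

# Simplified illustration of the dispatch shared by the three adapters above.
# All names here are stand-ins for the real llama-stack request types and helpers.

@dataclass
class ChatRequest:
    model: str
    messages: list = field(default_factory=list)
    has_media: bool = False

def build_input(request: ChatRequest, llama_model: Optional[str]) -> dict:
    """Return provider-ready params: OpenAI-style messages for non-Llama
    (or multimodal) requests, a rendered Llama prompt otherwise."""
    if request.has_media or not llama_model:
        return {"messages": [{"role": m["role"], "content": m["content"]} for m in request.messages]}
    # For a registered Llama model, the real code renders the conversation into
    # the model's prompt format via chat_completion_request_to_prompt().
    rendered = "\n".join(f"{m['role']}: {m['content']}" for m in request.messages)
    return {"prompt": rendered}

# A non-Llama model (llama_model is None) falls through to OpenAI-style messages.
req = ChatRequest(model="phi-3-vision-128k-instruct",
                  messages=[{"role": "user", "content": "hi"}])
print(build_input(req, llama_model=None))                      # -> {"messages": [...]}
print(build_input(req, llama_model="Llama-3.1-8B-Instruct"))   # -> {"prompt": "..."}
```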


```diff
@@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             provider_resource_id = model.provider_resource_id
         else:
             provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
-            if model.metadata.get("llama_model") is None:
-                raise ValueError(
-                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
-                    "Please specify a llama_model in metadata or use a supported model identifier"
-                )
+            llama_model = model.metadata.get("llama_model")
+            if llama_model is None:
+                return model
+
             existing_llama_model = self.get_llama_model(model.provider_resource_id)
             if existing_llama_model:
-                if existing_llama_model != model.metadata["llama_model"]:
+                if existing_llama_model != llama_model:
                     raise ValueError(
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
-                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Invalid llama_model '{llama_model}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                 )
         return model
```
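
On the client side, this yields two registration shapes for these providers. The sketch below is hypothetical (base URL, provider id, and the second model id are assumptions; the `models.register` signature matches the one used in the conftest fixtures below):

```python
from llama_stack_client import LlamaStackClient

# Hypothetical sketch; ids and base_url are for illustration only.
client = LlamaStackClient(base_url="http://localhost:8321")

# 1) A non-Llama model: no `llama_model` metadata, registration simply succeeds
#    and requests are sent through as OpenAI-style messages.
client.models.register(
    model_id="accounts/fireworks/models/phi-3-vision-128k-instruct",
    provider_id="fireworks",
)

# 2) A Llama model served under a provider-specific id: `llama_model` metadata
#    ties it to a known Llama descriptor so prompt formatting and tool calling
#    keep working.
client.models.register(
    model_id="my-custom-llama-deployment",
    provider_id="fireworks",
    metadata={"llama_model": "meta-llama/Llama-3.1-8B-Instruct"},
)
```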


```diff
@@ -42,28 +42,30 @@ def pytest_addoption(parser):
     )
     parser.addoption(
         "--inference-model",
         action="store",
         default=TEXT_MODEL,
         help="Specify the inference model to use for testing",
     )
     parser.addoption(
         "--vision-inference-model",
         action="store",
         default=VISION_MODEL,
         help="Specify the vision inference model to use for testing",
     )
     parser.addoption(
         "--safety-shield",
         action="store",
         default="meta-llama/Llama-Guard-3-1B",
         help="Specify the safety shield model to use for testing",
     )
     parser.addoption(
         "--embedding-model",
         action="store",
-        default=TEXT_MODEL,
+        default=None,
         help="Specify the embedding model to use for testing",
     )
+    parser.addoption(
+        "--embedding-dimension",
+        type=int,
+        default=384,
+        help="Output dimensionality of the embedding model to use for testing",
+    )


 @pytest.fixture(scope="session")
@@ -78,7 +80,7 @@ def provider_data():


 @pytest.fixture(scope="session")
-def llama_stack_client(provider_data):
+def llama_stack_client(provider_data, text_model_id):
     if os.environ.get("LLAMA_STACK_CONFIG"):
         client = LlamaStackAsLibraryClient(
             get_env_or_fail("LLAMA_STACK_CONFIG"),
@@ -95,6 +97,45 @@ def llama_stack_client(provider_data):
         )
     else:
         raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
     return client
+
+
+@pytest.fixture(scope="session")
+def inference_provider_type(llama_stack_client):
+    providers = llama_stack_client.providers.list()
+    inference_providers = [p for p in providers if p.api == "inference"]
+    assert len(inference_providers) > 0, "No inference providers found"
+    return inference_providers[0].provider_type
+
+
+@pytest.fixture(scope="session")
+def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension):
+    client = llama_stack_client
+
+    providers = [p for p in client.providers.list() if p.api == "inference"]
+    assert len(providers) > 0, "No inference providers found"
+    inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
+    if text_model_id:
+        client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
+    if vision_model_id:
+        client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
+
+    if embedding_model_id and embedding_dimension:
+        # try to find a provider that supports embeddings, if sentence-transformers is not available
+        selected_provider = None
+        for p in providers:
+            if p.provider_type == "inline::sentence-transformers":
+                selected_provider = p
+                break
+
+        selected_provider = selected_provider or providers[0]
+        client.models.register(
+            model_id=embedding_model_id,
+            provider_id=selected_provider.provider_id,
+            model_type="embedding",
+            metadata={"embedding_dimension": embedding_dimension},
+        )
+    return client
@@ -117,3 +158,9 @@ def pytest_generate_tests(metafunc):
             [metafunc.config.getoption("--embedding-model")],
             scope="session",
         )
+    if "embedding_dimension" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "embedding_dimension",
+            [metafunc.config.getoption("--embedding-dimension")],
+            scope="session",
+        )
```
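
With these fixtures, the client-sdk tests can be pointed at arbitrary text, vision, and embedding models from the command line. A hypothetical invocation (model names are examples only, not defaults shipped with the PR):

```bash
# Hypothetical: substitute models your configured providers actually serve.
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/inference/ \
  --inference-model accounts/fireworks/models/phi-3-vision-128k-instruct \
  --embedding-model all-MiniLM-L6-v2 \
  --embedding-dimension 384
```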


```diff
@@ -28,14 +28,6 @@ def provider_tool_format(inference_provider_type):
     )


-@pytest.fixture(scope="session")
-def inference_provider_type(llama_stack_client):
-    providers = llama_stack_client.providers.list()
-    inference_providers = [p for p in providers if p.api == "inference"]
-    assert len(inference_providers) > 0, "No inference providers found"
-    return inference_providers[0].provider_type
-
-
 @pytest.fixture
 def get_weather_tool_definition():
     return {
@@ -50,8 +42,8 @@ def get_weather_tool_definition():
     }


-def test_text_completion_non_streaming(llama_stack_client, text_model_id):
-    response = llama_stack_client.inference.completion(
+def test_text_completion_non_streaming(client_with_models, text_model_id):
+    response = client_with_models.inference.completion(
         content="Complete the sentence using one word: Roses are red, violets are ",
         stream=False,
         model_id=text_model_id,
@@ -63,8 +55,8 @@ def test_text_completion_non_streaming(llama_stack_client, text_model_id):
     # assert "blue" in response.content.lower().strip()


-def test_text_completion_streaming(llama_stack_client, text_model_id):
-    response = llama_stack_client.inference.completion(
+def test_text_completion_streaming(client_with_models, text_model_id):
+    response = client_with_models.inference.completion(
         content="Complete the sentence using one word: Roses are red, violets are ",
         stream=True,
         model_id=text_model_id,
@@ -78,11 +70,11 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
     assert len(content_str) > 10


-def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, inference_provider_type):
+def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type):
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")

-    response = llama_stack_client.inference.completion(
+    response = client_with_models.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=False,
         model_id=text_model_id,
@@ -98,11 +90,11 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, i
     assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)


-def test_completion_log_probs_streaming(llama_stack_client, text_model_id, inference_provider_type):
+def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type):
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")

-    response = llama_stack_client.inference.completion(
+    response = client_with_models.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=True,
         model_id=text_model_id,
@@ -123,7 +115,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id, infer


 @pytest.mark.parametrize("test_case", ["completion-01"])
-def test_text_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
+def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
     class AnswerFormat(BaseModel):
         name: str
         year_born: str
@@ -132,7 +124,7 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
     tc = TestCase(test_case)

     user_input = tc["user_input"]
-    response = llama_stack_client.inference.completion(
+    response = client_with_models.inference.completion(
         model_id=text_model_id,
         content=user_input,
         stream=False,
@@ -161,8 +153,8 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
         ),
     ],
 )
-def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected):
-    response = llama_stack_client.inference.chat_completion(
+def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected):
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {
@@ -184,8 +176,8 @@ def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, q
         ("What is the name of the US captial?", "Washington"),
     ],
 )
-def test_text_chat_completion_streaming(llama_stack_client, text_model_id, question, expected):
-    response = llama_stack_client.inference.chat_completion(
+def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected):
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[{"role": "user", "content": question}],
         stream=True,
@@ -196,9 +188,9 @@ def test_text_chat_completion_streaming(llama_stack_client, text_model_id, quest


 def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
+    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -233,9 +225,9 @@ def extract_tool_invocation_content(response):


 def test_text_chat_completion_with_tool_calling_and_streaming(
-    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
+    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -251,13 +243,12 @@ def test_text_chat_completion_with_tool_calling_and_streaming(


 def test_text_chat_completion_with_tool_choice_required(
-    llama_stack_client,
+    client_with_models,
     text_model_id,
     get_weather_tool_definition,
     provider_tool_format,
-    inference_provider_type,
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -275,9 +266,9 @@ def test_text_chat_completion_with_tool_choice_required(


 def test_text_chat_completion_with_tool_choice_none(
-    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
+    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -292,7 +283,7 @@ def test_text_chat_completion_with_tool_choice_none(


 @pytest.mark.parametrize("test_case", ["chat_completion-01"])
-def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
+def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
     class AnswerFormat(BaseModel):
         first_name: str
         last_name: str
@@ -301,7 +292,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
     tc = TestCase(test_case)

-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         response_format={
@@ -325,7 +316,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
         False,
     ],
 )
-def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming):
+def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming):
     # TODO: more dynamic lookup on tool_prompt_format for model family
     tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
     request = {
@@ -381,7 +372,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_clie
         "stream": streaming,
     }

-    response = llama_stack_client.inference.chat_completion(**request)
+    response = client_with_models.inference.chat_completion(**request)

     if streaming:
         for chunk in response:
```


```diff
@@ -10,14 +10,6 @@ import pathlib
 import pytest


-@pytest.fixture(scope="session")
-def inference_provider_type(llama_stack_client):
-    providers = llama_stack_client.providers.list()
-    inference_providers = [p for p in providers if p.api == "inference"]
-    assert len(inference_providers) > 0, "No inference providers found"
-    return inference_providers[0].provider_type
-
-
 @pytest.fixture
 def image_path():
     return pathlib.Path(__file__).parent / "dog.png"
@@ -35,7 +27,7 @@ def base64_image_url(base64_image_data, image_path):
     return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"


-def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id):
+def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
         "content": [
@@ -53,7 +45,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id
             },
         ],
     }
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=False,
@@ -63,7 +55,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id
     assert any(expected in message_content for expected in {"dog", "puppy", "pup"})


-def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
+def test_image_chat_completion_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
         "content": [
@@ -81,7 +73,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
             },
         ],
     }
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=True,
@@ -94,7 +86,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):


 @pytest.mark.parametrize("type_", ["url", "data"])
-def test_image_chat_completion_base64(llama_stack_client, vision_model_id, base64_image_data, base64_image_url, type_):
+def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
     image_spec = {
         "url": {
             "type": "image",
@@ -122,7 +114,7 @@ def test_image_chat_completion_base64(llama_stack_client, vision_model_id, base6
             },
         ],
     }
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=False,
```