feat(providers): support non-llama models for inference providers (#1200)

This PR begins the process of supporting non-llama models within Llama
Stack. We start simple by adding support for this functionality within a
few existing providers: fireworks, together and ollama.

## Test Plan

```bash
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/inference/test_text_inference.py \
  --inference-model accounts/fireworks/models/phi-3-vision-128k-instruct
```

^ this passes most of the tests but as expected fails the tool calling
related tests since they are very specific to Llama models

```
inference/test_text_inference.py::test_text_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_completion_log_probs_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_completion_log_probs_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_text_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct-completion-01] PASSED
inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet do humans live on?-Earth] PASSED
inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet has rings around it with a name starting w
ith letter S?-Saturn] PASSED
inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What's the name of the Sun in latin?-Sol] PASSED
inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What is the name of the US captial?-Washington] PASSED
inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED
inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED
inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_required[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED
inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_none[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED
inference/test_text_inference.py::test_text_chat_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct] ERROR
inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-True] PASSED
inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-False] PASSED
```
This commit is contained in:
Ashwin Bharambe 2025-02-21 13:21:28 -08:00 committed by GitHub
parent 9bbe34694d
commit ab54b8cd58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 103 additions and 74 deletions

View file

@ -28,14 +28,6 @@ def provider_tool_format(inference_provider_type):
)
@pytest.fixture(scope="session")
def inference_provider_type(llama_stack_client):
providers = llama_stack_client.providers.list()
inference_providers = [p for p in providers if p.api == "inference"]
assert len(inference_providers) > 0, "No inference providers found"
return inference_providers[0].provider_type
@pytest.fixture
def get_weather_tool_definition():
return {
@ -50,8 +42,8 @@ def get_weather_tool_definition():
}
def test_text_completion_non_streaming(llama_stack_client, text_model_id):
response = llama_stack_client.inference.completion(
def test_text_completion_non_streaming(client_with_models, text_model_id):
response = client_with_models.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",
stream=False,
model_id=text_model_id,
@ -63,8 +55,8 @@ def test_text_completion_non_streaming(llama_stack_client, text_model_id):
# assert "blue" in response.content.lower().strip()
def test_text_completion_streaming(llama_stack_client, text_model_id):
response = llama_stack_client.inference.completion(
def test_text_completion_streaming(client_with_models, text_model_id):
response = client_with_models.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",
stream=True,
model_id=text_model_id,
@ -78,11 +70,11 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
assert len(content_str) > 10
def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, inference_provider_type):
def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type):
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
response = llama_stack_client.inference.completion(
response = client_with_models.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
stream=False,
model_id=text_model_id,
@ -98,11 +90,11 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, i
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
def test_completion_log_probs_streaming(llama_stack_client, text_model_id, inference_provider_type):
def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type):
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
response = llama_stack_client.inference.completion(
response = client_with_models.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
stream=True,
model_id=text_model_id,
@ -123,7 +115,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id, infer
@pytest.mark.parametrize("test_case", ["completion-01"])
def test_text_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
class AnswerFormat(BaseModel):
name: str
year_born: str
@ -132,7 +124,7 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
tc = TestCase(test_case)
user_input = tc["user_input"]
response = llama_stack_client.inference.completion(
response = client_with_models.inference.completion(
model_id=text_model_id,
content=user_input,
stream=False,
@ -161,8 +153,8 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
),
],
)
def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected):
response = llama_stack_client.inference.chat_completion(
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected):
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{
@ -184,8 +176,8 @@ def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, q
("What is the name of the US captial?", "Washington"),
],
)
def test_text_chat_completion_streaming(llama_stack_client, text_model_id, question, expected):
response = llama_stack_client.inference.chat_completion(
def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected):
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[{"role": "user", "content": question}],
stream=True,
@ -196,9 +188,9 @@ def test_text_chat_completion_streaming(llama_stack_client, text_model_id, quest
def test_text_chat_completion_with_tool_calling_and_non_streaming(
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
):
response = llama_stack_client.inference.chat_completion(
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
@ -233,9 +225,9 @@ def extract_tool_invocation_content(response):
def test_text_chat_completion_with_tool_calling_and_streaming(
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
):
response = llama_stack_client.inference.chat_completion(
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
@ -251,13 +243,12 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
def test_text_chat_completion_with_tool_choice_required(
llama_stack_client,
client_with_models,
text_model_id,
get_weather_tool_definition,
provider_tool_format,
inference_provider_type,
):
response = llama_stack_client.inference.chat_completion(
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
@ -275,9 +266,9 @@ def test_text_chat_completion_with_tool_choice_required(
def test_text_chat_completion_with_tool_choice_none(
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
):
response = llama_stack_client.inference.chat_completion(
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
@ -292,7 +283,7 @@ def test_text_chat_completion_with_tool_choice_none(
@pytest.mark.parametrize("test_case", ["chat_completion-01"])
def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
class AnswerFormat(BaseModel):
first_name: str
last_name: str
@ -301,7 +292,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
tc = TestCase(test_case)
response = llama_stack_client.inference.chat_completion(
response = client_with_models.inference.chat_completion(
model_id=text_model_id,
messages=tc["messages"],
response_format={
@ -325,7 +316,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
False,
],
)
def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming):
def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming):
# TODO: more dynamic lookup on tool_prompt_format for model family
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
request = {
@ -381,7 +372,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_clie
"stream": streaming,
}
response = llama_stack_client.inference.chat_completion(**request)
response = client_with_models.inference.chat_completion(**request)
if streaming:
for chunk in response: