forked from phoenix-oss/llama-stack-mirror
feat(providers): support non-llama models for inference providers (#1200)
This PR begins the process of supporting non-llama models within Llama Stack. We start simple by adding support for this functionality within a few existing providers: fireworks, together and ollama. ## Test Plan ```bash LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/inference/test_text_inference.py \ --inference-model accounts/fireworks/models/phi-3-vision-128k-instruct ``` ^ this passes most of the tests but as expected fails the tool calling related tests since they are very specific to Llama models ``` inference/test_text_inference.py::test_text_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_completion_log_probs_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_completion_log_probs_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_text_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct-completion-01] PASSED inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet do humans live on?-Earth] PASSED inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet has rings around it with a name starting w ith letter S?-Saturn] PASSED inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What's the name of the Sun in latin?-Sol] PASSED inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What is the name of the US captial?-Washington] PASSED inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_required[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_none[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_text_chat_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct] ERROR inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-True] PASSED inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-False] PASSED ```
This commit is contained in:
parent
9bbe34694d
commit
ab54b8cd58
7 changed files with 103 additions and 74 deletions
|
@ -28,14 +28,6 @@ def provider_tool_format(inference_provider_type):
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_provider_type(llama_stack_client):
|
||||
providers = llama_stack_client.providers.list()
|
||||
inference_providers = [p for p in providers if p.api == "inference"]
|
||||
assert len(inference_providers) > 0, "No inference providers found"
|
||||
return inference_providers[0].provider_type
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_weather_tool_definition():
|
||||
return {
|
||||
|
@ -50,8 +42,8 @@ def get_weather_tool_definition():
|
|||
}
|
||||
|
||||
|
||||
def test_text_completion_non_streaming(llama_stack_client, text_model_id):
|
||||
response = llama_stack_client.inference.completion(
|
||||
def test_text_completion_non_streaming(client_with_models, text_model_id):
|
||||
response = client_with_models.inference.completion(
|
||||
content="Complete the sentence using one word: Roses are red, violets are ",
|
||||
stream=False,
|
||||
model_id=text_model_id,
|
||||
|
@ -63,8 +55,8 @@ def test_text_completion_non_streaming(llama_stack_client, text_model_id):
|
|||
# assert "blue" in response.content.lower().strip()
|
||||
|
||||
|
||||
def test_text_completion_streaming(llama_stack_client, text_model_id):
|
||||
response = llama_stack_client.inference.completion(
|
||||
def test_text_completion_streaming(client_with_models, text_model_id):
|
||||
response = client_with_models.inference.completion(
|
||||
content="Complete the sentence using one word: Roses are red, violets are ",
|
||||
stream=True,
|
||||
model_id=text_model_id,
|
||||
|
@ -78,11 +70,11 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
|
|||
assert len(content_str) > 10
|
||||
|
||||
|
||||
def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, inference_provider_type):
|
||||
def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type):
|
||||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
||||
response = llama_stack_client.inference.completion(
|
||||
response = client_with_models.inference.completion(
|
||||
content="Complete the sentence: Micheael Jordan is born in ",
|
||||
stream=False,
|
||||
model_id=text_model_id,
|
||||
|
@ -98,11 +90,11 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, i
|
|||
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
|
||||
|
||||
|
||||
def test_completion_log_probs_streaming(llama_stack_client, text_model_id, inference_provider_type):
|
||||
def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type):
|
||||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
||||
response = llama_stack_client.inference.completion(
|
||||
response = client_with_models.inference.completion(
|
||||
content="Complete the sentence: Micheael Jordan is born in ",
|
||||
stream=True,
|
||||
model_id=text_model_id,
|
||||
|
@ -123,7 +115,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id, infer
|
|||
|
||||
|
||||
@pytest.mark.parametrize("test_case", ["completion-01"])
|
||||
def test_text_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
|
||||
def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
|
||||
class AnswerFormat(BaseModel):
|
||||
name: str
|
||||
year_born: str
|
||||
|
@ -132,7 +124,7 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
|
|||
tc = TestCase(test_case)
|
||||
|
||||
user_input = tc["user_input"]
|
||||
response = llama_stack_client.inference.completion(
|
||||
response = client_with_models.inference.completion(
|
||||
model_id=text_model_id,
|
||||
content=user_input,
|
||||
stream=False,
|
||||
|
@ -161,8 +153,8 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
|
|||
),
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected):
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{
|
||||
|
@ -184,8 +176,8 @@ def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, q
|
|||
("What is the name of the US captial?", "Washington"),
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_streaming(llama_stack_client, text_model_id, question, expected):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected):
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[{"role": "user", "content": question}],
|
||||
stream=True,
|
||||
|
@ -196,9 +188,9 @@ def test_text_chat_completion_streaming(llama_stack_client, text_model_id, quest
|
|||
|
||||
|
||||
def test_text_chat_completion_with_tool_calling_and_non_streaming(
|
||||
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
@ -233,9 +225,9 @@ def extract_tool_invocation_content(response):
|
|||
|
||||
|
||||
def test_text_chat_completion_with_tool_calling_and_streaming(
|
||||
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
@ -251,13 +243,12 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
|
|||
|
||||
|
||||
def test_text_chat_completion_with_tool_choice_required(
|
||||
llama_stack_client,
|
||||
client_with_models,
|
||||
text_model_id,
|
||||
get_weather_tool_definition,
|
||||
provider_tool_format,
|
||||
inference_provider_type,
|
||||
):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
@ -275,9 +266,9 @@ def test_text_chat_completion_with_tool_choice_required(
|
|||
|
||||
|
||||
def test_text_chat_completion_with_tool_choice_none(
|
||||
llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
|
||||
):
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
@ -292,7 +283,7 @@ def test_text_chat_completion_with_tool_choice_none(
|
|||
|
||||
|
||||
@pytest.mark.parametrize("test_case", ["chat_completion-01"])
|
||||
def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
|
||||
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
|
||||
class AnswerFormat(BaseModel):
|
||||
first_name: str
|
||||
last_name: str
|
||||
|
@ -301,7 +292,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
|
|||
|
||||
tc = TestCase(test_case)
|
||||
|
||||
response = llama_stack_client.inference.chat_completion(
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=tc["messages"],
|
||||
response_format={
|
||||
|
@ -325,7 +316,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
|
|||
False,
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming):
|
||||
def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming):
|
||||
# TODO: more dynamic lookup on tool_prompt_format for model family
|
||||
tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
|
||||
request = {
|
||||
|
@ -381,7 +372,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_clie
|
|||
"stream": streaming,
|
||||
}
|
||||
|
||||
response = llama_stack_client.inference.chat_completion(**request)
|
||||
response = client_with_models.inference.chat_completion(**request)
|
||||
|
||||
if streaming:
|
||||
for chunk in response:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue