diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index b9b23584b..90fe70cbf 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -209,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         input_dict = {}
         media_present = request_has_media(request)
 
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 058bbeeee..6fcfd2e99 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -178,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         input_dict = {}
         media_present = request_has_media(request)
 
+        llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                 # flatten the list of lists
                 input_dict["messages"] = [item for sublist in contents for item in sublist]
@@ -187,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["raw"] = True
             input_dict["prompt"] = await chat_completion_request_to_prompt(
                 request,
-                self.register_helper.get_llama_model(request.model),
+                llama_model,
             )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 1fca54bb3..040f04e77 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -203,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 0882019e3..d9e24662a 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             provider_resource_id = model.provider_resource_id
         else:
             provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
-            if model.metadata.get("llama_model") is None:
-                raise ValueError(
-                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
-                    "Please specify a llama_model in metadata or use a supported model identifier"
-                )
+            llama_model = model.metadata.get("llama_model")
+            if llama_model is None:
+                return model
+
             existing_llama_model = self.get_llama_model(model.provider_resource_id)
             if existing_llama_model:
-                if existing_llama_model != model.metadata["llama_model"]:
+                if existing_llama_model != llama_model:
                     raise ValueError(
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
-                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Invalid llama_model '{llama_model}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                 )
 
         return model
diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py
index efdec6b01..662505590 100644
--- a/tests/client-sdk/conftest.py
+++ b/tests/client-sdk/conftest.py
@@ -42,28 +42,30 @@ def pytest_addoption(parser):
     )
     parser.addoption(
         "--inference-model",
-        action="store",
         default=TEXT_MODEL,
         help="Specify the inference model to use for testing",
     )
     parser.addoption(
         "--vision-inference-model",
-        action="store",
         default=VISION_MODEL,
         help="Specify the vision inference model to use for testing",
     )
     parser.addoption(
         "--safety-shield",
-        action="store",
         default="meta-llama/Llama-Guard-3-1B",
         help="Specify the safety shield model to use for testing",
     )
     parser.addoption(
         "--embedding-model",
-        action="store",
-        default=TEXT_MODEL,
+        default=None,
         help="Specify the embedding model to use for testing",
     )
+    parser.addoption(
+        "--embedding-dimension",
+        type=int,
+        default=384,
+        help="Output dimensionality of the embedding model to use for testing",
+    )
 
 
 @pytest.fixture(scope="session")
@@ -78,7 +80,7 @@ def provider_data():
 
 
 @pytest.fixture(scope="session")
-def llama_stack_client(provider_data):
+def llama_stack_client(provider_data, text_model_id):
     if os.environ.get("LLAMA_STACK_CONFIG"):
         client = LlamaStackAsLibraryClient(
             get_env_or_fail("LLAMA_STACK_CONFIG"),
@@ -95,6 +97,45 @@ def provider_data():
         )
     else:
         raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+
+    return client
+
+
+@pytest.fixture(scope="session")
+def inference_provider_type(llama_stack_client):
+    providers = llama_stack_client.providers.list()
+    inference_providers = [p for p in providers if p.api == "inference"]
+    assert len(inference_providers) > 0, "No inference providers found"
+    return inference_providers[0].provider_type
+
+
+@pytest.fixture(scope="session")
+def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension):
+    client = llama_stack_client
+
+    providers = [p for p in client.providers.list() if p.api == "inference"]
+    assert len(providers) > 0, "No inference providers found"
+    inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
+    if text_model_id:
+        client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
+    if vision_model_id:
+        client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
+
+    if embedding_model_id and embedding_dimension:
+        # try to find a provider that supports embeddings, if sentence-transformers is not available
+        selected_provider = None
+        for p in providers:
+            if p.provider_type == "inline::sentence-transformers":
+                selected_provider = p
+                break
+
+        selected_provider = selected_provider or providers[0]
+        client.models.register(
+            model_id=embedding_model_id,
+            provider_id=selected_provider.provider_id,
+            model_type="embedding",
+            metadata={"embedding_dimension": embedding_dimension},
+        )
     return client
 
 
@@ -117,3 +158,9 @@ def pytest_generate_tests(metafunc):
             [metafunc.config.getoption("--embedding-model")],
             scope="session",
         )
+    if "embedding_dimension" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "embedding_dimension",
+            [metafunc.config.getoption("--embedding-dimension")],
+            scope="session",
+        )
diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py
index 545325bbe..75d932380 100644
--- a/tests/client-sdk/inference/test_text_inference.py
+++ b/tests/client-sdk/inference/test_text_inference.py
@@ -28,14 +28,6 @@ def provider_tool_format(inference_provider_type):
     )
 
 
-@pytest.fixture(scope="session")
-def inference_provider_type(llama_stack_client):
-    providers = llama_stack_client.providers.list()
-    inference_providers = [p for p in providers if p.api == "inference"]
-    assert len(inference_providers) > 0, "No inference providers found"
-    return inference_providers[0].provider_type
-
-
 @pytest.fixture
 def get_weather_tool_definition():
     return {
@@ -50,8 +42,8 @@ def get_weather_tool_definition():
     }
 
 
-def test_text_completion_non_streaming(llama_stack_client, text_model_id):
-    response = llama_stack_client.inference.completion(
+def test_text_completion_non_streaming(client_with_models, text_model_id):
+    response = client_with_models.inference.completion(
         content="Complete the sentence using one word: Roses are red, violets are ",
         stream=False,
         model_id=text_model_id,
@@ -63,8 +55,8 @@ def test_text_completion_non_streaming(llama_stack_client, text_model_id):
     # assert "blue" in response.content.lower().strip()
 
 
-def test_text_completion_streaming(llama_stack_client, text_model_id):
-    response = llama_stack_client.inference.completion(
+def test_text_completion_streaming(client_with_models, text_model_id):
+    response = client_with_models.inference.completion(
         content="Complete the sentence using one word: Roses are red, violets are ",
         stream=True,
         model_id=text_model_id,
@@ -78,11 +70,11 @@
     assert len(content_str) > 10
 
 
-def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, inference_provider_type):
+def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type):
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
 
-    response = llama_stack_client.inference.completion(
+    response = client_with_models.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=False,
         model_id=text_model_id,
@@ -98,11 +90,11 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, i
     assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
 
 
-def test_completion_log_probs_streaming(llama_stack_client, text_model_id, inference_provider_type):
+def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type):
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
 
-    response = llama_stack_client.inference.completion(
+    response = client_with_models.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=True,
         model_id=text_model_id,
@@ -123,7 +115,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id, infer
 
 
 @pytest.mark.parametrize("test_case", ["completion-01"])
-def test_text_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
+def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
     class AnswerFormat(BaseModel):
         name: str
         year_born: str
@@ -132,7 +124,7 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
     tc = TestCase(test_case)
 
     user_input = tc["user_input"]
-    response = llama_stack_client.inference.completion(
+    response = client_with_models.inference.completion(
         model_id=text_model_id,
         content=user_input,
         stream=False,
@@ -161,8 +153,8 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in
         ),
     ],
 )
-def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected):
-    response = llama_stack_client.inference.chat_completion(
+def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected):
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {
@@ -184,8 +176,8 @@ def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, q
         ("What is the name of the US captial?", "Washington"),
     ],
 )
-def test_text_chat_completion_streaming(llama_stack_client, text_model_id, question, expected):
-    response = llama_stack_client.inference.chat_completion(
+def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected):
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[{"role": "user", "content": question}],
         stream=True,
@@ -196,9 +188,9 @@ def test_text_chat_completion_streaming(llama_stack_client, text_model_id, quest
 
 
 def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
+    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -233,9 +225,9 @@ def extract_tool_invocation_content(response):
 
 
 def test_text_chat_completion_with_tool_calling_and_streaming(
-    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
+    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -251,13 +243,12 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
 
 
 def test_text_chat_completion_with_tool_choice_required(
-    llama_stack_client,
+    client_with_models,
     text_model_id,
     get_weather_tool_definition,
     provider_tool_format,
-    inference_provider_type,
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -275,9 +266,9 @@ def test_text_chat_completion_with_tool_choice_required(
 
 
 def test_text_chat_completion_with_tool_choice_none(
-    llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format
+    client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format
 ):
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=[
             {"role": "system", "content": "You are a helpful assistant."},
@@ -292,7 +283,7 @@ def test_text_chat_completion_with_tool_choice_none(
 
 
 @pytest.mark.parametrize("test_case", ["chat_completion-01"])
-def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
+def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
     class AnswerFormat(BaseModel):
         first_name: str
         last_name: str
@@ -301,7 +292,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
 
     tc = TestCase(test_case)
 
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         response_format={
@@ -325,7 +316,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i
         False,
     ],
 )
-def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming):
+def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming):
     # TODO: more dynamic lookup on tool_prompt_format for model family
     tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
     request = {
@@ -381,7 +372,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_clie
         "stream": streaming,
     }
 
-    response = llama_stack_client.inference.chat_completion(**request)
+    response = client_with_models.inference.chat_completion(**request)
 
     if streaming:
         for chunk in response:
diff --git a/tests/client-sdk/inference/test_vision_inference.py b/tests/client-sdk/inference/test_vision_inference.py
index b23089747..8fa0d8023 100644
--- a/tests/client-sdk/inference/test_vision_inference.py
+++ b/tests/client-sdk/inference/test_vision_inference.py
@@ -10,14 +10,6 @@ import pathlib
 import pytest
 
 
-@pytest.fixture(scope="session")
-def inference_provider_type(llama_stack_client):
-    providers = llama_stack_client.providers.list()
-    inference_providers = [p for p in providers if p.api == "inference"]
-    assert len(inference_providers) > 0, "No inference providers found"
-    return inference_providers[0].provider_type
-
-
 @pytest.fixture
 def image_path():
     return pathlib.Path(__file__).parent / "dog.png"
@@ -35,7 +27,7 @@ def base64_image_url(base64_image_data, image_path):
     return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"
 
 
-def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id):
+def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
         "content": [
@@ -53,7 +45,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id
            },
        ],
    }
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=False,
@@ -63,7 +55,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id
     assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
 
 
-def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
+def test_image_chat_completion_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
         "content": [
@@ -81,7 +73,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
             },
         ],
     }
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=True,
@@ -94,7 +86,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
 
 
 @pytest.mark.parametrize("type_", ["url", "data"])
-def test_image_chat_completion_base64(llama_stack_client, vision_model_id, base64_image_data, base64_image_url, type_):
+def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
     image_spec = {
         "url": {
             "type": "image",
@@ -122,7 +114,7 @@ def test_image_chat_completion_base64(llama_stack_client, vision_model_id, base6
             },
         ],
     }
-    response = llama_stack_client.inference.chat_completion(
+    response = client_with_models.inference.chat_completion(
         model_id=vision_model_id,
         messages=[message],
         stream=False,