From 1b64573284c8e7babe832b79321457a40c8cc6e7 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 20 Feb 2025 22:41:41 -0800
Subject: [PATCH] Support non-llama models for inference providers

---
 .../remote/inference/fireworks/fireworks.py   |  7 +++----
 .../remote/inference/ollama/ollama.py         |  5 +++--
 .../remote/inference/together/together.py     |  7 +++----
 .../utils/inference/model_registry.py         | 18 +++++++++---------
 tests/client-sdk/conftest.py                  | 10 +++++++++-
 5 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index b9b23584b..90fe70cbf 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -209,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)

         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index f61ac9898..5841b13aa 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -178,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                 # flatten the list of lists
                 input_dict["messages"] = [item for sublist in contents for item in sublist]
             else:
                 input_dict["raw"] = True
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
-                    self.register_helper.get_llama_model(request.model),
+                    llama_model,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 1fca54bb3..040f04e77 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -203,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 0882019e3..d9e24662a 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             provider_resource_id = model.provider_resource_id
         else:
             provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
-            if model.metadata.get("llama_model") is None:
-                raise ValueError(
-                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
-                    "Please specify a llama_model in metadata or use a supported model identifier"
-                )
+            llama_model = model.metadata.get("llama_model")
+            if llama_model is None:
+                return model
+
             existing_llama_model = self.get_llama_model(model.provider_resource_id)
             if existing_llama_model:
-                if existing_llama_model != model.metadata["llama_model"]:
+                if existing_llama_model != llama_model:
                     raise ValueError(
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
-                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Invalid llama_model '{llama_model}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                 )

         return model
diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py
index efdec6b01..a2e52c8ec 100644
--- a/tests/client-sdk/conftest.py
+++ b/tests/client-sdk/conftest.py
@@ -78,7 +78,7 @@ def provider_data():


 @pytest.fixture(scope="session")
-def llama_stack_client(provider_data):
+def llama_stack_client(provider_data, text_model_id):
     if os.environ.get("LLAMA_STACK_CONFIG"):
         client = LlamaStackAsLibraryClient(
             get_env_or_fail("LLAMA_STACK_CONFIG"),
@@ -95,6 +95,14 @@ def llama_stack_client(provider_data):
         )
     else:
         raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+
+    inference_providers = [
+        p.provider_id
+        for p in client.providers.list()
+        if p.api == "inference" and p.provider_id != "sentence-transformers"
+    ]
+    assert len(inference_providers) > 0, "No inference providers found"
+    client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
     return client
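
The snippet below is not part of the patch; it is a minimal client-side sketch of what the change enables: registering a provider model that carries no llama_model metadata and running chat completion against it, which now takes the OpenAI-style messages branch in _get_params() instead of chat_completion_request_to_prompt. The base URL, model id, and provider id are placeholders, and the chat_completion call shape is assumed from the client SDK of this era rather than taken from this diff.

# Hypothetical end-to-end sketch (placeholders throughout); assumes a running
# llama-stack server with the fireworks inference provider configured.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder URL/port

# With this patch, register_model() returns the model even though no
# llama_model metadata is provided, instead of raising ValueError.
client.models.register(
    model_id="accounts/fireworks/models/deepseek-v3",  # placeholder non-llama model
    provider_id="fireworks",
)

# Because get_llama_model() finds no mapping for this model, the adapter
# builds OpenAI-style "messages" instead of a raw Llama prompt.
response = client.inference.chat_completion(
    model_id="accounts/fireworks/models/deepseek-v3",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.completion_message.content)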