Support non-llama models for inference providers

Ashwin Bharambe 2025-02-20 22:41:41 -08:00
parent 0fe071764f
commit 1b64573284
5 changed files with 27 additions and 20 deletions

@@ -209,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)

@@ -178,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                 # flatten the list of lists
                 input_dict["messages"] = [item for sublist in contents for item in sublist]
@@ -187,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["raw"] = True
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
-                    self.register_helper.get_llama_model(request.model),
+                    llama_model,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"

@@ -203,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)

@@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             provider_resource_id = model.provider_resource_id
         else:
             provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
-            if model.metadata.get("llama_model") is None:
-                raise ValueError(
-                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
-                    "Please specify a llama_model in metadata or use a supported model identifier"
-                )
+            llama_model = model.metadata.get("llama_model")
+            if llama_model is None:
+                return model
             existing_llama_model = self.get_llama_model(model.provider_resource_id)
             if existing_llama_model:
-                if existing_llama_model != model.metadata["llama_model"]:
+                if existing_llama_model != llama_model:
                     raise ValueError(
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
-                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Invalid llama_model '{llama_model}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                 )
         return model
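
Net effect of this hunk: register_model no longer raises when a provider model id is unrecognized and no llama_model metadata is supplied; the model is registered as-is, which is what lets non-llama models through. A hedged client-side usage sketch follows. Only the model_id/provider_id keywords are confirmed by the fixture change below; the base_url, the model and provider ids, and the provider_model_id/metadata keywords in the second call are assumptions.

# Hedged usage sketch; requires a running stack, ids and keywords as noted above.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # endpoint is illustrative

# Previously this raised unless llama_model metadata was supplied; with this
# change the non-llama model is accepted as-is.
client.models.register(
    model_id="accounts/fireworks/models/some-non-llama-model",
    provider_id="fireworks",
)

# Registering an unknown provider model *as* a llama model still requires (and
# validates) the llama_model metadata. Keyword names here are assumptions.
client.models.register(
    model_id="my-finetuned-llama",
    provider_id="fireworks",
    provider_model_id="accounts/fireworks/models/my-finetune",
    metadata={"llama_model": "meta-llama/Llama-3.1-8B-Instruct"},
)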

@@ -78,7 +78,7 @@ def provider_data():
 @pytest.fixture(scope="session")
-def llama_stack_client(provider_data):
+def llama_stack_client(provider_data, text_model_id):
     if os.environ.get("LLAMA_STACK_CONFIG"):
         client = LlamaStackAsLibraryClient(
             get_env_or_fail("LLAMA_STACK_CONFIG"),
@@ -95,6 +95,14 @@ def llama_stack_client(provider_data):
         )
     else:
         raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+    inference_providers = [
+        p.provider_id
+        for p in client.providers.list()
+        if p.api == "inference" and p.provider_id != "sentence-transformers"
+    ]
+    assert len(inference_providers) > 0, "No inference providers found"
+    client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
     return client
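
The fixture now depends on a text_model_id fixture defined elsewhere in the suite, and registers it against the first non-sentence-transformers inference provider at session start. A minimal hypothetical definition of that fixture is sketched below; the environment-variable name and default model id are chosen here for illustration, not taken from the repo.

# Hypothetical sketch of the text_model_id fixture this change assumes exists.
import os

import pytest


@pytest.fixture(scope="session")
def text_model_id():
    # The real conftest may source this from a pytest CLI option instead.
    return os.environ.get("TEXT_MODEL_ID", "meta-llama/Llama-3.1-8B-Instruct")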