Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-12 13:00:39 +00:00
commit 1b64573284 (parent 0fe071764f)

Support non-llama models for inference providers

The Fireworks, Ollama, and Together adapters now fall back to OpenAI-style
chat messages whenever the registered model has no llama-model mapping, and
ModelRegistryHelper.register_model accepts models without llama_model metadata
instead of raising. The client-sdk fixture registers the text model against
the first available inference provider before returning the client.

5 changed files with 27 additions and 20 deletions
@@ -209,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         input_dict = {}
         media_present = request_has_media(request)
 
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
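The same branching is repeated in the Ollama and Together adapters below: a model with a known llama mapping still has its request rendered into a raw prompt, while a model with no mapping falls through to OpenAI-style messages. A minimal standalone sketch of that decision, using simplified stand-ins rather than the real llama-stack request types and async helpers:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FakeChatRequest:
    # simplified stand-in for ChatCompletionRequest (illustration only)
    model: str
    messages: list = field(default_factory=list)

def build_input_dict(request: FakeChatRequest, llama_model: Optional[str], media_present: bool) -> dict:
    """Mirrors the branching added in this commit, without the async prompt helpers."""
    input_dict: dict = {}
    if media_present or not llama_model:
        # no llama mapping (or media attached): pass OpenAI-style messages through
        input_dict["messages"] = [dict(m) for m in request.messages]
    else:
        # known llama model: render a raw prompt with its chat template (stubbed here)
        input_dict["prompt"] = f"<prompt rendered for {llama_model}>"
    return input_dict

# a provider model id with no llama mapping takes the messages path
print(build_input_dict(
    FakeChatRequest(model="some-non-llama-model",
                    messages=[{"role": "user", "content": "hi"}]),
    llama_model=None,
    media_present=False,
))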
@@ -178,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
 
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                 # flatten the list of lists
                 input_dict["messages"] = [item for sublist in contents for item in sublist]

@@ -187,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["raw"] = True
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
-                    self.register_helper.get_llama_model(request.model),
+                    llama_model,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
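With the Ollama adapter making the same fallback, a model outside the llama catalog can be registered and used through the client. A hedged usage sketch; the base URL, provider id, and model id are placeholders, and it assumes a running stack plus the llama_stack_client package:

from llama_stack_client import LlamaStackClient

# placeholder endpoint; point this at your running llama-stack server
client = LlamaStackClient(base_url="http://localhost:8321")

# register a non-llama model against the ollama provider (ids are examples)
client.models.register(model_id="qwen2.5:3b", provider_id="ollama")

# the adapter now sends OpenAI-style messages instead of a rendered llama prompt
response = client.inference.chat_completion(
    model_id="qwen2.5:3b",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response)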
@@ -203,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
@@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             provider_resource_id = model.provider_resource_id
         else:
             provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
 
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
-            if model.metadata.get("llama_model") is None:
-                raise ValueError(
-                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
-                    "Please specify a llama_model in metadata or use a supported model identifier"
-                )
+            llama_model = model.metadata.get("llama_model")
+            if llama_model is None:
+                return model
+
             existing_llama_model = self.get_llama_model(model.provider_resource_id)
             if existing_llama_model:
-                if existing_llama_model != model.metadata["llama_model"]:
+                if existing_llama_model != llama_model:
                     raise ValueError(
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
-                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Invalid llama_model '{llama_model}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                 )
+
         return model
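Restated as a standalone sketch, the register_model change means a provider model id without llama_model metadata is now returned as-is instead of raising; only models that do claim a llama_model are validated and mapped. The Model class and repo table below are simplified stand-ins for the llama-stack types:

from dataclasses import dataclass, field

# tiny illustrative excerpt of ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR
KNOWN_REPOS = {"meta-llama/Llama-3.1-8B-Instruct": "Llama3.1-8B-Instruct"}

@dataclass
class FakeModel:
    provider_resource_id: str
    metadata: dict = field(default_factory=dict)

def register_model(model: FakeModel, provider_map: dict) -> FakeModel:
    # sketch of the new fallthrough: no llama_model metadata -> accept and skip mapping
    llama_model = model.metadata.get("llama_model")
    if llama_model is None:
        return model
    if llama_model not in KNOWN_REPOS:
        raise ValueError(f"Invalid llama_model '{llama_model}' specified in metadata.")
    provider_map[model.provider_resource_id] = KNOWN_REPOS[llama_model]
    return model

provider_map: dict = {}
register_model(FakeModel("my-custom-model"), provider_map)  # accepted, left unmapped
register_model(
    FakeModel("my-llama-alias", metadata={"llama_model": "meta-llama/Llama-3.1-8B-Instruct"}),
    provider_map,
)
print(provider_map)  # {'my-llama-alias': 'Llama3.1-8B-Instruct'}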
@@ -78,7 +78,7 @@ def provider_data():
 
 
 @pytest.fixture(scope="session")
-def llama_stack_client(provider_data):
+def llama_stack_client(provider_data, text_model_id):
     if os.environ.get("LLAMA_STACK_CONFIG"):
         client = LlamaStackAsLibraryClient(
             get_env_or_fail("LLAMA_STACK_CONFIG"),

@@ -95,6 +95,14 @@ def llama_stack_client(provider_data):
         )
     else:
         raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+
+    inference_providers = [
+        p.provider_id
+        for p in client.providers.list()
+        if p.api == "inference" and p.provider_id != "sentence-transformers"
+    ]
+    assert len(inference_providers) > 0, "No inference providers found"
+    client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
     return client
 
 
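The renamed fixture now also depends on a text_model_id fixture that is not part of this diff. One possible shape for it, assuming the model id comes in through a pytest command-line option; the option name and default are assumptions, not the repository's actual definition:

import pytest

def pytest_addoption(parser):
    # assumed option name; the real conftest may source the model id differently
    parser.addoption("--text-model", default="meta-llama/Llama-3.1-8B-Instruct")

@pytest.fixture(scope="session")
def text_model_id(request):
    return request.config.getoption("--text-model")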