Mirror of https://github.com/meta-llama/llama-stack.git
Support non-llama models for inference providers
parent 0fe071764f
commit 1b64573284

5 changed files with 27 additions and 20 deletions
```diff
@@ -209,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         input_dict = {}
         media_present = request_has_media(request)

+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
```
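Each remote adapter in this commit gets the same guard: when the request carries media or the registered model does not resolve to a known Llama model, the adapter sends OpenAI-style chat messages and lets the provider apply its own chat template, instead of rendering a raw Llama prompt locally. Below is a minimal, self-contained sketch of that decision; `KNOWN_LLAMA_MODELS`, `build_params`, and the model ids are illustrative stand-ins, not the actual llama-stack helpers.

```python
# Minimal sketch of the fallback added above, under assumed stand-in names.
# A provider model id that maps to a known Llama model gets a locally rendered
# prompt; anything else (non-llama models, or requests with media) is sent as
# OpenAI-style messages so the provider applies its own chat template.
from typing import Optional

# Illustrative mapping; the real adapters use ModelRegistryHelper.get_llama_model().
KNOWN_LLAMA_MODELS = {
    "accounts/fireworks/models/llama-v3p1-8b-instruct": "Llama3.1-8B-Instruct",
}


def get_llama_model(provider_model_id: str) -> Optional[str]:
    return KNOWN_LLAMA_MODELS.get(provider_model_id)


def build_params(provider_model_id: str, messages: list, media_present: bool) -> dict:
    llama_model = get_llama_model(provider_model_id)
    if media_present or not llama_model:
        # Non-llama model (or media present): pass chat messages through as-is.
        return {"messages": messages}
    # Known Llama model: render the prompt locally (stand-in for the real template).
    prompt = "".join(f"[{m['role']}] {m['content']}\n" for m in messages)
    return {"prompt": prompt}


msgs = [{"role": "user", "content": "Hello"}]
print(build_params("accounts/fireworks/models/llama-v3p1-8b-instruct", msgs, media_present=False))
print(build_params("accounts/fireworks/models/qwen2p5-72b-instruct", msgs, media_present=False))
```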
```diff
@@ -178,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                 # flatten the list of lists
                 input_dict["messages"] = [item for sublist in contents for item in sublist]
@@ -187,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["raw"] = True
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
-                    self.register_helper.get_llama_model(request.model),
+                    llama_model,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
```
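For Ollama the distinction shows up in the request payload: a known Llama model gets the locally rendered prompt with `raw` set so Ollama applies no template of its own, while a non-llama model gets plain chat messages. The payload shapes below are an illustrative sketch (model names are assumptions, not taken from the diff).

```python
# Illustrative payload shapes produced by _get_params after this change
# (model names are assumptions, not from the diff).

# Known Llama model: llama-stack renders the prompt and bypasses Ollama's template.
params_for_llama_model = {
    "model": "llama3.1:8b-instruct-fp16",
    "prompt": "<rendered by chat_completion_request_to_prompt>",
    "raw": True,
}

# Non-llama model (or media present): send messages and let Ollama template them.
params_for_other_model = {
    "model": "qwen2.5:7b",
    "messages": [{"role": "user", "content": "Hello"}],
}
```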
```diff
@@ -203,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
```
```diff
@@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             provider_resource_id = model.provider_resource_id
         else:
             provider_resource_id = self.get_provider_model_id(model.provider_resource_id)

         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
-            if model.metadata.get("llama_model") is None:
-                raise ValueError(
-                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
-                    "Please specify a llama_model in metadata or use a supported model identifier"
-                )
+            llama_model = model.metadata.get("llama_model")
+            if llama_model is None:
+                return model
+
             existing_llama_model = self.get_llama_model(model.provider_resource_id)
             if existing_llama_model:
-                if existing_llama_model != model.metadata["llama_model"]:
+                if existing_llama_model != llama_model:
                     raise ValueError(
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
-                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Invalid llama_model '{llama_model}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
                     ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                 )

         return model
```
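The registry change is what makes the adapter fallback reachable: a model registered without `llama_model` metadata is now returned as-is instead of raising, so `get_llama_model()` later yields nothing for it and the adapters take the messages path. Below is a simplified, self-contained sketch of that branch; the class, `Model` dataclass, and repo mapping are stand-ins mirroring only the edited logic, not the real llama-stack types.

```python
# Simplified sketch of register_model() after this change. Model, the repo
# mapping, and the class itself are stand-ins for llama-stack's
# ModelRegistryHelper, covering only the branch edited above.
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class Model:
    provider_resource_id: str
    metadata: dict = field(default_factory=dict)


# Illustrative single entry; the real table is ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.
REPOS_TO_DESCRIPTOR = {"meta-llama/Llama-3.1-8B-Instruct": "Llama3.1-8B-Instruct"}


class RegistrySketch:
    def __init__(self) -> None:
        self.provider_id_to_llama_model_map: Dict[str, str] = {}

    def get_llama_model(self, provider_model_id: str) -> Optional[str]:
        return self.provider_id_to_llama_model_map.get(provider_model_id)

    def register_model(self, model: Model) -> Model:
        llama_model = model.metadata.get("llama_model")
        if llama_model is None:
            # Previously a ValueError; unknown models now pass through unmapped.
            return model
        if llama_model not in REPOS_TO_DESCRIPTOR:
            raise ValueError(f"Invalid llama_model '{llama_model}' specified in metadata.")
        self.provider_id_to_llama_model_map[model.provider_resource_id] = REPOS_TO_DESCRIPTOR[llama_model]
        return model


registry = RegistrySketch()
registry.register_model(Model("qwen2p5-72b-instruct"))  # accepted; no llama mapping recorded
registry.register_model(Model("my-llama", {"llama_model": "meta-llama/Llama-3.1-8B-Instruct"}))
print(registry.get_llama_model("qwen2p5-72b-instruct"))  # None -> adapters use the messages path
```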
```diff
@@ -78,7 +78,7 @@ def provider_data():


 @pytest.fixture(scope="session")
-def llama_stack_client(provider_data):
+def llama_stack_client(provider_data, text_model_id):
     if os.environ.get("LLAMA_STACK_CONFIG"):
         client = LlamaStackAsLibraryClient(
             get_env_or_fail("LLAMA_STACK_CONFIG"),
```
```diff
@@ -95,6 +95,14 @@ def llama_stack_client(provider_data):
         )
     else:
         raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+
+    inference_providers = [
+        p.provider_id
+        for p in client.providers.list()
+        if p.api == "inference" and p.provider_id != "sentence-transformers"
+    ]
+    assert len(inference_providers) > 0, "No inference providers found"
+    client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
     return client
```
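With the fixture now depending on `text_model_id`, a client-sdk test can target whatever model id the test run supplies, including a non-llama one, and it is registered against the first non-embedding inference provider before the test body runs. A hedged usage sketch follows; the fixture names match the diff, while the `chat_completion` call is an assumption about the client API and may differ between client versions.

```python
# Hedged sketch of a test relying on the updated fixture. Fixture names mirror
# the diff; the inference call itself is an assumption about the llama-stack
# client API, not taken from this commit.
def test_chat_completion_on_registered_model(llama_stack_client, text_model_id):
    response = llama_stack_client.inference.chat_completion(
        model_id=text_model_id,  # registered by the fixture, may be a non-llama model
        messages=[{"role": "user", "content": "Say hello."}],
    )
    assert response is not None
```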