Support for Llama3.2 models and Swift SDK (#98)

This commit is contained in:
Ashwin Bharambe 2024-09-25 10:29:58 -07:00 committed by GitHub
parent 95abbf576b
commit 56aed59eb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
56 changed files with 3745 additions and 630 deletions

View file

@ -15,14 +15,16 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Meta-Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Meta-Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
}
@ -106,7 +108,7 @@ class FireworksInferenceAdapter(Inference):
logprobs=logprobs,
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to fireworks
options = self.get_fireworks_chat_options(request)

View file

@ -16,14 +16,16 @@ from llama_models.sku_list import resolve_model
from ollama import AsyncClient
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
# "Meta-Llama3.1-8B-Instruct": "llama3.1",
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
# "Llama3.1-8B-Instruct": "llama3.1",
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
}
@ -115,7 +117,7 @@ class OllamaInferenceAdapter(Inference):
logprobs=logprobs,
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
ollama_model = self.resolve_ollama_model(request.model)

View file

@ -14,7 +14,9 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import TGIImplConfig
@ -95,7 +97,7 @@ class TGIAdapter(Inference):
logprobs=logprobs,
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens)

View file

@ -15,14 +15,16 @@ from llama_models.sku_list import resolve_model
from together import Together
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import TogetherImplConfig
TOGETHER_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
"Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct-Turbo",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct-Turbo",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-Turbo",
}
@ -110,7 +112,7 @@ class TogetherInferenceAdapter(Inference):
# accumulate sampling params and other options to pass to together
options = self.get_together_chat_options(request)
together_model = self.resolve_together_model(request.model)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
if not request.stream:
# TODO: might need to add back an async here