Support for Llama3.2 models and Swift SDK (#98)

This commit is contained in:
Ashwin Bharambe 2024-09-25 10:29:58 -07:00 committed by GitHub
parent 95abbf576b
commit 56aed59eb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
56 changed files with 3745 additions and 630 deletions

View file

@ -8,9 +8,9 @@ import unittest
from llama_models.llama3.api import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.prepare_messages import prepare_messages
from llama_stack.inference.augment_messages import augment_messages_for_tools
MODEL = "Meta-Llama3.1-8B-Instruct"
MODEL = "Llama3.1-8B-Instruct"
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
@ -22,7 +22,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
UserMessage(content=content),
],
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -39,7 +39,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
ToolDefinition(tool_name=BuiltinTool.brave_search),
],
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@ -67,7 +67,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
tool_prompt_format=ToolPromptFormat.json,
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content)
@ -97,7 +97,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
),
],
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content)
@ -119,7 +119,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
)
messages = prepare_messages(request)
messages = augment_messages_for_tools(request)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))

View file

@ -59,7 +59,7 @@ class TestE2E(unittest.IsolatedAsyncioTestCase):
host=TestE2E.HOST,
port=TestE2E.PORT,
custom_tools=custom_tools,
# model="Meta-Llama3.1-70B-Instruct", # Defaults to 8B
# model="Llama3.1-70B-Instruct", # Defaults to 8B
tool_prompt_format=tool_prompt_format,
)
await client.create_session(__file__)

View file

@ -9,31 +9,15 @@
import asyncio
import os
import textwrap
import unittest
from datetime import datetime
from llama_models.llama3.api.datatypes import (
BuiltinTool,
StopReason,
SystemMessage,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
from llama_stack.inference.api import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
)
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.meta_reference.config import MetaReferenceImplConfig
from llama_stack.inference.meta_reference.inference import get_provider_impl
MODEL = "Meta-Llama3.1-8B-Instruct"
MODEL = "Llama3.1-8B-Instruct"
HELPER_MSG = """
This test needs llama-3.1-8b-instruct models.
Please donwload using the llama cli
@ -45,11 +29,10 @@ llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token <
class InferenceTests(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
# This runs the async setup function
asyncio.run(cls.asyncSetUpClass())
@classmethod
async def asyncSetUpClass(cls):
async def asyncSetUpClass(cls): # noqa
# assert model exists on local
model_dir = os.path.expanduser(f"~/.llama/checkpoints/{MODEL}/original/")
assert os.path.isdir(model_dir), HELPER_MSG
@ -67,11 +50,10 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
@classmethod
def tearDownClass(cls):
# This runs the async teardown function
asyncio.run(cls.asyncTearDownClass())
@classmethod
async def asyncTearDownClass(cls):
async def asyncTearDownClass(cls): # noqa
await cls.api.shutdown()
async def asyncSetUp(self):

View file

@ -4,26 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import textwrap
import unittest
from datetime import datetime
from llama_models.llama3.api.datatypes import (
BuiltinTool,
SamplingParams,
SamplingStrategy,
StopReason,
SystemMessage,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
from llama_stack.inference.api import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
)
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.inference.api import * # noqa: F403
from llama_stack.inference.ollama.config import OllamaImplConfig
from llama_stack.inference.ollama.ollama import get_provider_impl
@ -52,7 +36,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
),
},
)
self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
self.valid_supported_model = "Llama3.1-8B-Instruct"
async def asyncTearDown(self):
await self.api.shutdown()
@ -272,7 +256,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
ollama_model = self.api.resolve_ollama_model(self.valid_supported_model)
self.assertEqual(ollama_model, "llama3.1:8b-instruct-fp16")
invalid_model = "Meta-Llama3.1-8B"
invalid_model = "Llama3.1-8B"
with self.assertRaisesRegex(
AssertionError, f"Unsupported model: {invalid_model}"
):