rebase on top of registry

Xi Yan 2024-10-08 23:41:03 -07:00
commit 6abef716dd
107 changed files with 4813 additions and 3587 deletions


@@ -1,8 +1,9 @@
built_at: '2024-09-23T00:54:40.551416'
version: '2'
built_at: '2024-10-08T17:40:45.325529'
image_name: local
docker_image: null
conda_env: local
apis_to_serve:
apis:
- shields
- agents
- models
@@ -11,56 +12,23 @@ apis_to_serve:
- inference
- safety
- evals
api_providers:
providers:
evals:
provider_type: eleuther
config: {}
# evals:
# provider_type: meta-reference
# config: {}
inference:
providers:
- meta-reference
safety:
providers:
- meta-reference
agents:
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: /home/xiyan/.llama/runtime/kvstore.db
memory:
providers:
- meta-reference
telemetry:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
routing_table:
inference:
- provider_type: meta-reference
- provider_id: meta-reference
provider_type: meta-reference
config:
model: Llama3.2-1B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
routing_key: Llama3.2-1B-Instruct
# - provider_type: meta-reference
# config:
# model: Llama-Guard-3-1B
# quantization: null
# torch_seed: null
# max_seq_len: 4096
# max_batch_size: 1
# routing_key: Llama-Guard-3-1B
# - provider_type: remote::tgi
# config:
# url: http://127.0.0.1:5009
# routing_key: Llama3.1-8B-Instruct
safety:
- provider_type: meta-reference
- provider_id: meta-reference
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
@@ -69,8 +37,35 @@ routing_table:
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_type: meta-reference
- provider_id: meta-reference
provider_type: meta-reference
config: {}
routing_key: vector
agents:
- provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: /home/xiyan/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
models:
- identifier: Llama3.2-1B-Instruct
llama_model: Llama3.2-1B-Instruct
provider_id: meta-reference
shields:
- identifier: llama_guard
type: llama_guard
provider_id: meta-reference
params: {}
memory_banks:
- identifier: vector
provider_id: meta-reference
type: vector
embedding_model: all-MiniLM-L6-v2
chunk_size_in_tokens: 512
overlap_size_in_tokens: null
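
The reworked run config above moves provider wiring under a single providers map keyed by API and registers resources (models, shields, memory_banks) at the top level. A minimal sketch of loading such a config and listing what it registers, assuming only PyYAML and a local run.yaml in the new layout (the path and the list-or-mapping handling are illustrative assumptions, not part of this commit):

import yaml  # PyYAML

# Hypothetical path; point this at a run config in the new registry layout.
CONFIG_PATH = "run.yaml"

with open(CONFIG_PATH) as f:
    cfg = yaml.safe_load(f)

print("version:", cfg.get("version"), "| apis:", ", ".join(cfg.get("apis", [])))

# Providers are grouped under a single `providers` map keyed by API name.
for api, entry in cfg.get("providers", {}).items():
    entries = entry if isinstance(entry, list) else [entry]
    for p in entries:
        print(f"provider [{api}]: id={p.get('provider_id', '?')} type={p.get('provider_type')}")

# Registered resources live at the top level of the config.
for model in cfg.get("models", []):
    print("model:", model["identifier"], "-> provider", model["provider_id"])
for shield in cfg.get("shields", []):
    print("shield:", shield["identifier"], f"(type={shield['type']})")
for bank in cfg.get("memory_banks", []):
    print("memory_bank:", bank["identifier"], f"(embedding={bank['embedding_model']})")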


@@ -1,126 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import unittest

from llama_models.llama3.api import *  # noqa: F403
from llama_stack.inference.api import *  # noqa: F403
from llama_stack.inference.augment_messages import augment_messages_for_tools

MODEL = "Llama3.1-8B-Instruct"


class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
    async def test_system_default(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
        )
        messages = augment_messages_for_tools(request)
        self.assertEqual(len(messages), 2)
        self.assertEqual(messages[-1].content, content)
        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)

    async def test_system_builtin_only(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
                ToolDefinition(tool_name=BuiltinTool.brave_search),
            ],
        )
        messages = augment_messages_for_tools(request)
        self.assertEqual(len(messages), 2)
        self.assertEqual(messages[-1].content, content)
        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
        self.assertTrue("Tools: brave_search" in messages[0].content)

    async def test_system_custom_only(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                )
            ],
            tool_prompt_format=ToolPromptFormat.json,
        )
        messages = augment_messages_for_tools(request)
        self.assertEqual(len(messages), 3)
        self.assertTrue("Environment: ipython" in messages[0].content)
        self.assertTrue("Return function calls in JSON format" in messages[1].content)
        self.assertEqual(messages[-1].content, content)

    async def test_system_custom_and_builtin(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
                ToolDefinition(tool_name=BuiltinTool.brave_search),
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                ),
            ],
        )
        messages = augment_messages_for_tools(request)
        self.assertEqual(len(messages), 3)
        self.assertTrue("Environment: ipython" in messages[0].content)
        self.assertTrue("Tools: brave_search" in messages[0].content)
        self.assertTrue("Return function calls in JSON format" in messages[1].content)
        self.assertEqual(messages[-1].content, content)

    async def test_user_provided_system_message(self):
        content = "Hello !"
        system_prompt = "You are a pirate"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ],
        )
        messages = augment_messages_for_tools(request)
        self.assertEqual(len(messages), 2, messages)
        self.assertTrue(messages[0].content.endswith(system_prompt))
        self.assertEqual(messages[-1].content, content)
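
The deleted tests above call augment_messages_for_tools directly on a ChatCompletionRequest. A standalone sketch of the same call pattern, outside unittest, is shown below; the import paths follow the removed file and may have been relocated by this commit, so treat them as assumptions.

# Sketch only: mirrors the request construction used in the deleted tests above.
# Import paths follow the removed file; this commit may have moved these modules.
from llama_models.llama3.api import *  # noqa: F403
from llama_stack.inference.api import *  # noqa: F403
from llama_stack.inference.augment_messages import augment_messages_for_tools

request = ChatCompletionRequest(
    model="Llama3.1-8B-Instruct",
    messages=[UserMessage(content="Hello !")],
    tools=[
        ToolDefinition(tool_name=BuiltinTool.brave_search),
        ToolDefinition(
            tool_name="custom1",
            description="custom1 tool",
            parameters={
                "param1": ToolParamDefinition(
                    param_type="str",
                    description="param1 description",
                    required=True,
                ),
            },
        ),
    ],
    tool_prompt_format=ToolPromptFormat.json,
)

# The helper expands the user messages with system/tool instructions before inference.
messages = augment_messages_for_tools(request)
for message in messages:
    print(type(message).__name__, "->", str(message.content)[:60])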