Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-19 11:20:03 +00:00)
refactor: move more tests, delete some providers tests (#1382)

Move unit tests to tests/unittests/. Gradually removing tests from providers/tests/ and unifying them into tests/api (which are e2e tests using SDK types).

## Test Plan

`pytest -s -v tests/unittests/`
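A single file or test can also be selected with standard pytest node-id syntax (nothing repo-specific), e.g. `pytest -s -v tests/unittests/cli/test_stack_config.py` or `pytest -s -v tests/unittests/server/test_logcat.py::TestLogcat::test_basic_logging`.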
parent e5ec68f66e
commit 86fc514abb
11 changed files with 6 additions and 142 deletions
127  tests/unittests/cli/test_stack_config.py  Normal file
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime

import pytest
import yaml

from llama_stack.distribution.configure import (
    LLAMA_STACK_RUN_CONFIG_VERSION,
    parse_and_maybe_upgrade_config,
)


@pytest.fixture
def up_to_date_config():
    return yaml.safe_load(
        """
        version: {version}
        image_name: foo
        apis_to_serve: []
        built_at: {built_at}
        providers:
          inference:
            - provider_id: provider1
              provider_type: inline::meta-reference
              config: {{}}
          safety:
            - provider_id: provider1
              provider_type: inline::meta-reference
              config:
                llama_guard_shield:
                  model: Llama-Guard-3-1B
                  excluded_categories: []
                  disable_input_check: false
                  disable_output_check: false
                enable_prompt_guard: false
          memory:
            - provider_id: provider1
              provider_type: inline::meta-reference
              config: {{}}
        """.format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
    )


@pytest.fixture
def old_config():
    return yaml.safe_load(
        """
        image_name: foo
        built_at: {built_at}
        apis_to_serve: []
        routing_table:
          inference:
            - provider_type: remote::ollama
              config:
                host: localhost
                port: 11434
              routing_key: Llama3.2-1B-Instruct
            - provider_type: inline::meta-reference
              config:
                model: Llama3.1-8B-Instruct
              routing_key: Llama3.1-8B-Instruct
          safety:
            - routing_key: ["shield1", "shield2"]
              provider_type: inline::meta-reference
              config:
                llama_guard_shield:
                  model: Llama-Guard-3-1B
                  excluded_categories: []
                  disable_input_check: false
                  disable_output_check: false
                enable_prompt_guard: false
          memory:
            - routing_key: vector
              provider_type: inline::meta-reference
              config: {{}}
        api_providers:
          telemetry:
            provider_type: noop
            config: {{}}
        """.format(built_at=datetime.now().isoformat())
    )


@pytest.fixture
def invalid_config():
    return yaml.safe_load(
        """
        routing_table: {}
        api_providers: {}
        """
    )


def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
    result = parse_and_maybe_upgrade_config(up_to_date_config)
    assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
    assert "inference" in result.providers


def test_parse_and_maybe_upgrade_config_old_format(old_config):
    result = parse_and_maybe_upgrade_config(old_config)
    assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
    assert all(api in result.providers for api in ["inference", "safety", "memory", "telemetry"])
    safety_provider = result.providers["safety"][0]
    assert safety_provider.provider_type == "inline::meta-reference"
    assert "llama_guard_shield" in safety_provider.config

    inference_providers = result.providers["inference"]
    assert len(inference_providers) == 2
    assert {x.provider_id for x in inference_providers} == {
        "remote::ollama-00",
        "inline::meta-reference-01",
    }

    ollama = inference_providers[0]
    assert ollama.provider_type == "remote::ollama"
    assert ollama.config["port"] == 11434


def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
    with pytest.raises(KeyError):
        parse_and_maybe_upgrade_config(invalid_config)
281  tests/unittests/models/test_prompt_adapter.py  Normal file
@@ -0,0 +1,281 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import unittest

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    CompletionMessage,
    StopReason,
    SystemMessage,
    ToolCall,
    ToolConfig,
    UserMessage,
)
from llama_stack.models.llama.datatypes import (
    BuiltinTool,
    ToolDefinition,
    ToolParamDefinition,
    ToolPromptFormat,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
    chat_completion_request_to_prompt,
)

MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"


class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
    async def test_system_default(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
        )
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 2)
        self.assertEqual(messages[-1].content, content)
        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)

    async def test_system_builtin_only(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
                ToolDefinition(tool_name=BuiltinTool.brave_search),
            ],
        )
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 2)
        self.assertEqual(messages[-1].content, content)
        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
        self.assertTrue("Tools: brave_search" in messages[0].content)

    async def test_system_custom_only(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                )
            ],
            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
        )
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 3)
        self.assertTrue("Environment: ipython" in messages[0].content)

        self.assertTrue("Return function calls in JSON format" in messages[1].content)
        self.assertEqual(messages[-1].content, content)

    async def test_system_custom_and_builtin(self):
        content = "Hello !"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
                ToolDefinition(tool_name=BuiltinTool.brave_search),
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                ),
            ],
        )
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 3)

        self.assertTrue("Environment: ipython" in messages[0].content)
        self.assertTrue("Tools: brave_search" in messages[0].content)

        self.assertTrue("Return function calls in JSON format" in messages[1].content)
        self.assertEqual(messages[-1].content, content)

    async def test_completion_message_encoding(self):
        request = ChatCompletionRequest(
            model=MODEL3_2,
            messages=[
                UserMessage(content="hello"),
                CompletionMessage(
                    content="",
                    stop_reason=StopReason.end_of_turn,
                    tool_calls=[
                        ToolCall(
                            tool_name="custom1",
                            arguments={"param1": "value1"},
                            call_id="123",
                        )
                    ],
                ),
            ],
            tools=[
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                ),
            ],
            tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
        )
        prompt = await chat_completion_request_to_prompt(request, request.model)
        self.assertIn('[custom1(param1="value1")]', prompt)

        request.model = MODEL
        request.tool_config.tool_prompt_format = ToolPromptFormat.json
        prompt = await chat_completion_request_to_prompt(request, request.model)
        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)

    async def test_user_provided_system_message(self):
        content = "Hello !"
        system_prompt = "You are a pirate"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ],
        )
        messages = chat_completion_request_to_messages(request, MODEL)
        self.assertEqual(len(messages), 2, messages)
        self.assertTrue(messages[0].content.endswith(system_prompt))

        self.assertEqual(messages[-1].content, content)

    async def test_replace_system_message_behavior_builtin_tools(self):
        content = "Hello !"
        system_prompt = "You are a pirate"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ],
            tool_config=ToolConfig(
                tool_choice="auto",
                tool_prompt_format="python_list",
                system_message_behavior="replace",
            ),
        )
        messages = chat_completion_request_to_messages(request, MODEL3_2)
        self.assertEqual(len(messages), 2, messages)
        self.assertTrue(messages[0].content.endswith(system_prompt))
        self.assertIn("Environment: ipython", messages[0].content)
        self.assertEqual(messages[-1].content, content)

    async def test_replace_system_message_behavior_custom_tools(self):
        content = "Hello !"
        system_prompt = "You are a pirate"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                ),
            ],
            tool_config=ToolConfig(
                tool_choice="auto",
                tool_prompt_format="python_list",
                system_message_behavior="replace",
            ),
        )
        messages = chat_completion_request_to_messages(request, MODEL3_2)

        self.assertEqual(len(messages), 2, messages)
        self.assertTrue(messages[0].content.endswith(system_prompt))
        self.assertIn("Environment: ipython", messages[0].content)
        self.assertEqual(messages[-1].content, content)

    async def test_replace_system_message_behavior_custom_tools_with_template(self):
        content = "Hello !"
        system_prompt = "You are a pirate {{ function_description }}"
        request = ChatCompletionRequest(
            model=MODEL,
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=content),
            ],
            tools=[
                ToolDefinition(tool_name=BuiltinTool.code_interpreter),
                ToolDefinition(
                    tool_name="custom1",
                    description="custom1 tool",
                    parameters={
                        "param1": ToolParamDefinition(
                            param_type="str",
                            description="param1 description",
                            required=True,
                        ),
                    },
                ),
            ],
            tool_config=ToolConfig(
                tool_choice="auto",
                tool_prompt_format="python_list",
                system_message_behavior="replace",
            ),
        )
        messages = chat_completion_request_to_messages(request, MODEL3_2)

        self.assertEqual(len(messages), 2, messages)
        self.assertIn("Environment: ipython", messages[0].content)
        self.assertIn("You are a pirate", messages[0].content)
        # function description is present in the system prompt
        self.assertIn('"name": "custom1"', messages[0].content)
        self.assertEqual(messages[-1].content, content)
198  tests/unittests/models/test_system_prompts.py  Normal file
@@ -0,0 +1,198 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import textwrap
import unittest
from datetime import datetime

from llama_stack.models.llama.llama3.prompt_templates import (
    BuiltinToolGenerator,
    FunctionTagCustomToolGenerator,
    JsonCustomToolGenerator,
    PythonListCustomToolGenerator,
    SystemDefaultGenerator,
)


class PromptTemplateTests(unittest.TestCase):
    def check_generator_output(self, generator, expected_text):
        example = generator.data_examples()[0]

        pt = generator.gen(example)
        text = pt.render()
        # print(text)  # debugging
        assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"

    def test_system_default(self):
        generator = SystemDefaultGenerator()
        today = datetime.now().strftime("%d %B %Y")
        expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
        self.check_generator_output(generator, expected_text)

    def test_system_builtin_only(self):
        generator = BuiltinToolGenerator()
        expected_text = textwrap.dedent(
            """
            Environment: ipython
            Tools: brave_search, wolfram_alpha
            """
        )
        self.check_generator_output(generator, expected_text.strip("\n"))

    def test_system_custom_only(self):
        self.maxDiff = None
        generator = JsonCustomToolGenerator()
        expected_text = textwrap.dedent(
            """
            Answer the user's question by making use of the following functions if needed.
            If none of the function can be used, please say so.
            Here is a list of functions in JSON format:
            {
                "type": "function",
                "function": {
                    "name": "trending_songs",
                    "description": "Returns the trending songs on a Music site",
                    "parameters": {
                        "type": "object",
                        "properties": [
                            {
                                "n": {
                                    "type": "object",
                                    "description": "The number of songs to return"
                                }
                            },
                            {
                                "genre": {
                                    "type": "object",
                                    "description": "The genre of the songs to return"
                                }
                            }
                        ],
                        "required": ["n"]
                    }
                }
            }

            Return function calls in JSON format.
            """
        )
        self.check_generator_output(generator, expected_text.strip("\n"))

    def test_system_custom_function_tag(self):
        self.maxDiff = None
        generator = FunctionTagCustomToolGenerator()
        expected_text = textwrap.dedent(
            """
            You have access to the following functions:

            Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
            {"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}

            Think very carefully before calling functions.
            If you choose to call a function ONLY reply in the following format with no prefix or suffix:

            <function=example_function_name>{"example_name": "example_value"}</function>

            Reminder:
            - If looking for real time information use relevant functions before falling back to brave_search
            - Function calls MUST follow the specified format, start with <function= and end with </function>
            - Required parameters MUST be specified
            - Only call one function at a time
            - Put the entire function call reply on one line
            """
        )
        self.check_generator_output(generator, expected_text.strip("\n"))

    def test_llama_3_2_system_zero_shot(self):
        generator = PythonListCustomToolGenerator()
        expected_text = textwrap.dedent(
            """
            You are a helpful assistant. You have access to functions, but you should only use them if they are required.
            You are an expert in composing functions. You are given a question and a set of possible functions.
            Based on the question, you may or may not need to make one function/tool call to achieve the purpose.

            If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
            You SHOULD NOT include any other text in the response.

            Here is a list of functions in JSON format that you can invoke.

            [
                {
                    "name": "get_weather",
                    "description": "Get weather info for places",
                    "parameters": {
                        "type": "dict",
                        "required": ["city"],
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "The name of the city to get the weather for"
                            },
                            "metric": {
                                "type": "string",
                                "description": "The metric for weather. Options are: celsius, fahrenheit",
                                "default": "celsius"
                            }
                        }
                    }
                }
            ]
            """
        )
        self.check_generator_output(generator, expected_text.strip("\n"))

    def test_llama_3_2_provided_system_prompt(self):
        generator = PythonListCustomToolGenerator()
        expected_text = textwrap.dedent(
            """
            Overriding message.

            If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
            You SHOULD NOT include any other text in the response.

            Here is a list of functions in JSON format that you can invoke.

            [
                {
                    "name": "get_weather",
                    "description": "Get weather info for places",
                    "parameters": {
                        "type": "dict",
                        "required": ["city"],
                        "properties": {
                            "city": {
                                "type": "string",
                                "description": "The name of the city to get the weather for"
                            },
                            "metric": {
                                "type": "string",
                                "description": "The metric for weather. Options are: celsius, fahrenheit",
                                "default": "celsius"
                            }
                        }
                    }
                }
            ]"""
        )
        user_system_prompt = textwrap.dedent(
            """
            Overriding message.

            {{ function_description }}
            """
        )
        example = generator.data_examples()[0]

        pt = generator.gen(example, user_system_prompt)
        text = pt.render()
        assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
BIN  tests/unittests/rag/fixtures/dummy.pdf  Normal file
Binary file not shown.
76  tests/unittests/rag/test_vector_store.py  Normal file
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import mimetypes
import os
from pathlib import Path

import pytest

from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc

DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"


def read_file(file_path: str) -> bytes:
    with open(file_path, "rb") as file:
        return file.read()


def data_url_from_file(file_path: str) -> str:
    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(file_path)

    data_url = f"data:{mime_type};base64,{base64_content}"

    return data_url


class TestVectorStore:
    @pytest.mark.asyncio
    async def test_returns_content_from_pdf_data_uri(self):
        data_uri = data_url_from_file(DUMMY_PDF_PATH)
        doc = RAGDocument(
            document_id="dummy",
            content=data_uri,
            mime_type="application/pdf",
            metadata={},
        )
        content = await content_from_doc(doc)
        assert content == "Dumm y PDF file"

    @pytest.mark.asyncio
    async def test_downloads_pdf_and_returns_content(self):
        # Using GitHub to host the PDF file
        url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
        doc = RAGDocument(
            document_id="dummy",
            content=url,
            mime_type="application/pdf",
            metadata={},
        )
        content = await content_from_doc(doc)
        assert content == "Dumm y PDF file"

    @pytest.mark.asyncio
    async def test_downloads_pdf_and_returns_content_with_url_object(self):
        # Using GitHub to host the PDF file
        url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
        doc = RAGDocument(
            document_id="dummy",
            content=URL(
                uri=url,
            ),
            mime_type="application/pdf",
            metadata={},
        )
        content = await content_from_doc(doc)
        assert content == "Dumm y PDF file"
199  tests/unittests/registry/test_registry.py  Normal file
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import pytest
import pytest_asyncio

from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.store.registry import (
    CachedDiskDistributionRegistry,
    DiskDistributionRegistry,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


@pytest.fixture
def config():
    config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
    if os.path.exists(config.db_path):
        os.remove(config.db_path)
    return config


@pytest_asyncio.fixture(scope="function")
async def registry(config):
    registry = DiskDistributionRegistry(await kvstore_impl(config))
    await registry.initialize()
    return registry


@pytest_asyncio.fixture(scope="function")
async def cached_registry(config):
    registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await registry.initialize()
    return registry


@pytest.fixture
def sample_vector_db():
    return VectorDB(
        identifier="test_vector_db",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_resource_id="test_vector_db",
        provider_id="test-provider",
    )


@pytest.fixture
def sample_model():
    return Model(
        identifier="test_model",
        provider_resource_id="test_model",
        provider_id="test-provider",
    )


@pytest.mark.asyncio
async def test_registry_initialization(registry):
    # Test empty registry
    result = await registry.get("nonexistent", "nonexistent")
    assert result is None


@pytest.mark.asyncio
async def test_basic_registration(registry, sample_vector_db, sample_model):
    print(f"Registering {sample_vector_db}")
    await registry.register(sample_vector_db)
    print(f"Registering {sample_model}")
    await registry.register(sample_model)
    print("Getting vector_db")
    result_vector_db = await registry.get("vector_db", "test_vector_db")
    assert result_vector_db is not None
    assert result_vector_db.identifier == sample_vector_db.identifier
    assert result_vector_db.embedding_model == sample_vector_db.embedding_model
    assert result_vector_db.provider_id == sample_vector_db.provider_id

    result_model = await registry.get("model", "test_model")
    assert result_model is not None
    assert result_model.identifier == sample_model.identifier
    assert result_model.provider_id == sample_model.provider_id


@pytest.mark.asyncio
async def test_cached_registry_initialization(config, sample_vector_db, sample_model):
    # First populate the disk registry
    disk_registry = DiskDistributionRegistry(await kvstore_impl(config))
    await disk_registry.initialize()
    await disk_registry.register(sample_vector_db)
    await disk_registry.register(sample_model)

    # Test cached version loads from disk
    cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cached_registry.initialize()

    result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
    assert result_vector_db is not None
    assert result_vector_db.identifier == sample_vector_db.identifier
    assert result_vector_db.embedding_model == sample_vector_db.embedding_model
    assert result_vector_db.embedding_dimension == sample_vector_db.embedding_dimension
    assert result_vector_db.provider_id == sample_vector_db.provider_id


@pytest.mark.asyncio
async def test_cached_registry_updates(config):
    cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cached_registry.initialize()

    new_vector_db = VectorDB(
        identifier="test_vector_db_2",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_resource_id="test_vector_db_2",
        provider_id="baz",
    )
    await cached_registry.register(new_vector_db)

    # Verify in cache
    result_vector_db = await cached_registry.get("vector_db", "test_vector_db_2")
    assert result_vector_db is not None
    assert result_vector_db.identifier == new_vector_db.identifier
    assert result_vector_db.provider_id == new_vector_db.provider_id

    # Verify persisted to disk
    new_registry = DiskDistributionRegistry(await kvstore_impl(config))
    await new_registry.initialize()
    result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
    assert result_vector_db is not None
    assert result_vector_db.identifier == new_vector_db.identifier
    assert result_vector_db.provider_id == new_vector_db.provider_id


@pytest.mark.asyncio
async def test_duplicate_provider_registration(config):
    cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cached_registry.initialize()

    original_vector_db = VectorDB(
        identifier="test_vector_db_2",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_resource_id="test_vector_db_2",
        provider_id="baz",
    )
    await cached_registry.register(original_vector_db)

    duplicate_vector_db = VectorDB(
        identifier="test_vector_db_2",
        embedding_model="different-model",
        embedding_dimension=384,
        provider_resource_id="test_vector_db_2",
        provider_id="baz",  # Same provider_id
    )
    await cached_registry.register(duplicate_vector_db)

    result = await cached_registry.get("vector_db", "test_vector_db_2")
    assert result is not None
    assert result.embedding_model == original_vector_db.embedding_model  # Original values preserved


@pytest.mark.asyncio
async def test_get_all_objects(config):
    cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
    await cached_registry.initialize()

    # Create multiple test banks
    test_vector_dbs = [
        VectorDB(
            identifier=f"test_vector_db_{i}",
            embedding_model="all-MiniLM-L6-v2",
            embedding_dimension=384,
            provider_resource_id=f"test_vector_db_{i}",
            provider_id=f"provider_{i}",
        )
        for i in range(3)
    ]

    # Register all vector_dbs
    for vector_db in test_vector_dbs:
        await cached_registry.register(vector_db)

    # Test get_all retrieval
    all_results = await cached_registry.get_all()
    assert len(all_results) == 3

    # Verify each vector_db was stored correctly
    for original_vector_db in test_vector_dbs:
        matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
        assert len(matching_vector_dbs) == 1
        stored_vector_db = matching_vector_dbs[0]
        assert stored_vector_db.embedding_model == original_vector_db.embedding_model
        assert stored_vector_db.provider_id == original_vector_db.provider_id
        assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
88  tests/unittests/server/test_logcat.py  Normal file
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import io
import logging
import os
import unittest

from llama_stack import logcat


class TestLogcat(unittest.TestCase):
    def setUp(self):
        self.original_env = os.environ.get("LLAMA_STACK_LOGGING")

        self.log_output = io.StringIO()
        self._init_logcat()

    def tearDown(self):
        if self.original_env is not None:
            os.environ["LLAMA_STACK_LOGGING"] = self.original_env
        else:
            os.environ.pop("LLAMA_STACK_LOGGING", None)

    def _init_logcat(self):
        logcat.init(default_level=logging.DEBUG)
        self.handler = logging.StreamHandler(self.log_output)
        self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s"))
        logcat._logger.handlers.clear()
        logcat._logger.addHandler(self.handler)

    def test_basic_logging(self):
        logcat.info("server", "Info message")
        logcat.warning("server", "Warning message")
        logcat.error("server", "Error message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Info message", output)
        self.assertIn("[server] Warning message", output)
        self.assertIn("[server] Error message", output)

    def test_different_categories(self):
        # Log messages with different categories
        logcat.info("server", "Server message")
        logcat.info("inference", "Inference message")
        logcat.info("router", "Router message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Server message", output)
        self.assertIn("[inference] Inference message", output)
        self.assertIn("[router] Router message", output)

    def test_env_var_control(self):
        os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"
        self._init_logcat()

        # These should be visible based on the environment settings
        logcat.debug("server", "Server debug message")
        logcat.info("server", "Server info message")
        logcat.warning("inference", "Inference warning message")
        logcat.error("inference", "Inference error message")

        # These should be filtered out based on the environment settings
        logcat.debug("inference", "Inference debug message")
        logcat.info("inference", "Inference info message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Server debug message", output)
        self.assertIn("[server] Server info message", output)
        self.assertIn("[inference] Inference warning message", output)
        self.assertIn("[inference] Inference error message", output)

        self.assertNotIn("[inference] Inference debug message", output)
        self.assertNotIn("[inference] Inference info message", output)

    def test_invalid_category(self):
        logcat.info("nonexistent", "This message should not be logged")

        # Check that the message was not logged
        output = self.log_output.getvalue()
        self.assertNotIn("[nonexistent] This message should not be logged", output)


if __name__ == "__main__":
    unittest.main()
66  tests/unittests/server/test_replace_env_vars.py  Normal file
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import unittest

from llama_stack.distribution.stack import replace_env_vars


class TestReplaceEnvVars(unittest.TestCase):
    def setUp(self):
        # Clear any existing environment variables we'll use in tests
        for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
            if var in os.environ:
                del os.environ[var]

        # Set up test environment variables
        os.environ["TEST_VAR"] = "test_value"
        os.environ["EMPTY_VAR"] = ""
        os.environ["ZERO_VAR"] = "0"

    def test_simple_replacement(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value")

    def test_default_value_when_not_set(self):
        self.assertEqual(replace_env_vars("${env.NOT_SET:default}"), "default")

    def test_default_value_when_set(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR:default}"), "test_value")

    def test_default_value_when_empty(self):
        self.assertEqual(replace_env_vars("${env.EMPTY_VAR:default}"), "default")

    def test_conditional_value_when_set(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR+conditional}"), "conditional")

    def test_conditional_value_when_not_set(self):
        self.assertEqual(replace_env_vars("${env.NOT_SET+conditional}"), "")

    def test_conditional_value_when_empty(self):
        self.assertEqual(replace_env_vars("${env.EMPTY_VAR+conditional}"), "")

    def test_conditional_value_with_zero(self):
        self.assertEqual(replace_env_vars("${env.ZERO_VAR+conditional}"), "conditional")

    def test_mixed_syntax(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR:default} and ${env.NOT_SET+conditional}"), "test_value and ")
        self.assertEqual(
            replace_env_vars("${env.NOT_SET:default} and ${env.TEST_VAR+conditional}"), "default and conditional"
        )

    def test_nested_structures(self):
        data = {
            "key1": "${env.TEST_VAR:default}",
            "key2": ["${env.NOT_SET:default}", "${env.TEST_VAR+conditional}"],
            "key3": {"nested": "${env.NOT_SET+conditional}"},
        }
        expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": ""}}
        self.assertEqual(replace_env_vars(data), expected)


if __name__ == "__main__":
    unittest.main()