From 597869a2aa46d70a96ef60aee266d66c57d0ed8e Mon Sep 17 00:00:00 2001
From: Sixian Yi
Date: Wed, 22 Jan 2025 19:20:49 -0800
Subject: [PATCH] add distro report (#847)

# What does this PR do?
Generate a per-distribution report covering the inference, agents, and
vector_io APIs: for each distro, which models it serves and which
client-sdk tests pass against it.
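The tables are driven by the test metadata in `tests/client-sdk/metadata.py`: `API_MAPS` maps each `Api` to `{endpoint: {capability: [test names]}}`, and the pytest plugin in `tests/client-sdk/report.py` walks that structure to emit one markdown row per test. A simplified sketch of that walk (not the plugin's exact code; `outcomes` stands in for the pass/fail state the real plugin collects in `pytest_runtest_logreport`):

```python
# Simplified sketch, not the exact plugin code: render one report table
# for a given API group from the API_MAPS metadata added in this PR.
from llama_stack.providers.datatypes import Api

from metadata import API_MAPS  # local module in tests/client-sdk


def render_api_table(api: Api, outcomes: dict) -> list:
    rows = [
        "| API | Capability | Test | Status |",
        "|:-----|:-----|:-----|:-----|",
    ]
    for endpoint, capa_map in API_MAPS[api].items():
        for capa, tests in capa_map.items():
            for test_name in tests:
                # outcomes: test name -> "passed" / "failed"
                icon = "✅" if outcomes.get(test_name) == "passed" else "❌"
                rows.append(f"| /{endpoint} | {capa} | {test_name} | {icon} |")
    return rows
```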
## Test Plan
Report generated by running:
`/opt/miniconda3/envs/stack/bin/pytest -s -v tests/client-sdk/ --report`
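`--report` is a custom pytest option; its conftest wiring is not part of this diff, but it would look roughly like the following (hypothetical sketch — the option name matches the command above, everything else is assumed):

```python
# conftest.py (hypothetical sketch): register the --report flag and attach
# the Report plugin from tests/client-sdk/report.py when the flag is passed.
from report import Report


def pytest_addoption(parser):
    parser.addoption(
        "--report", action="store_true", help="generate a distro test report"
    )


def pytest_configure(config):
    if config.getoption("--report"):
        config.pluginmanager.register(Report())
```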
---
 llama_stack/templates/cerebras/report.md     |  44 ++++++++
 llama_stack/templates/fireworks/report.md    |  57 +++++-----
 llama_stack/templates/ollama/report.md       |  44 ++++++++
 llama_stack/templates/tgi/report.md          |  44 ++++++++
 llama_stack/templates/together/report.md     |  44 ++++++++
 tests/client-sdk/metadata.py                 |  14 +--
 tests/client-sdk/report.py                   | 103 +++++++++++++++---
 tests/client-sdk/vector_io/test_vector_io.py |  38 ++++++-
 8 files changed, 328 insertions(+), 60 deletions(-)
 create mode 100644 llama_stack/templates/cerebras/report.md
 create mode 100644 llama_stack/templates/ollama/report.md
 create mode 100644 llama_stack/templates/tgi/report.md
 create mode 100644 llama_stack/templates/together/report.md

diff --git a/llama_stack/templates/cerebras/report.md b/llama_stack/templates/cerebras/report.md
new file mode 100644
index 000000000..c65cd4979
--- /dev/null
+++ b/llama_stack/templates/cerebras/report.md
@@ -0,0 +1,44 @@
+# Report for cerebras distribution
+
+## Supported Models:
+| Model Descriptor | cerebras |
+|:---|:---|
+| meta-llama/Llama-3-8B-Instruct | ❌ |
+| meta-llama/Llama-3-70B-Instruct | ❌ |
+| meta-llama/Llama-3.1-8B-Instruct | ✅ |
+| meta-llama/Llama-3.1-70B-Instruct | ❌ |
+| meta-llama/Llama-3.1-405B-Instruct-FP8 | ❌ |
+| meta-llama/Llama-3.2-1B-Instruct | ❌ |
+| meta-llama/Llama-3.2-3B-Instruct | ❌ |
+| meta-llama/Llama-3.2-11B-Vision-Instruct | ❌ |
+| meta-llama/Llama-3.2-90B-Vision-Instruct | ❌ |
+| meta-llama/Llama-3.3-70B-Instruct | ✅ |
+| meta-llama/Llama-Guard-3-11B-Vision | ❌ |
+| meta-llama/Llama-Guard-3-1B | ❌ |
+| meta-llama/Llama-Guard-3-8B | ❌ |
+| meta-llama/Llama-Guard-2-8B | ❌ |
+
+## Inference:
+| Model | API | Capability | Test | Status |
+|:----- |:-----|:-----|:-----|:-----|
+| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
+| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ❌ |
+| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ❌ |
+| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ❌ |
+
+## Vector_io:
+| API | Capability | Test | Status |
+|:-----|:-----|:-----|:-----|
+| /retrieve | | test_vector_db_retrieve | ✅ |
+
+## Agents:
+| API | Capability | Test | Status |
+|:-----|:-----|:-----|:-----|
+| /create_agent_turn | rag | test_rag_agent | ✅ |
+| /create_agent_turn | custom_tool | test_custom_tool | ❌ |
+| /create_agent_turn | code_execution | test_code_interpreter_for_attachments | ✅ |
diff --git a/llama_stack/templates/fireworks/report.md b/llama_stack/templates/fireworks/report.md
index 55efec0f5..1c5550bf4 100644
--- a/llama_stack/templates/fireworks/report.md
+++ b/llama_stack/templates/fireworks/report.md
@@ -3,43 +3,42 @@
 ## Supported Models:
 | Model Descriptor | fireworks |
 |:---|:---|
-| meta-llama/Llama-3-8B-Instruct | ❌ |
-| meta-llama/Llama-3-70B-Instruct | ❌ |
-| meta-llama/Llama-3.1-8B-Instruct | ✅ |
-| meta-llama/Llama-3.1-70B-Instruct | ✅ |
-| meta-llama/Llama-3.1-405B-Instruct-FP8 | ✅ |
-| meta-llama/Llama-3.2-1B-Instruct | ✅ |
-| meta-llama/Llama-3.2-3B-Instruct | ✅ |
-| meta-llama/Llama-3.2-11B-Vision-Instruct | ✅ |
-| meta-llama/Llama-3.2-90B-Vision-Instruct | ✅ |
-| meta-llama/Llama-3.3-70B-Instruct | ✅ |
-| meta-llama/Llama-Guard-3-11B-Vision | ✅ |
-| meta-llama/Llama-Guard-3-1B | ❌ |
-| meta-llama/Llama-Guard-3-8B | ✅ |
-| meta-llama/Llama-Guard-2-8B | ❌ |
+| Llama-3-8B-Instruct | ❌ |
+| Llama-3-70B-Instruct | ❌ |
+| Llama3.1-8B-Instruct | ✅ |
+| Llama3.1-70B-Instruct | ✅ |
+| Llama3.1-405B-Instruct | ✅ |
+| Llama3.2-1B-Instruct | ✅ |
+| Llama3.2-3B-Instruct | ✅ |
+| Llama3.2-11B-Vision-Instruct | ✅ |
+| Llama3.2-90B-Vision-Instruct | ✅ |
+| Llama3.3-70B-Instruct | ✅ |
+| Llama-Guard-3-11B-Vision | ✅ |
+| Llama-Guard-3-1B | ❌ |
+| Llama-Guard-3-8B | ✅ |
+| Llama-Guard-2-8B | ❌ |
 
 ## Inference:
 | Model | API | Capability | Test | Status |
 |:----- |:-----|:-----|:-----|:-----|
-| Text | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
-| Vision | /chat_completion | streaming | test_image_chat_completion_streaming | ✅ |
-| Vision | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ✅ |
-| Text | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
-| Text | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
-| Text | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
-| Text | /completion | streaming | test_text_completion_streaming | ✅ |
-| Text | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
-| Text | /completion | structured_output | test_text_completion_structured_output | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
+| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ✅ |
+| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ✅ |
 
-## Memory:
+## Vector_io:
 | API | Capability | Test | Status |
 |:-----|:-----|:-----|:-----|
-| /insert, /query | inline | test_memory_bank_insert_inline_and_query | ✅ |
-| /insert, /query | url | test_memory_bank_insert_from_url_and_query | ✅ |
+| /retrieve | | test_vector_db_retrieve | ✅ |
 
 ## Agents:
 | API | Capability | Test | Status |
 |:-----|:-----|:-----|:-----|
-| create_agent_turn | rag | test_rag_agent | ✅ |
-| create_agent_turn | custom_tool | test_custom_tool | ✅ |
-| create_agent_turn | code_execution | test_code_execution | ❌ |
+| /create_agent_turn | rag | test_rag_agent | ✅ |
+| /create_agent_turn | custom_tool | test_custom_tool | ✅ |
+| /create_agent_turn | code_execution | test_code_interpreter_for_attachments | ✅ |
diff --git a/llama_stack/templates/ollama/report.md b/llama_stack/templates/ollama/report.md
new file mode 100644
index 000000000..0d370b8ec
--- /dev/null
+++ b/llama_stack/templates/ollama/report.md
@@ -0,0 +1,44 @@
+# Report for ollama distribution
+
+## Supported Models:
+| Model Descriptor | ollama |
+|:---|:---|
+| Llama-3-8B-Instruct | ❌ |
+| Llama-3-70B-Instruct | ❌ |
+| Llama3.1-8B-Instruct | ✅ |
+| Llama3.1-70B-Instruct | ✅ |
+| Llama3.1-405B-Instruct | ✅ |
+| Llama3.2-1B-Instruct | ✅ |
+| Llama3.2-3B-Instruct | ✅ |
+| Llama3.2-11B-Vision-Instruct | ✅ |
+| Llama3.2-90B-Vision-Instruct | ✅ |
+| Llama3.3-70B-Instruct | ✅ |
+| Llama-Guard-3-11B-Vision | ❌ |
+| Llama-Guard-3-1B | ✅ |
+| Llama-Guard-3-8B | ✅ |
+| Llama-Guard-2-8B | ❌ |
+
+## Inference:
+| Model | API | Capability | Test | Status |
+|:----- |:-----|:-----|:-----|:-----|
+| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
+| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ❌ |
+| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ❌ |
+| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ✅ |
+
+## Vector_io:
+| API | Capability | Test | Status |
+|:-----|:-----|:-----|:-----|
+| /retrieve | | test_vector_db_retrieve | ✅ |
+
+## Agents:
+| API | Capability | Test | Status |
+|:-----|:-----|:-----|:-----|
+| /create_agent_turn | rag | test_rag_agent | ✅ |
+| /create_agent_turn | custom_tool | test_custom_tool | ✅ |
+| /create_agent_turn | code_execution | test_code_interpreter_for_attachments | ✅ |
diff --git a/llama_stack/templates/tgi/report.md b/llama_stack/templates/tgi/report.md
new file mode 100644
index 000000000..1f76ff692
--- /dev/null
+++ b/llama_stack/templates/tgi/report.md
@@ -0,0 +1,44 @@
+# Report for tgi distribution
+
+## Supported Models:
+| Model Descriptor | tgi |
+|:---|:---|
+| Llama-3-8B-Instruct | ✅ |
+| Llama-3-70B-Instruct | ✅ |
+| Llama3.1-8B-Instruct | ✅ |
+| Llama3.1-70B-Instruct | ✅ |
+| Llama3.1-405B-Instruct | ✅ |
+| Llama3.2-1B-Instruct | ✅ |
+| Llama3.2-3B-Instruct | ✅ |
+| Llama3.2-11B-Vision-Instruct | ✅ |
+| Llama3.2-90B-Vision-Instruct | ✅ |
+| Llama3.3-70B-Instruct | ✅ |
+| Llama-Guard-3-11B-Vision | ✅ |
+| Llama-Guard-3-1B | ✅ |
+| Llama-Guard-3-8B | ✅ |
+| Llama-Guard-2-8B | ✅ |
+
+## Inference:
+| Model | API | Capability | Test | Status |
+|:----- |:-----|:-----|:-----|:-----|
+| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
+| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ❌ |
+| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ❌ |
+| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ✅ |
+
+## Vector_io:
+| API | Capability | Test | Status |
+|:-----|:-----|:-----|:-----|
+| /retrieve | | test_vector_db_retrieve | ✅ |
+
+## Agents:
+| API | Capability | Test | Status |
+|:-----|:-----|:-----|:-----|
+| /create_agent_turn | rag | test_rag_agent | ✅ |
+| /create_agent_turn | custom_tool | test_custom_tool | ✅ |
+| /create_agent_turn | code_execution | test_code_interpreter_for_attachments | ✅ |
diff --git a/llama_stack/templates/together/report.md b/llama_stack/templates/together/report.md
new file mode 100644
index 000000000..10891f4e5
--- /dev/null
+++ b/llama_stack/templates/together/report.md
@@ -0,0 +1,44 @@
+# Report for together distribution
+
+## Supported Models:
+| Model Descriptor | together |
+|:---|:---|
+| Llama-3-8B-Instruct | ❌ |
+| Llama-3-70B-Instruct | ❌ |
+| Llama3.1-8B-Instruct | ✅ |
+| Llama3.1-70B-Instruct | ✅ |
+| Llama3.1-405B-Instruct | ✅ |
+| Llama3.2-1B-Instruct | ❌ |
+| Llama3.2-3B-Instruct | ✅ |
+| Llama3.2-11B-Vision-Instruct | ✅ |
+| Llama3.2-90B-Vision-Instruct | ✅ |
+| Llama3.3-70B-Instruct | ✅ |
+| Llama-Guard-3-11B-Vision | ✅ |
+| Llama-Guard-3-1B | ❌ |
+| Llama-Guard-3-8B | ✅ |
+| Llama-Guard-2-8B | ❌ |
+
+## Inference:
+| Model | API | Capability | Test | Status |
+|:----- |:-----|:-----|:-----|:-----|
+| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ |
+| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ✅ |
+| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ |
+| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ✅ |
+
+## Vector_io:
+| API | Capability | Test | Status |
+|:-----|:-----|:-----|:-----|
+| /retrieve | | test_vector_db_retrieve | ✅ |
+
+## Agents:
+| API | Capability | Test | Status |
+|:-----|:-----|:-----|:-----|
+| /create_agent_turn | rag | test_rag_agent | ✅ |
+| /create_agent_turn | custom_tool | test_custom_tool | ✅ |
+| /create_agent_turn | code_execution | test_code_interpreter_for_attachments | ✅ |
diff --git a/tests/client-sdk/metadata.py b/tests/client-sdk/metadata.py
index 1a87c6bd0..badd7edff 100644
--- a/tests/client-sdk/metadata.py
+++ b/tests/client-sdk/metadata.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.providers.datatypes import Api
 
 INFERENCE_API_CAPA_TEST_MAP = {
     "chat_completion": {
@@ -27,10 +28,9 @@ INFERENCE_API_CAPA_TEST_MAP = {
     },
 }
 
-MEMORY_API_TEST_MAP = {
-    "/insert, /query": {
-        "inline": ["test_memory_bank_insert_inline_and_query"],
-        "url": ["test_memory_bank_insert_from_url_and_query"],
+VECTORIO_API_TEST_MAP = {
+    "retrieve": {
+        "": ["test_vector_db_retrieve"],
     }
 }
 
@@ -44,7 +44,7 @@ AGENTS_API_TEST_MAP = {
 
 
 API_MAPS = {
-    "inference": INFERENCE_API_CAPA_TEST_MAP,
-    "memory": MEMORY_API_TEST_MAP,
-    "agents": AGENTS_API_TEST_MAP,
+    Api.inference: INFERENCE_API_CAPA_TEST_MAP,
+    Api.vector_io: VECTORIO_API_TEST_MAP,
+    Api.agents: AGENTS_API_TEST_MAP,
 }
diff --git a/tests/client-sdk/report.py b/tests/client-sdk/report.py
index 5a291f1af..de50efa46 100644
--- a/tests/client-sdk/report.py
+++ b/tests/client-sdk/report.py
@@ -12,8 +12,9 @@ from pathlib import Path
 from urllib.parse import urlparse
 
 import pytest
-
+from llama_models.datatypes import CoreModelId
 from llama_models.sku_list import (
+    all_registered_models,
     llama3_1_instruct_models,
     llama3_2_instruct_models,
     llama3_3_instruct_models,
@@ -22,6 +23,7 @@
 )
 
 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack.providers.datatypes import Api
 from llama_stack.providers.tests.env import get_env_or_fail
 
 from llama_stack_client import LlamaStackClient
@@ -42,6 +44,40 @@ def featured_models_repo_names():
     return [model.huggingface_repo for model in models if not model.variant]
 
 
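+# Statically known per-distro model support; consulted by the report plugin
+# instead of querying the live client when the distro appears here.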
+SUPPORTED_MODELS = {
+    "ollama": set(
+        [
+            CoreModelId.llama3_1_8b_instruct.value,
+            CoreModelId.llama3_1_70b_instruct.value,
+            CoreModelId.llama3_1_405b_instruct.value,
+            CoreModelId.llama3_2_1b_instruct.value,
+            CoreModelId.llama3_2_3b_instruct.value,
+            CoreModelId.llama3_2_11b_vision_instruct.value,
+            CoreModelId.llama3_2_90b_vision_instruct.value,
+            CoreModelId.llama3_3_70b_instruct.value,
+            CoreModelId.llama_guard_3_8b.value,
+            CoreModelId.llama_guard_3_1b.value,
+        ]
+    ),
+    "tgi": set(
+        [
+            model.core_model_id.value
+            for model in all_registered_models()
+            if model.huggingface_repo
+        ]
+    ),
+    "vllm": set(
+        [
+            model.core_model_id.value
+            for model in all_registered_models()
+            if model.huggingface_repo
+        ]
+    ),
+}
+
+
 class Report:
 
     def __init__(self):
@@ -90,6 +131,8 @@ class Report:
         # test function -> test nodeid
         self.test_data = dict()
         self.test_name_to_nodeid = defaultdict(list)
+        self.vision_model_id = None
+        self.text_model_id = None
 
     @pytest.hookimpl(tryfirst=True)
     def pytest_runtest_logreport(self, report):
@@ -113,20 +156,28 @@ class Report:
         report.append(dividor)
 
         rows = []
-
-        try:
+        if self.image_name in SUPPORTED_MODELS:
+            for model in all_registered_models():
+                if (
+                    "Instruct" not in model.core_model_id.value
+                    and "Guard" not in model.core_model_id.value
+                ) or (model.variant):
+                    continue
+                row = f"| {model.core_model_id.value} |"
+                if model.core_model_id.value in SUPPORTED_MODELS[self.image_name]:
+                    row += " ✅ |"
+                else:
+                    row += " ❌ |"
+                rows.append(row)
+        else:
             supported_models = {m.identifier for m in self.client.models.list()}
-        except Exception as e:
-            cprint(f"Error getting models: {e}", "red")
-            supported_models = set()
-
-        for m_name in featured_models_repo_names():
-            row = f"| {m_name} |"
-            if m_name in supported_models:
-                row += " ✅ |"
-            else:
-                row += " ❌ |"
-            rows.append(row)
+            for model in featured_models_repo_names():
+                row = f"| {model} |"
+                if model in supported_models:
+                    row += " ✅ |"
+                else:
+                    row += " ❌ |"
+                rows.append(row)
         report.extend(rows)
 
         report.append("\n## Inference:")
         test_table = [
             "| Model | API | Capability | Test | Status |",
             "|:----- |:-----|:-----|:-----|:-----|",
         ]
-        for api, capa_map in API_MAPS["inference"].items():
+        for api, capa_map in API_MAPS[Api.inference].items():
             for capa, tests in capa_map.items():
                 for test_name in tests:
-                    model_type = "Text" if "text" in test_name else "Vision"
+                    model_id = (
+                        self.text_model_id
+                        if "text" in test_name
+                        else self.vision_model_id
+                    )
                     test_nodeids = self.test_name_to_nodeid[test_name]
                     assert len(test_nodeids) > 0
+
                     # There might be more than one parametrizations for the same test function. We take
                     # the result of the first one for now. Ideally we should mark the test as failed if
                     # any of the parametrizations failed.
                     test_table.append(
-                        f"| {model_type} | /{api} | {capa} | {test_name} | {self._print_result_icon(self.test_data[test_nodeids[0]])} |"
+                        f"| {model_id} | /{api} | {capa} | {test_name} | {self._print_result_icon(self.test_data[test_nodeids[0]])} |"
                     )
         report.extend(test_table)
 
-        for api_group in ["memory", "agents"]:
-            api_capitalized = api_group.capitalize()
+        for api_group in [Api.vector_io, Api.agents]:
+            api_capitalized = api_group.name.capitalize()
             report.append(f"\n## {api_capitalized}:")
             test_table = [
                 "| API | Capability | Test | Status |",
@@ -162,7 +218,7 @@
                     test_nodeids = self.test_name_to_nodeid[test_name]
                     assert len(test_nodeids) > 0
                     test_table.append(
-                        f"| {api} | {capa} | {test_name} | {self._print_result_icon(self.test_data[test_nodeids[0]])} |"
+                        f"| /{api} | {capa} | {test_name} | {self._print_result_icon(self.test_data[test_nodeids[0]])} |"
                     )
             report.extend(test_table)
 
@@ -173,6 +229,13 @@
     def pytest_runtest_makereport(self, item, call):
         func_name = getattr(item, "originalname", item.name)
+        if "text_model_id" in item.funcargs:
+            text_model = item.funcargs["text_model_id"].split("/")[1]
+            self.text_model_id = self.text_model_id or text_model
+        elif "vision_model_id" in item.funcargs:
+            vision_model = item.funcargs["vision_model_id"].split("/")[1]
+            self.vision_model_id = self.vision_model_id or vision_model
+
         self.test_name_to_nodeid[func_name].append(item.nodeid)
 
     def _print_result_icon(self, result):
diff --git a/tests/client-sdk/vector_io/test_vector_io.py b/tests/client-sdk/vector_io/test_vector_io.py
index 20e49d805..2a110b73a 100644
--- a/tests/client-sdk/vector_io/test_vector_io.py
+++ b/tests/client-sdk/vector_io/test_vector_io.py
@@ -6,8 +6,39 @@
 import random
 
+import pytest
 
-def test_vector_db_retrieve(llama_stack_client, embedding_model):
+
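+# Registry fixtures: `empty_vector_db_registry` unregisters every vector DB so
+# each test starts from a clean state; `single_entry_vector_db_registry` then
+# registers exactly one DB and returns the resulting identifier list.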
+@pytest.fixture(scope="function")
+def empty_vector_db_registry(llama_stack_client):
+    vector_dbs = [
+        vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
+    ]
+    for vector_db_id in vector_dbs:
+        llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
+
+
+@pytest.fixture(scope="function")
+def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry):
+    vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
+    llama_stack_client.vector_dbs.register(
+        vector_db_id=vector_db_id,
+        embedding_model="all-MiniLM-L6-v2",
+        embedding_dimension=384,
+        provider_id="faiss",
+    )
+    vector_dbs = [
+        vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
+    ]
+    return vector_dbs
+
+
+def test_vector_db_retrieve(
+    llama_stack_client, embedding_model, empty_vector_db_registry
+):
     # Register a memory bank first
     vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model=embedding_model,
         embedding_dimension=384,
     )
 
     response = llama_stack_client.vector_dbs.retrieve(vector_db_id=vector_db_id)
     assert response is not None
     assert response.identifier == vector_db_id
     assert response.embedding_model == embedding_model
     assert response.provider_resource_id == vector_db_id
 
 
-def test_vector_db_list(llama_stack_client):
+def test_vector_db_list(llama_stack_client, empty_vector_db_registry):
     vector_dbs_after_register = [
         vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
     ]
     assert len(vector_dbs_after_register) == 0
 
 
-def test_vector_db_register(llama_stack_client, embedding_model):
+def test_vector_db_register(
+    llama_stack_client, embedding_model, empty_vector_db_registry
+):
     vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,
         embedding_model=embedding_model,
         embedding_dimension=384,
     )
 
     vector_dbs_after_register = [
         vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
     ]
@@ -48,7 +78,7 @@ def test_vector_db_register(llama_stack_client, embedding_model):
     assert vector_dbs_after_register == [vector_db_id]
 
 
-def test_vector_db_unregister(llama_stack_client):
+def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry):
     vector_dbs = [
         vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()
     ]