mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-26 09:15:40 +00:00 
			
		
		
		
	# What does this PR do? **Fixes** #1959 HuggingFace provides several loading paths that the datasets library can use. My theory on why the test previously failed intermittently is that when calling `load_dataset(...)`, it may try several options such as the local cache, the Hugging Face Hub, a dataset script, or others. One of these options seems to work inconsistently in the CI. The HuggingFace datasets library relies on the `transformers` package to load certain datasets such as `llamastack/simpleqa`, and by adding the package, we can see that the dataset is loaded consistently via the Hugging Face Hub. Please see the PR in my fork demonstrating over 7 consecutive passes: https://github.com/ChristianZaccaria/llama-stack/pull/1 **Some References:** - https://github.com/huggingface/transformers/issues/8690 - https://huggingface.co/docs/datasets/en/loading [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation)
		
			
				
	
	
		
			95 lines
		
	
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			95 lines
		
	
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| 
 | |
| import base64
 | |
| import mimetypes
 | |
| import os
 | |
| 
 | |
| import pytest
 | |
| 
 | |
| # How to run this test:
 | |
| #
 | |
| # LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasets
 | |
| 
 | |
| 
 | |
def data_url_from_file(file_path: str) -> str:
    """Encode the contents of a file as a base64 ``data:`` URL.

    The MIME type is guessed from the file name. When it cannot be guessed
    (unknown or missing extension), ``application/octet-stream`` is used so
    the URL never embeds the literal text ``None`` as its media type.

    Args:
        file_path: Path of the file whose bytes should be embedded.

    Returns:
        A string of the form ``data:<mime-type>;base64,<payload>``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")

    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type is None:
        # guess_type() yields None for unrecognised names; fall back to the
        # generic binary type instead of formatting "None" into the URL.
        mime_type = "application/octet-stream"

    return f"data:{mime_type};base64,{base64_content}"
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize(
    "purpose, source, provider_id, limit",
    [
        (
            "eval/messages-answer",
            {
                "type": "uri",
                "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
            },
            "huggingface",
            10,
        ),
        (
            "eval/messages-answer",
            {
                "type": "rows",
                "rows": [
                    {
                        "messages": [{"role": "user", "content": "Hello, world!"}],
                        "answer": "Hello, world!",
                    },
                    {
                        "messages": [
                            {
                                "role": "user",
                                "content": "What is the capital of France?",
                            }
                        ],
                        "answer": "Paris",
                    },
                ],
            },
            "localfs",
            2,
        ),
        (
            "eval/messages-answer",
            {
                "type": "uri",
                "uri": data_url_from_file(os.path.join(os.path.dirname(__file__), "test_dataset.csv")),
            },
            "localfs",
            5,
        ),
    ],
)
def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit):
    """Register a dataset, read rows from it, then unregister it.

    For each parametrized source the test checks that registration yields an
    identifier and the expected provider, that ``iterrows`` returns exactly
    ``limit`` rows, and that the dataset appears in the listing while
    registered and is absent after being unregistered.
    """
    dataset = llama_stack_client.datasets.register(
        purpose=purpose,
        source=source,
    )
    try:
        assert dataset.identifier is not None
        assert dataset.provider_id == provider_id
        iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=limit)
        assert len(iterrow_response.data) == limit

        dataset_list = llama_stack_client.datasets.list()
        assert dataset.identifier in [d.identifier for d in dataset_list]
    finally:
        # Always unregister, even when a check above fails, so a failing case
        # does not leave its dataset registered for later tests.
        llama_stack_client.datasets.unregister(dataset.identifier)

    dataset_list = llama_stack_client.datasets.list()
    assert dataset.identifier not in [d.identifier for d in dataset_list]
 |