mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 09:05:37 +00:00 
			
		
		
		
	# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
		
			
				
	
	
		
			223 lines
		
	
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			223 lines
		
	
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| from typing import Any
 | |
| from unittest.mock import patch
 | |
| 
 | |
| import pytest
 | |
| import yaml
 | |
| from pydantic import BaseModel, Field, ValidationError
 | |
| 
 | |
| from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
 | |
| from llama_stack.distribution.distribution import get_provider_registry
 | |
| from llama_stack.providers.datatypes import ProviderSpec
 | |
| 
 | |
| 
 | |
| class SampleConfig(BaseModel):
 | |
|     foo: str = Field(
 | |
|         default="bar",
 | |
|         description="foo",
 | |
|     )
 | |
| 
 | |
|     @classmethod
 | |
|     def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
 | |
|         return {
 | |
|             "foo": "baz",
 | |
|         }
 | |
| 
 | |
| 
 | |
| @pytest.fixture
 | |
| def mock_providers():
 | |
|     """Mock the available_providers function to return test providers."""
 | |
|     with patch("llama_stack.providers.registry.inference.available_providers") as mock:
 | |
|         mock.return_value = [
 | |
|             ProviderSpec(
 | |
|                 provider_type="test_provider",
 | |
|                 api=Api.inference,
 | |
|                 adapter_type="test_adapter",
 | |
|                 config_class="test_provider.config.TestProviderConfig",
 | |
|             )
 | |
|         ]
 | |
|         yield mock
 | |
| 
 | |
| 
 | |
| @pytest.fixture
 | |
| def base_config(tmp_path):
 | |
|     """Create a base StackRunConfig with common settings."""
 | |
|     return StackRunConfig(
 | |
|         image_name="test_image",
 | |
|         providers={
 | |
|             "inference": [
 | |
|                 Provider(
 | |
|                     provider_id="sample_provider",
 | |
|                     provider_type="sample",
 | |
|                     config=SampleConfig.sample_run_config(),
 | |
|                 )
 | |
|             ]
 | |
|         },
 | |
|         external_providers_dir=str(tmp_path),
 | |
|     )
 | |
| 
 | |
| 
 | |
| @pytest.fixture
 | |
| def provider_spec_yaml():
 | |
|     """Common provider spec YAML for testing."""
 | |
|     return """
 | |
| adapter:
 | |
|   adapter_type: test_provider
 | |
|   config_class: test_provider.config.TestProviderConfig
 | |
|   module: test_provider
 | |
| api_dependencies:
 | |
|   - safety
 | |
| """
 | |
| 
 | |
| 
 | |
| @pytest.fixture
 | |
| def inline_provider_spec_yaml():
 | |
|     """Common inline provider spec YAML for testing."""
 | |
|     return """
 | |
| module: test_provider
 | |
| config_class: test_provider.config.TestProviderConfig
 | |
| pip_packages:
 | |
|   - test-package
 | |
| api_dependencies:
 | |
|   - safety
 | |
| optional_api_dependencies:
 | |
|   - vector_io
 | |
| provider_data_validator: test_provider.validator.TestValidator
 | |
| container_image: test-image:latest
 | |
| """
 | |
| 
 | |
| 
 | |
| @pytest.fixture
 | |
| def api_directories(tmp_path):
 | |
|     """Create the API directory structure for testing."""
 | |
|     # Create remote provider directory
 | |
|     remote_inference_dir = tmp_path / "remote" / "inference"
 | |
|     remote_inference_dir.mkdir(parents=True, exist_ok=True)
 | |
| 
 | |
|     # Create inline provider directory
 | |
|     inline_inference_dir = tmp_path / "inline" / "inference"
 | |
|     inline_inference_dir.mkdir(parents=True, exist_ok=True)
 | |
| 
 | |
|     return remote_inference_dir, inline_inference_dir
 | |
| 
 | |
| 
 | |
| class TestProviderRegistry:
 | |
|     """Test suite for provider registry functionality."""
 | |
| 
 | |
|     def test_builtin_providers(self, mock_providers):
 | |
|         """Test loading built-in providers."""
 | |
|         registry = get_provider_registry(None)
 | |
| 
 | |
|         assert Api.inference in registry
 | |
|         assert "test_provider" in registry[Api.inference]
 | |
|         assert registry[Api.inference]["test_provider"].provider_type == "test_provider"
 | |
|         assert registry[Api.inference]["test_provider"].api == Api.inference
 | |
| 
 | |
|     def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml):
 | |
|         """Test loading external remote providers from YAML files."""
 | |
|         remote_dir, _ = api_directories
 | |
|         with open(remote_dir / "test_provider.yaml", "w") as f:
 | |
|             f.write(provider_spec_yaml)
 | |
| 
 | |
|         registry = get_provider_registry(base_config)
 | |
|         assert len(registry[Api.inference]) == 2
 | |
| 
 | |
|         assert Api.inference in registry
 | |
|         assert "remote::test_provider" in registry[Api.inference]
 | |
|         provider = registry[Api.inference]["remote::test_provider"]
 | |
|         assert provider.adapter.adapter_type == "test_provider"
 | |
|         assert provider.adapter.module == "test_provider"
 | |
|         assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
 | |
|         assert Api.safety in provider.api_dependencies
 | |
| 
 | |
|     def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
 | |
|         """Test loading external inline providers from YAML files."""
 | |
|         _, inline_dir = api_directories
 | |
|         with open(inline_dir / "test_provider.yaml", "w") as f:
 | |
|             f.write(inline_provider_spec_yaml)
 | |
| 
 | |
|         registry = get_provider_registry(base_config)
 | |
|         assert len(registry[Api.inference]) == 2
 | |
| 
 | |
|         assert Api.inference in registry
 | |
|         assert "inline::test_provider" in registry[Api.inference]
 | |
|         provider = registry[Api.inference]["inline::test_provider"]
 | |
|         assert provider.provider_type == "inline::test_provider"
 | |
|         assert provider.module == "test_provider"
 | |
|         assert provider.config_class == "test_provider.config.TestProviderConfig"
 | |
|         assert provider.pip_packages == ["test-package"]
 | |
|         assert Api.safety in provider.api_dependencies
 | |
|         assert Api.vector_io in provider.optional_api_dependencies
 | |
|         assert provider.provider_data_validator == "test_provider.validator.TestValidator"
 | |
|         assert provider.container_image == "test-image:latest"
 | |
| 
 | |
|     def test_invalid_yaml(self, api_directories, mock_providers, base_config):
 | |
|         """Test handling of invalid YAML files."""
 | |
|         remote_dir, inline_dir = api_directories
 | |
|         with open(remote_dir / "invalid.yaml", "w") as f:
 | |
|             f.write("invalid: yaml: content: -")
 | |
|         with open(inline_dir / "invalid.yaml", "w") as f:
 | |
|             f.write("invalid: yaml: content: -")
 | |
| 
 | |
|         with pytest.raises(yaml.YAMLError):
 | |
|             get_provider_registry(base_config)
 | |
| 
 | |
|     def test_missing_directory(self, mock_providers):
 | |
|         """Test handling of missing external providers directory."""
 | |
|         config = StackRunConfig(
 | |
|             image_name="test_image",
 | |
|             providers={
 | |
|                 "inference": [
 | |
|                     Provider(
 | |
|                         provider_id="sample_provider",
 | |
|                         provider_type="sample",
 | |
|                         config=SampleConfig.sample_run_config(),
 | |
|                     )
 | |
|                 ]
 | |
|             },
 | |
|             external_providers_dir="/nonexistent/dir",
 | |
|         )
 | |
|         with pytest.raises(FileNotFoundError):
 | |
|             get_provider_registry(config)
 | |
| 
 | |
|     def test_empty_api_directory(self, api_directories, mock_providers, base_config):
 | |
|         """Test handling of empty API directory."""
 | |
|         registry = get_provider_registry(base_config)
 | |
|         assert len(registry[Api.inference]) == 1  # Only built-in provider
 | |
| 
 | |
|     def test_malformed_remote_provider_spec(self, api_directories, mock_providers, base_config):
 | |
|         """Test handling of malformed remote provider spec (missing required fields)."""
 | |
|         remote_dir, _ = api_directories
 | |
|         malformed_spec = """
 | |
| adapter:
 | |
|   adapter_type: test_provider
 | |
|   # Missing required fields
 | |
| api_dependencies:
 | |
|   - safety
 | |
| """
 | |
|         with open(remote_dir / "malformed.yaml", "w") as f:
 | |
|             f.write(malformed_spec)
 | |
| 
 | |
|         with pytest.raises(ValidationError):
 | |
|             get_provider_registry(base_config)
 | |
| 
 | |
|     def test_malformed_inline_provider_spec(self, api_directories, mock_providers, base_config):
 | |
|         """Test handling of malformed inline provider spec (missing required fields)."""
 | |
|         _, inline_dir = api_directories
 | |
|         malformed_spec = """
 | |
| module: test_provider
 | |
| # Missing required config_class
 | |
| pip_packages:
 | |
|   - test-package
 | |
| """
 | |
|         with open(inline_dir / "malformed.yaml", "w") as f:
 | |
|             f.write(malformed_spec)
 | |
| 
 | |
|         with pytest.raises(KeyError) as exc_info:
 | |
|             get_provider_registry(base_config)
 | |
|         assert "config_class" in str(exc_info.value)
 |