mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-26 09:15:40 +00:00 
			
		
		
		
Adds flexible CORS (Cross-Origin Resource Sharing) configuration support to the FastAPI server, with both a local development mode and an explicit configuration mode:

- **Local development mode**: `cors: true` enables localhost-only access using the regex pattern `https?://localhost:\d+`
- **Explicit configuration mode**: specific origins can be configured, with credential support and validation
- Prevents insecure combinations (wildcard origins with credentials enabled)
- Integrates with FastAPI's CORSMiddleware via `model_dump()`

Addresses the need for configurable CORS policies to support web frontends and cross-origin API access while maintaining security.

Closes #2119

## Test Plan

1. Ran unit tests.
2. Manual tests:
   - FastAPI middleware integration with actual HTTP requests
   - Local development mode localhost access validation
   - Explicit configuration mode origins validation
   - Preflight OPTIONS request handling

Some screenshots of the manual tests:

<img width="1920" height="927" alt="image" src="https://github.com/user-attachments/assets/79322338-40c7-45c9-a9ea-e3e8d8e2f849" /> <img width="1911" height="1037" alt="image" src="https://github.com/user-attachments/assets/1683524e-b0c9-48c9-a0a5-782e949cde01" />

cc: @leseb @rhuss @franciscojavierarceo
		
			
				
	
	
		
			105 lines
		
	
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			105 lines
		
	
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
| 
 | |
| import pytest
 | |
| 
 | |
| from llama_stack.core.datatypes import CORSConfig, process_cors_config
 | |
| 
 | |
| 
 | |
def test_cors_config_defaults():
    """A CORSConfig built with no arguments carries the restrictive defaults."""
    cfg = CORSConfig()

    # No origins, headers, or exposed headers are configured by default.
    assert cfg.allow_origins == []
    assert cfg.allow_origin_regex is None
    assert cfg.allow_headers == []
    assert cfg.expose_headers == []

    # Only the OPTIONS method is listed, and credentials are off.
    assert cfg.allow_methods == ["OPTIONS"]
    assert cfg.allow_credentials is False
    assert cfg.max_age == 600


def test_cors_config_explicit_config():
    """Explicitly supplied fields are stored unchanged on the model."""
    cfg = CORSConfig(
        allow_origins=["https://example.com"],
        allow_credentials=True,
        max_age=3600,
        allow_methods=["GET", "POST"],
    )

    assert cfg.allow_origins == ["https://example.com"]
    assert cfg.allow_credentials is True
    assert cfg.max_age == 3600
    assert cfg.allow_methods == ["GET", "POST"]


def test_cors_config_regex():
    """An origin regex may be configured alongside an empty origin list."""
    localhost_pattern = r"https?://localhost:\d+"
    cfg = CORSConfig(allow_origins=[], allow_origin_regex=localhost_pattern)

    assert cfg.allow_origins == []
    assert cfg.allow_origin_regex == localhost_pattern


def test_cors_config_wildcard_credentials_error():
    """A "*" origin combined with allow_credentials=True must be rejected."""
    expected_error = "Cannot use wildcard origins with credentials enabled"

    # The wildcard is rejected both on its own and when mixed with a
    # concrete origin.
    for origins in (["*"], ["https://example.com", "*"]):
        with pytest.raises(ValueError, match=expected_error):
            CORSConfig(allow_origins=origins, allow_credentials=True)


def test_process_cors_config_false():
    """Passing False yields no CORS configuration at all."""
    assert process_cors_config(False) is None


def test_process_cors_config_true():
    """Passing True produces the localhost-only development policy."""
    cfg = process_cors_config(True)

    assert isinstance(cfg, CORSConfig)
    assert cfg.allow_origins == []
    assert cfg.allow_origin_regex == r"https?://localhost:\d+"
    assert cfg.allow_credentials is False

    # Order is irrelevant; each common HTTP verb just has to be present.
    for verb in ("GET", "POST", "PUT", "DELETE", "OPTIONS"):
        assert verb in cfg.allow_methods


def test_process_cors_config_passthrough():
    """An existing CORSConfig is handed back as the same object, not a copy."""
    supplied = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"])

    assert process_cors_config(supplied) is supplied


def test_process_cors_config_invalid_type():
    """A value that is neither bool nor CORSConfig (here a str) raises ValueError."""
    bogus_value = "invalid"

    with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"):
        process_cors_config(bogus_value)


def test_cors_config_model_dump():
    """model_dump() returns every CORS field with its configured or default value.

    Per the change description, the server feeds this dict to FastAPI's
    CORSMiddleware, so both the values and the exact key set are checked.
    """
    cors_config = CORSConfig(
        allow_origins=["https://example.com"],
        allow_methods=["GET", "POST"],
        allow_headers=["Content-Type"],
        allow_credentials=True,
        max_age=3600,
    )

    config_dict = cors_config.model_dump()

    # Explicitly configured fields round-trip unchanged.
    assert config_dict["allow_origins"] == ["https://example.com"]
    assert config_dict["allow_methods"] == ["GET", "POST"]
    assert config_dict["allow_headers"] == ["Content-Type"]
    assert config_dict["allow_credentials"] is True
    assert config_dict["max_age"] == 3600

    # Fields left unset are still dumped, carrying the defaults asserted in
    # test_cors_config_defaults, rather than being omitted or altered.
    assert config_dict["allow_origin_regex"] is None
    assert config_dict["expose_headers"] == []

    expected_keys = {
        "allow_origins",
        "allow_origin_regex",
        "allow_methods",
        "allow_headers",
        "allow_credentials",
        "expose_headers",
        "max_age",
    }
    assert set(config_dict) == expected_keys