mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 01:01:13 +00:00 
			
		
		
		
	# What does this PR do? On the path to maintainable implementations of inference providers, this makes all provider configs instances of `RemoteInferenceProviderConfig`. ## Test Plan: CI
		
			
				
	
	
		
			64 lines
		
	
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			64 lines
		
	
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import os
 | |
| from typing import Any
 | |
| 
 | |
| from pydantic import BaseModel, Field, HttpUrl, SecretStr
 | |
| 
 | |
| from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 | |
| from llama_stack.schema_utils import json_schema_type
 | |
| 
 | |
| 
 | |
class AzureProviderDataValidator(BaseModel):
    """Pydantic schema for Azure connection settings supplied as provider data.

    The field set parallels ``AzureConfig`` but every name carries an
    ``azure_`` prefix. The key and the base URL are mandatory; the API
    version and API type are optional.

    NOTE(review): presumably this validates per-request provider data rather
    than static config -- confirm against the caller that consumes it.
    """

    # Credential. SecretStr keeps the raw value out of repr()/str() output.
    azure_api_key: SecretStr = Field(description="Azure API key for Azure")

    # Endpoint; pydantic's HttpUrl type rejects values that are not URLs.
    azure_api_base: HttpUrl = Field(
        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)"
    )

    # Optional; left as None when the caller does not provide a version.
    azure_api_version: str | None = Field(
        None,
        description="Azure API version for Azure (e.g., 2024-06-01)",
    )

    # Optional; falls back to the literal "azure" when omitted.
    azure_api_type: str | None = Field(
        "azure",
        description="Azure API type for Azure (e.g., azure)",
    )
 | |
| 
 | |
| 
 | |
def _azure_api_version_from_env() -> str | None:
    """Return ``$AZURE_API_VERSION``, or None when it is unset."""
    return os.getenv("AZURE_API_VERSION")


def _azure_api_type_from_env() -> str | None:
    """Return ``$AZURE_API_TYPE``, or ``"azure"`` when it is unset."""
    return os.getenv("AZURE_API_TYPE", "azure")


@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
    """Configuration model for the remote Azure inference provider.

    ``api_key`` and ``api_base`` must be supplied explicitly. When
    ``api_version`` and ``api_type`` are not given, they are read from the
    process environment at instantiation time.
    """

    # Credential. SecretStr keeps the raw value out of repr()/str() output.
    api_key: SecretStr = Field(description="Azure API key for Azure")

    # Endpoint; pydantic's HttpUrl type rejects values that are not URLs.
    api_base: HttpUrl = Field(
        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)"
    )

    # The env lookup runs per instance (default_factory), not once at import.
    api_version: str | None = Field(
        default_factory=_azure_api_version_from_env,
        description="Azure API version for Azure (e.g., 2024-12-01-preview)",
    )

    api_type: str | None = Field(
        default_factory=_azure_api_type_from_env,
        description="Azure API type for Azure (e.g., azure)",
    )

    @classmethod
    def sample_run_config(
        cls,
        api_key: str = "${env.AZURE_API_KEY:=}",
        api_base: str = "${env.AZURE_API_BASE:=}",
        api_version: str = "${env.AZURE_API_VERSION:=}",
        api_type: str = "${env.AZURE_API_TYPE:=}",
        **kwargs,
    ) -> dict[str, Any]:
        """Build a sample run-config mapping for this provider.

        Each argument is copied verbatim into the result under its own name.
        By default, each value is a ``${env.VAR:=}`` placeholder string. Any
        extra keyword arguments are accepted and ignored.

        NOTE(review): because ``api_type`` is always present in this mapping,
        the ``"azure"`` fallback in the field's default factory will not apply
        to configs built from it. Presumably an unset AZURE_API_TYPE resolves
        to an empty value here; confirm that the adapter treats that as azure.
        """
        return dict(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            api_type=api_type,
        )
 |