mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
# What does this PR do? Completes #3732 by removing runtime URL transformations and requiring users to provide full URLs in configuration. All providers now use 'base_url' consistently and respect the exact URL provided without appending paths like /v1 or /openai/v1 at runtime. BREAKING CHANGE: Users must update configs to include full URL paths (e.g., http://localhost:11434/v1 instead of http://localhost:11434). Closes #3732 ## Test Plan Existing tests should pass even with the URL changes, due to default URLs being altered. Add unit test to enforce URL standardization across remote inference providers (verifies all use 'base_url' field with HttpUrl | None type) Signed-off-by: Charlie Doern <cdoern@redhat.com>
62 lines
2 KiB
Python
62 lines
2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import os
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
|
|
|
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
|
from llama_stack_api import json_schema_type
|
|
|
|
|
|
# Shape of the Azure-specific provider data a caller can attach to a request:
# an API key and endpoint, plus optional version/type overrides.
# NOTE(review): only the declarations are visible here; which of these fields the
# Azure adapter actually consumes is decided elsewhere — confirm before relying on it.
class AzureProviderDataValidator(BaseModel):
    # Required. Typed as SecretStr so the key is masked when the model is printed.
    azure_api_key: SecretStr = Field(
        ...,
        description="Azure API key for Azure",
    )
    # Required. Must parse as an HTTP(S) URL.
    azure_api_base: HttpUrl = Field(
        ...,
        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
    )
    # Optional; stays None when omitted.
    azure_api_version: str | None = Field(
        None,
        description="Azure API version for Azure (e.g., 2024-06-01)",
    )
    # Optional; falls back to the literal "azure" when omitted.
    azure_api_type: str | None = Field(
        "azure",
        description="Azure API type for Azure (e.g., azure)",
    )
|
|
|
|
|
|
@json_schema_type
class AzureConfig(RemoteInferenceProviderConfig):
    # Configuration for the remote Azure inference provider.
    #
    # base_url is stored as given: it is expected to already carry the full
    # endpoint path (e.g. ending in /openai/v1); this class appends nothing.
    base_url: HttpUrl | None = Field(
        None,
        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com/openai/v1)",
    )
    # Taken from the AZURE_API_VERSION environment variable when an instance is
    # created without an explicit value; None if the variable is unset.
    api_version: str | None = Field(
        default_factory=lambda: os.environ.get("AZURE_API_VERSION"),
        description="Azure API version for Azure (e.g., 2024-12-01-preview)",
    )
    # Taken from the AZURE_API_TYPE environment variable when an instance is
    # created without an explicit value; "azure" if the variable is unset.
    api_type: str | None = Field(
        default_factory=lambda: os.environ.get("AZURE_API_TYPE", "azure"),
        description="Azure API type for Azure (e.g., azure)",
    )

    @classmethod
    def sample_run_config(
        cls,
        api_key: str = "${env.AZURE_API_KEY:=}",
        base_url: str = "${env.AZURE_API_BASE:=}",
        api_version: str = "${env.AZURE_API_VERSION:=}",
        api_type: str = "${env.AZURE_API_TYPE:=}",
        **kwargs,
    ) -> dict[str, Any]:
        """Return a sample run-config mapping for this provider.

        Each argument defaults to a ``${env.NAME:=}`` placeholder string
        (presumably expanded from the environment when the run config is
        loaded — that resolution happens outside this file). Any extra
        keyword arguments are accepted and ignored.

        Returns:
            A dict with the keys ``api_key``, ``base_url``, ``api_version``
            and ``api_type``, in that order.
        """
        return dict(
            api_key=api_key,
            base_url=base_url,
            api_version=api_version,
            api_type=api_type,
        )
|