test: add unit test to ensure all config types are instantiable (#1601)

This commit is contained in:
Ashwin Bharambe 2025-03-12 22:29:58 -07:00 committed by GitHub
parent 0a0d6cb96e
commit d072b5fa0c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
71 changed files with 662 additions and 465 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel, Field
@ -20,3 +21,15 @@ class DatabricksImplConfig(BaseModel):
default=None,
description="The Databricks API token",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_URL}",
api_token: str = "${env.DATABRICKS_API_TOKEN}",
**kwargs: Any,
) -> Dict[str, Any]:
return {
"url": url,
"api_token": api_token,
}

View file

@ -5,10 +5,11 @@
# the root directory of this source tree.
from .config import RunpodImplConfig
from .runpod import RunpodInferenceAdapter
async def get_adapter_impl(config: RunpodImplConfig, _deps):
from .runpod import RunpodInferenceAdapter
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
impl = RunpodInferenceAdapter(config)
await impl.initialize()

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
@ -21,3 +21,10 @@ class RunpodImplConfig(BaseModel):
default=None,
description="The API token",
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
return {
"url": "${env.RUNPOD_URL:}",
"api_token": "${env.RUNPOD_API_TOKEN:}",
}

View file

@ -8,7 +8,6 @@ from typing import AsyncGenerator
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.models.llama.datatypes import Message
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleInferenceImpl
impl = SampleInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -1,12 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model
from .config import SampleConfig
class SampleInferenceImpl(Inference):
def __init__(self, config: SampleConfig):
self.config = config
async def register_model(self, model: Model) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def initialize(self):
pass