diff --git a/llama_stack/providers/adapters/inference/tgi/__init__.py b/llama_stack/providers/adapters/inference/tgi/__init__.py
index 451650323..e3b24de2f 100644
--- a/llama_stack/providers/adapters/inference/tgi/__init__.py
+++ b/llama_stack/providers/adapters/inference/tgi/__init__.py
@@ -6,15 +6,32 @@
 
 from typing import Union
 
-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
-from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
+from .config import (
+    DellTGIImplConfig,
+    InferenceAPIImplConfig,
+    InferenceEndpointImplConfig,
+    TGIImplConfig,
+)
+from .tgi import (
+    DellTGIAdapter,
+    InferenceAPIAdapter,
+    InferenceEndpointAdapter,
+    TGIAdapter,
+)
 
 
 async def get_adapter_impl(
-    config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
+    config: Union[
+        InferenceAPIImplConfig,
+        InferenceEndpointImplConfig,
+        TGIImplConfig,
+        DellTGIImplConfig,
+    ],
     _deps,
 ):
-    if isinstance(config, TGIImplConfig):
+    if isinstance(config, DellTGIImplConfig):
+        impl = DellTGIAdapter()
+    elif isinstance(config, TGIImplConfig):
         impl = TGIAdapter()
     elif isinstance(config, InferenceAPIImplConfig):
         impl = InferenceAPIAdapter()
diff --git a/llama_stack/providers/adapters/inference/tgi/config.py b/llama_stack/providers/adapters/inference/tgi/config.py
index 6ce2b9dc6..801d5fc8f 100644
--- a/llama_stack/providers/adapters/inference/tgi/config.py
+++ b/llama_stack/providers/adapters/inference/tgi/config.py
@@ -41,3 +41,17 @@ class InferenceAPIImplConfig(BaseModel):
         default=None,
         description="Your Hugging Face user access token (will default to locally saved token if not provided)",
     )
+
+
+@json_schema_type
+class DellTGIImplConfig(BaseModel):
+    url: str = Field(
+        description="The URL for the Dell TGI endpoint (e.g. 'http://localhost:8080')",
+    )
+    hf_model_name: str = Field(
+        description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="A bearer token if your TGI endpoint is protected.",
+    )
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 92cb9ba6a..4fe160045 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -29,7 +29,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_model_input_info,
 )
 
-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+from .config import (
+    DellTGIImplConfig,
+    InferenceAPIImplConfig,
+    InferenceEndpointImplConfig,
+    TGIImplConfig,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -52,10 +57,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 
     async def list_models(self) -> List[ModelDef]:
         repo = self.model_id
-        # tmp hack to support Dell
-        if repo not in self.huggingface_repo_to_llama_model_id:
-            repo = "meta-llama/Llama-3.1-8B-Instruct"
-
         identifier = self.huggingface_repo_to_llama_model_id[repo]
         return [
             ModelDef(
@@ -177,6 +178,14 @@
         self.model_id = endpoint_info["model_id"]
 
 
+class DellTGIAdapter(_HfAdapter):
+    async def initialize(self, config: DellTGIImplConfig) -> None:
+        self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
+        endpoint_info = await self.client.get_endpoint_info()
+        self.max_tokens = endpoint_info["max_total_tokens"]
+        self.model_id = config.hf_model_name
+
+
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
         self.client = AsyncInferenceClient(
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 686fc273b..8530109c5 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -87,6 +87,15 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="dell-tgi",
+                pip_packages=["huggingface_hub", "aiohttp"],
+                module="llama_stack.providers.adapters.inference.tgi",
+                config_class="llama_stack.providers.adapters.inference.tgi.DellTGIImplConfig",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
diff --git a/tests/examples/tgi-run.yaml b/tests/examples/dell-tgi-run.yaml
similarity index 88%
rename from tests/examples/tgi-run.yaml
rename to tests/examples/dell-tgi-run.yaml
index a398b20a3..0db5bdb3e 100644
--- a/tests/examples/tgi-run.yaml
+++ b/tests/examples/dell-tgi-run.yaml
@@ -13,10 +13,11 @@ apis:
 - safety
 providers:
   inference:
-  - provider_id: remote::tgi
-    provider_type: remote::tgi
+  - provider_id: remote::dell-tgi
+    provider_type: remote::dell-tgi
     config:
       url: http://127.0.0.1:5009
+      hf_model_name: meta-llama/Llama-3.1-8B-Instruct
   safety:
   - provider_id: meta-reference
     provider_type: meta-reference
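
For illustration only, a minimal sketch of how the new adapter could be exercised directly, assuming the module layout introduced in this diff and a Dell TGI endpoint already serving at the placeholder URL below (the URL and model name are example values, not part of the change):

import asyncio

from llama_stack.providers.adapters.inference.tgi.config import DellTGIImplConfig
from llama_stack.providers.adapters.inference.tgi.tgi import DellTGIAdapter


async def main() -> None:
    # Placeholder endpoint URL and model name; point these at a real deployment.
    config = DellTGIImplConfig(
        url="http://127.0.0.1:5009",
        hf_model_name="meta-llama/Llama-3.1-8B-Instruct",
    )
    adapter = DellTGIAdapter()
    # initialize() reads max_total_tokens from the endpoint info but takes the
    # model id from config.hf_model_name, which is the behavior this diff adds.
    await adapter.initialize(config)
    print(adapter.model_id, adapter.max_tokens)


asyncio.run(main())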