dell tgi adapter

Xi Yan 2024-10-16 16:35:46 -07:00
parent cfc97df6d5
commit cd1f1a86bf
5 changed files with 61 additions and 11 deletions
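
Summary of the change: a dedicated remote::dell-tgi inference provider is added alongside the existing TGI adapters. It consists of a DellTGIImplConfig / DellTGIAdapter pair in the tgi adapter package, a provider registry entry, and an updated distribution run.yaml; the temporary Dell workaround in _HfAdapter.list_models is removed in the same pass.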

llama_stack/providers/adapters/inference/tgi/__init__.py

@@ -6,15 +6,32 @@
 from typing import Union
 
-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
-from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
+from .config import (
+    DellTGIImplConfig,
+    InferenceAPIImplConfig,
+    InferenceEndpointImplConfig,
+    TGIImplConfig,
+)
+from .tgi import (
+    DellTGIAdapter,
+    InferenceAPIAdapter,
+    InferenceEndpointAdapter,
+    TGIAdapter,
+)
 
 
 async def get_adapter_impl(
-    config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
+    config: Union[
+        InferenceAPIImplConfig,
+        InferenceEndpointImplConfig,
+        TGIImplConfig,
+        DellTGIImplConfig,
+    ],
     _deps,
 ):
-    if isinstance(config, TGIImplConfig):
+    if isinstance(config, DellTGIImplConfig):
+        impl = DellTGIAdapter()
+    elif isinstance(config, TGIImplConfig):
         impl = TGIAdapter()
     elif isinstance(config, InferenceAPIImplConfig):
         impl = InferenceAPIAdapter()
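
With this dispatch in place, passing a DellTGIImplConfig selects the new adapter before the generic TGI branch. A minimal sketch of driving it directly (hypothetical usage, not from the commit; the empty dict stands in for the unused _deps parameter, and a reachable TGI endpoint is assumed if get_adapter_impl goes on to initialize the adapter, which this hunk truncates before showing):

import asyncio

from llama_stack.providers.adapters.inference.tgi import get_adapter_impl
from llama_stack.providers.adapters.inference.tgi.config import DellTGIImplConfig

async def main() -> None:
    config = DellTGIImplConfig(
        url="http://127.0.0.1:5009",
        hf_model_name="meta-llama/Llama-3.1-8B-Instruct",
    )
    # The isinstance chain above picks DellTGIAdapter for this config type.
    impl = await get_adapter_impl(config, {})

asyncio.run(main())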

llama_stack/providers/adapters/inference/tgi/config.py

@@ -41,3 +41,17 @@ class InferenceAPIImplConfig(BaseModel):
         default=None,
         description="Your Hugging Face user access token (will default to locally saved token if not provided)",
     )
+
+
+@json_schema_type
+class DellTGIImplConfig(BaseModel):
+    url: str = Field(
+        description="The URL for the Dell TGI endpoint (e.g. 'http://localhost:8080')",
+    )
+    hf_model_name: str = Field(
+        description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="A bearer token if your TGI endpoint is protected.",
+    )
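
url and hf_model_name are declared without defaults, so both are required; only api_token may be omitted. A small sketch of constructing the config directly (hypothetical usage; model_dump assumes pydantic v2, on v1 it would be .dict()):

from llama_stack.providers.adapters.inference.tgi.config import DellTGIImplConfig

cfg = DellTGIImplConfig(
    url="http://127.0.0.1:5009",
    hf_model_name="meta-llama/Llama-3.1-8B-Instruct",
    # api_token stays None for an unprotected endpoint
)
print(cfg.model_dump())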

llama_stack/providers/adapters/inference/tgi/tgi.py

@@ -29,7 +29,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_model_input_info,
 )
 
-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+from .config import (
+    DellTGIImplConfig,
+    InferenceAPIImplConfig,
+    InferenceEndpointImplConfig,
+    TGIImplConfig,
+)
 
 logger = logging.getLogger(__name__)
@@ -52,10 +57,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 
     async def list_models(self) -> List[ModelDef]:
         repo = self.model_id
-        # tmp hack to support Dell
-        if repo not in self.huggingface_repo_to_llama_model_id:
-            repo = "meta-llama/Llama-3.1-8B-Instruct"
         identifier = self.huggingface_repo_to_llama_model_id[repo]
         return [
             ModelDef(
@@ -177,6 +178,14 @@ class TGIAdapter(_HfAdapter):
         self.model_id = endpoint_info["model_id"]
 
 
+class DellTGIAdapter(_HfAdapter):
+    async def initialize(self, config: DellTGIImplConfig) -> None:
+        self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
+        endpoint_info = await self.client.get_endpoint_info()
+        self.max_tokens = endpoint_info["max_total_tokens"]
+        self.model_id = config.hf_model_name
+
+
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
         self.client = AsyncInferenceClient(
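
DellTGIAdapter differs from TGIAdapter in exactly one respect: the model id comes from config.hf_model_name instead of the endpoint's reported model_id, presumably because Dell's TGI image does not report a Hugging Face repo id; the removed list_models hack above papered over exactly that case. A standalone sketch of the same probe (hypothetical helper, assuming a recent huggingface_hub and a TGI server on the given URL):

import asyncio
from typing import Optional

from huggingface_hub import AsyncInferenceClient

async def probe(url: str, token: Optional[str] = None) -> int:
    # get_endpoint_info() queries the TGI server's info route;
    # "max_total_tokens" is part of TGI's reported configuration.
    client = AsyncInferenceClient(model=url, token=token)
    info = await client.get_endpoint_info()
    return info["max_total_tokens"]

print(asyncio.run(probe("http://127.0.0.1:5009")))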

llama_stack/providers/registry/inference.py

@@ -87,6 +87,15 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="dell-tgi",
+                pip_packages=["huggingface_hub", "aiohttp"],
+                module="llama_stack.providers.adapters.inference.tgi",
+                config_class="llama_stack.providers.adapters.inference.tgi.DellTGIImplConfig",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
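
The adapter_type string is what run configs reference with the remote:: prefix, so this entry is what makes provider_type: remote::dell-tgi resolvable (see the run.yaml change below). module points back at the tgi package whose get_adapter_impl performs the dispatch shown in the first file, and pip_packages pulls in huggingface_hub for the async client plus aiohttp for its HTTP transport.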

run.yaml (distribution config; full path not shown in this view)

@@ -13,10 +13,11 @@ apis:
 - safety
 providers:
   inference:
-  - provider_id: remote::tgi
-    provider_type: remote::tgi
+  - provider_id: remote::dell-tgi
+    provider_type: remote::dell-tgi
     config:
       url: http://127.0.0.1:5009
+      hf_model_name: meta-llama/Llama-3.1-8B-Instruct
   safety:
   - provider_id: meta-reference
     provider_type: meta-reference
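
Note that hf_model_name is required by DellTGIImplConfig (it is declared without a default), so the added line is load-bearing: a remote::dell-tgi provider entry whose config block omits it will fail pydantic validation when the stack parses this file.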