mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
dell tgi adapter
This commit is contained in:
parent
cfc97df6d5
commit
cd1f1a86bf
5 changed files with 61 additions and 11 deletions
|
@ -6,15 +6,32 @@
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
from .config import (
|
||||||
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
|
DellTGIImplConfig,
|
||||||
|
InferenceAPIImplConfig,
|
||||||
|
InferenceEndpointImplConfig,
|
||||||
|
TGIImplConfig,
|
||||||
|
)
|
||||||
|
from .tgi import (
|
||||||
|
DellTGIAdapter,
|
||||||
|
InferenceAPIAdapter,
|
||||||
|
InferenceEndpointAdapter,
|
||||||
|
TGIAdapter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(
|
async def get_adapter_impl(
|
||||||
config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
|
config: Union[
|
||||||
|
InferenceAPIImplConfig,
|
||||||
|
InferenceEndpointImplConfig,
|
||||||
|
TGIImplConfig,
|
||||||
|
DellTGIImplConfig,
|
||||||
|
],
|
||||||
_deps,
|
_deps,
|
||||||
):
|
):
|
||||||
if isinstance(config, TGIImplConfig):
|
if isinstance(config, DellTGIImplConfig):
|
||||||
|
impl = DellTGIAdapter()
|
||||||
|
elif isinstance(config, TGIImplConfig):
|
||||||
impl = TGIAdapter()
|
impl = TGIAdapter()
|
||||||
elif isinstance(config, InferenceAPIImplConfig):
|
elif isinstance(config, InferenceAPIImplConfig):
|
||||||
impl = InferenceAPIAdapter()
|
impl = InferenceAPIAdapter()
|
||||||
|
|
|
@ -41,3 +41,17 @@ class InferenceAPIImplConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class DellTGIImplConfig(BaseModel):
|
||||||
|
url: str = Field(
|
||||||
|
description="The URL for the Dell TGI endpoint (e.g. 'http://localhost:8080')",
|
||||||
|
)
|
||||||
|
hf_model_name: str = Field(
|
||||||
|
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
|
||||||
|
)
|
||||||
|
api_token: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="A bearer token if your TGI endpoint is protected.",
|
||||||
|
)
|
||||||
|
|
|
@ -29,7 +29,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
chat_completion_request_to_model_input_info,
|
chat_completion_request_to_model_input_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
from .config import (
|
||||||
|
DellTGIImplConfig,
|
||||||
|
InferenceAPIImplConfig,
|
||||||
|
InferenceEndpointImplConfig,
|
||||||
|
TGIImplConfig,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -52,10 +57,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
async def list_models(self) -> List[ModelDef]:
|
async def list_models(self) -> List[ModelDef]:
|
||||||
repo = self.model_id
|
repo = self.model_id
|
||||||
# tmp hack to support Dell
|
|
||||||
if repo not in self.huggingface_repo_to_llama_model_id:
|
|
||||||
repo = "meta-llama/Llama-3.1-8B-Instruct"
|
|
||||||
|
|
||||||
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
||||||
return [
|
return [
|
||||||
ModelDef(
|
ModelDef(
|
||||||
|
@ -177,6 +178,14 @@ class TGIAdapter(_HfAdapter):
|
||||||
self.model_id = endpoint_info["model_id"]
|
self.model_id = endpoint_info["model_id"]
|
||||||
|
|
||||||
|
|
||||||
|
class DellTGIAdapter(_HfAdapter):
|
||||||
|
async def initialize(self, config: DellTGIImplConfig) -> None:
|
||||||
|
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
|
||||||
|
endpoint_info = await self.client.get_endpoint_info()
|
||||||
|
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||||
|
self.model_id = config.hf_model_name
|
||||||
|
|
||||||
|
|
||||||
class InferenceAPIAdapter(_HfAdapter):
|
class InferenceAPIAdapter(_HfAdapter):
|
||||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||||
self.client = AsyncInferenceClient(
|
self.client = AsyncInferenceClient(
|
||||||
|
|
|
@ -87,6 +87,15 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
|
config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.inference,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_type="dell-tgi",
|
||||||
|
pip_packages=["huggingface_hub", "aiohttp"],
|
||||||
|
module="llama_stack.providers.adapters.inference.tgi",
|
||||||
|
config_class="llama_stack.providers.adapters.inference.tgi.DellTGIImplConfig",
|
||||||
|
),
|
||||||
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
|
|
|
@ -13,10 +13,11 @@ apis:
|
||||||
- safety
|
- safety
|
||||||
providers:
|
providers:
|
||||||
inference:
|
inference:
|
||||||
- provider_id: remote::tgi
|
- provider_id: remote::dell-tgi
|
||||||
provider_type: remote::tgi
|
provider_type: remote::dell-tgi
|
||||||
config:
|
config:
|
||||||
url: http://127.0.0.1:5009
|
url: http://127.0.0.1:5009
|
||||||
|
hf_model_name: meta-llama/Llama-3.1-8B-Instruct
|
||||||
safety:
|
safety:
|
||||||
- provider_id: meta-reference
|
- provider_id: meta-reference
|
||||||
provider_type: meta-reference
|
provider_type: meta-reference
|
Loading…
Add table
Add a link
Reference in a new issue