dell tgi adapter

This commit is contained in:
Xi Yan 2024-10-16 16:35:46 -07:00
parent cfc97df6d5
commit cd1f1a86bf
5 changed files with 61 additions and 11 deletions

View file

@ -29,7 +29,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
)
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
from .config import (
DellTGIImplConfig,
InferenceAPIImplConfig,
InferenceEndpointImplConfig,
TGIImplConfig,
)
logger = logging.getLogger(__name__)
@ -52,10 +57,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def list_models(self) -> List[ModelDef]:
repo = self.model_id
# tmp hack to support Dell
if repo not in self.huggingface_repo_to_llama_model_id:
repo = "meta-llama/Llama-3.1-8B-Instruct"
identifier = self.huggingface_repo_to_llama_model_id[repo]
return [
ModelDef(
@ -177,6 +178,14 @@ class TGIAdapter(_HfAdapter):
self.model_id = endpoint_info["model_id"]
class DellTGIAdapter(_HfAdapter):
async def initialize(self, config: DellTGIImplConfig) -> None:
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = config.hf_model_name
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(