mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-13 13:42:35 +00:00
dell tgi adapter
This commit is contained in:
parent
cfc97df6d5
commit
cd1f1a86bf
5 changed files with 61 additions and 11 deletions
|
|
@ -29,7 +29,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
chat_completion_request_to_model_input_info,
|
||||
)
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
from .config import (
|
||||
DellTGIImplConfig,
|
||||
InferenceAPIImplConfig,
|
||||
InferenceEndpointImplConfig,
|
||||
TGIImplConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -52,10 +57,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
repo = self.model_id
|
||||
# tmp hack to support Dell
|
||||
if repo not in self.huggingface_repo_to_llama_model_id:
|
||||
repo = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
||||
return [
|
||||
ModelDef(
|
||||
|
|
@ -177,6 +178,14 @@ class TGIAdapter(_HfAdapter):
|
|||
self.model_id = endpoint_info["model_id"]
|
||||
|
||||
|
||||
class DellTGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: DellTGIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = config.hf_model_name
|
||||
|
||||
|
||||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue