diff --git a/llama_stack/providers/adapters/inference/tgi/__init__.py b/llama_stack/providers/adapters/inference/tgi/__init__.py
index 95803c200..451650323 100644
--- a/llama_stack/providers/adapters/inference/tgi/__init__.py
+++ b/llama_stack/providers/adapters/inference/tgi/__init__.py
@@ -10,7 +10,10 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
 from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter


-async def get_adapter_impl(config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig], _deps):
+async def get_adapter_impl(
+    config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
+    _deps,
+):
     if isinstance(config, TGIImplConfig):
         impl = TGIAdapter()
     elif isinstance(config, InferenceAPIImplConfig):
diff --git a/llama_stack/providers/adapters/inference/tgi/config.py b/llama_stack/providers/adapters/inference/tgi/config.py
index c91846fc9..233205066 100644
--- a/llama_stack/providers/adapters/inference/tgi/config.py
+++ b/llama_stack/providers/adapters/inference/tgi/config.py
@@ -20,6 +20,7 @@ class TGIImplConfig(BaseModel):
         description="A bearer token if your TGI endpoint is protected.",
     )

+
 @json_schema_type
 class InferenceEndpointImplConfig(BaseModel):
     endpoint_name: str = Field(
@@ -40,5 +41,3 @@ class InferenceAPIImplConfig(BaseModel):
         default=None,
         description="Your Hugging Face user access token (will default to locally saved token if not provided)",
     )
-
-
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 7f95df313..60c1d895e 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -20,6 +20,7 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl

 logger = logging.getLogger(__name__)

+
 class _HfAdapter(Inference):
     client: AsyncInferenceClient
     max_tokens: int
@@ -214,6 +215,7 @@ class _HfAdapter(Inference):
             )
         )

+
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
         self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
@@ -221,13 +223,17 @@ class TGIAdapter(_HfAdapter):
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]

+
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
-        self.client = AsyncInferenceClient(model=config.model_id, token=config.api_token)
+        self.client = AsyncInferenceClient(
+            model=config.model_id, token=config.api_token
+        )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]

+
 class InferenceEndpointAdapter(_HfAdapter):
     async def initialize(self, config: InferenceEndpointImplConfig) -> None:
         # Get the inference endpoint details
@@ -240,4 +246,6 @@ class InferenceEndpointAdapter(_HfAdapter):
         # Initialize the adapter
         self.client = endpoint.async_client
         self.model_id = endpoint.repository
-        self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
+        self.max_tokens = int(
+            endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
+        )
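
Note: the changes above are formatting-only (Black-style line wrapping and PEP 8 blank lines); behavior is unchanged. For context, `get_adapter_impl` dispatches on the concrete config type to pick one of the three adapters. Below is a minimal, self-contained sketch of that dispatch pattern; the config and adapter classes here are simplified stand-ins, not the real `llama_stack` imports, and the trailing `initialize`/`return` is assumed from the pattern since the diff hunk cuts off before the function's tail.

```python
# Sketch of the config-type dispatch pattern used by get_adapter_impl.
# All classes are illustrative stand-ins for the real adapters in
# llama_stack.providers.adapters.inference.tgi.
import asyncio
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class TGIImplConfig:
    url: str
    api_token: Optional[str] = None


@dataclass
class InferenceAPIImplConfig:
    model_id: str
    api_token: Optional[str] = None


class TGIAdapter:
    async def initialize(self, config: TGIImplConfig) -> None:
        print(f"connecting to TGI endpoint at {config.url}")


class InferenceAPIAdapter:
    async def initialize(self, config: InferenceAPIImplConfig) -> None:
        print(f"using Inference API model {config.model_id}")


async def get_adapter_impl(
    config: Union[TGIImplConfig, InferenceAPIImplConfig],
    _deps,
):
    # Choose the adapter class matching the concrete config type.
    if isinstance(config, TGIImplConfig):
        impl = TGIAdapter()
    elif isinstance(config, InferenceAPIImplConfig):
        impl = InferenceAPIAdapter()
    else:
        raise ValueError(f"Unexpected config type: {type(config)}")
    # Assumed tail of the function (not shown in the hunk above).
    await impl.initialize(config)
    return impl


if __name__ == "__main__":
    asyncio.run(get_adapter_impl(TGIImplConfig(url="http://localhost:8080"), None))
```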