This commit is contained in:
Wauplin 2024-09-25 22:27:06 +02:00
parent 854ff85784
commit fcb3438031
No known key found for this signature in database
GPG key ID: 9838FE02BECE1A02
3 changed files with 15 additions and 5 deletions

View file

@ -10,7 +10,10 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
async def get_adapter_impl(config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig], _deps):
async def get_adapter_impl(
config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
_deps,
):
if isinstance(config, TGIImplConfig):
impl = TGIAdapter()
elif isinstance(config, InferenceAPIImplConfig):

View file

@ -20,6 +20,7 @@ class TGIImplConfig(BaseModel):
description="A bearer token if your TGI endpoint is protected.",
)
@json_schema_type
class InferenceEndpointImplConfig(BaseModel):
endpoint_name: str = Field(
@ -40,5 +41,3 @@ class InferenceAPIImplConfig(BaseModel):
default=None,
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
)

View file

@ -20,6 +20,7 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
logger = logging.getLogger(__name__)
class _HfAdapter(Inference):
client: AsyncInferenceClient
max_tokens: int
@ -214,6 +215,7 @@ class _HfAdapter(Inference):
)
)
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
@ -221,13 +223,17 @@ class TGIAdapter(_HfAdapter):
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
class InferenceAPIAdapter(_HfAdapter):
    """Adapter that is initialized from an ``InferenceAPIImplConfig``.

    The client targets the model named by ``config.model_id`` and
    authenticates with ``config.api_token``.
    """

    async def initialize(self, config: InferenceAPIImplConfig) -> None:
        """Create the async client and record the endpoint's reported limits.

        Args:
            config: Supplies ``model_id`` and ``api_token`` for the client.

        After this call ``self.client``, ``self.max_tokens`` and
        ``self.model_id`` are set; the latter two are read from the
        ``"max_total_tokens"`` and ``"model_id"`` keys of the endpoint info.
        """
        # Build the client exactly once: constructing it twice with the same
        # arguments only creates and discards an extra client object.
        self.client = AsyncInferenceClient(
            model=config.model_id, token=config.api_token
        )
        endpoint_info = await self.client.get_endpoint_info()
        self.max_tokens = endpoint_info["max_total_tokens"]
        self.model_id = endpoint_info["model_id"]
class InferenceEndpointAdapter(_HfAdapter):
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
# Get the inference endpoint details
@ -240,4 +246,6 @@ class InferenceEndpointAdapter(_HfAdapter):
# Initialize the adapter
self.client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
self.max_tokens = int(
endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
)