mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
lint
This commit is contained in:
parent
854ff85784
commit
fcb3438031
3 changed files with 15 additions and 5 deletions
|
@ -10,7 +10,10 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
|
||||||
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
|
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig], _deps):
|
async def get_adapter_impl(
|
||||||
|
config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
|
||||||
|
_deps,
|
||||||
|
):
|
||||||
if isinstance(config, TGIImplConfig):
|
if isinstance(config, TGIImplConfig):
|
||||||
impl = TGIAdapter()
|
impl = TGIAdapter()
|
||||||
elif isinstance(config, InferenceAPIImplConfig):
|
elif isinstance(config, InferenceAPIImplConfig):
|
||||||
|
|
|
@ -20,6 +20,7 @@ class TGIImplConfig(BaseModel):
|
||||||
description="A bearer token if your TGI endpoint is protected.",
|
description="A bearer token if your TGI endpoint is protected.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class InferenceEndpointImplConfig(BaseModel):
|
class InferenceEndpointImplConfig(BaseModel):
|
||||||
endpoint_name: str = Field(
|
endpoint_name: str = Field(
|
||||||
|
@ -40,5 +41,3 @@ class InferenceAPIImplConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImpl
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class _HfAdapter(Inference):
|
class _HfAdapter(Inference):
|
||||||
client: AsyncInferenceClient
|
client: AsyncInferenceClient
|
||||||
max_tokens: int
|
max_tokens: int
|
||||||
|
@ -214,6 +215,7 @@ class _HfAdapter(Inference):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TGIAdapter(_HfAdapter):
|
class TGIAdapter(_HfAdapter):
|
||||||
async def initialize(self, config: TGIImplConfig) -> None:
|
async def initialize(self, config: TGIImplConfig) -> None:
|
||||||
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
|
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
|
||||||
|
@ -221,13 +223,17 @@ class TGIAdapter(_HfAdapter):
|
||||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||||
self.model_id = endpoint_info["model_id"]
|
self.model_id = endpoint_info["model_id"]
|
||||||
|
|
||||||
|
|
||||||
class InferenceAPIAdapter(_HfAdapter):
|
class InferenceAPIAdapter(_HfAdapter):
|
||||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||||
self.client = AsyncInferenceClient(model=config.model_id, token=config.api_token)
|
self.client = AsyncInferenceClient(
|
||||||
|
model=config.model_id, token=config.api_token
|
||||||
|
)
|
||||||
endpoint_info = await self.client.get_endpoint_info()
|
endpoint_info = await self.client.get_endpoint_info()
|
||||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||||
self.model_id = endpoint_info["model_id"]
|
self.model_id = endpoint_info["model_id"]
|
||||||
|
|
||||||
|
|
||||||
class InferenceEndpointAdapter(_HfAdapter):
|
class InferenceEndpointAdapter(_HfAdapter):
|
||||||
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
|
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
|
||||||
# Get the inference endpoint details
|
# Get the inference endpoint details
|
||||||
|
@ -240,4 +246,6 @@ class InferenceEndpointAdapter(_HfAdapter):
|
||||||
# Initialize the adapter
|
# Initialize the adapter
|
||||||
self.client = endpoint.async_client
|
self.client = endpoint.async_client
|
||||||
self.model_id = endpoint.repository
|
self.model_id = endpoint.repository
|
||||||
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
|
self.max_tokens = int(
|
||||||
|
endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue