mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: add dynamic model registration support to TGI inference (#3417)
Some checks failed
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
UI Tests / ui-tests (22) (push) Successful in 43s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 3s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
API Conformance Tests / check-schema-compatibility (push) Successful in 7s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
Pre-commit / pre-commit (push) Successful in 1m21s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
Python Package Build Test / build (3.13) (push) Failing after 2s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 5s
Some checks failed
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
UI Tests / ui-tests (22) (push) Successful in 43s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 3s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
API Conformance Tests / check-schema-compatibility (push) Successful in 7s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
Pre-commit / pre-commit (push) Successful in 1m21s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
Python Package Build Test / build (3.13) (push) Failing after 2s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 5s
# What does this PR do? adds dynamic model support to TGI add new overwrite_completion_id feature to OpenAIMixin to deal with TGI always returning id="" ## Test Plan tgi: `docker run --gpus all --shm-size 1g -p 8080:80 -v /data:/data ghcr.io/huggingface/text-generation-inference --model-id Qwen/Qwen3-0.6B` stack: `TGI_URL=http://localhost:8080 uv run llama stack build --image-type venv --distro ci-tests --run` test: `./scripts/integration-tests.sh --stack-config http://localhost:8321 --setup tgi --subdirs inference --pattern openai`
This commit is contained in:
parent
ab321739f2
commit
f4ab154ade
14 changed files with 12218 additions and 20 deletions
|
@ -8,6 +8,7 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from pydantic import SecretStr
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@ -33,6 +34,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
@ -41,16 +43,15 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_model_input_info,
|
||||
completion_request_to_prompt_model_input_info,
|
||||
|
@ -73,26 +74,49 @@ def build_hf_repo_model_entries():
|
|||
|
||||
|
||||
class _HfAdapter(
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
client: AsyncInferenceClient
|
||||
url: str
|
||||
api_key: SecretStr
|
||||
|
||||
hf_client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
overwrite_completion_id = True # TGI always returns id=""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.huggingface_repo_to_llama_model_id = {
|
||||
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
|
||||
}
|
||||
|
||||
def get_api_key(self):
|
||||
return self.api_key.get_secret_value()
|
||||
|
||||
def get_base_url(self):
|
||||
return self.url
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
models = []
|
||||
async for model in self.client.models.list():
|
||||
models.append(
|
||||
Model(
|
||||
identifier=model.id,
|
||||
provider_resource_id=model.id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
if model.provider_resource_id != self.model_id:
|
||||
raise ValueError(
|
||||
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
|
||||
|
@ -176,7 +200,7 @@ class _HfAdapter(
|
|||
params = await self._get_params_for_completion(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
s = await self.hf_client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
finish_reason = None
|
||||
|
@ -194,7 +218,7 @@ class _HfAdapter(
|
|||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params_for_completion(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
r = await self.hf_client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
|
@ -241,7 +265,7 @@ class _HfAdapter(
|
|||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
r = await self.hf_client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
|
@ -256,7 +280,7 @@ class _HfAdapter(
|
|||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
s = await self.hf_client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
|
||||
|
@ -308,18 +332,21 @@ class TGIAdapter(_HfAdapter):
|
|||
if not config.url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
|
||||
endpoint_info = await self.hf_client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
self.url = f"{config.url.rstrip('/')}/v1"
|
||||
self.api_key = SecretStr("NO_KEY")
|
||||
|
||||
|
||||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
|
||||
endpoint_info = await self.hf_client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
# TODO: how do we set url for this?
|
||||
|
||||
|
||||
class InferenceEndpointAdapter(_HfAdapter):
|
||||
|
@ -331,6 +358,7 @@ class InferenceEndpointAdapter(_HfAdapter):
|
|||
endpoint.wait(timeout=60)
|
||||
|
||||
# Initialize the adapter
|
||||
self.client = endpoint.async_client
|
||||
self.hf_client = endpoint.async_client
|
||||
self.model_id = endpoint.repository
|
||||
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
|
||||
# TODO: how do we set url for this?
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
@ -43,6 +44,12 @@ class OpenAIMixin(ABC):
|
|||
The model_store is set in routing_tables/common.py during provider initialization.
|
||||
"""
|
||||
|
||||
# Allow subclasses to control whether to overwrite the 'id' field in OpenAI responses
|
||||
# is overwritten with a client-side generated id.
|
||||
#
|
||||
# This is useful for providers that do not return a unique id in the response.
|
||||
overwrite_completion_id: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
|
@ -110,6 +117,23 @@ class OpenAIMixin(ABC):
|
|||
raise ValueError(f"Model {model} has no provider_resource_id")
|
||||
return model_obj.provider_resource_id
|
||||
|
||||
async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
|
||||
if not self.overwrite_completion_id:
|
||||
return resp
|
||||
|
||||
new_id = f"cltsd-{uuid.uuid4()}"
|
||||
if stream:
|
||||
|
||||
async def _gen():
|
||||
async for chunk in resp:
|
||||
chunk.id = new_id
|
||||
yield chunk
|
||||
|
||||
return _gen()
|
||||
else:
|
||||
resp.id = new_id
|
||||
return resp
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -147,7 +171,7 @@ class OpenAIMixin(ABC):
|
|||
extra_body["guided_choice"] = guided_choice
|
||||
|
||||
# TODO: fix openai_completion to return type compatible with OpenAI's API response
|
||||
return await self.client.completions.create( # type: ignore[no-any-return]
|
||||
resp = await self.client.completions.create(
|
||||
**await prepare_openai_completion_params(
|
||||
model=await self._get_provider_model_id(model),
|
||||
prompt=prompt,
|
||||
|
@ -171,6 +195,8 @@ class OpenAIMixin(ABC):
|
|||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -200,8 +226,7 @@ class OpenAIMixin(ABC):
|
|||
"""
|
||||
Direct OpenAI chat completion API call.
|
||||
"""
|
||||
# Type ignore because return types are compatible
|
||||
return await self.client.chat.completions.create( # type: ignore[no-any-return]
|
||||
resp = await self.client.chat.completions.create(
|
||||
**await prepare_openai_completion_params(
|
||||
model=await self._get_provider_model_id(model),
|
||||
messages=messages,
|
||||
|
@ -229,6 +254,8 @@ class OpenAIMixin(ABC):
|
|||
)
|
||||
)
|
||||
|
||||
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue