feat: add dynamic model registration support to TGI inference

add new overwrite_completion_id feature to OpenAIMixin to deal with TGI always returning id=""

test with -

tgi: `docker run --gpus all --shm-size 1g -p 8080:80 -v /data:/data ghcr.io/huggingface/text-generation-inference --model-id Qwen/Qwen3-0.6B`

stack: `TGI_URL=http://localhost:8080 uv run llama stack build --image-type venv --distro ci-tests --run`

test: `./scripts/integration-tests.sh --stack-config http://localhost:8321 --setup tgi --subdirs inference --pattern openai`
This commit is contained in:
Matthew Farrellee 2025-09-11 02:02:02 -04:00
parent d15368a302
commit c3fc859257
14 changed files with 12218 additions and 20 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
@ -43,6 +44,12 @@ class OpenAIMixin(ABC):
The model_store is set in routing_tables/common.py during provider initialization.
"""
# Allow subclasses to control whether to overwrite the 'id' field in OpenAI responses
# is overwritten with a client-side generated id.
#
# This is useful for providers that do not return a unique id in the response,
overwrite_completion_id: bool = False
@abstractmethod
def get_api_key(self) -> str:
"""
@ -98,6 +105,23 @@ class OpenAIMixin(ABC):
raise ValueError(f"Model {model} has no provider_resource_id")
return model_obj.provider_resource_id
async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
if not self.overwrite_completion_id:
return resp
new_id = f"cltsd-{uuid.uuid4()}"
if stream:
async def _gen():
async for chunk in resp:
chunk.id = new_id
yield chunk
return _gen()
else:
resp.id = new_id
return resp
async def openai_completion(
self,
model: str,
@ -130,7 +154,7 @@ class OpenAIMixin(ABC):
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
# TODO: fix openai_completion to return type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return]
resp = await self.client.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
prompt=prompt,
@ -153,6 +177,8 @@ class OpenAIMixin(ABC):
)
)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
async def openai_chat_completion(
self,
model: str,
@ -182,8 +208,7 @@ class OpenAIMixin(ABC):
"""
Direct OpenAI chat completion API call.
"""
# Type ignore because return types are compatible
return await self.client.chat.completions.create( # type: ignore[no-any-return]
resp = await self.client.chat.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
messages=messages,
@ -211,6 +236,8 @@ class OpenAIMixin(ABC):
)
)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
async def openai_embeddings(
self,
model: str,