Merge branch 'main' into remove-deprecated-embeddings

This commit is contained in:
Matthew Farrellee 2025-09-27 15:01:32 -04:00
commit 5c44dcdf0e
770 changed files with 176834 additions and 27431 deletions

View file

@ -8,6 +8,7 @@
from collections.abc import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -29,6 +30,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
@ -37,16 +39,15 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
completion_request_to_prompt_model_input_info,
@ -69,26 +70,49 @@ def build_hf_repo_model_entries():
class _HfAdapter(
OpenAIMixin,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
client: AsyncInferenceClient
url: str
api_key: SecretStr
hf_client: AsyncInferenceClient
max_tokens: int
model_id: str
overwrite_completion_id = True # TGI always returns id=""
def __init__(self) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
}
def get_api_key(self):
return self.api_key.get_secret_value()
def get_base_url(self):
return self.url
async def shutdown(self) -> None:
pass
async def list_models(self) -> list[Model] | None:
models = []
async for model in self.client.models.list():
models.append(
Model(
identifier=model.id,
provider_resource_id=model.id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
)
return models
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.provider_resource_id != self.model_id:
raise ValueError(
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
@ -172,7 +196,7 @@ class _HfAdapter(
params = await self._get_params_for_completion(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
finish_reason = None
@ -190,7 +214,7 @@ class _HfAdapter(
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_completion(request)
r = await self.client.text_generation(**params)
r = await self.hf_client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
@ -237,7 +261,7 @@ class _HfAdapter(
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = await self.client.text_generation(**params)
r = await self.hf_client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
@ -252,7 +276,7 @@ class _HfAdapter(
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
@ -294,18 +318,21 @@ class TGIAdapter(_HfAdapter):
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.client.get_endpoint_info()
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
self.url = f"{config.url.rstrip('/')}/v1"
self.api_key = SecretStr("NO_KEY")
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.client.get_endpoint_info()
self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
# TODO: how do we set url for this?
class InferenceEndpointAdapter(_HfAdapter):
@ -317,6 +344,7 @@ class InferenceEndpointAdapter(_HfAdapter):
endpoint.wait(timeout=60)
# Initialize the adapter
self.client = endpoint.async_client
self.hf_client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
# TODO: how do we set url for this?