Merge branch 'main' into add-watsonx-inference-adapter

This commit is contained in:
Sajikumar JS 2025-04-25 10:57:45 +05:30
commit 6fe8b292b1
74 changed files with 5033 additions and 1685 deletions

View file

@ -362,6 +362,39 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
# Divert Llama Models through Llama Stack inference APIs because
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
self,
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
params = await prepare_openai_completion_params(
messages=messages,
frequency_penalty=frequency_penalty,
@ -387,11 +420,4 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
user=user,
)
# Divert Llama Models through Llama Stack inference APIs because
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

View file

@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel):
default=60,
description="Timeout for the HTTP requests",
)
append_api_version: bool = Field(
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
"api_key": "${env.NVIDIA_API_KEY:}",
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
}

View file

@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -120,12 +123,20 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}
base_url = f"{self._config.url}/v1"
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
base_url = special_model_urls[provider_model_id]
return _get_client_for_base_url(base_url)
async def _get_provider_model_id(self, model_id: str) -> str:
if not self.model_store:
raise RuntimeError("Model store is not set")
model = await self.model_store.get_model(model_id)
if model is None:
raise ValueError(f"Model {model_id} is unknown")
return model.provider_model_id
async def completion(
self,
model_id: str,
@ -144,7 +155,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# removing this health check as NeMo customizer endpoint health check is returning 404
# await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)
request = convert_completion_request(
request=CompletionRequest(
model=provider_model_id,
@ -188,7 +199,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
#
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
model = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)
extra_body = {}
@ -211,8 +222,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self._get_client(model).embeddings.create(
model=model,
response = await self._get_client(provider_model_id).embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
@ -246,10 +257,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)
request = await convert_chat_completion_request(
request=ChatCompletionRequest(
model=self.get_provider_model_id(model_id),
model=provider_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
@ -294,7 +305,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
provider_model_id = self.get_provider_model_id(model)
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
@ -347,7 +358,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
provider_model_id = self.get_provider_model_id(model)
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
@ -379,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
return await self._get_client(provider_model_id).chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def register_model(self, model: Model) -> Model:
"""
Allow non-llama model registration.
Non-llama model registration: API Catalogue models, post-training models, etc.
client = LlamaStackAsLibraryClient("nvidia")
client.models.register(
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
model_type=ModelType.llm,
provider_id="nvidia",
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
)
NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
"""
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
if provider_resource_id:
model.provider_resource_id = provider_resource_id
else:
llama_model = model.metadata.get("llama_model")
existing_llama_model = self.get_llama_model(model.provider_resource_id)
if existing_llama_model:
if existing_llama_model != llama_model:
raise ValueError(
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
)
else:
# not llama model
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
else:
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
return model

View file

@ -76,8 +76,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def shutdown(self) -> None:
if self._client:
await self._client.close()
# Together client has no close method, so just set to None
self._client = None
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
async def completion(
self,
@ -359,7 +362,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
top_p=top_p,
user=user,
)
if params.get("stream", True):
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore

View file

@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None
async def initialize(self) -> None:
log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
)
pass
async def shutdown(self) -> None:
pass
@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def _lazy_initialize_client(self):
if self.client is not None:
return
log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client()
def _create_client(self):
return AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
)
async def completion(
self,
model_id: str,
@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@ -357,9 +368,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
async def register_model(self, model: Model) -> Model:
assert self.client is not None
# register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
# Changing this may lead to unpredictable behavior.
client = self._create_client() if self.client is None else self.client
model = await self.register_helper.register_model(model)
res = await self.client.models.list()
res = await client.models.list()
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(
@ -410,6 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
self._lazy_initialize_client()
assert self.client is not None
model = await self._get_model(model_id)
@ -449,6 +464,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
extra_body: Dict[str, Any] = {}
@ -505,6 +521,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,