Add unsupported OpenAI mixin to all remaining inference providers

Ben Browning 2025-04-08 12:50:23 -04:00
parent 00c4493bda
commit 15d37fde19
7 changed files with 56 additions and 7 deletions
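For context, the two mixins added here come from llama_stack.providers.utils.inference.openai_compat. A minimal sketch of what they presumably provide, assuming they simply stub out the OpenAI-compatible entry points (the argument lists below are placeholders, not the real signatures):

class OpenAICompletionUnsupportedMixin:
    # Assumed behavior: satisfy the Inference interface while making it explicit
    # that OpenAI-compatible /completions is not supported by this adapter.
    async def openai_completion(self, *args, **kwargs):
        raise ValueError(f"{type(self).__name__} doesn't support OpenAI-compatible completions")


class OpenAIChatCompletionUnsupportedMixin:
    # Assumed behavior: same idea for OpenAI-compatible /chat/completions.
    async def openai_chat_completion(self, *args, **kwargs):
        raise ValueError(f"{type(self).__name__} doesn't support OpenAI-compatible chat completions")

With the mixins in a provider's bases, a caller that invokes the OpenAI-compatible methods gets a clear "unsupported" error instead of an AttributeError.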


@@ -36,8 +36,10 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAIChatCompletionUnsupportedMixin,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    OpenAICompletionUnsupportedMixin,
     get_sampling_strategy_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -51,7 +53,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .models import MODEL_ENTRIES


-class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
+class BedrockInferenceAdapter(
+    ModelRegistryHelper,
+    Inference,
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
+):
     def __init__(self, config: BedrockConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self._config = config


@@ -34,6 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -49,7 +51,12 @@ from .config import CerebrasImplConfig
 from .models import MODEL_ENTRIES


-class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
+class CerebrasInferenceAdapter(
+    ModelRegistryHelper,
+    Inference,
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
+):
     def __init__(self, config: CerebrasImplConfig) -> None:
         ModelRegistryHelper.__init__(
             self,


@@ -34,6 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -56,7 +58,12 @@ model_entries = [
 ]


-class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
+class DatabricksInferenceAdapter(
+    ModelRegistryHelper,
+    Inference,
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
+):
     def __init__(self, config: DatabricksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, model_entries=model_entries)
         self.config = config


@@ -40,6 +40,8 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
     convert_openai_chat_completion_choice,
     convert_openai_chat_completion_stream,
 )
@@ -58,7 +60,12 @@ from .utils import _is_nvidia_hosted
 logger = logging.getLogger(__name__)


-class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
+class NVIDIAInferenceAdapter(
+    Inference,
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
+    ModelRegistryHelper,
+):
     def __init__(self, config: NVIDIAConfig) -> None:
         # TODO(mf): filter by available models
         ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)


@@ -12,6 +12,8 @@ from llama_stack.apis.inference import * # noqa: F403
 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -38,7 +40,12 @@ RUNPOD_SUPPORTED_MODELS = {
 }


-class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
+class RunpodInferenceAdapter(
+    ModelRegistryHelper,
+    Inference,
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
+):
     def __init__(self, config: RunpodImplConfig) -> None:
         ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
         self.config = config


@@ -42,6 +42,8 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
     process_chat_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -52,7 +54,12 @@ from .config import SambaNovaImplConfig
 from .models import MODEL_ENTRIES


-class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
+class SambaNovaInferenceAdapter(
+    ModelRegistryHelper,
+    Inference,
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
+):
     def __init__(self, config: SambaNovaImplConfig) -> None:
         ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
         self.config = config


@@ -40,8 +40,10 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAIChatCompletionUnsupportedMixin,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    OpenAICompletionUnsupportedMixin,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -69,7 +71,12 @@ def build_hf_repo_model_entries():
     ]


-class _HfAdapter(Inference, ModelsProtocolPrivate):
+class _HfAdapter(
+    Inference,
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
+    ModelsProtocolPrivate,
+):
     client: AsyncInferenceClient
     max_tokens: int
     model_id: str
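One small difference across the files: NVIDIAInferenceAdapter and _HfAdapter list the mixins before ModelRegistryHelper / ModelsProtocolPrivate, while the other adapters list them last. Assuming no other base class defines openai_completion or openai_chat_completion, the ordering should not change behavior, since Python resolves attributes along the MRO and the mixin methods are found either way. A self-contained illustration with hypothetical class names:

class _Unsupported:
    # Stand-in for the unsupported mixins: the only definition of the method.
    def openai_completion(self):
        raise ValueError(f"{type(self).__name__} doesn't support OpenAI completions")


class _Registry:
    # Stand-in for ModelRegistryHelper / ModelsProtocolPrivate: no OpenAI methods.
    pass


class AdapterA(_Registry, _Unsupported):
    pass


class AdapterB(_Unsupported, _Registry):
    pass


for cls in (AdapterA, AdapterB):
    try:
        cls().openai_completion()
    except ValueError as exc:
        print(exc)  # both orderings raise the same "unsupported" error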