From 4de45560bf60072978fbefab64a95b98c604948f Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Tue, 11 Mar 2025 18:15:45 -0400
Subject: [PATCH] feat: remote ramalama provider implementation

Implement a remote ramalama inference provider. Since ramalama does not ship
its own async Python client, the adapter talks to ramalama's OpenAI-compatible
endpoint through AsyncOpenAI. Like Ollama, ramalama is a lightweight local
inference server; unlike Ollama, it runs in a containerized mode by default.

RAMALAMA_URL defaults to http://localhost:8080

Signed-off-by: Charlie Doern
---
 llama_stack/distribution/resolver.py          |   1 +
 llama_stack/providers/registry/inference.py   |   9 +
 .../remote/inference/ramalama/__init__.py     |  15 +
 .../remote/inference/ramalama/config.py       |  19 +
 .../remote/inference/ramalama/models.py       | 103 ++++++
 .../remote/inference/ramalama/openai_utils.py | 344 ++++++++++++++++++
 .../remote/inference/ramalama/ramalama.py     | 188 ++++++++++
 pyproject.toml                                |   1 +
 8 files changed, 680 insertions(+)
 create mode 100644 llama_stack/providers/remote/inference/ramalama/__init__.py
 create mode 100644 llama_stack/providers/remote/inference/ramalama/config.py
 create mode 100644 llama_stack/providers/remote/inference/ramalama/models.py
 create mode 100644 llama_stack/providers/remote/inference/ramalama/openai_utils.py
 create mode 100644 llama_stack/providers/remote/inference/ramalama/ramalama.py

diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index e9a594eba..4c419042b 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -306,6 +306,7 @@ async def instantiate_provider(
     additional_protocols = additional_protocols_map()
 
     provider_spec = provider.spec
+
     if not hasattr(provider_spec, "module"):
         raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
 
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 3c54cabcf..788e66505 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -77,6 +77,15 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.remote.inference.ollama",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="ramalama",
+            pip_packages=["ramalama", "openai", "aiohttp"],
+            config_class="llama_stack.providers.remote.inference.ramalama.RamalamaImplConfig",
+            module="llama_stack.providers.remote.inference.ramalama",
+        ),
+    ),
     remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(
diff --git a/llama_stack/providers/remote/inference/ramalama/__init__.py b/llama_stack/providers/remote/inference/ramalama/__init__.py
new file mode 100644
index 000000000..77d1a32b9
--- /dev/null
+++ b/llama_stack/providers/remote/inference/ramalama/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+ +from .config import RamalamaImplConfig + + +async def get_adapter_impl(config: RamalamaImplConfig, _deps): + from .ramalama import RamalamaInferenceAdapter + + impl = RamalamaInferenceAdapter(config.url) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/inference/ramalama/config.py b/llama_stack/providers/remote/inference/ramalama/config.py new file mode 100644 index 000000000..b6b59cee4 --- /dev/null +++ b/llama_stack/providers/remote/inference/ramalama/config.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from pydantic import BaseModel + +DEFAULT_RAMALAMA_URL = "http://localhost:8080" + + +class RamalamaImplConfig(BaseModel): + url: str = DEFAULT_RAMALAMA_URL + + @classmethod + def sample_run_config(cls, url: str = "${env.RAMALAMA_URL:http://localhost:8080}", **kwargs) -> Dict[str, Any]: + return {"url": url} diff --git a/llama_stack/providers/remote/inference/ramalama/models.py b/llama_stack/providers/remote/inference/ramalama/models.py new file mode 100644 index 000000000..be556762c --- /dev/null +++ b/llama_stack/providers/remote/inference/ramalama/models.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.models.models import ModelType +from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.providers.utils.inference.model_registry import ( + ProviderModelEntry, + build_hf_repo_model_entry, + build_model_entry, +) + +model_entries = [ + build_hf_repo_model_entry( + "llama3.1:8b-instruct-fp16", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_entry( + "llama3.1:8b", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_hf_repo_model_entry( + "llama3.1:70b-instruct-fp16", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_model_entry( + "llama3.1:70b", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_hf_repo_model_entry( + "llama3.1:405b-instruct-fp16", + CoreModelId.llama3_1_405b_instruct.value, + ), + build_model_entry( + "llama3.1:405b", + CoreModelId.llama3_1_405b_instruct.value, + ), + build_hf_repo_model_entry( + "llama3.2:1b-instruct-fp16", + CoreModelId.llama3_2_1b_instruct.value, + ), + build_model_entry( + "llama3.2:1b", + CoreModelId.llama3_2_1b_instruct.value, + ), + build_hf_repo_model_entry( + "llama3.2:3b-instruct-fp16", + CoreModelId.llama3_2_3b_instruct.value, + ), + build_model_entry( + "llama3.2:3b", + CoreModelId.llama3_2_3b_instruct.value, + ), + build_hf_repo_model_entry( + "llama3.2-vision:11b-instruct-fp16", + CoreModelId.llama3_2_11b_vision_instruct.value, + ), + build_model_entry( + "llama3.2-vision:latest", + CoreModelId.llama3_2_11b_vision_instruct.value, + ), + build_hf_repo_model_entry( + "llama3.2-vision:90b-instruct-fp16", + CoreModelId.llama3_2_90b_vision_instruct.value, + ), + build_model_entry( + "llama3.2-vision:90b", + CoreModelId.llama3_2_90b_vision_instruct.value, + ), + build_hf_repo_model_entry( + "llama3.3:70b", + CoreModelId.llama3_3_70b_instruct.value, + ), + # The Llama Guard models don't have their full fp16 versions + # so we are going to alias their default version to the canonical SKU + build_hf_repo_model_entry( + "llama-guard3:8b", + 
CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+    ProviderModelEntry(
+        provider_model_id="all-minilm:latest",
+        aliases=["all-minilm"],
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+            "context_length": 512,
+        },
+    ),
+    ProviderModelEntry(
+        provider_model_id="nomic-embed-text",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 768,
+            "context_length": 8192,
+        },
+    ),
+]
diff --git a/llama_stack/providers/remote/inference/ramalama/openai_utils.py b/llama_stack/providers/remote/inference/ramalama/openai_utils.py
new file mode 100644
index 000000000..7458d28ea
--- /dev/null
+++ b/llama_stack/providers/remote/inference/ramalama/openai_utils.py
@@ -0,0 +1,344 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List, Optional
+
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    GrammarResponseFormat,
+    JsonSchemaResponseFormat,
+    Message,
+    ToolChoice,
+    UserMessage,
+)
+from llama_stack.providers.utils.inference.openai_compat import (
+    convert_message_to_openai_dict,
+    get_sampling_options,
+)
+
+
+def _merge_context_into_content(message: Message) -> Message:  # type: ignore
+    """
+    Merge the ``context`` field of a Llama Stack ``Message`` object into
+    the content field for compatibility with OpenAI-style APIs.
+
+    Generates a content string that emulates the current behavior
+    of ``llama_models.llama3.api.chat_format.encode_message()``.
+
+    :param message: Message that may include ``context`` field
+
+    :returns: A version of ``message`` with any context merged into the
+     ``content`` field.
+    """
+    if not isinstance(message, UserMessage):  # Separate type check for linter
+        return message
+    if message.context is None:
+        return message
+    return UserMessage(
+        role=message.role,
+        # Emulate llama_models.llama3.api.chat_format.encode_message()
+        content=message.content + "\n\n" + message.context,
+        context=None,
+    )
+
+
+async def llama_stack_chat_completion_to_openai_chat_completion_dict(
+    request: ChatCompletionRequest,
+) -> dict:
+    """
+    Convert a chat completion request in Llama Stack format into an
+    equivalent set of arguments to pass to an OpenAI-compatible
+    chat completions API.
+
+    :param request: Bundled request parameters in Llama Stack format.
+
+    :returns: Dictionary of key-value pairs to use as an initializer
+     for a dataclass or to be converted directly to JSON and sent
+     over the wire.
+    """
+
+    converted_messages = [
+        # This mystery async call makes the parent function also be async
+        await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
+        for m in request.messages
+    ]
+    # converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
+
+    # Llama will try to use built-in tools with no tool catalog, so don't enable
+    # tool choice unless at least one tool is enabled.
+    converted_tool_choice = "none"
+    if (
+        request.tool_config is not None
+        and request.tool_config.tool_choice == ToolChoice.auto
+        and request.tools is not None
+        and len(request.tools) > 0
+    ):
+        converted_tool_choice = "auto"
+
+    # TODO: Figure out what to do with the tool_prompt_format argument.
+    # Other connectors appear to drop it quietly.
+
+    # Use Llama Stack shared code to translate sampling parameters.
+ sampling_options = get_sampling_options(request.sampling_params) + + # get_sampling_options() translates repetition penalties to an option that + # OpenAI's APIs don't know about. + # vLLM's OpenAI-compatible API also handles repetition penalties wrong. + # For now, translate repetition penalties into a format that vLLM's broken + # API will handle correctly. Two wrongs make a right... + if "repeat_penalty" in sampling_options: + del sampling_options["repeat_penalty"] + if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0: + sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty + + # Convert a single response format into four different parameters, per + # the OpenAI spec + guided_decoding_options = dict() + if request.response_format is None: + # Use defaults + pass + elif isinstance(request.response_format, JsonSchemaResponseFormat): + guided_decoding_options["guided_json"] = request.response_format.json_schema + elif isinstance(request.response_format, GrammarResponseFormat): + guided_decoding_options["guided_grammar"] = request.response_format.bnf + else: + raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'") + + logprob_options = dict() + if request.logprobs is not None: + logprob_options["logprobs"] = request.logprobs.top_k + + # Marshall together all the arguments for a ChatCompletionRequest + request_options = { + "model": request.model, + "messages": converted_messages, + "tool_choice": converted_tool_choice, + "stream": request.stream, + **sampling_options, + **guided_decoding_options, + **logprob_options, + } + + return request_options + + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import warnings +from typing import Any, AsyncGenerator, Dict + +from openai import AsyncStream +from openai.types.chat.chat_completion import ( + Choice as OpenAIChoice, +) +from openai.types.completion import Completion as OpenAICompletion +from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + CompletionRequest, + CompletionResponse, + CompletionResponseStreamChunk, + TokenLogProbs, +) +from llama_stack.models.llama.datatypes import ( + GreedySamplingStrategy, + TopKSamplingStrategy, + TopPSamplingStrategy, +) +from llama_stack.providers.utils.inference.openai_compat import ( + _convert_openai_finish_reason, + convert_message_to_openai_dict_new, + convert_tooldef_to_openai_tool, +) + + +async def convert_chat_completion_request( + request: ChatCompletionRequest, + n: int = 1, +) -> dict: + """ + Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. 
+    """
+    # model -> model
+    # messages -> messages
+    # sampling_params TODO(mattf): review strategy
+    # strategy=greedy -> nvext.top_k = -1, temperature = temperature
+    # strategy=top_p -> nvext.top_k = -1, top_p = top_p
+    # strategy=top_k -> nvext.top_k = top_k
+    # temperature -> temperature
+    # top_p -> top_p
+    # top_k -> nvext.top_k
+    # max_tokens -> max_tokens
+    # repetition_penalty -> nvext.repetition_penalty
+    # response_format -> GrammarResponseFormat TODO(mf)
+    # response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
+    # tools -> tools
+    # tool_choice ("auto", "required") -> tool_choice
+    # tool_prompt_format -> TBD
+    # stream -> stream
+    # logprobs -> logprobs
+
+    if request.response_format and not isinstance(request.response_format, JsonSchemaResponseFormat):
+        raise ValueError(
+            f"Unsupported response format: {request.response_format}. Only JsonSchemaResponseFormat is supported."
+        )
+
+    nvext = {}
+    payload: Dict[str, Any] = dict(
+        model=request.model,
+        messages=[await convert_message_to_openai_dict_new(message) for message in request.messages],
+        stream=request.stream,
+        n=n,
+        extra_body=dict(nvext=nvext),
+        extra_headers={
+            b"User-Agent": b"llama-stack: ramalama-inference-adapter",
+        },
+    )
+
+    if request.response_format:
+        # server bug - setting guided_json changes the behavior of response_format resulting in an error
+        # payload.update(response_format="json_object")
+        nvext.update(guided_json=request.response_format.json_schema)
+
+    if request.tools:
+        payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools])
+        if request.tool_config and request.tool_config.tool_choice:
+            payload.update(
+                tool_choice=request.tool_config.tool_choice.value
+            )  # we cannot include tool_choice w/o tools, server will complain
+
+    if request.logprobs:
+        payload.update(logprobs=True)
+        payload.update(top_logprobs=request.logprobs.top_k)
+
+    if request.sampling_params:
+        nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
+
+        if request.sampling_params.max_tokens:
+            payload.update(max_tokens=request.sampling_params.max_tokens)
+
+        strategy = request.sampling_params.strategy
+        if isinstance(strategy, TopPSamplingStrategy):
+            nvext.update(top_k=-1)
+            payload.update(top_p=strategy.top_p)
+            payload.update(temperature=strategy.temperature)
+        elif isinstance(strategy, TopKSamplingStrategy):
+            if strategy.top_k != -1 and strategy.top_k < 1:
+                warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
+            nvext.update(top_k=strategy.top_k)
+        elif isinstance(strategy, GreedySamplingStrategy):
+            nvext.update(top_k=-1)
+        else:
+            raise ValueError(f"Unsupported sampling strategy: {strategy}")
+
+    return payload
+
+
+def convert_completion_request(
+    request: CompletionRequest,
+    n: int = 1,
+) -> dict:
+    """
+    Convert a CompletionRequest to an OpenAI API-compatible dictionary.
+    """
+    # model -> model
+    # prompt -> prompt
+    # sampling_params TODO(mattf): review strategy
+    # strategy=greedy -> nvext.top_k = -1, temperature = temperature
+    # strategy=top_p -> nvext.top_k = -1, top_p = top_p
+    # strategy=top_k -> nvext.top_k = top_k
+    # temperature -> temperature
+    # top_p -> top_p
+    # top_k -> nvext.top_k
+    # max_tokens -> max_tokens
+    # repetition_penalty -> nvext.repetition_penalty
+    # response_format -> nvext.guided_json
+    # stream -> stream
+    # logprobs.top_k -> logprobs
+
+    nvext = {}
+    payload: Dict[str, Any] = dict(
+        model=request.model,
+        prompt=request.content,
+        stream=request.stream,
+        extra_body=dict(nvext=nvext),
+        extra_headers={
+            b"User-Agent": b"llama-stack: ramalama-inference-adapter",
+        },
+        n=n,
+    )
+
+    if request.response_format:
+        # this is not openai compliant, it is a nim extension
+        nvext.update(guided_json=request.response_format.json_schema)
+
+    if request.logprobs:
+        payload.update(logprobs=request.logprobs.top_k)
+
+    if request.sampling_params:
+        nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
+
+        if request.sampling_params.max_tokens:
+            payload.update(max_tokens=request.sampling_params.max_tokens)
+
+        if request.sampling_params.strategy == "top_p":
+            nvext.update(top_k=-1)
+            payload.update(top_p=request.sampling_params.top_p)
+        elif request.sampling_params.strategy == "top_k":
+            if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
+                warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
+            nvext.update(top_k=request.sampling_params.top_k)
+        elif request.sampling_params.strategy == "greedy":
+            nvext.update(top_k=-1)
+            payload.update(temperature=request.sampling_params.temperature)
+
+    return payload
+
+
+def _convert_openai_completion_logprobs(
+    logprobs: Optional[OpenAICompletionLogprobs],
+) -> Optional[List[TokenLogProbs]]:
+    """
+    Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
+    """
+    if not logprobs:
+        return None
+
+    return [TokenLogProbs(logprobs_by_token=token_logprobs) for token_logprobs in logprobs.top_logprobs]
+
+
+def convert_openai_completion_choice(
+    choice: OpenAIChoice,
+) -> CompletionResponse:
+    """
+    Convert an OpenAI Completion Choice into a CompletionResponse.
+    """
+    return CompletionResponse(
+        content=choice.text,
+        stop_reason=_convert_openai_finish_reason(choice.finish_reason),
+        logprobs=_convert_openai_completion_logprobs(choice.logprobs),
+    )
+
+
+async def convert_openai_completion_stream(
+    stream: AsyncStream[OpenAICompletion],
+) -> AsyncGenerator[CompletionResponse, None]:
+    """
+    Convert a stream of OpenAI Completions into a stream
+    of CompletionResponseStreamChunks.
+    """
+    async for chunk in stream:
+        choice = chunk.choices[0]
+        yield CompletionResponseStreamChunk(
+            delta=choice.text,
+            stop_reason=_convert_openai_finish_reason(choice.finish_reason),
+            logprobs=_convert_openai_completion_logprobs(choice.logprobs),
+        )
diff --git a/llama_stack/providers/remote/inference/ramalama/ramalama.py b/llama_stack/providers/remote/inference/ramalama/ramalama.py
new file mode 100644
index 000000000..4478935c0
--- /dev/null
+++ b/llama_stack/providers/remote/inference/ramalama/ramalama.py
@@ -0,0 +1,188 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+ + +from typing import AsyncGenerator, List, Optional + +from openai import AsyncOpenAI, BadRequestError + +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, + TextContentItem, +) +from llama_stack.apis.inference import ( + ChatCompletionRequest, + CompletionRequest, + EmbeddingsResponse, + EmbeddingTaskType, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + TextTruncation, + ToolChoice, + ToolConfig, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.models import Model +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import ModelsProtocolPrivate +from llama_stack.providers.utils.inference.model_registry import ( + ModelRegistryHelper, +) +from llama_stack.providers.utils.inference.openai_compat import ( + convert_openai_chat_completion_choice, + convert_openai_chat_completion_stream, +) + +from .models import model_entries +from .openai_utils import ( + convert_chat_completion_request, + convert_completion_request, + convert_openai_completion_choice, + convert_openai_completion_stream, +) + +logger = get_logger(name=__name__, category="inference") + + +class RamalamaInferenceAdapter(Inference, ModelsProtocolPrivate): + def __init__(self, url: str) -> None: + self.register_helper = ModelRegistryHelper(model_entries) + self.url = url + + async def initialize(self) -> None: + logger.info(f"checking connectivity to Ramalama at `{self.url}`...") + self.client = AsyncOpenAI(base_url=self.url, api_key="NO KEY") + + async def shutdown(self) -> None: + pass + + async def unregister_model(self, model_id: str) -> None: + pass + + async def completion( + self, + model_id: str, + content: InterleavedContent, + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + if sampling_params is None: + sampling_params = SamplingParams() + model = await self.model_store.get_model(model_id) + request = convert_completion_request( + request=CompletionRequest( + model=model.provider_resource_id, + content=content, + sampling_params=sampling_params, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + ) + + response = await self.client.completions.create(**request) + if stream: + return convert_openai_completion_stream(response) + else: + # we pass n=1 to get only one completion + return convert_openai_completion_choice(response.choices[0]) + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, + ) -> AsyncGenerator: + if sampling_params is None: + sampling_params = SamplingParams() + model = await self.model_store.get_model(model_id) + request = await convert_chat_completion_request( + request=ChatCompletionRequest( + model=model.provider_resource_id, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + stream=stream, + logprobs=logprobs, + response_format=response_format, + tool_config=tool_config, + ), + n=1, + ) + s = await self.client.chat.completions.create(**request) + if stream: + 
return convert_openai_chat_completion_stream(s, enable_incremental_tool_calls=False)
+        else:
+            # we pass n=1 to get only one completion
+            return convert_openai_chat_completion_choice(s.choices[0])
+
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
+    ) -> EmbeddingsResponse:
+        flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
+        input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
+        model = self.register_helper.get_provider_model_id(model_id)
+
+        extra_body = {}
+
+        if text_truncation is not None:
+            text_truncation_options = {
+                TextTruncation.none: "NONE",
+                TextTruncation.end: "END",
+                TextTruncation.start: "START",
+            }
+            extra_body["truncate"] = text_truncation_options[text_truncation]
+
+        if output_dimension is not None:
+            extra_body["dimensions"] = output_dimension
+
+        if task_type is not None:
+            task_type_options = {
+                EmbeddingTaskType.document: "passage",
+                EmbeddingTaskType.query: "query",
+            }
+            extra_body["input_type"] = task_type_options[task_type]
+
+        try:
+            response = await self.client.embeddings.create(
+                model=model,
+                input=input,
+                extra_body=extra_body,
+            )
+        except BadRequestError as e:
+            raise ValueError(f"Failed to get embeddings: {e}") from e
+
+        return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
+
+    async def register_model(self, model: Model) -> Model:
+        model = await self.register_helper.register_model(model)
+        res = await self.client.models.list()
+        available_models = [m.id async for m in res]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model {model.provider_resource_id} is not being served by Ramalama. "
+                f"Available models: {', '.join(available_models)}"
+            )
+        return model
diff --git a/pyproject.toml b/pyproject.toml
index 47d845c30..d1735319f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -260,6 +260,7 @@ exclude = [
     "^llama_stack/providers/remote/inference/nvidia/",
     "^llama_stack/providers/remote/inference/openai/",
     "^llama_stack/providers/remote/inference/passthrough/",
+    "^llama_stack/providers/remote/inference/ramalama/",
     "^llama_stack/providers/remote/inference/runpod/",
     "^llama_stack/providers/remote/inference/sambanova/",
     "^llama_stack/providers/remote/inference/sample/",
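
Usage sketch: because ramalama exposes an OpenAI-compatible API, the adapter above is a thin wrapper around AsyncOpenAI. The snippet below performs the same round trip directly against a local ramalama server, which is a quick way to confirm the server is reachable before registering the provider through Llama Stack. The model name is illustrative (it must match an entry in models.py and be loaded in ramalama), and depending on how the server is started the API may be exposed under a /v1 prefix rather than the bare RAMALAMA_URL used here.

import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Same connection style the adapter uses in initialize():
    # ramalama's OpenAI-compatible endpoint, no real API key required.
    client = AsyncOpenAI(base_url="http://localhost:8080", api_key="NO KEY")

    # "llama3.2:3b" is an assumed example model; substitute whatever
    # ramalama is actually serving.
    response = await client.chat.completions.create(
        model="llama3.2:3b",
        messages=[{"role": "user", "content": "Say hello in one short sentence."}],
        stream=False,
    )
    print(response.choices[0].message.content)


asyncio.run(main())

This mirrors what RamalamaInferenceAdapter.chat_completion() does internally once the Llama Stack request has been converted by openai_utils.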