diff --git a/README.md b/README.md index 8c201e43d..9fe41f85d 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,7 @@ Here is a list of the various API providers and available distributions that can | OpenAI | Hosted | | ✅ | | | | | Anthropic | Hosted | | ✅ | | | | | Gemini | Hosted | | ✅ | | | | +| Ramalama | Single Node | | ✅ | | | | ### Distributions diff --git a/docs/source/distributions/self_hosted_distro/ramalama.md b/docs/source/distributions/self_hosted_distro/ramalama.md new file mode 100644 index 000000000..84eb352bd --- /dev/null +++ b/docs/source/distributions/self_hosted_distro/ramalama.md @@ -0,0 +1,194 @@ +--- +orphan: true +--- + +# RamaLama Distribution + +```{toctree} +:maxdepth: 2 +:hidden: + +self +``` + +The `llamastack/distribution-ramalama` distribution consists of the following provider configurations. + +| API | Provider(s) | +|-----|-------------| +| agents | `inline::meta-reference` | +| datasetio | `remote::huggingface`, `inline::localfs` | +| eval | `inline::meta-reference` | +| inference | `remote::ramalama` | +| safety | `inline::llama-guard` | +| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | +| telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` | +| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | + + +You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since RamaLama supports GPU acceleration. + +### Environment Variables + +The following environment variables can be configured: + +- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`) +- `RAMALAMA_URL`: URL of the RamaLama server (default: `http://0.0.0.0:8080/v1`) +- `INFERENCE_MODEL`: Inference model loaded into the RamaLama server (default: `meta-llama/Llama-3.2-3B-Instruct`) +- `SAFETY_MODEL`: Safety model loaded into the RamaLama server (default: `meta-llama/Llama-Guard-3-1B`) + + +## Setting up RamaLama server + +Please check the [RamaLama Documentation](https://github.com/containers/ramalama) on how to install and run RamaLama. After installing RamaLama, you need to run `ramalama serve` to start the server. + +In order to load models, you can run: + +```bash +export RAMALAMA_INFERENCE_MODEL="llama3.2:3b-instruct-fp16" + +export INFERENCE_MODEL="~/path_to_model/meta-llama/Llama-3.2-3B-Instruct" + +ramalama serve $RAMALAMA_INFERENCE_MODEL +``` +RamaLama requires the inference model to be the fully qualified path to the model on disk when running on MacOS, on Linux it can just be the model name. + +If you are using Llama Stack Safety / Shield APIs, you will also need to pull and run the safety model. + +```bash +export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B" + +# ramalama names this model differently, and we must use the ramalama name when loading the model +export RAMALAMA_SAFETY_MODEL="llama-guard3:1b" +ramalama run $RAMALAMA_SAFETY_MODEL --keepalive 60m +``` + +## Running Llama Stack + +Now you are ready to run Llama Stack with RamaLama as the inference provider. You can do this via Conda, Venv, or Podman which has a pre-built image. + +### Via Podman + +This method allows you to get started quickly without having to build the distribution code. + +```bash +export LLAMA_STACK_PORT=5001 +podman run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama:z \ + llamastack/distribution-ramalama \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env RAMALAMA_URL=http://0.0.0.0:8080/v1 +``` + +If you are using Llama Stack Safety / Shield APIs, use: + +```bash +# You need a local checkout of llama-stack to run this, get it using +# git clone https://github.com/meta-llama/llama-stack.git +cd /path/to/llama-stack + +podman run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama:z \ + -v ./llama_stack/templates/ramalama/run-with-safety.yaml:/root/my-run.yaml:z \ + llamastack/distribution-ramalama \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env SAFETY_MODEL=$SAFETY_MODEL \ + --env RAMALAMA_URL=http://host.containers.internal:8080/v1 +``` + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +export LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama \ + llamastack/distribution-ramalama \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env RAMALAMA_URL=http://host.docker.internal:8080/v1 +``` + +If you are using Llama Stack Safety / Shield APIs, use: + +```bash +# You need a local checkout of llama-stack to run this, get it using +# git clone https://github.com/meta-llama/llama-stack.git +cd /path/to/llama-stack + +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama \ + -v ./llama_stack/templates/ramalama/run-with-safety.yaml:/root/my-run.yaml \ + llamastack/distribution-ramalama \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env SAFETY_MODEL=$SAFETY_MODEL \ + --env RAMALAMA_URL=http://host.docker.internal:8080/v1 +``` + +### Via Conda + +Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available. + +```bash +export LLAMA_STACK_PORT=5001 + +llama stack build --template ramalama --image-type conda +llama stack run ./run.yaml \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env RAMALAMA_URL=http://host.docker.internal:8080/v1 +``` + +If you are using Llama Stack Safety / Shield APIs, use: + +```bash +llama stack run ./run-with-safety.yaml \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env SAFETY_MODEL=$SAFETY_MODEL \ + --env RAMALAMA_URL=http://host.docker.internal:8080/v1 +``` + + +### (Optional) Update Model Serving Configuration + +```{note} +Please check the [model_aliases](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ramalama/ramalama.py#L45) for the supported RamaLama models. +``` + +To serve a new model with `ramalama` +```bash +ramalama run +``` + +To make sure that the model is being served correctly, run `ramalama ps` to get a list of models being served by ramalama. +``` +$ ramalama ps + +NAME ID SIZE PROCESSOR UNTIL +llama3.1:8b-instruct-fp16 4aacac419454 17 GB 100% GPU 4 minutes from now +``` + +To verify that the model served by ramalama is correctly connected to Llama Stack server +```bash +$ llama-stack-client models list ++----------------------+----------------------+---------------+-----------------------------------------------+ +| identifier | llama_model | provider_id | metadata | ++======================+======================+===============+===============================================+ +| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | ramalama0 | {'ramalama_model': 'llama3.1:8b-instruct-fp16'} | ++----------------------+----------------------+---------------+-----------------------------------------------+ +``` diff --git a/llama_stack/providers/remote/inference/ramalama/models.py b/llama_stack/providers/remote/inference/ramalama/models.py index be556762c..42e364105 100644 --- a/llama_stack/providers/remote/inference/ramalama/models.py +++ b/llama_stack/providers/remote/inference/ramalama/models.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from llama_stack.apis.models.models import ModelType -from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( ProviderModelEntry, build_hf_repo_model_entry, diff --git a/llama_stack/providers/remote/inference/ramalama/openai_utils.py b/llama_stack/providers/remote/inference/ramalama/openai_utils.py deleted file mode 100644 index 7458d28ea..000000000 --- a/llama_stack/providers/remote/inference/ramalama/openai_utils.py +++ /dev/null @@ -1,344 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import List, Optional - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - GrammarResponseFormat, - JsonSchemaResponseFormat, - Message, - ToolChoice, - UserMessage, -) -from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict, - get_sampling_options, -) - - -def _merge_context_into_content(message: Message) -> Message: # type: ignore - """ - Merge the ``context`` field of a Llama Stack ``Message`` object into - the content field for compabilitiy with OpenAI-style APIs. - - Generates a content string that emulates the current behavior - of ``llama_models.llama3.api.chat_format.encode_message()``. - - :param message: Message that may include ``context`` field - - :returns: A version of ``message`` with any context merged into the - ``content`` field. - """ - if not isinstance(message, UserMessage): # Separate type check for linter - return message - if message.context is None: - return message - return UserMessage( - role=message.role, - # Emumate llama_models.llama3.api.chat_format.encode_message() - content=message.content + "\n\n" + message.context, - context=None, - ) - - -async def llama_stack_chat_completion_to_openai_chat_completion_dict( - request: ChatCompletionRequest, -) -> dict: - """ - Convert a chat completion request in Llama Stack format into an - equivalent set of arguments to pass to an OpenAI-compatible - chat completions API. - - :param request: Bundled request parameters in Llama Stack format. - - :returns: Dictionary of key-value pairs to use as an initializer - for a dataclass or to be converted directly to JSON and sent - over the wire. - """ - - converted_messages = [ - # This mystery async call makes the parent function also be async - await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) - for m in request.messages - ] - # converted_tools = _llama_stack_tools_to_openai_tools(request.tools) - - # Llama will try to use built-in tools with no tool catalog, so don't enable - # tool choice unless at least one tool is enabled. - converted_tool_choice = "none" - if ( - request.tool_config is not None - and request.tool_config.tool_choice == ToolChoice.auto - and request.tools is not None - and len(request.tools) > 0 - ): - converted_tool_choice = "auto" - - # TODO: Figure out what to do with the tool_prompt_format argument. - # Other connectors appear to drop it quietly. - - # Use Llama Stack shared code to translate sampling parameters. - sampling_options = get_sampling_options(request.sampling_params) - - # get_sampling_options() translates repetition penalties to an option that - # OpenAI's APIs don't know about. - # vLLM's OpenAI-compatible API also handles repetition penalties wrong. - # For now, translate repetition penalties into a format that vLLM's broken - # API will handle correctly. Two wrongs make a right... - if "repeat_penalty" in sampling_options: - del sampling_options["repeat_penalty"] - if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0: - sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty - - # Convert a single response format into four different parameters, per - # the OpenAI spec - guided_decoding_options = dict() - if request.response_format is None: - # Use defaults - pass - elif isinstance(request.response_format, JsonSchemaResponseFormat): - guided_decoding_options["guided_json"] = request.response_format.json_schema - elif isinstance(request.response_format, GrammarResponseFormat): - guided_decoding_options["guided_grammar"] = request.response_format.bnf - else: - raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'") - - logprob_options = dict() - if request.logprobs is not None: - logprob_options["logprobs"] = request.logprobs.top_k - - # Marshall together all the arguments for a ChatCompletionRequest - request_options = { - "model": request.model, - "messages": converted_messages, - "tool_choice": converted_tool_choice, - "stream": request.stream, - **sampling_options, - **guided_decoding_options, - **logprob_options, - } - - return request_options - - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import warnings -from typing import Any, AsyncGenerator, Dict - -from openai import AsyncStream -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, -) -from openai.types.completion import Completion as OpenAICompletion -from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - CompletionRequest, - CompletionResponse, - CompletionResponseStreamChunk, - TokenLogProbs, -) -from llama_stack.models.llama.datatypes import ( - GreedySamplingStrategy, - TopKSamplingStrategy, - TopPSamplingStrategy, -) -from llama_stack.providers.utils.inference.openai_compat import ( - _convert_openai_finish_reason, - convert_message_to_openai_dict_new, - convert_tooldef_to_openai_tool, -) - - -async def convert_chat_completion_request( - request: ChatCompletionRequest, - n: int = 1, -) -> dict: - """ - Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. - """ - # model -> model - # messages -> messages - # sampling_params TODO(mattf): review strategy - # strategy=greedy -> nvext.top_k = -1, temperature = temperature - # strategy=top_p -> nvext.top_k = -1, top_p = top_p - # strategy=top_k -> nvext.top_k = top_k - # temperature -> temperature - # top_p -> top_p - # top_k -> nvext.top_k - # max_tokens -> max_tokens - # repetition_penalty -> nvext.repetition_penalty - # response_format -> GrammarResponseFormat TODO(mf) - # response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema - # tools -> tools - # tool_choice ("auto", "required") -> tool_choice - # tool_prompt_format -> TBD - # stream -> stream - # logprobs -> logprobs - - if request.response_format and not isinstance(request.response_format, JsonSchemaResponseFormat): - raise ValueError( - f"Unsupported response format: {request.response_format}. Only JsonSchemaResponseFormat is supported." - ) - - nvext = {} - payload: Dict[str, Any] = dict( - model=request.model, - messages=[await convert_message_to_openai_dict_new(message) for message in request.messages], - stream=request.stream, - n=n, - extra_body=dict(nvext=nvext), - extra_headers={ - b"User-Agent": b"llama-stack: nvidia-inference-adapter", - }, - ) - - if request.response_format: - # server bug - setting guided_json changes the behavior of response_format resulting in an error - # payload.update(response_format="json_object") - nvext.update(guided_json=request.response_format.json_schema) - - if request.tools: - payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools]) - if request.tool_config.tool_choice: - payload.update( - tool_choice=request.tool_config.tool_choice.value - ) # we cannot include tool_choice w/o tools, server will complain - - if request.logprobs: - payload.update(logprobs=True) - payload.update(top_logprobs=request.logprobs.top_k) - - if request.sampling_params: - nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) - - if request.sampling_params.max_tokens: - payload.update(max_tokens=request.sampling_params.max_tokens) - - strategy = request.sampling_params.strategy - if isinstance(strategy, TopPSamplingStrategy): - nvext.update(top_k=-1) - payload.update(top_p=strategy.top_p) - payload.update(temperature=strategy.temperature) - elif isinstance(strategy, TopKSamplingStrategy): - if strategy.top_k != -1 and strategy.top_k < 1: - warnings.warn("top_k must be -1 or >= 1", stacklevel=2) - nvext.update(top_k=strategy.top_k) - elif isinstance(strategy, GreedySamplingStrategy): - nvext.update(top_k=-1) - else: - raise ValueError(f"Unsupported sampling strategy: {strategy}") - - return payload - - -def convert_completion_request( - request: CompletionRequest, - n: int = 1, -) -> dict: - """ - Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. - """ - # model -> model - # prompt -> prompt - # sampling_params TODO(mattf): review strategy - # strategy=greedy -> nvext.top_k = -1, temperature = temperature - # strategy=top_p -> nvext.top_k = -1, top_p = top_p - # strategy=top_k -> nvext.top_k = top_k - # temperature -> temperature - # top_p -> top_p - # top_k -> nvext.top_k - # max_tokens -> max_tokens - # repetition_penalty -> nvext.repetition_penalty - # response_format -> nvext.guided_json - # stream -> stream - # logprobs.top_k -> logprobs - - nvext = {} - payload: Dict[str, Any] = dict( - model=request.model, - prompt=request.content, - stream=request.stream, - extra_body=dict(nvext=nvext), - extra_headers={ - b"User-Agent": b"llama-stack: nvidia-inference-adapter", - }, - n=n, - ) - - if request.response_format: - # this is not openai compliant, it is a nim extension - nvext.update(guided_json=request.response_format.json_schema) - - if request.logprobs: - payload.update(logprobs=request.logprobs.top_k) - - if request.sampling_params: - nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) - - if request.sampling_params.max_tokens: - payload.update(max_tokens=request.sampling_params.max_tokens) - - if request.sampling_params.strategy == "top_p": - nvext.update(top_k=-1) - payload.update(top_p=request.sampling_params.top_p) - elif request.sampling_params.strategy == "top_k": - if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1: - warnings.warn("top_k must be -1 or >= 1", stacklevel=2) - nvext.update(top_k=request.sampling_params.top_k) - elif request.sampling_params.strategy == "greedy": - nvext.update(top_k=-1) - payload.update(temperature=request.sampling_params.temperature) - - return payload - - -def _convert_openai_completion_logprobs( - logprobs: Optional[OpenAICompletionLogprobs], -) -> Optional[List[TokenLogProbs]]: - """ - Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs. - """ - if not logprobs: - return None - - return [TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs] - - -def convert_openai_completion_choice( - choice: OpenAIChoice, -) -> CompletionResponse: - """ - Convert an OpenAI Completion Choice into a CompletionResponse. - """ - return CompletionResponse( - content=choice.text, - stop_reason=_convert_openai_finish_reason(choice.finish_reason), - logprobs=_convert_openai_completion_logprobs(choice.logprobs), - ) - - -async def convert_openai_completion_stream( - stream: AsyncStream[OpenAICompletion], -) -> AsyncGenerator[CompletionResponse, None]: - """ - Convert a stream of OpenAI Completions into a stream - of ChatCompletionResponseStreamChunks. - """ - async for chunk in stream: - choice = chunk.choices[0] - yield CompletionResponseStreamChunk( - delta=choice.text, - stop_reason=_convert_openai_finish_reason(choice.finish_reason), - logprobs=_convert_openai_completion_logprobs(choice.logprobs), - ) diff --git a/llama_stack/providers/remote/inference/ramalama/ramalama.py b/llama_stack/providers/remote/inference/ramalama/ramalama.py index 4478935c0..7d0d66e22 100644 --- a/llama_stack/providers/remote/inference/ramalama/ramalama.py +++ b/llama_stack/providers/remote/inference/ramalama/ramalama.py @@ -5,9 +5,12 @@ # the root directory of this source tree. -from typing import AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from openai import AsyncOpenAI, BadRequestError +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk as OpenAIChatCompletionChunk, +) from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -30,6 +33,12 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletion, + OpenAICompletion, + OpenAIMessageParam, + OpenAIResponseFormatParam, +) from llama_stack.apis.models import Model from llama_stack.log import get_logger from llama_stack.providers.datatypes import ModelsProtocolPrivate @@ -37,17 +46,16 @@ from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( + convert_chat_completion_request, + convert_completion_request, convert_openai_chat_completion_choice, convert_openai_chat_completion_stream, + convert_openai_completion_choice, + convert_openai_completion_stream, + prepare_openai_completion_params, ) from .models import model_entries -from .openai_utils import ( - convert_chat_completion_request, - convert_completion_request, - convert_openai_completion_choice, - convert_openai_completion_stream, -) logger = get_logger(name=__name__, category="inference") @@ -180,9 +188,132 @@ class RamalamaInferenceAdapter(Inference, ModelsProtocolPrivate): model = await self.register_helper.register_model(model) res = await self.client.models.list() available_models = [m.id async for m in res] - if model.provider_resource_id not in available_models: + # Ramalama handles paths on MacOS and Linux differently + if (model.provider_resource_id.split("/")[-1] not in available_models) and ( + model.provider_resource_id not in available_models + ): raise ValueError( - f"Model {model.provider_resource_id} is not being served by vLLM. " + f"Model {model.provider_resource_id} is not being served by Ramalama. " f"Available models: {', '.join(available_models)}" ) return model + + async def openai_completion( + self, + model: str, + prompt: Union[str, List[str], List[int], List[List[int]]], + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, + ) -> OpenAICompletion: + model_obj = await self.model_store.get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) + return await self.client.completions.create(**params) # type: ignore + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[OpenAIResponseFormatParam] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + model_obj = await self.model_store.get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + return await self.client.chat.completions.create(**params) # type: ignore + + async def batch_completion( + self, + model_id: str, + content_batch: List[InterleavedContent], + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + logprobs: Optional[LogProbConfig] = None, + ): + raise NotImplementedError("Batch completion is not supported for Ramalama") + + async def batch_chat_completion( + self, + model_id: str, + messages_batch: List[List[Message]], + sampling_params: Optional[SamplingParams] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_config: Optional[ToolConfig] = None, + response_format: Optional[ResponseFormat] = None, + logprobs: Optional[LogProbConfig] = None, + ): + raise NotImplementedError("Batch chat completion is not supported for Ramalama") diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index f91e7d7dc..82a92c751 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -78,6 +78,7 @@ from openai.types.chat.chat_completion_content_part_image_param import ( from openai.types.chat.chat_completion_message_tool_call_param import ( Function as OpenAIFunction, ) +from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs from pydantic import BaseModel from llama_stack.apis.common.content_types import ( @@ -97,8 +98,10 @@ from llama_stack.apis.inference import ( ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, + CompletionRequest, CompletionResponse, CompletionResponseStreamChunk, + GrammarResponseFormat, GreedySamplingStrategy, Message, SamplingParams, @@ -1466,3 +1469,292 @@ class OpenAIChatCompletionToLlamaStackMixin: model=model, object="chat.completion", ) + + +async def convert_chat_completion_request( + request: ChatCompletionRequest, + n: int = 1, +) -> dict: + """ + Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. + """ + # model -> model + # messages -> messages + # sampling_params TODO(mattf): review strategy + # strategy=greedy -> nvext.top_k = -1, temperature = temperature + # strategy=top_p -> nvext.top_k = -1, top_p = top_p + # strategy=top_k -> nvext.top_k = top_k + # temperature -> temperature + # top_p -> top_p + # top_k -> nvext.top_k + # max_tokens -> max_tokens + # repetition_penalty -> nvext.repetition_penalty + # response_format -> GrammarResponseFormat TODO(mf) + # response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema + # tools -> tools + # tool_choice ("auto", "required") -> tool_choice + # tool_prompt_format -> TBD + # stream -> stream + # logprobs -> logprobs + + if request.response_format and not isinstance(request.response_format, JsonSchemaResponseFormat): + raise ValueError( + f"Unsupported response format: {request.response_format}. Only JsonSchemaResponseFormat is supported." + ) + + nvext = {} + payload: Dict[str, Any] = dict( + model=request.model, + messages=[await convert_message_to_openai_dict_new(message) for message in request.messages], + stream=request.stream, + n=n, + extra_body=dict(nvext=nvext), + extra_headers={ + b"User-Agent": b"llama-stack: nvidia-inference-adapter", + }, + ) + + if request.response_format: + # server bug - setting guided_json changes the behavior of response_format resulting in an error + # payload.update(response_format="json_object") + nvext.update(guided_json=request.response_format.json_schema) + + if request.tools: + payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools]) + if request.tool_config.tool_choice: + payload.update( + tool_choice=request.tool_config.tool_choice.value + ) # we cannot include tool_choice w/o tools, server will complain + + if request.logprobs: + payload.update(logprobs=True) + payload.update(top_logprobs=request.logprobs.top_k) + + if request.sampling_params: + nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) + + if request.sampling_params.max_tokens: + payload.update(max_tokens=request.sampling_params.max_tokens) + + strategy = request.sampling_params.strategy + if isinstance(strategy, TopPSamplingStrategy): + nvext.update(top_k=-1) + payload.update(top_p=strategy.top_p) + payload.update(temperature=strategy.temperature) + elif isinstance(strategy, TopKSamplingStrategy): + if strategy.top_k != -1 and strategy.top_k < 1: + warnings.warn("top_k must be -1 or >= 1", stacklevel=2) + nvext.update(top_k=strategy.top_k) + elif isinstance(strategy, GreedySamplingStrategy): + nvext.update(top_k=-1) + else: + raise ValueError(f"Unsupported sampling strategy: {strategy}") + + return payload + + +def convert_completion_request( + request: CompletionRequest, + n: int = 1, +) -> dict: + """ + Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. + """ + # model -> model + # prompt -> prompt + # sampling_params TODO(mattf): review strategy + # strategy=greedy -> nvext.top_k = -1, temperature = temperature + # strategy=top_p -> nvext.top_k = -1, top_p = top_p + # strategy=top_k -> nvext.top_k = top_k + # temperature -> temperature + # top_p -> top_p + # top_k -> nvext.top_k + # max_tokens -> max_tokens + # repetition_penalty -> nvext.repetition_penalty + # response_format -> nvext.guided_json + # stream -> stream + # logprobs.top_k -> logprobs + + nvext = {} + payload: Dict[str, Any] = dict( + model=request.model, + prompt=request.content, + stream=request.stream, + extra_body=dict(nvext=nvext), + extra_headers={ + b"User-Agent": b"llama-stack: nvidia-inference-adapter", + }, + n=n, + ) + + if request.response_format: + # this is not openai compliant, it is a nim extension + nvext.update(guided_json=request.response_format.json_schema) + + if request.logprobs: + payload.update(logprobs=request.logprobs.top_k) + + if request.sampling_params: + nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) + + if request.sampling_params.max_tokens: + payload.update(max_tokens=request.sampling_params.max_tokens) + + if request.sampling_params.strategy == "top_p": + nvext.update(top_k=-1) + payload.update(top_p=request.sampling_params.top_p) + elif request.sampling_params.strategy == "top_k": + if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1: + warnings.warn("top_k must be -1 or >= 1", stacklevel=2) + nvext.update(top_k=request.sampling_params.top_k) + elif request.sampling_params.strategy == "greedy": + nvext.update(top_k=-1) + payload.update(temperature=request.sampling_params.temperature) + + return payload + + +def _convert_openai_completion_logprobs( + logprobs: Optional[OpenAICompletionLogprobs], +) -> Optional[List[TokenLogProbs]]: + """ + Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs. + """ + if not logprobs: + return None + + return [TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs] + + +def convert_openai_completion_choice( + choice: OpenAIChoice, +) -> CompletionResponse: + """ + Convert an OpenAI Completion Choice into a CompletionResponse. + """ + return CompletionResponse( + content=choice.text, + stop_reason=_convert_openai_finish_reason(choice.finish_reason), + logprobs=_convert_openai_completion_logprobs(choice.logprobs), + ) + + +async def convert_openai_completion_stream( + stream: AsyncStream[OpenAICompletion], +) -> AsyncGenerator[CompletionResponse, None]: + """ + Convert a stream of OpenAI Completions into a stream + of ChatCompletionResponseStreamChunks. + """ + async for chunk in stream: + choice = chunk.choices[0] + yield CompletionResponseStreamChunk( + delta=choice.text, + stop_reason=_convert_openai_finish_reason(choice.finish_reason), + logprobs=_convert_openai_completion_logprobs(choice.logprobs), + ) + + +def _merge_context_into_content(message: Message) -> Message: # type: ignore + """ + Merge the ``context`` field of a Llama Stack ``Message`` object into + the content field for compabilitiy with OpenAI-style APIs. + + Generates a content string that emulates the current behavior + of ``llama_models.llama3.api.chat_format.encode_message()``. + + :param message: Message that may include ``context`` field + + :returns: A version of ``message`` with any context merged into the + ``content`` field. + """ + if not isinstance(message, UserMessage): # Separate type check for linter + return message + if message.context is None: + return message + return UserMessage( + role=message.role, + # Emumate llama_models.llama3.api.chat_format.encode_message() + content=message.content + "\n\n" + message.context, + context=None, + ) + + +async def llama_stack_chat_completion_to_openai_chat_completion_dict( + request: ChatCompletionRequest, +) -> dict: + """ + Convert a chat completion request in Llama Stack format into an + equivalent set of arguments to pass to an OpenAI-compatible + chat completions API. + + :param request: Bundled request parameters in Llama Stack format. + + :returns: Dictionary of key-value pairs to use as an initializer + for a dataclass or to be converted directly to JSON and sent + over the wire. + """ + + converted_messages = [ + # This mystery async call makes the parent function also be async + await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) + for m in request.messages + ] + # converted_tools = _llama_stack_tools_to_openai_tools(request.tools) + + # Llama will try to use built-in tools with no tool catalog, so don't enable + # tool choice unless at least one tool is enabled. + converted_tool_choice = "none" + if ( + request.tool_config is not None + and request.tool_config.tool_choice == ToolChoice.auto + and request.tools is not None + and len(request.tools) > 0 + ): + converted_tool_choice = "auto" + + # TODO: Figure out what to do with the tool_prompt_format argument. + # Other connectors appear to drop it quietly. + + # Use Llama Stack shared code to translate sampling parameters. + sampling_options = get_sampling_options(request.sampling_params) + + # get_sampling_options() translates repetition penalties to an option that + # OpenAI's APIs don't know about. + # vLLM's OpenAI-compatible API also handles repetition penalties wrong. + # For now, translate repetition penalties into a format that vLLM's broken + # API will handle correctly. Two wrongs make a right... + if "repeat_penalty" in sampling_options: + del sampling_options["repeat_penalty"] + if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0: + sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty + + # Convert a single response format into four different parameters, per + # the OpenAI spec + guided_decoding_options = dict() + if request.response_format is None: + # Use defaults + pass + elif isinstance(request.response_format, JsonSchemaResponseFormat): + guided_decoding_options["guided_json"] = request.response_format.json_schema + elif isinstance(request.response_format, GrammarResponseFormat): + guided_decoding_options["guided_grammar"] = request.response_format.bnf + else: + raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'") + + logprob_options = dict() + if request.logprobs is not None: + logprob_options["logprobs"] = request.logprobs.top_k + + # Marshall together all the arguments for a ChatCompletionRequest + request_options = { + "model": request.model, + "messages": converted_messages, + "tool_choice": converted_tool_choice, + "stream": request.stream, + **sampling_options, + **guided_decoding_options, + **logprob_options, + } + + return request_options diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index b96191752..49ed2163d 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -536,6 +536,43 @@ "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" ], + "ramalama": [ + "aiohttp", + "aiosqlite", + "autoevals", + "blobfile", + "chardet", + "chromadb-client", + "datasets", + "emoji", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "langdetect", + "matplotlib", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pymongo", + "pypdf", + "pythainlp", + "ramalama", + "redis", + "requests", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "tree_sitter", + "uvicorn" + ], "remote-vllm": [ "aiosqlite", "autoevals", diff --git a/llama_stack/templates/ramalama/__init__.py b/llama_stack/templates/ramalama/__init__.py new file mode 100644 index 000000000..cdb8595fa --- /dev/null +++ b/llama_stack/templates/ramalama/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .ramalama import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/ramalama/build.yaml b/llama_stack/templates/ramalama/build.yaml new file mode 100644 index 000000000..f6f7fcf4f --- /dev/null +++ b/llama_stack/templates/ramalama/build.yaml @@ -0,0 +1,31 @@ +version: '2' +distribution_spec: + description: Use (an external) RamaLama server for running LLM inference + providers: + inference: + - remote::ramalama + vector_io: + - inline::faiss + - remote::chromadb + - remote::pgvector + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference + eval: + - inline::meta-reference + datasetio: + - remote::huggingface + - inline::localfs + scoring: + - inline::basic + - inline::llm-as-judge + - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::rag-runtime +image_type: conda diff --git a/llama_stack/templates/ramalama/doc_template.md b/llama_stack/templates/ramalama/doc_template.md new file mode 100644 index 000000000..37f3bce39 --- /dev/null +++ b/llama_stack/templates/ramalama/doc_template.md @@ -0,0 +1,183 @@ +--- +orphan: true +--- +# RamaLama Distribution + +```{toctree} +:maxdepth: 2 +:hidden: + +self +``` + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. + +{{ providers_table }} + +You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since RamaLama supports GPU acceleration. + +{% if run_config_env_vars %} +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + + +## Setting up RamaLama server + +Please check the [RamaLama Documentation](https://github.com/containers/ramalama) on how to install and run RamaLama. After installing RamaLama, you need to run `ramalama serve` to start the server. + +In order to load models, you can run: + +```bash +export RAMALAMA_INFERENCE_MODEL="llama3.2:3b-instruct-fp16" + +export INFERENCE_MODEL="~/path_to_model/meta-llama/Llama-3.2-3B-Instruct" + +ramalama serve $RAMALAMA_INFERENCE_MODEL +``` +RamaLama requires the inference model to be the fully qualified path to the model on disk when running on MacOS, on Linux it can just be the model name. + +If you are using Llama Stack Safety / Shield APIs, you will also need to pull and run the safety model. + +```bash +export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B" + +# ramalama names this model differently, and we must use the ramalama name when loading the model +export RAMALAMA_SAFETY_MODEL="llama-guard3:1b" +ramalama run $RAMALAMA_SAFETY_MODEL --keepalive 60m +``` + +## Running Llama Stack + +Now you are ready to run Llama Stack with RamaLama as the inference provider. You can do this via Conda, Venv, or Podman which has a pre-built image. + +### Via Podman + +This method allows you to get started quickly without having to build the distribution code. + +```bash +export LLAMA_STACK_PORT=5001 +podman run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama:z \ + llamastack/distribution-{{ name }} \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env RAMALAMA_URL=http://0.0.0.0:8080/v1 +``` + +If you are using Llama Stack Safety / Shield APIs, use: + +```bash +# You need a local checkout of llama-stack to run this, get it using +# git clone https://github.com/meta-llama/llama-stack.git +cd /path/to/llama-stack + +podman run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama:z \ + -v ./llama_stack/templates/ramalama/run-with-safety.yaml:/root/my-run.yaml:z \ + llamastack/distribution-{{ name }} \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env SAFETY_MODEL=$SAFETY_MODEL \ + --env RAMALAMA_URL=http://host.containers.internal:8080/v1 +``` + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +export LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama \ + llamastack/distribution-{{ name }} \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env RAMALAMA_URL=http://host.docker.internal:8080/v1 +``` + +If you are using Llama Stack Safety / Shield APIs, use: + +```bash +# You need a local checkout of llama-stack to run this, get it using +# git clone https://github.com/meta-llama/llama-stack.git +cd /path/to/llama-stack + +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama \ + -v ./llama_stack/templates/ramalama/run-with-safety.yaml:/root/my-run.yaml \ + llamastack/distribution-{{ name }} \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env SAFETY_MODEL=$SAFETY_MODEL \ + --env RAMALAMA_URL=http://host.docker.internal:8080/v1 +``` + +### Via Conda + +Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available. + +```bash +export LLAMA_STACK_PORT=5001 + +llama stack build --template {{ name }} --image-type conda +llama stack run ./run.yaml \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env RAMALAMA_URL=http://host.docker.internal:8080/v1 +``` + +If you are using Llama Stack Safety / Shield APIs, use: + +```bash +llama stack run ./run-with-safety.yaml \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env SAFETY_MODEL=$SAFETY_MODEL \ + --env RAMALAMA_URL=http://host.docker.internal:8080/v1 +``` + + +### (Optional) Update Model Serving Configuration + +```{note} +Please check the [model_aliases](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ramalama/ramalama.py#L45) for the supported RamaLama models. +``` + +To serve a new model with `ramalama` +```bash +ramalama run +``` + +To make sure that the model is being served correctly, run `ramalama ps` to get a list of models being served by ramalama. +``` +$ ramalama ps + +NAME ID SIZE PROCESSOR UNTIL +llama3.1:8b-instruct-fp16 4aacac419454 17 GB 100% GPU 4 minutes from now +``` + +To verify that the model served by ramalama is correctly connected to Llama Stack server +```bash +$ llama-stack-client models list ++----------------------+----------------------+---------------+-----------------------------------------------+ +| identifier | llama_model | provider_id | metadata | ++======================+======================+===============+===============================================+ +| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | ramalama0 | {'ramalama_model': 'llama3.1:8b-instruct-fp16'} | ++----------------------+----------------------+---------------+-----------------------------------------------+ +``` diff --git a/llama_stack/templates/ramalama/ramalama.py b/llama_stack/templates/ramalama/ramalama.py new file mode 100644 index 000000000..4968dc604 --- /dev/null +++ b/llama_stack/templates/ramalama/ramalama.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) +from llama_stack.providers.inline.inference.sentence_transformers import ( + SentenceTransformersInferenceConfig, +) +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig +from llama_stack.providers.remote.inference.ramalama import RamalamaImplConfig +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings + + +def get_distribution_template() -> DistributionTemplate: + providers = { + "inference": ["remote::ramalama"], + "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + "eval": ["inline::meta-reference"], + "datasetio": ["remote::huggingface", "inline::localfs"], + "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::rag-runtime", + ], + } + name = "ramalama" + inference_provider = Provider( + provider_id="ramalama", + provider_type="remote::ramalama", + config=RamalamaImplConfig.sample_run_config(), + ) + embedding_provider = Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + config=SentenceTransformersInferenceConfig.sample_run_config(), + ) + vector_io_provider = Provider( + provider_id="faiss", + provider_type="inline::faiss", + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), + ) + inference_model = ModelInput( + model_id="${env.INFERENCE_MODEL}", + provider_id="ramalama", + ) + safety_model = ModelInput( + model_id="${env.SAFETY_MODEL}", + provider_id="ramalama", + ) + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] + + return DistributionTemplate( + name=name, + distro_type="self_hosted", + description="Use (an external) RamaLama server for running LLM inference", + container_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider, embedding_provider], + "vector_io": [vector_io_provider], + }, + default_models=[inference_model], + default_tool_groups=default_tool_groups, + ), + "run-with-safety.yaml": RunConfigSettings( + provider_overrides={ + "inference": [ + inference_provider, + embedding_provider, + ], + "vector_io": [vector_io_provider], + "safety": [ + Provider( + provider_id="llama-guard", + provider_type="inline::llama-guard", + config={}, + ), + Provider( + provider_id="code-scanner", + provider_type="inline::code-scanner", + config={}, + ), + ], + }, + default_models=[inference_model, safety_model], + default_shields=[ + ShieldInput( + shield_id="${env.SAFETY_MODEL}", + provider_id="llama-guard", + ), + ShieldInput( + shield_id="CodeScanner", + provider_id="code-scanner", + ), + ], + default_tool_groups=default_tool_groups, + ), + }, + run_config_env_vars={ + "LLAMA_STACK_PORT": ( + "8321", + "Port for the Llama Stack distribution server", + ), + "RAMALAMA_URL": ( + "http://0.0.0.0:8080/v1", + "URL of the RamaLama server", + ), + "INFERENCE_MODEL": ( + "meta-llama/Llama-3.2-3B-Instruct", + "Inference model loaded into the RamaLama server", + ), + "SAFETY_MODEL": ( + "meta-llama/Llama-Guard-3-1B", + "Safety model loaded into the RamaLama server", + ), + }, + ) diff --git a/llama_stack/templates/ramalama/report.md b/llama_stack/templates/ramalama/report.md new file mode 100644 index 000000000..ac95d42f2 --- /dev/null +++ b/llama_stack/templates/ramalama/report.md @@ -0,0 +1,44 @@ +# Report for ramalama distribution + +## Supported Models +| Model Descriptor | ramalama | +|:---|:---| +| Llama-3-8B-Instruct | ❌ | +| Llama-3-70B-Instruct | ❌ | +| Llama3.1-8B-Instruct | ✅ | +| Llama3.1-70B-Instruct | ✅ | +| Llama3.1-405B-Instruct | ✅ | +| Llama3.2-1B-Instruct | ✅ | +| Llama3.2-3B-Instruct | ✅ | +| Llama3.2-11B-Vision-Instruct | ✅ | +| Llama3.2-90B-Vision-Instruct | ✅ | +| Llama3.3-70B-Instruct | ✅ | +| Llama-Guard-3-11B-Vision | ❌ | +| Llama-Guard-3-1B | ✅ | +| Llama-Guard-3-8B | ✅ | +| Llama-Guard-2-8B | ❌ | + +## Inference +| Model | API | Capability | Test | Status | +|:----- |:-----|:-----|:-----|:-----| +| Llama-3.1-8B-Instruct | /chat_completion | streaming | test_text_chat_completion_streaming | ✅ | +| Llama-3.2-11B-Vision-Instruct | /chat_completion | streaming | test_image_chat_completion_streaming | ❌ | +| Llama-3.2-11B-Vision-Instruct | /chat_completion | non_streaming | test_image_chat_completion_non_streaming | ❌ | +| Llama-3.1-8B-Instruct | /chat_completion | non_streaming | test_text_chat_completion_non_streaming | ✅ | +| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_streaming | ✅ | +| Llama-3.1-8B-Instruct | /chat_completion | tool_calling | test_text_chat_completion_with_tool_calling_and_non_streaming | ✅ | +| Llama-3.1-8B-Instruct | /completion | streaming | test_text_completion_streaming | ✅ | +| Llama-3.1-8B-Instruct | /completion | non_streaming | test_text_completion_non_streaming | ✅ | +| Llama-3.1-8B-Instruct | /completion | structured_output | test_text_completion_structured_output | ✅ | + +## Vector IO +| API | Capability | Test | Status | +|:-----|:-----|:-----|:-----| +| /retrieve | | test_vector_db_retrieve | ✅ | + +## Agents +| API | Capability | Test | Status | +|:-----|:-----|:-----|:-----| +| /create_agent_turn | rag | test_rag_agent | ✅ | +| /create_agent_turn | custom_tool | test_custom_tool | ✅ | +| /create_agent_turn | code_execution | test_code_interpreter_for_attachments | ✅ | diff --git a/llama_stack/templates/ramalama/run-with-safety.yaml b/llama_stack/templates/ramalama/run-with-safety.yaml new file mode 100644 index 000000000..2170ed0b5 --- /dev/null +++ b/llama_stack/templates/ramalama/run-with-safety.yaml @@ -0,0 +1,132 @@ +version: '2' +image_name: ramalama +apis: +- agents +- datasetio +- eval +- inference +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: ramalama + provider_type: remote::ramalama + config: + url: ${env.RAMALAMA_URL:http://localhost:8080} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:distributions/ramalama}/faiss_store.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + - provider_id: code-scanner + provider_type: inline::code-scanner + config: {} + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ramalama}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:\u200B}" + sinks: ${env.TELEMETRY_SINKS:console,sqlite} + sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ramalama/trace_store.db} + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ramalama}/meta_reference_eval.db + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ramalama}/huggingface_datasetio.db + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ramalama}/localfs_datasetio.db + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ramalama}/registry.db +models: +- metadata: {} + model_id: ${env.INFERENCE_MODEL} + provider_id: ramalama + model_type: llm +- metadata: {} + model_id: ${env.SAFETY_MODEL} + provider_id: ramalama + model_type: llm +shields: +- shield_id: ${env.SAFETY_MODEL} + provider_id: llama-guard +- shield_id: CodeScanner + provider_id: code-scanner +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter +server: + port: 8321 diff --git a/llama_stack/templates/ramalama/run.yaml b/llama_stack/templates/ramalama/run.yaml new file mode 100644 index 000000000..b5b5f2ab9 --- /dev/null +++ b/llama_stack/templates/ramalama/run.yaml @@ -0,0 +1,122 @@ +version: '2' +image_name: ramalama +apis: +- agents +- datasetio +- eval +- inference +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: ramalama + provider_type: remote::ramalama + config: + url: ${env.RAMALAMA_URL:http://localhost:8080} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:distributions/ramalama}/faiss_store.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ramalama}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:\u200B}" + sinks: ${env.TELEMETRY_SINKS:console,sqlite} + sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ramalama/trace_store.db} + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ramalama}/meta_reference_eval.db + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ramalama}/huggingface_datasetio.db + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ramalama}/localfs_datasetio.db + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ramalama}/registry.db +models: +- metadata: {} + model_id: ${env.INFERENCE_MODEL} + provider_id: ramalama + model_type: llm +shields: [] +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter +server: + port: 8321