diff --git a/distributions/runpod/build.yaml b/distributions/runpod/build.yaml
new file mode 100644
index 000000000..9348573ef
--- /dev/null
+++ b/distributions/runpod/build.yaml
@@ -0,0 +1,9 @@
+name: runpod
+distribution_spec:
+  description: Use Runpod for running LLM inference
+  providers:
+    inference: remote::runpod
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
diff --git a/llama_stack/providers/adapters/inference/runpod/__init__.py b/llama_stack/providers/adapters/inference/runpod/__init__.py
new file mode 100644
index 000000000..67d49bc45
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/runpod/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import RunpodImplConfig
+from .runpod import RunpodInferenceAdapter
+
+
+async def get_adapter_impl(config: RunpodImplConfig, _deps):
+    assert isinstance(
+        config, RunpodImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = RunpodInferenceAdapter(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/adapters/inference/runpod/config.py b/llama_stack/providers/adapters/inference/runpod/config.py
new file mode 100644
index 000000000..1a9582052
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/runpod/config.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class RunpodImplConfig(BaseModel):
+    url: Optional[str] = Field(
+        default=None,
+        description="The URL for the Runpod model serving endpoint",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="The API token",
+    )
diff --git a/llama_stack/providers/adapters/inference/runpod/runpod.py b/llama_stack/providers/adapters/inference/runpod/runpod.py
new file mode 100644
index 000000000..cb2e6b237
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/runpod/runpod.py
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import AsyncGenerator
+
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from openai import OpenAI
+
+from llama_stack.apis.inference import *  # noqa: F403
+# from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
+
+from .config import RunpodImplConfig
+
+RUNPOD_SUPPORTED_MODELS = {
+    "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
+    "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
+    "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
+    "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
+    "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
+    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
+    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
+    "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
+    "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
+    "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
+    "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
+    "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
+}
+class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
+    def __init__(self, config: RunpodImplConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS
+        )
+        self.config = config
+        self.formatter = ChatFormat(Tokenizer.get_instance())
+
+    async def initialize(self) -> None:
+        return
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        raise NotImplementedError()
+
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+        if stream:
+            return self._stream_chat_completion(request, client)
+        else:
+            return await self._nonstream_chat_completion(request, client)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = client.completions.create(**params)
+        return process_chat_completion_response(r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        async def _to_async_generator():
+            s = client.completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_chat_completion_stream_response(
+            stream, self.formatter
+        ):
+            yield chunk
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "stream": request.stream,
+            **get_sampling_options(request.sampling_params),
+        }
+
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
\ No newline at end of file
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 55924a1e9..320c649d2 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -195,4 +195,13 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="runpod",
+                pip_packages=["openai"],
+                module="llama_stack.providers.adapters.inference.runpod",
+                config_class="llama_stack.providers.adapters.inference.runpod.RunpodImplConfig",
+            ),
+        ),
     ]
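
A minimal, illustrative smoke test for reviewers, not part of this diff: it assumes a RunPod pod exposing an OpenAI-compatible endpoint, so the URL, API token, and the asyncio harness below are placeholders, and `Llama3.1-8B-Instruct` is simply one of the aliases declared in `RUNPOD_SUPPORTED_MODELS`.

```python
import asyncio

# UserMessage lives in the same llama_models module the adapter already
# imports Message from.
from llama_models.llama3.api.datatypes import UserMessage

from llama_stack.providers.adapters.inference.runpod import get_adapter_impl
from llama_stack.providers.adapters.inference.runpod.config import RunpodImplConfig


async def main() -> None:
    # Placeholder endpoint and credential; substitute real pod values.
    config = RunpodImplConfig(
        url="https://<pod-id>-8000.proxy.runpod.net/v1",
        api_token="<RUNPOD_API_TOKEN>",
    )
    # get_adapter_impl() also calls initialize() on the adapter.
    adapter = await get_adapter_impl(config, {})

    # Non-streaming chat completion routed through the new adapter.
    response = await adapter.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Say hello in one sentence.")],
    )
    print(response)


if __name__ == "__main__":
    asyncio.run(main())
```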