diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index e72140ccf..346a2bd73 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -215,4 +215,14 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="passthrough",
+            pip_packages=[],
+            module="llama_stack.providers.remote.inference.passthrough",
+            config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
+        ),
+    ),
 ]
diff --git a/llama_stack/providers/remote/inference/passthrough/__init__.py b/llama_stack/providers/remote/inference/passthrough/__init__.py
new file mode 100644
index 000000000..69dd4c461
--- /dev/null
+++ b/llama_stack/providers/remote/inference/passthrough/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+from .config import PassthroughImplConfig
+
+
+class PassthroughProviderDataValidator(BaseModel):
+    url: str
+    api_key: str
+
+
+async def get_adapter_impl(config: PassthroughImplConfig, _deps):
+    from .passthrough import PassthroughInferenceAdapter
+
+    assert isinstance(config, PassthroughImplConfig), f"Unexpected config type: {type(config)}"
+    impl = PassthroughInferenceAdapter(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/remote/inference/passthrough/config.py b/llama_stack/providers/remote/inference/passthrough/config.py
new file mode 100644
index 000000000..46325e428
--- /dev/null
+++ b/llama_stack/providers/remote/inference/passthrough/config.py
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, SecretStr
+
+from llama_stack.schema_utils import json_schema_type
+
+
+@json_schema_type
+class PassthroughImplConfig(BaseModel):
+    url: str = Field(
+        default=None,
+        description="The URL for the passthrough endpoint",
+    )
+
+    api_key: Optional[SecretStr] = Field(
+        default=None,
+        description="API Key for the passthrough endpoint",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+        return {
+            "url": "${env.PASSTHROUGH_URL}",
+            "api_key": "${env.PASSTHROUGH_API_KEY}",
+        }
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
new file mode 100644
index 000000000..a34c34f69
--- /dev/null
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -0,0 +1,148 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import AsyncGenerator, List, Optional
+
+from llama_stack_client import LlamaStackClient
+
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import (
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolConfig,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.apis.models import Model
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+
+from .config import PassthroughImplConfig
+
+
+class PassthroughInferenceAdapter(Inference):
+    def __init__(self, config: PassthroughImplConfig) -> None:
+        ModelRegistryHelper.__init__(self, [])
+        self.config = config
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
+
+    async def register_model(self, model: Model) -> Model:
+        return model
+
+    def _get_client(self) -> LlamaStackClient:
+        passthrough_url = None
+        passthrough_api_key = None
+        provider_data = None
+
+        if self.config.url is not None:
+            passthrough_url = self.config.url
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.passthrough_url:
+                raise ValueError(
+                    'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": }'
+                )
+            passthrough_url = provider_data.passthrough_url
+
+        if self.config.api_key is not None:
+            passthrough_api_key = self.config.api_key.get_secret_value()
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.passthrough_api_key:
+                raise ValueError(
+                    'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": }'
+                )
+            passthrough_api_key = provider_data.passthrough_api_key
+
+        return LlamaStackClient(
+            base_url=passthrough_url,
+            api_key=passthrough_api_key,
+            provider_data=provider_data,
+        )
+
+    async def completion(
+        self,
+        model_id: str,
+        content: InterleavedContent,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        client = self._get_client()
+        model = await self.model_store.get_model(model_id)
+
+        params = {
+            "model_id": model.provider_resource_id,
+            "content": content,
+            "sampling_params": sampling_params,
+            "response_format": response_format,
+            "stream": stream,
+            "logprobs": logprobs,
+        }
+
+        params = {key: value for key, value in params.items() if value is not None}
+
+        # only pass through params that are not None
+        return client.inference.completion(**params)
+
+    async def chat_completion(
+        self,
+        model_id: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
+    ) -> AsyncGenerator:
+        client = self._get_client()
+        model = await self.model_store.get_model(model_id)
+
+        params = {
+            "model_id": model.provider_resource_id,
+            "messages": messages,
+            "sampling_params": sampling_params,
+            "tools": tools,
+            "tool_choice": tool_choice,
"tool_prompt_format": tool_prompt_format, + "response_format": response_format, + "stream": stream, + "logprobs": logprobs, + } + + params = {key: value for key, value in params.items() if value is not None} + + # only pass through the not None params + return client.inference.chat_completion(**params) + + async def embeddings( + self, + model_id: str, + contents: List[InterleavedContent], + ) -> EmbeddingsResponse: + client = self._get_client() + model = await self.model_store.get_model(model_id) + + return client.inference.embeddings( + model_id=model.provider_resource_id, + contents=contents, + ) diff --git a/llama_stack/templates/passthrough/build.yaml b/llama_stack/templates/passthrough/build.yaml new file mode 100644 index 000000000..5fed5286e --- /dev/null +++ b/llama_stack/templates/passthrough/build.yaml @@ -0,0 +1,32 @@ +version: '2' +distribution_spec: + description: Use for running LLM inference with the endpoint that compatible with Llama Stack API + providers: + inference: + - remote::passthrough + vector_io: + - inline::faiss + - remote::chromadb + - remote::pgvector + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference + eval: + - inline::meta-reference + datasetio: + - remote::huggingface + - inline::localfs + scoring: + - inline::basic + - inline::llm-as-judge + - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::rag-runtime + - remote::model-context-protocol +image_type: conda diff --git a/llama_stack/templates/passthrough/run.yaml b/llama_stack/templates/passthrough/run.yaml new file mode 100644 index 000000000..2548faa5d --- /dev/null +++ b/llama_stack/templates/passthrough/run.yaml @@ -0,0 +1,120 @@ +version: '2' +image_name: passthrough +apis: +- agents +- datasetio +- eval +- inference +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: passthrough + provider_type: remote::passthrough + config: + url: ${env.PASSTHROUGH_URL} + api_key: ${env.PASSTHROUGH_API_KEY} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/faiss_store.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: ${env.OTEL_SERVICE_NAME:llama-stack} + sinks: ${env.TELEMETRY_SINKS:console,sqlite} + sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/passthrough/trace_store.db} + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: {} + - provider_id: localfs + provider_type: inline::localfs + config: {} + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:} + 
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: rag-runtime
+    provider_type: inline::rag-runtime
+    config: {}
+  - provider_id: model-context-protocol
+    provider_type: remote::model-context-protocol
+    config: {}
+metadata_store:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-llama}/registry.db
+models:
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-8B-Instruct
+  provider_id: passthrough
+  provider_model_id: llama3.1-8b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
+  provider_id: passthrough
+  provider_model_id: llama3.2-11b-vision-instruct
+  model_type: llm
+shields:
+- shield_id: meta-llama/Llama-Guard-3-8B
+vector_dbs: []
+datasets: []
+scoring_fns: []
+eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::rag
+  provider_id: rag-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter
+server:
+  port: 8321
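
For reference, a minimal client-side sketch of how this provider could be exercised once the passthrough distribution above is built and running. It assumes the server from run.yaml is listening on localhost:8321 and that PASSTHROUGH_URL / PASSTHROUGH_API_KEY point at a Llama Stack compatible endpoint; the provider_data keys mirror what PassthroughInferenceAdapter._get_client() reads, the endpoint URL and API key shown are placeholders, and the plain-dict message format is an assumption about the client API rather than something defined in this diff.

from llama_stack_client import LlamaStackClient

# Point the client at the passthrough distribution started from run.yaml above
# (server.port: 8321). The provider_data block is optional: it overrides the
# PASSTHROUGH_URL / PASSTHROUGH_API_KEY env vars per request, using the keys
# that PassthroughInferenceAdapter._get_client() expects to find in the
# X-LlamaStack-Provider-Data header. Both values below are placeholders.
client = LlamaStackClient(
    base_url="http://localhost:8321",
    provider_data={
        "passthrough_url": "https://my-llama-stack-endpoint.example.com",  # placeholder
        "passthrough_api_key": "my-api-key",  # placeholder
    },
)

# model_id matches the models section of run.yaml; the adapter forwards the
# request unchanged to the remote endpoint.
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],  # assumed message format
)
print(response)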