diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index e72140ccf..3fd084e1b 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -164,6 +164,15 @@ def available_providers() -> List[ProviderSpec]:
             provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="litellm",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.litellm",
+                config_class="llama_stack.providers.remote.inference.litellm.LitellmConfig",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
diff --git a/llama_stack/providers/remote/inference/litellm/__init__.py b/llama_stack/providers/remote/inference/litellm/__init__.py
new file mode 100644
index 000000000..969943ba2
--- /dev/null
+++ b/llama_stack/providers/remote/inference/litellm/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+from llama_stack.apis.inference import Inference
+
+from .config import LitellmConfig
+
+
+async def get_adapter_impl(config: LitellmConfig, _deps) -> Inference:
+    # import lazily so the litellm dependency is only loaded when the adapter is used
+    from .litellm import LitellmInferenceAdapter
+    assert isinstance(config, LitellmConfig), f"Unexpected config type: {type(config)}"
+    adapter = LitellmInferenceAdapter(config)
+    return adapter
diff --git a/llama_stack/providers/remote/inference/litellm/config.py b/llama_stack/providers/remote/inference/litellm/config.py
new file mode 100644
index 000000000..a6661ab81
--- /dev/null
+++ b/llama_stack/providers/remote/inference/litellm/config.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class LitellmConfig(BaseModel):
+    openai_api_key: Optional[str] = Field(
+        default=None,
+        description="The OpenAI API key to pass through to LiteLLM. If unset, LiteLLM falls back to the OPENAI_API_KEY environment variable.",
+    )
+    llm_provider: Optional[str] = Field(
+        default="openai",
+        description="The LiteLLM provider to route requests to. Defaults to 'openai'.",
+    )
diff --git a/llama_stack/providers/remote/inference/litellm/litellm.py b/llama_stack/providers/remote/inference/litellm/litellm.py
new file mode 100644
index 000000000..b62d25546
--- /dev/null
+++ b/llama_stack/providers/remote/inference/litellm/litellm.py
@@ -0,0 +1,116 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from typing import AsyncIterator, List, Optional, Union, Any
+
+from litellm import completion as litellm_completion
+from litellm.types.utils import ModelResponse
+
+from llama_models.datatypes import SamplingParams
+from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat, StopReason
+from llama_stack.apis.inference import (
+    ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
+    CompletionResponse,
+    CompletionMessage,
+    CompletionResponseStreamChunk,
+    EmbeddingsResponse,
+    Inference,
+    InterleavedContent,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    ToolChoice,
+)
+# from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.providers.remote.inference.litellm.config import LitellmConfig
+from llama_stack.providers.utils.inference.model_registry import (
+    build_model_alias,
+    ModelRegistryHelper,
+)
+
+_MODEL_ALIASES = [
+    build_model_alias(
+        "gpt-4o",  # provider_model_id
+        "gpt-4o",  # model_descriptor
+    ),
+]
+
+class LitellmInferenceAdapter(Inference, ModelRegistryHelper):
+    _config: LitellmConfig
+
+    def __init__(self, config: LitellmConfig):
+        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
+        self._config = config
+
+    def completion(
+        self,
+        model_id: str,
+        content: InterleavedContent,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        # litellm doesn't support non-chat completion as of the time of writing
+        raise NotImplementedError()
+
+    def _messages_to_litellm_messages(
+        self,
+        messages: List[Message],
+    ) -> list[dict[str, Any]]:
+        litellm_messages = []
+        for message in messages:
+            lm_message = {
+                "role": message.role,
+                "content": message.content,
+            }
+            litellm_messages.append(lm_message)
+        return litellm_messages
+
+    def _convert_to_llama_stack_response(
+        self,
+        litellm_response: ModelResponse,
+    ) -> ChatCompletionResponse:
+        assert litellm_response.choices is not None
+        assert len(litellm_response.choices) == 1
+        message = litellm_response.choices[0].message
+        completion_message = CompletionMessage(content=message["content"], role=message["role"], stop_reason=StopReason.end_of_message, tool_calls=[])
+        return ChatCompletionResponse(completion_message=completion_message)
+
+    async def chat_completion(
+        self,
+        model_id: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
+        assert stream is False, "streaming not supported"
+        model_id = self.get_provider_model_id(model_id)
+        response = litellm_completion(
+            model=model_id,
+            custom_llm_provider=self._config.llm_provider,
+            messages=self._messages_to_litellm_messages(messages),
+            api_key=self._config.openai_api_key,
+        )
+
+        return self._convert_to_llama_stack_response(response)
+
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[InterleavedContent],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
+
\ No newline at end of file
diff --git a/llama_stack_server-run.yaml b/llama_stack_server-run.yaml
new file mode 100644
index 000000000..d04823d2d
--- /dev/null
+++ b/llama_stack_server-run.yaml
@@ -0,0 +1,79 @@
+version: '2'
+image_name: llama_stack_server
+container_image: null
+apis:
+- inference
+- safety
+- agents
+- vector_io
+- datasetio
+- scoring
+- eval
+- post_training
+- tool_runtime
+- telemetry
+providers:
+  inference:
+  - provider_id: litellm
+    provider_type: remote::litellm
+    config:
+      openai_api_key: ???
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config: {}
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_stack_server}/agents_store.db
+  vector_io:
+  - provider_id: faiss
+    provider_type: inline::faiss
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_stack_server}/faiss_store.db
+  datasetio:
+  - provider_id: localfs
+    provider_type: inline::localfs
+    config: {}
+  scoring:
+  - provider_id: basic
+    provider_type: inline::basic
+    config: {}
+  eval:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config: {}
+  post_training:
+  - provider_id: torchtune
+    provider_type: inline::torchtune
+    config: {}
+  tool_runtime:
+  - provider_id: rag-runtime
+    provider_type: inline::rag-runtime
+    config: {}
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
+      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
+      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/llama_stack_server/trace_store.db}
+metadata_store: null
+models:
+- metadata: {}
+  model_id: gpt-4o
+  provider_id: litellm
+  model_type: llm
+shields: []
+vector_dbs: []
+datasets: []
+scoring_fns: []
+eval_tasks: []
+tool_groups: []
diff --git a/meta-reference-gpu-run.yaml b/meta-reference-gpu-run.yaml
new file mode 100644
index 000000000..4f8c191ca
--- /dev/null
+++ b/meta-reference-gpu-run.yaml
@@ -0,0 +1,117 @@
+version: '2'
+image_name: meta-reference-gpu
+apis:
+- agents
+- datasetio
+- eval
+- inference
+- safety
+- scoring
+- telemetry
+- tool_runtime
+- vector_io
+providers:
+  inference:
+  - provider_id: meta-reference-inference
+    provider_type: inline::meta-reference
+    config:
+      model: Llama3.3-70B-Instruct
+      max_seq_len: 64000
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
+  vector_io:
+  - provider_id: faiss
+    provider_type: inline::faiss
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config: {}
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
+      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
+      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
+  eval:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config: {}
+  datasetio:
+  - provider_id: huggingface
+    provider_type: remote::huggingface
+    config: {}
+  - provider_id: localfs
+    provider_type: inline::localfs
+    config: {}
+  scoring:
+  - provider_id: basic
+    provider_type: inline::basic
+    config: {}
+  - provider_id: llm-as-judge
+    provider_type: inline::llm-as-judge
+    config: {}
+  - provider_id: braintrust
+    provider_type: inline::braintrust
+    config:
+      openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: rag-runtime
+    provider_type: inline::rag-runtime
+    config: {}
+  - provider_id: model-context-protocol
+    provider_type: remote::model-context-protocol
+    config: {}
+metadata_store:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
+models:
+- metadata: {}
+  model_id: Llama3.3-70B-Instruct
+  provider_id: meta-reference-inference
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  model_type: embedding
+shields: []
+vector_dbs: []
+datasets: []
+scoring_fns: []
+eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::rag
+  provider_id: rag-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter
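A quick way to smoke-test the new adapter without standing up a full distribution is to instantiate it directly. The sketch below is illustrative only and not part of this change: it assumes the diff above is applied, `litellm` is installed, and that `UserMessage` is importable from `llama_models.llama3.api.datatypes` (the same module the adapter imports its message datatypes from); the `gpt-4o` alias and the `OPENAI_API_KEY` environment variable mirror the example run config above.

```python
# Hypothetical smoke test for the LiteLLM adapter; not part of the diff itself.
import asyncio
import os

from llama_models.llama3.api.datatypes import UserMessage  # import path is an assumption

from llama_stack.providers.remote.inference.litellm.config import LitellmConfig
from llama_stack.providers.remote.inference.litellm.litellm import LitellmInferenceAdapter


async def main() -> None:
    # The key is taken from the environment here; the adapter also lets LiteLLM
    # fall back to OPENAI_API_KEY when openai_api_key is left unset.
    config = LitellmConfig(
        openai_api_key=os.environ["OPENAI_API_KEY"],
        llm_provider="openai",
    )
    adapter = LitellmInferenceAdapter(config)

    # "gpt-4o" is the only alias registered via _MODEL_ALIASES in this diff.
    response = await adapter.chat_completion(
        model_id="gpt-4o",
        messages=[UserMessage(content="Say hello in one short sentence.")],
    )
    print(response.completion_message.content)


if __name__ == "__main__":
    asyncio.run(main())
```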