Compare commits


1 commit

Author: Abhishek Kumawat
SHA1: 9e0c8a82cb
Message: Litellm support in llama stack
Date: 2025-02-03 06:15:09 -08:00
6 changed files with 359 additions and 0 deletions


@@ -164,6 +164,15 @@ def available_providers() -> List[ProviderSpec]:
                provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="litellm",
                pip_packages=["litellm"],
                module="llama_stack.providers.remote.inference.litellm",
                config_class="llama_stack.providers.remote.inference.litellm.LitellmConfig",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
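For orientation (not part of the diff): a small hedged sketch of how the new spec surfaces through the registry, assuming this hunk lives in llama_stack.providers.registry.inference; the print format is purely illustrative.

from llama_stack.providers.registry.inference import available_providers

# The litellm entry added above should now appear among the remote adapters,
# carrying the adapter_type, module, and config_class strings from the spec.
for spec in available_providers():
    adapter = getattr(spec, "adapter", None)
    if adapter is not None:
        print(adapter.adapter_type, adapter.module)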


@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from llama_stack.apis.inference import Inference

from .config import LitellmConfig


async def get_adapter_impl(config: LitellmConfig, _deps) -> Inference:
    # Import lazily so litellm is only pulled in when this adapter is actually used.
    from .litellm import LitellmInferenceAdapter

    assert isinstance(config, LitellmConfig), f"Unexpected config type: {type(config)}"
    adapter = LitellmInferenceAdapter(config)
    return adapter
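A minimal sketch (not part of the diff) of exercising this entry point directly, assuming the package re-exports LitellmConfig as shown above; the API key is a placeholder:

import asyncio

from llama_stack.providers.remote.inference.litellm import LitellmConfig, get_adapter_impl


async def build_adapter():
    # _deps is unused by this adapter, so an empty dict suffices here.
    config = LitellmConfig(openai_api_key="sk-...")  # illustrative key
    return await get_adapter_impl(config, {})


adapter = asyncio.run(build_adapter())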


@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel, Field


class LitellmConfig(BaseModel):
    openai_api_key: Optional[str] = Field(
        default=None,
        description="The access key to use for OpenAI (environment variable: OPENAI_API_KEY).",
    )
    llm_provider: Optional[str] = Field(
        default="openai",
        description="The provider to use (environment variable: LLM_PROVIDER).",
    )
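The field descriptions reference OPENAI_API_KEY and LLM_PROVIDER, but the config class itself does not read the environment; a hedged sketch of how a caller might populate it from those variables:

import os

from llama_stack.providers.remote.inference.litellm.config import LitellmConfig

# Hypothetical usage: pull values from the environment variables named in the
# field descriptions, falling back to the declared defaults.
config = LitellmConfig(
    openai_api_key=os.environ.get("OPENAI_API_KEY"),
    llm_provider=os.environ.get("LLM_PROVIDER", "openai"),
)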


@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import AsyncIterator, List, Optional, Union, Any

from litellm import completion as litellm_completion
from litellm.types.utils import ModelResponse
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat, StopReason

from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionMessage,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    ToolChoice,
)

# from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.remote.inference.litellm.config import LitellmConfig
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)

_MODEL_ALIASES = [
    build_model_alias(
        "gpt-4o",  # provider_model_id
        "gpt-4o",  # model_descriptor
    ),
]


class LitellmInferenceAdapter(Inference, ModelRegistryHelper):
    _config: LitellmConfig

    def __init__(self, config: LitellmConfig):
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        self._config = config

    def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        # litellm doesn't support non-chat completion as of time of writing
        raise NotImplementedError()

    def _messages_to_litellm_messages(
        self,
        messages: List[Message],
    ) -> list[dict[str, Any]]:
        # Map llama-stack Message objects to the plain role/content dicts litellm expects.
        litellm_messages = []
        for message in messages:
            lm_message = {
                "role": message.role,
                "content": message.content,
            }
            litellm_messages.append(lm_message)
        return litellm_messages

    def _convert_to_llama_stack_response(
        self,
        litellm_response: ModelResponse,
    ) -> ChatCompletionResponse:
        assert litellm_response.choices is not None
        assert len(litellm_response.choices) == 1
        message = litellm_response.choices[0].message
        # Stop reason is hard-coded; litellm's finish_reason is not mapped yet.
        completion_message = CompletionMessage(
            content=message["content"],
            role=message["role"],
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        )
        return ChatCompletionResponse(completion_message=completion_message)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        assert stream is False, "streaming not supported"
        model_id = self.get_provider_model_id(model_id)
        # litellm_completion is synchronous; this call blocks the coroutine.
        response = litellm_completion(
            model=model_id,
            custom_llm_provider=self._config.llm_provider,
            messages=self._messages_to_litellm_messages(messages),
            api_key=self._config.openai_api_key,
        )
        return self._convert_to_llama_stack_response(response)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
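A rough usage sketch of the adapter on its own, outside a running stack. Assumptions: UserMessage is importable from llama_stack.apis.inference, the gpt-4o alias above is resolvable via get_provider_model_id, and the API key is a placeholder.

import asyncio

from llama_stack.apis.inference import UserMessage  # assumed export location
from llama_stack.providers.remote.inference.litellm.config import LitellmConfig
from llama_stack.providers.remote.inference.litellm.litellm import LitellmInferenceAdapter


async def demo():
    adapter = LitellmInferenceAdapter(LitellmConfig(openai_api_key="sk-..."))
    # Non-streaming only; the adapter asserts stream is False.
    response = await adapter.chat_completion(
        model_id="gpt-4o",
        messages=[UserMessage(content="Say hello in one word.")],
    )
    print(response.completion_message.content)


asyncio.run(demo())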


@@ -0,0 +1,79 @@
version: '2'
image_name: llama_stack_server
container_image: null
apis:
- inference
- safety
- agents
- vector_io
- datasetio
- scoring
- eval
- post_training
- tool_runtime
- telemetry
providers:
  inference:
  - provider_id: litellm
    provider_type: remote::litellm
    config:
      openai_api_key: ???
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_stack_server}/agents_store.db
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_stack_server}/faiss_store.db
  datasetio:
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  post_training:
  - provider_id: torchtune
    provider_type: inline::torchtune
    config: {}
  tool_runtime:
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/llama_stack_server/trace_store.db}
metadata_store: null
models:
- metadata: {}
  model_id: gpt-4o
  provider_id: litellm
  model_type: llm
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups: []
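Once a server is started with this run.yaml, a client-side sketch of calling the litellm-backed gpt-4o model. Assumptions, not part of this diff: the llama-stack-client Python package, a local base URL/port, and that the client's chat_completion mirrors the server-side signature above.

from llama_stack_client import LlamaStackClient  # assumed client package

# Base URL and port are illustrative; point this at wherever the server runs.
client = LlamaStackClient(base_url="http://localhost:8321")
response = client.inference.chat_completion(
    model_id="gpt-4o",
    messages=[{"role": "user", "content": "Hello from litellm"}],
)
print(response.completion_message.content)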

meta-reference-gpu-run.yaml (new file, 117 lines)

@@ -0,0 +1,117 @@
version: '2'
image_name: meta-reference-gpu
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: meta-reference-inference
    provider_type: inline::meta-reference
    config:
      model: Llama3.3-70B-Instruct
      max_seq_len: 64000
      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config: {}
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
models:
- metadata: {}
  model_id: Llama3.3-70B-Instruct
  provider_id: meta-reference-inference
  model_type: llm
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: sentence-transformers
  model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
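Both run.yaml files lean on ${env.VAR:default} placeholders for paths and keys. An illustrative, self-contained sketch of how such placeholders expand; the real llama_stack config loader has its own resolution logic, so this is only a reading aid.

import os
import re

# Matches ${env.VAR:default}; group 1 is the variable name, group 2 the default.
_PLACEHOLDER = re.compile(r"\$\{env\.([A-Z_][A-Z0-9_]*):?([^}]*)\}")


def expand_env_placeholders(text: str) -> str:
    return _PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), text)


print(expand_env_placeholders(
    "${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db"
))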