forked from phoenix-oss/llama-stack-mirror

Compare commits: kvant...ak/llama-s (1 commit)

Commit 9e0c8a82cb

6 changed files with 359 additions and 0 deletions
@@ -164,6 +164,15 @@ def available_providers() -> List[ProviderSpec]:
                 provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="litellm",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.litellm",
+                config_class="llama_stack.providers.remote.inference.litellm.LitellmConfig",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
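For orientation: a spec like this is resolved by importing the adapter's module and calling its get_adapter_impl() entry point (defined in the new __init__.py below) with an instance of config_class. A minimal illustrative sketch of that resolution step, not the actual llama-stack resolver:

# Illustrative sketch of how a remote provider spec is typically resolved:
# import the module named in AdapterSpec and call its get_adapter_impl()
# entry point. This is a simplification, not the real llama-stack resolver.
import importlib

async def load_inference_adapter(module_name: str, config, deps):
    # e.g. module_name = "llama_stack.providers.remote.inference.litellm"
    module = importlib.import_module(module_name)
    return await module.get_adapter_impl(config, deps)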
llama_stack/providers/remote/inference/litellm/__init__.py (new file, 19 additions)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from llama_stack.apis.inference import Inference

from .config import LitellmConfig


async def get_adapter_impl(config: LitellmConfig, _deps) -> Inference:
    # Import lazily so litellm is only loaded when the adapter is actually used.
    from .litellm import LitellmInferenceAdapter

    assert isinstance(config, LitellmConfig), f"Unexpected config type: {type(config)}"
    adapter = LitellmInferenceAdapter(config)
    return adapter
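A minimal usage sketch of this entry point, calling it directly with an empty deps dict (normally the stack resolver supplies the dependencies):

# Usage sketch: construct the adapter directly via get_adapter_impl().
# The empty deps dict is an assumption for standalone use.
import asyncio

from llama_stack.providers.remote.inference.litellm import get_adapter_impl
from llama_stack.providers.remote.inference.litellm.config import LitellmConfig

async def main():
    config = LitellmConfig(openai_api_key="sk-...", llm_provider="openai")  # placeholder key
    adapter = await get_adapter_impl(config, {})
    print(type(adapter).__name__)  # LitellmInferenceAdapter

asyncio.run(main())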
llama_stack/providers/remote/inference/litellm/config.py (new file, 19 additions)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional

from pydantic import BaseModel, Field


class LitellmConfig(BaseModel):
    openai_api_key: Optional[str] = Field(
        default=None,
        description="The API key to use for OpenAI. Defaults to the OPENAI_API_KEY environment variable.",
    )
    llm_provider: Optional[str] = Field(
        default="openai",
        description="The provider to use. Defaults to the LLM_PROVIDER environment variable.",
    )
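Note that although the field descriptions mention environment variables, the config class itself does not read them; a caller would wire that up explicitly, as in this sketch:

# Sketch: populate LitellmConfig from the environment variables its
# field descriptions mention. The config class does not read the env itself.
import os

from llama_stack.providers.remote.inference.litellm.config import LitellmConfig

config = LitellmConfig(
    openai_api_key=os.environ.get("OPENAI_API_KEY"),
    llm_provider=os.environ.get("LLM_PROVIDER", "openai"),
)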
llama_stack/providers/remote/inference/litellm/litellm.py (new file, 116 additions)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any, AsyncIterator, List, Optional, Union

from litellm import completion as litellm_completion
from litellm.types.utils import ModelResponse

from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import StopReason, ToolDefinition, ToolPromptFormat
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    ToolChoice,
)

# from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.remote.inference.litellm.config import LitellmConfig
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)

_MODEL_ALIASES = [
    build_model_alias(
        "gpt-4o",  # provider_model_id
        "gpt-4o",  # model_descriptor
    ),
]


class LitellmInferenceAdapter(Inference, ModelRegistryHelper):
    _config: LitellmConfig

    def __init__(self, config: LitellmConfig):
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        self._config = config

    def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        # litellm does not support non-chat completion as of this writing.
        raise NotImplementedError()

    def _messages_to_litellm_messages(
        self,
        messages: List[Message],
    ) -> list[dict[str, Any]]:
        # Flatten llama-stack messages into the plain role/content dicts litellm expects.
        litellm_messages = []
        for message in messages:
            lm_message = {
                "role": message.role,
                "content": message.content,
            }
            litellm_messages.append(lm_message)
        return litellm_messages

    def _convert_to_llama_stack_response(
        self,
        litellm_response: ModelResponse,
    ) -> ChatCompletionResponse:
        assert litellm_response.choices is not None
        assert len(litellm_response.choices) == 1
        message = litellm_response.choices[0].message
        completion_message = CompletionMessage(
            content=message["content"],
            role=message["role"],
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        )
        return ChatCompletionResponse(completion_message=completion_message)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        assert stream is False, "streaming not supported"
        model_id = self.get_provider_model_id(model_id)
        # NOTE: litellm_completion is synchronous, so this call blocks the event loop.
        response = litellm_completion(
            model=model_id,
            custom_llm_provider=self._config.llm_provider,
            messages=self._messages_to_litellm_messages(messages),
            api_key=self._config.openai_api_key,
        )

        return self._convert_to_llama_stack_response(response)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
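An end-to-end sketch of driving the adapter directly, assuming UserMessage is importable from llama_stack.apis.inference and OPENAI_API_KEY is set:

# Sketch: call the adapter without going through a server. UserMessage and
# the env-var wiring are assumptions for this standalone example.
import asyncio
import os

from llama_stack.apis.inference import UserMessage
from llama_stack.providers.remote.inference.litellm.config import LitellmConfig
from llama_stack.providers.remote.inference.litellm.litellm import LitellmInferenceAdapter

async def main():
    config = LitellmConfig(openai_api_key=os.environ["OPENAI_API_KEY"])
    adapter = LitellmInferenceAdapter(config)
    response = await adapter.chat_completion(
        model_id="gpt-4o",
        messages=[UserMessage(content="Say hello in one word.")],
    )
    print(response.completion_message.content)

asyncio.run(main())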
llama_stack_server-run.yaml (new file, 79 additions)
version: '2'
image_name: llama_stack_server
container_image: null
apis:
- inference
- safety
- agents
- vector_io
- datasetio
- scoring
- eval
- post_training
- tool_runtime
- telemetry
providers:
  inference:
  - provider_id: litellm
    provider_type: remote::litellm
    config:
      openai_api_key: ???
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_stack_server}/agents_store.db
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_stack_server}/faiss_store.db
  datasetio:
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  post_training:
  - provider_id: torchtune
    provider_type: inline::torchtune
    config: {}
  tool_runtime:
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/llama_stack_server/trace_store.db}
metadata_store: null
models:
- metadata: {}
  model_id: gpt-4o
  provider_id: litellm
  model_type: llm
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups: []
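Once a server is running with this config, a client request against the registered gpt-4o model might look like the following sketch (assuming the llama-stack-client package and the default port 8321; adjust the URL to your deployment):

# Sketch of a client call against a server started from this run.yaml.
# Package, method shape, and port are assumptions, not confirmed by this diff.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
response = client.inference.chat_completion(
    model_id="gpt-4o",  # registered above under the litellm provider
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.completion_message.content)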
meta-reference-gpu-run.yaml (new file, 117 additions)
version: '2'
image_name: meta-reference-gpu
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: meta-reference-inference
    provider_type: inline::meta-reference
    config:
      model: Llama3.3-70B-Instruct
      max_seq_len: 64000
      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config: {}
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
models:
- metadata: {}
  model_id: Llama3.3-70B-Instruct
  provider_id: meta-reference-inference
  model_type: llm
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: sentence-transformers
  model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
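Both run files lean on the ${env.VAR:default} substitution syntax for paths and keys; an illustrative re-implementation of that lookup (not the actual llama-stack code) is:

# Illustrative re-implementation of the ${env.VAR:default} substitution used
# throughout these run.yaml files. For demonstration only.
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):?([^}]*)\}")

def substitute_env(value: str) -> str:
    # Replace each ${env.NAME:default} with the env var's value, else the default.
    return _ENV_PATTERN.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

print(substitute_env("${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db"))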