forked from phoenix-oss/llama-stack-mirror
feat: add (openai, anthropic, gemini) providers via litellm (#1267)
# What does this PR do?

This PR introduces more non-Llama model support to Llama Stack. Providers introduced: openai, anthropic, and gemini. All of these providers use essentially the same piece of code: the implementation works via the `litellm` library.

We will expose only specific models for the providers we enable, making sure they all work well and pass tests. This setup (instead of automatically enabling _all_ providers and models allowed by LiteLLM) ensures we can also perform any needed prompt tuning on a per-model basis, just as we do for Llama models.

## Test Plan

```bash
#!/bin/bash

args=("$@")
for model in openai/gpt-4o anthropic/claude-3-5-sonnet-latest gemini/gemini-1.5-flash; do
    LLAMA_STACK_CONFIG=dev pytest -s -v tests/client-sdk/inference/test_text_inference.py \
        --embedding-model=all-MiniLM-L6-v2 \
        --vision-inference-model="" \
        --inference-model=$model "${args[@]}"
done
```
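For a quick smoke test outside pytest, a minimal client sketch (assumptions beyond this diff: a stack built from the `dev` template is already running and listens on port 8321, the default in the generated run.yaml; the call shape follows the tests in this PR):

```python
# Minimal sketch, not part of this PR: exercise one of the newly wired
# providers through the client SDK used by the tests below.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# "openai/gpt-4o" is one of the model IDs registered by the dev template's run.yaml.
response = client.inference.chat_completion(
    model_id="openai/gpt-4o",
    messages=[{"role": "user", "content": "Say hello in five words."}],
)
print(response.completion_message.content)
```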
parent b0310af177
commit 63e6acd0c3
25 changed files with 1048 additions and 33 deletions
```diff
@@ -136,6 +136,42 @@
     "sentence-transformers --no-deps",
     "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
   ],
+  "dev": [
+    "aiosqlite",
+    "autoevals",
+    "blobfile",
+    "chardet",
+    "chromadb-client",
+    "datasets",
+    "fastapi",
+    "fire",
+    "fireworks-ai",
+    "httpx",
+    "litellm",
+    "matplotlib",
+    "mcp",
+    "nltk",
+    "numpy",
+    "openai",
+    "opentelemetry-exporter-otlp-proto-http",
+    "opentelemetry-sdk",
+    "pandas",
+    "pillow",
+    "psycopg2-binary",
+    "pymongo",
+    "pypdf",
+    "redis",
+    "requests",
+    "scikit-learn",
+    "scipy",
+    "sentencepiece",
+    "sqlite-vec",
+    "tqdm",
+    "transformers",
+    "uvicorn",
+    "sentence-transformers --no-deps",
+    "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+  ],
   "fireworks": [
     "aiosqlite",
     "autoevals",
```
```diff
@@ -207,6 +207,33 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="openai",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.openai",
+                config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="anthropic",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.anthropic",
+                config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="gemini",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.gemini",
+                config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
```
llama_stack/providers/remote/inference/anthropic/__init__.py (new file, 23 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel

from .config import AnthropicConfig


class AnthropicProviderDataValidator(BaseModel):
    anthropic_api_key: Optional[str] = None


async def get_adapter_impl(config: AnthropicConfig, _deps):
    from .anthropic import AnthropicInferenceAdapter

    impl = AnthropicInferenceAdapter(config)
    await impl.initialize()
    return impl
```
llama_stack/providers/remote/inference/anthropic/anthropic.py (new file, 22 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from .config import AnthropicConfig
from .models import MODEL_ENTRIES


class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
    def __init__(self, config: AnthropicConfig) -> None:
        LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass
```
llama_stack/providers/remote/inference/anthropic/config.py (new file, 25 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class AnthropicConfig(BaseModel):
    api_key: Optional[str] = Field(
        default=None,
        description="API key for Anthropic models",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> Dict[str, Any]:
        return {
            "api_key": api_key,
        }
```
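For orientation (my reading of the template machinery, hedged accordingly): `sample_run_config` supplies the provider stanza that the distribution templates serialize into run.yaml. The `dev` template later in this diff passes a trailing-colon placeholder instead of the default, which in the stack's `${env.VAR:default}` substitution syntax appears to give the variable an empty default, i.e. makes the key optional. A small illustrative sketch:

```python
# Illustrative only; mirrors the classmethod defined just above.
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig

# Default: a required environment-variable placeholder.
print(AnthropicConfig.sample_run_config())
# {'api_key': '${env.ANTHROPIC_API_KEY}'}

# The dev template passes the trailing-colon form (empty default, so the key
# is optional -- an assumption based on the template comment "we allow each
# API key to be optional").
print(AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"))
# {'api_key': '${env.ANTHROPIC_API_KEY:}'}
```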
llama_stack/providers/remote/inference/anthropic/models.py (new file, 35 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
)

LLM_MODEL_IDS = [
    "anthropic/claude-3-5-sonnet-latest",
    "anthropic/claude-3-7-sonnet-latest",
    "anthropic/claude-3-5-haiku-latest",
]


MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
    ProviderModelEntry(
        provider_model_id="anthropic/voyage-3",
        model_type=ModelType.embedding,
        metadata={"embedding_dimension": 1024, "context_length": 32000},
    ),
    ProviderModelEntry(
        provider_model_id="anthropic/voyage-3-lite",
        model_type=ModelType.embedding,
        metadata={"embedding_dimension": 512, "context_length": 32000},
    ),
    ProviderModelEntry(
        provider_model_id="anthropic/voyage-code-3",
        model_type=ModelType.embedding,
        metadata={"embedding_dimension": 1024, "context_length": 32000},
    ),
]
```
llama_stack/providers/remote/inference/fireworks/config.py (modified)

```diff
@@ -23,8 +23,8 @@ class FireworksImplConfig(BaseModel):
     )
 
     @classmethod
-    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> Dict[str, Any]:
         return {
             "url": "https://api.fireworks.ai/inference/v1",
-            "api_key": "${env.FIREWORKS_API_KEY}",
+            "api_key": api_key,
         }
```
llama_stack/providers/remote/inference/gemini/__init__.py (new file, 23 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel

from .config import GeminiConfig


class GeminiProviderDataValidator(BaseModel):
    gemini_api_key: Optional[str] = None


async def get_adapter_impl(config: GeminiConfig, _deps):
    from .gemini import GeminiInferenceAdapter

    impl = GeminiInferenceAdapter(config)
    await impl.initialize()
    return impl
```
llama_stack/providers/remote/inference/gemini/config.py (new file, 25 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class GeminiConfig(BaseModel):
    api_key: Optional[str] = Field(
        default=None,
        description="API key for Gemini models",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> Dict[str, Any]:
        return {
            "api_key": api_key,
        }
```
llama_stack/providers/remote/inference/gemini/gemini.py (new file, 22 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from .config import GeminiConfig
from .models import MODEL_ENTRIES


class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
    def __init__(self, config: GeminiConfig) -> None:
        LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass
```
llama_stack/providers/remote/inference/gemini/models.py (new file, 24 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
)

LLM_MODEL_IDS = [
    "gemini/gemini-1.5-flash",
    "gemini/gemini-1.5-pro",
]


MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
    ProviderModelEntry(
        provider_model_id="gemini/text-embedding-004",
        model_type=ModelType.embedding,
        metadata={"embedding_dimension": 768, "context_length": 2048},
    ),
]
```
llama_stack/providers/remote/inference/openai/__init__.py (new file, 23 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel

from .config import OpenAIConfig


class OpenAIProviderDataValidator(BaseModel):
    openai_api_key: Optional[str] = None


async def get_adapter_impl(config: OpenAIConfig, _deps):
    from .openai import OpenAIInferenceAdapter

    impl = OpenAIInferenceAdapter(config)
    await impl.initialize()
    return impl
```
llama_stack/providers/remote/inference/openai/config.py (new file, 25 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class OpenAIConfig(BaseModel):
    api_key: Optional[str] = Field(
        default=None,
        description="API key for OpenAI models",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> Dict[str, Any]:
        return {
            "api_key": api_key,
        }
```
llama_stack/providers/remote/inference/openai/models.py (new file, 30 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
)

LLM_MODEL_IDS = [
    "openai/gpt-4o",
    "openai/gpt-4o-mini",
    "openai/chatgpt-4o-latest",
]


MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
    ProviderModelEntry(
        provider_model_id="openai/text-embedding-3-small",
        model_type=ModelType.embedding,
        metadata={"embedding_dimension": 1536},
    ),
    ProviderModelEntry(
        provider_model_id="openai/text-embedding-3-large",
        model_type=ModelType.embedding,
        metadata={"embedding_dimension": 3072},
    ),
]
```
llama_stack/providers/remote/inference/openai/openai.py (new file, 22 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from .config import OpenAIConfig
from .models import MODEL_ENTRIES


class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
    def __init__(self, config: OpenAIConfig) -> None:
        LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass
```
```diff
@@ -40,7 +40,7 @@
     "tool_calling": {
         "data": {
             "messages": [
-                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "system", "content": "Pretend you are a weather assistant."},
                 {"role": "user", "content": "What's the weather like in San Francisco?"}
             ],
             "tools": [
@@ -65,7 +65,7 @@
             "messages": [
                 {
                     "role": "system",
-                    "content": "You are a helpful assistant."
+                    "content": "Pretend you are a weather assistant."
                 },
                 {
                     "role": "user",
```
llama_stack/providers/utils/inference/litellm_openai_mixin.py (new file, 171 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator, AsyncIterator, List, Optional, Union

import litellm

from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    JsonSchemaResponseFormat,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.models.models import Model
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict_new,
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,
    convert_tooldef_to_openai_tool,
    get_sampling_options,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)


class LiteLLMOpenAIMixin(
    ModelRegistryHelper,
    Inference,
):
    def __init__(self, model_entries) -> None:
        self.model_entries = model_entries
        ModelRegistryHelper.__init__(self, model_entries)

    async def register_model(self, model: Model) -> Model:
        model_id = self.get_provider_model_id(model.provider_resource_id)
        if model_id is None:
            raise ValueError(f"Unsupported model: {model.provider_resource_id}")
        return model

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError("LiteLLM does not support completion requests")

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )

        params = await self._get_params(request)

        # unfortunately, we need to use synchronous litellm.completion here because litellm
        # caches various httpx.client objects in a non-eventloop aware manner
        response = litellm.completion(**params)
        if stream:
            return self._stream_chat_completion(response)
        else:
            return convert_openai_chat_completion_choice(response.choices[0])

    async def _stream_chat_completion(
        self, response: litellm.ModelResponse
    ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
        async def _stream_generator():
            for chunk in response:
                yield chunk

        async for chunk in convert_openai_chat_completion_stream(
            _stream_generator(), enable_incremental_tool_calls=True
        ):
            yield chunk

    async def _get_params(self, request: ChatCompletionRequest) -> dict:
        input_dict = {}

        input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
        if fmt := request.response_format:
            if not isinstance(fmt, JsonSchemaResponseFormat):
                raise ValueError(
                    f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
                )

            fmt = fmt.json_schema
            name = fmt["title"]
            del fmt["title"]
            fmt["additionalProperties"] = False
            input_dict["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": name,
                    "schema": fmt,
                    "strict": True,
                },
            }
        if request.tools:
            input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
            if request.tool_config.tool_choice:
                input_dict["tool_choice"] = request.tool_config.tool_choice.value

        return {
            "model": request.model,
            **input_dict,
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        model = await self.model_store.get_model(model_id)

        response = litellm.embedding(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
        )

        embeddings = [data["embedding"] for data in response["data"]]
        return EmbeddingsResponse(embeddings=embeddings)
```
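A design note on the streaming path above: `chat_completion` deliberately calls the synchronous `litellm.completion` (the inline comment explains that litellm caches `httpx` client objects in a non-event-loop-aware way) and then re-exposes the blocking iterator as an async generator in `_stream_chat_completion`. A self-contained sketch of that wrapping pattern, with invented stand-in names; note each blocking step still runs on the event-loop thread, the trade-off accepted here:

```python
import asyncio
from typing import AsyncIterator, Iterator


def blocking_chunks() -> Iterator[str]:
    # Stand-in for the iterator that litellm.completion(stream=True) returns.
    yield from ["It's ", "sunny ", "in SF."]


async def aiter_chunks(chunks: Iterator[str]) -> AsyncIterator[str]:
    # Same shape as _stream_chat_completion's inner _stream_generator:
    # iterate the synchronous source inside an async generator.
    for chunk in chunks:
        yield chunk


async def main() -> None:
    async for piece in aiter_chunks(blocking_chunks()):
        print(piece, end="", flush=True)
    print()


asyncio.run(main())
```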
```diff
@@ -57,17 +57,6 @@ def get_distribution_template() -> DistributionTemplate:
         config=SentenceTransformersInferenceConfig.sample_run_config(),
     )
 
-    core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
-    default_models = [
-        ModelInput(
-            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
-            provider_model_id=m.provider_model_id,
-            provider_id="fireworks",
-            metadata=m.metadata,
-            model_type=m.model_type,
-        )
-        for m in MODEL_ENTRIES
-    ]
     default_tool_groups = [
         ToolGroupInput(
             toolgroup_id="builtin::websearch",
@@ -82,6 +71,16 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="code-interpreter",
         ),
     ]
+    core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
+    default_models = [
+        ModelInput(
+            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
+            provider_id="fireworks",
+            model_type=m.model_type,
+            metadata=m.metadata,
+        )
+        for m in MODEL_ENTRIES
+    ]
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
         provider_id="sentence-transformers",
@@ -98,7 +97,7 @@ def get_distribution_template() -> DistributionTemplate:
         container_image=None,
         template_path=None,
         providers=providers,
-        default_models=default_models,
+        default_models=default_models + [embedding_model],
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
```
```diff
@@ -93,59 +93,48 @@ models:
 - metadata: {}
   model_id: meta-llama/Llama-3.1-8B-Instruct
   provider_id: fireworks
-  provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-70B-Instruct
   provider_id: fireworks
-  provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
   provider_id: fireworks
-  provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-1B-Instruct
   provider_id: fireworks
-  provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-3B-Instruct
   provider_id: fireworks
-  provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
   provider_id: fireworks
-  provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
   provider_id: fireworks
-  provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.3-70B-Instruct
   provider_id: fireworks
-  provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-Guard-3-8B
   provider_id: fireworks
-  provider_model_id: accounts/fireworks/models/llama-guard-3-8b
   model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-Guard-3-11B-Vision
   provider_id: fireworks
-  provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
   model_type: llm
 - metadata:
     embedding_dimension: 768
     context_length: 8192
   model_id: nomic-ai/nomic-embed-text-v1.5
   provider_id: fireworks
-  provider_model_id: nomic-ai/nomic-embed-text-v1.5
   model_type: embedding
 - metadata:
     embedding_dimension: 384
```
llama_stack/templates/dev/__init__.py (new file, 7 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .dev import get_distribution_template  # noqa: F401
```
llama_stack/templates/dev/build.yaml (new file, 36 lines)

```yaml
version: '2'
distribution_spec:
  description: Distribution for running e2e tests in CI
  providers:
    inference:
    - remote::openai
    - remote::fireworks
    - remote::anthropic
    - remote::gemini
    - inline::sentence-transformers
    vector_io:
    - inline::sqlite-vec
    - remote::chromadb
    - remote::pgvector
    safety:
    - inline::llama-guard
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
    eval:
    - inline::meta-reference
    datasetio:
    - remote::huggingface
    - inline::localfs
    scoring:
    - inline::basic
    - inline::llm-as-judge
    - inline::braintrust
    tool_runtime:
    - remote::brave-search
    - remote::tavily-search
    - inline::code-interpreter
    - inline::rag-runtime
    - remote::model-context-protocol
image_type: conda
```
llama_stack/templates/dev/dev.py (new file, 174 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Tuple

from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
    ModelInput,
    Provider,
    ShieldInput,
    ToolGroupInput,
)
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.inline.inference.sentence_transformers import (
    SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
from llama_stack.providers.remote.inference.anthropic.models import MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.gemini.models import MODEL_ENTRIES as GEMINI_MODEL_ENTRIES
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
    # in this template, we allow each API key to be optional
    providers = [
        (
            "openai",
            OPENAI_MODEL_ENTRIES,
            OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"),
        ),
        (
            "fireworks",
            FIREWORKS_MODEL_ENTRIES,
            FireworksImplConfig.sample_run_config(api_key="${env.FIREWORKS_API_KEY:}"),
        ),
        (
            "anthropic",
            ANTHROPIC_MODEL_ENTRIES,
            AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"),
        ),
        (
            "gemini",
            GEMINI_MODEL_ENTRIES,
            GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
        ),
    ]
    inference_providers = []
    default_models = []
    core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
    for provider_id, model_entries, config in providers:
        inference_providers.append(
            Provider(
                provider_id=provider_id,
                provider_type=f"remote::{provider_id}",
                config=config,
            )
        )
        default_models.extend(
            ModelInput(
                model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
                provider_model_id=m.provider_model_id,
                provider_id=provider_id,
                model_type=m.model_type,
                metadata=m.metadata,
            )
            for m in model_entries
        )
    return inference_providers, default_models


def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": [
            "remote::openai",
            "remote::fireworks",
            "remote::anthropic",
            "remote::gemini",
            "inline::sentence-transformers",
        ],
        "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
        "eval": ["inline::meta-reference"],
        "datasetio": ["remote::huggingface", "inline::localfs"],
        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
        "tool_runtime": [
            "remote::brave-search",
            "remote::tavily-search",
            "inline::code-interpreter",
            "inline::rag-runtime",
            "remote::model-context-protocol",
        ],
    }
    name = "dev"

    vector_io_provider = Provider(
        provider_id="sqlite-vec",
        provider_type="inline::sqlite-vec",
        config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"),
    )
    embedding_provider = Provider(
        provider_id="sentence-transformers",
        provider_type="inline::sentence-transformers",
        config=SentenceTransformersInferenceConfig.sample_run_config(),
    )

    default_tool_groups = [
        ToolGroupInput(
            toolgroup_id="builtin::websearch",
            provider_id="tavily-search",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::rag",
            provider_id="rag-runtime",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::code_interpreter",
            provider_id="code-interpreter",
        ),
    ]
    embedding_model = ModelInput(
        model_id="all-MiniLM-L6-v2",
        provider_id=embedding_provider.provider_id,
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 384,
        },
    )
    inference_providers, default_models = get_inference_providers()

    return DistributionTemplate(
        name=name,
        distro_type="self_hosted",
        description="Distribution for running e2e tests in CI",
        container_image=None,
        template_path=None,
        providers=providers,
        default_models=[],
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": inference_providers + [embedding_provider],
                    "vector_io": [vector_io_provider],
                },
                default_models=default_models + [embedding_model],
                default_tool_groups=default_tool_groups,
                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
            ),
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
            "FIREWORKS_API_KEY": (
                "",
                "Fireworks API Key",
            ),
            "OPENAI_API_KEY": (
                "",
                "OpenAI API Key",
            ),
        },
    )
```
llama_stack/templates/dev/run.yaml (new file, 261 lines)

```yaml
version: '2'
image_name: dev
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: openai
    provider_type: remote::openai
    config:
      api_key: ${env.OPENAI_API_KEY:}
  - provider_id: fireworks
    provider_type: remote::fireworks
    config:
      url: https://api.fireworks.ai/inference/v1
      api_key: ${env.FIREWORKS_API_KEY:}
  - provider_id: anthropic
    provider_type: remote::anthropic
    config:
      api_key: ${env.ANTHROPIC_API_KEY:}
  - provider_id: gemini
    provider_type: remote::gemini
    config:
      api_key: ${env.GEMINI_API_KEY:}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: sqlite-vec
    provider_type: inline::sqlite-vec
    config:
      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config: {}
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db
models:
- metadata: {}
  model_id: openai/gpt-4o
  provider_id: openai
  provider_model_id: openai/gpt-4o
  model_type: llm
- metadata: {}
  model_id: openai/gpt-4o-mini
  provider_id: openai
  provider_model_id: openai/gpt-4o-mini
  model_type: llm
- metadata: {}
  model_id: openai/chatgpt-4o-latest
  provider_id: openai
  provider_model_id: openai/chatgpt-4o-latest
  model_type: llm
- metadata:
    embedding_dimension: 1536
  model_id: openai/text-embedding-3-small
  provider_id: openai
  provider_model_id: openai/text-embedding-3-small
  model_type: embedding
- metadata:
    embedding_dimension: 3072
  model_id: openai/text-embedding-3-large
  provider_id: openai
  provider_model_id: openai/text-embedding-3-large
  model_type: embedding
- metadata: {}
  model_id: meta-llama/Llama-3.1-8B-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.1-70B-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-1B-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-3B-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.3-70B-Instruct
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-Guard-3-8B
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-guard-3-8b
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-Guard-3-11B-Vision
  provider_id: fireworks
  provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
  model_type: llm
- metadata:
    embedding_dimension: 768
    context_length: 8192
  model_id: nomic-ai/nomic-embed-text-v1.5
  provider_id: fireworks
  provider_model_id: nomic-ai/nomic-embed-text-v1.5
  model_type: embedding
- metadata: {}
  model_id: anthropic/claude-3-5-sonnet-latest
  provider_id: anthropic
  provider_model_id: anthropic/claude-3-5-sonnet-latest
  model_type: llm
- metadata: {}
  model_id: anthropic/claude-3-7-sonnet-latest
  provider_id: anthropic
  provider_model_id: anthropic/claude-3-7-sonnet-latest
  model_type: llm
- metadata: {}
  model_id: anthropic/claude-3-5-haiku-latest
  provider_id: anthropic
  provider_model_id: anthropic/claude-3-5-haiku-latest
  model_type: llm
- metadata:
    embedding_dimension: 1024
    context_length: 32000
  model_id: anthropic/voyage-3
  provider_id: anthropic
  provider_model_id: anthropic/voyage-3
  model_type: embedding
- metadata:
    embedding_dimension: 512
    context_length: 32000
  model_id: anthropic/voyage-3-lite
  provider_id: anthropic
  provider_model_id: anthropic/voyage-3-lite
  model_type: embedding
- metadata:
    embedding_dimension: 1024
    context_length: 32000
  model_id: anthropic/voyage-code-3
  provider_id: anthropic
  provider_model_id: anthropic/voyage-code-3
  model_type: embedding
- metadata: {}
  model_id: gemini/gemini-1.5-flash
  provider_id: gemini
  provider_model_id: gemini/gemini-1.5-flash
  model_type: llm
- metadata: {}
  model_id: gemini/gemini-1.5-pro
  provider_id: gemini
  provider_model_id: gemini/gemini-1.5-pro
  model_type: llm
- metadata:
    embedding_dimension: 768
    context_length: 2048
  model_id: gemini/text-embedding-004
  provider_id: gemini
  provider_model_id: gemini/text-embedding-004
  model_type: embedding
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: sentence-transformers
  model_type: embedding
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
server:
  port: 8321
```
```diff
@@ -116,12 +116,14 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
     providers = [p for p in client.providers.list() if p.api == "inference"]
     assert len(providers) > 0, "No inference providers found"
     inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
-    if text_model_id:
+
+    model_ids = [m.identifier for m in client.models.list()]
+    if text_model_id and text_model_id not in model_ids:
         client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
-    if vision_model_id:
+    if vision_model_id and vision_model_id not in model_ids:
         client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
 
-    if embedding_model_id and embedding_dimension:
+    if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids:
         # try to find a provider that supports embeddings, if sentence-transformers is not available
         selected_provider = None
         for p in providers:
```
tests/client-sdk/inference/test_text_inference.py (modified)

```diff
@@ -19,6 +19,16 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
 PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
 
 
+def skip_if_model_doesnt_support_completion(client_with_models, model_id):
+    models = {m.identifier: m for m in client_with_models.models.list()}
+    provider_id = models[model_id].provider_id
+    providers = {p.provider_id: p for p in client_with_models.providers.list()}
+    provider = providers[provider_id]
+    print(f"Provider: {provider.provider_type} for model {model_id}")
+    if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
+
+
 @pytest.fixture(scope="session")
 def provider_tool_format(inference_provider_type):
     return (
@@ -35,6 +45,7 @@ def provider_tool_format(inference_provider_type):
     ],
 )
 def test_text_completion_non_streaming(client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.completion(
@@ -56,6 +67,7 @@ def test_text_completion_non_streaming(client_with_models, text_model_id, test_c
     ],
 )
 def test_text_completion_streaming(client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.completion(
@@ -79,6 +91,7 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
     ],
 )
 def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
+    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
 
@@ -107,6 +120,7 @@ def test_text_completion_log_probs_non_streaming(client_with_models, text_model_
     ],
 )
 def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
+    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
 
@@ -139,6 +153,8 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
     ],
 )
 def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
+
     class AnswerFormat(BaseModel):
         name: str
         year_born: str
@@ -237,9 +253,7 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
         tool_prompt_format=tool_prompt_format,
         stream=False,
     )
-    # No content is returned for the system message since we expect the
-    # response to be a tool call
-    assert response.completion_message.content == ""
+    # some models can return content for the response in addition to the tool call
     assert response.completion_message.role == "assistant"
 
     assert len(response.completion_message.tool_calls) == 1
```