forked from phoenix-oss/llama-stack-mirror
Auto-generate distro yamls + docs (#468)
# What does this PR do? Automatically generates - build.yaml - run.yaml - run-with-safety.yaml - parts of markdown docs for the distributions. ## Test Plan At this point, this only updates the YAMLs and the docs. Some testing (especially with ollama and vllm) has been performed but needs to be much more tested.
This commit is contained in:
parent
0784284ab5
commit
2a31163178
88 changed files with 3008 additions and 852 deletions
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -20,3 +20,10 @@ class FireworksImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The Fireworks.ai API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference",
|
||||
"api_key": "${env.FIREWORKS_API_KEY}",
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import FireworksImplConfig
|
||||
|
||||
|
||||
model_aliases = [
|
||||
MODEL_ALIASES = [
|
||||
build_model_alias(
|
||||
"fireworks/llama-v3p1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
|
@ -79,7 +79,7 @@ class FireworksInferenceAdapter(
|
|||
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_aliases)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
|
|
|
@ -4,14 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
from .config import OllamaImplConfig
|
||||
|
||||
|
||||
class OllamaImplConfig(RemoteProviderConfig):
|
||||
port: int = 11434
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
|
||||
async def get_adapter_impl(config: OllamaImplConfig, _deps):
|
||||
from .ollama import OllamaInferenceAdapter
|
||||
|
||||
impl = OllamaInferenceAdapter(config.url)
|
||||
|
|
22
llama_stack/providers/remote/inference/ollama/config.py
Normal file
22
llama_stack/providers/remote/inference/ollama/config.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaImplConfig(BaseModel):
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
return {"url": url}
|
|
@ -82,7 +82,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
return AsyncClient(host=self.url)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print("Initializing Ollama, checking connectivity to server...")
|
||||
print(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
try:
|
||||
await self.client.ps()
|
||||
except httpx.ConnectError as e:
|
||||
|
|
|
@ -12,19 +12,20 @@ from pydantic import BaseModel, Field
|
|||
|
||||
@json_schema_type
|
||||
class TGIImplConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 8080
|
||||
protocol: str = "http"
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"{self.protocol}://{self.host}:{self.port}"
|
||||
|
||||
url: str = Field(
|
||||
description="The URL for the TGI serving endpoint",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="A bearer token if your TGI endpoint is protected.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):
|
||||
return {
|
||||
"url": url,
|
||||
}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceEndpointImplConfig(BaseModel):
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -20,3 +20,10 @@ class TogetherImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The Together AI API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.together.xyz/v1",
|
||||
"api_key": "${env.TOGETHER_API_KEY}",
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import TogetherImplConfig
|
||||
|
||||
|
||||
model_aliases = [
|
||||
MODEL_ALIASES = [
|
||||
build_model_alias(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
|
@ -78,7 +78,7 @@ class TogetherInferenceAdapter(
|
|||
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_aliases)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
|
|
|
@ -24,3 +24,15 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
default="fake",
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.VLLM_URL}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"url": url,
|
||||
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
|
||||
"api_token": "${env.VLLM_API_TOKEN:fake}",
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue