Auto-generate distro yamls + docs (#468)

# What does this PR do?

Automatically generates
- build.yaml
- run.yaml
- run-with-safety.yaml
- parts of markdown docs

for the distributions.
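
The mechanism behind this is that each provider config class now exposes a `sample_run_config()` classmethod (see the diffs below), and the distro templates collect those samples into the generated YAMLs. The sketch below is only an illustration of the idea — `write_run_yaml` and the run.yaml layout here are hypothetical, not the generator shipped in this PR:

```python
# Hypothetical sketch: collect each provider's sample_run_config() into a run.yaml.
# write_run_yaml and the YAML layout are illustrative, not the actual generator.
from typing import Any, Dict

import yaml  # pyyaml


def write_run_yaml(providers: Dict[str, Any], path: str) -> None:
    """providers maps a provider id (e.g. "ollama") to its config class."""
    doc: Dict[str, Any] = {"providers": {"inference": []}}
    for provider_id, config_cls in providers.items():
        doc["providers"]["inference"].append(
            {
                "provider_id": provider_id,
                "provider_type": f"remote::{provider_id}",
                # samples keep ${env.VAR[:default]} placeholders so the YAML
                # stays environment-agnostic until the stack actually starts
                "config": config_cls.sample_run_config(),
            }
        )
    with open(path, "w") as f:
        yaml.safe_dump(doc, f, sort_keys=False)
```

A distribution template would then call something like `write_run_yaml({"ollama": OllamaImplConfig}, "run.yaml")` for each distro it covers.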

## Test Plan

At this point, this only updates the YAMLs and the docs. Some testing (especially with ollama and vllm) has been performed, but much more thorough testing is still needed.

Authored by Ashwin Bharambe on 2024-11-18 14:57:06 -08:00; committed by GitHub
parent 0784284ab5
commit 2a31163178
88 changed files with 3008 additions and 852 deletions


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
@@ -20,3 +20,10 @@ class FireworksImplConfig(BaseModel):
         default=None,
         description="The Fireworks.ai API Key",
     )
+
+    @classmethod
+    def sample_run_config(cls) -> Dict[str, Any]:
+        return {
+            "url": "https://api.fireworks.ai/inference",
+            "api_key": "${env.FIREWORKS_API_KEY}",
+        }
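
Note that `${env.FIREWORKS_API_KEY}` above is not a literal value; it is a placeholder that gets resolved from the environment when the generated run.yaml is loaded. The helper below is only an illustration of the apparent `${env.VAR}` / `${env.VAR:default}` semantics — the name `resolve_env_placeholders` is made up, and the real substitution lives in the stack's config loading:

```python
# Illustration only: hypothetical resolver for ${env.VAR} / ${env.VAR:default}
# placeholders; the project's actual substitution happens at run.yaml load time.
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+)(?::([^}]*))?\}")


def resolve_env_placeholders(value: str) -> str:
    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        if name in os.environ:
            return os.environ[name]
        if default is not None:
            return default
        raise ValueError(f"environment variable {name} is not set and has no default")

    return _ENV_PATTERN.sub(_sub, value)


# e.g. resolve_env_placeholders("${env.OLLAMA_URL:http://localhost:11434}")
# falls back to "http://localhost:11434" when OLLAMA_URL is unset, while
# resolve_env_placeholders("${env.FIREWORKS_API_KEY}") raises instead.
```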


@@ -35,7 +35,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import FireworksImplConfig
 
-model_aliases = [
+MODEL_ALIASES = [
     build_model_alias(
         "fireworks/llama-v3p1-8b-instruct",
         CoreModelId.llama3_1_8b_instruct.value,
@@ -79,7 +79,7 @@ class FireworksInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: FireworksImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_aliases)
+        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())


@@ -4,14 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.distribution.datatypes import RemoteProviderConfig
+from .config import OllamaImplConfig
 
 
-class OllamaImplConfig(RemoteProviderConfig):
-    port: int = 11434
-
-
-async def get_adapter_impl(config: RemoteProviderConfig, _deps):
+async def get_adapter_impl(config: OllamaImplConfig, _deps):
     from .ollama import OllamaInferenceAdapter
 
     impl = OllamaInferenceAdapter(config.url)


@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict
+
+from pydantic import BaseModel
+
+
+DEFAULT_OLLAMA_URL = "http://localhost:11434"
+
+
+class OllamaImplConfig(BaseModel):
+    url: str = DEFAULT_OLLAMA_URL
+
+    @classmethod
+    def sample_run_config(
+        cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs
+    ) -> Dict[str, Any]:
+        return {"url": url}


@@ -82,7 +82,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return AsyncClient(host=self.url)
 
     async def initialize(self) -> None:
-        print("Initializing Ollama, checking connectivity to server...")
+        print(f"checking connectivity to Ollama at `{self.url}`...")
         try:
             await self.client.ps()
         except httpx.ConnectError as e:


@@ -12,19 +12,20 @@ from pydantic import BaseModel, Field
 
 @json_schema_type
 class TGIImplConfig(BaseModel):
-    host: str = "localhost"
-    port: int = 8080
-    protocol: str = "http"
-
-    @property
-    def url(self) -> str:
-        return f"{self.protocol}://{self.host}:{self.port}"
-
+    url: str = Field(
+        description="The URL for the TGI serving endpoint",
+    )
     api_token: Optional[str] = Field(
         default=None,
         description="A bearer token if your TGI endpoint is protected.",
    )
 
+    @classmethod
+    def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):
+        return {
+            "url": url,
+        }
+
 
 @json_schema_type
 class InferenceEndpointImplConfig(BaseModel):
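
The net effect for TGI users is that the endpoint is now given as a single URL rather than host/port/protocol, and the generated run.yaml defaults to a `${env.TGI_URL}` placeholder with no fallback. A small before/after sketch (import path assumed):

```python
# Import path assumed; the class itself is the one shown in the diff above.
from llama_stack.providers.remote.inference.tgi.config import TGIImplConfig

# old style (no longer valid):
#   TGIImplConfig(host="localhost", port=8080, protocol="http")
# new style -- pass the full serving endpoint URL directly:
config = TGIImplConfig(url="http://localhost:8080")

# the generated run.yaml entry has no fallback, so TGI_URL must be set at startup
assert TGIImplConfig.sample_run_config() == {"url": "${env.TGI_URL}"}
```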


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
@@ -20,3 +20,10 @@ class TogetherImplConfig(BaseModel):
         default=None,
         description="The Together AI API Key",
     )
+
+    @classmethod
+    def sample_run_config(cls) -> Dict[str, Any]:
+        return {
+            "url": "https://api.together.xyz/v1",
+            "api_key": "${env.TOGETHER_API_KEY}",
+        }


@@ -38,7 +38,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import TogetherImplConfig
 
-model_aliases = [
+MODEL_ALIASES = [
     build_model_alias(
         "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
         CoreModelId.llama3_1_8b_instruct.value,
@@ -78,7 +78,7 @@ class TogetherInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: TogetherImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_aliases)
+        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())


@@ -24,3 +24,15 @@ class VLLMInferenceAdapterConfig(BaseModel):
         default="fake",
         description="The API token",
     )
+
+    @classmethod
+    def sample_run_config(
+        cls,
+        url: str = "${env.VLLM_URL}",
+        **kwargs,
+    ):
+        return {
+            "url": url,
+            "max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
+            "api_token": "${env.VLLM_API_TOKEN:fake}",
+        }
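
The vLLM sample keeps `max_tokens` and `api_token` as placeholders with defaults, while the URL has none, so `VLLM_URL` must be provided when the stack starts. A quick check of the behavior shown in the diff (import path assumed):

```python
# Import path assumed; the values mirror the diff above.
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig

assert VLLMInferenceAdapterConfig.sample_run_config() == {
    "url": "${env.VLLM_URL}",
    "max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
    "api_token": "${env.VLLM_API_TOKEN:fake}",
}

# a distro that knows its vLLM endpoint can override just the URL
sample = VLLMInferenceAdapterConfig.sample_run_config(url="http://vllm:8000/v1")
assert sample["url"] == "http://vllm:8000/v1"
```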