diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py
new file mode 100644
index 000000000..929265f9e
--- /dev/null
+++ b/llama_stack/apis/models/client.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import json
+from pathlib import Path
+
+from typing import Any, Dict, List, Optional
+
+import fire
+import httpx
+
+from llama_stack.distribution.datatypes import RemoteProviderConfig
+from termcolor import cprint
+
+from .models import *  # noqa: F403
+
+
+class ModelsClient(Models):
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def list_models(self) -> ModelsListResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/models/list",
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return ModelsListResponse(**response.json())
+
+    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self.base_url}/models/get",
+                json={
+                    "core_model_id": core_model_id,
+                },
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return ModelsGetResponse(**response.json())
+
+
+async def run_main(host: str, port: int, stream: bool):
+    client = ModelsClient(f"http://{host}:{port}")
+
+    response = await client.list_models()
+    cprint(f"list_models response={response}", "green")
+
+    response = await client.get_model("Meta-Llama3.1-8B-Instruct")
+    cprint(f"get_model response={response}", "blue")
+
+    response = await client.get_model("Llama-Guard-3-8B")
+    cprint(f"get_model response={response}", "red")
+
+
+def main(host: str, port: int, stream: bool = True):
+    asyncio.run(run_main(host, port, stream))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
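For reviewers, a minimal programmatic use of the new client might look like the sketch below. The server address is an assumption; the `fire`-based `main` above covers the same flow from the command line.

```python
import asyncio

from llama_stack.apis.models.client import ModelsClient


async def check_models() -> None:
    # Assumed address of a running Llama Stack server; adjust host/port as needed.
    client = ModelsClient("http://localhost:5000")

    listing = await client.list_models()
    print(listing)

    spec = await client.get_model("Meta-Llama3.1-8B-Instruct")
    print(spec)


asyncio.run(check_models())
```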
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 7f15e0495..29923e0bd 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -130,6 +130,10 @@ Fully-qualified name of the module to import. The module is expected to have:
     provider_data_validator: Optional[str] = Field(
         default=None,
     )
+    supported_model_ids: List[str] = Field(
+        default_factory=list,
+        description="The list of model ids that this adapter supports",
+    )


 @json_schema_type
diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml
index 6bc0c6d14..d3287eb38 100644
--- a/llama_stack/examples/router-table-run.yaml
+++ b/llama_stack/examples/router-table-run.yaml
@@ -4,25 +4,25 @@ docker_image: null
 conda_env: local
 apis_to_serve:
 - inference
-- memory
+# - memory
 - telemetry
-- agents
-- safety
+# - agents
+# - safety
 - models
 provider_map:
   telemetry:
     provider_id: meta-reference
     config: {}
-  safety:
-    provider_id: meta-reference
-    config:
-      llama_guard_shield:
-        model: Llama-Guard-3-8B
-        excluded_categories: []
-        disable_input_check: false
-        disable_output_check: false
-      prompt_guard_shield:
-        model: Prompt-Guard-86M
+  # safety:
+  #   provider_id: meta-reference
+  #   config:
+  #     llama_guard_shield:
+  #       model: Llama-Guard-3-8B
+  #       excluded_categories: []
+  #       disable_input_check: false
+  #       disable_output_check: false
+  #     prompt_guard_shield:
+  #       model: Prompt-Guard-86M
   # inference:
   #   provider_id: meta-reference
   #   config:
@@ -31,32 +31,29 @@ provider_map:
   #     torch_seed: null
   #     max_seq_len: 4096
   #     max_batch_size: 1
-  inference:
-    provider_id: remote::ollama
-    config:
-
-  agents:
-    provider_id: meta-reference
-    config: {}
-provider_routing_table:
   # inference:
-  #   - routing_key: Meta-Llama3.1-8B-Instruct
-  #     provider_id: meta-reference
-  #     config:
-  #       model: Meta-Llama3.1-8B-Instruct
-  #       quantization: null
-  #       torch_seed: null
-  #       max_seq_len: 4096
-  #       max_batch_size: 1
-  #   - routing_key: Meta-Llama3.1-8B-Instruct
-  #     provider_id: meta-reference
-  #     config:
-  #       model: Meta-Llama3.1-8B
-  #       quantization: null
-  #       torch_seed: null
-  #       max_seq_len: 4096
-  #       max_batch_size: 1
-  memory:
+  #   provider_id: remote::ollama
+  #   config:
+  #     url: https://ollama-1.com
+  # agents:
+  #   provider_id: meta-reference
+  #   config: {}
+provider_routing_table:
+  inference:
+  - routing_key: Meta-Llama3.1-8B-Instruct
+    provider_id: meta-reference
+    config:
+      model: Meta-Llama3.1-8B-Instruct
+      quantization: null
+      torch_seed: null
+      max_seq_len: 4096
+      max_batch_size: 1
+  - routing_key: Meta-Llama3.1-8B
+    provider_id: remote::ollama
+    config:
+      url: https://ollama.com
+
+  # memory:
   # - routing_key: keyvalue
   #   provider_id: remote::pgvector
   #   config:
@@ -65,6 +62,6 @@ provider_routing_table:
   #     db: vectordb
   #     user: vectoruser
   #     password: xxxx
-  - routing_key: vector
-    provider_id: meta-reference
-    config: {}
+  # - routing_key: vector
+  #   provider_id: meta-reference
+  #   config: {}
diff --git a/llama_stack/examples/simple-local-run.yaml b/llama_stack/examples/simple-local-run.yaml
index d4e3d202e..b628894c1 100644
--- a/llama_stack/examples/simple-local-run.yaml
+++ b/llama_stack/examples/simple-local-run.yaml
@@ -7,6 +7,7 @@ apis_to_serve:
 - safety
 - agents
 - memory
+- models
 provider_map:
   inference:
     provider_id: meta-reference
@@ -16,6 +17,10 @@ provider_map:
       torch_seed: null
      max_seq_len: 4096
       max_batch_size: 1
+  # inference:
+  #   provider_id: remote::ollama
+  #   config:
+  #     url: https://xxx
   safety:
     provider_id: meta-reference
     config:
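As a quick sanity check of the routing-table shape above, a small sketch (PyYAML assumed as a dependency) that lists which routing key maps to which inference provider:

```python
import yaml  # PyYAML

# Load the example run config edited above and print the inference routing table.
with open("llama_stack/examples/router-table-run.yaml") as f:
    run_config = yaml.safe_load(f)

for entry in run_config.get("provider_routing_table", {}).get("inference", []):
    print(f"{entry['routing_key']} -> {entry['provider_id']}")
```

With the config above this prints `Meta-Llama3.1-8B-Instruct -> meta-reference` and `Meta-Llama3.1-8B -> remote::ollama`.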
diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index 1e6f2e753..6115d7d09 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -6,14 +6,14 @@
 from typing import AsyncGenerator

+from fireworks.client import Fireworks
+
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

-from fireworks.client import Fireworks
-
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

@@ -42,7 +42,14 @@ class FireworksInferenceAdapter(Inference):
     async def shutdown(self) -> None:
         pass

-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
         raise NotImplementedError()

     def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index ea726ff75..296fb61a6 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -30,25 +30,33 @@ OLLAMA_SUPPORTED_SKUS = {
 class OllamaInferenceAdapter(Inference):
     def __init__(self, url: str) -> None:
         self.url = url
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        # tokenizer = Tokenizer.get_instance()
+        # self.formatter = ChatFormat(tokenizer)

     @property
     def client(self) -> AsyncClient:
         return AsyncClient(host=self.url)

     async def initialize(self) -> None:
-        try:
-            await self.client.ps()
-        except httpx.ConnectError as e:
-            raise RuntimeError(
-                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
-            ) from e
+        print("Ollama init")
+        # try:
+        #     await self.client.ps()
+        # except httpx.ConnectError as e:
+        #     raise RuntimeError(
+        #         "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
+        #     ) from e

     async def shutdown(self) -> None:
         pass

-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
         raise NotImplementedError()

     def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 6c3b38347..6a385896d 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -54,7 +54,14 @@ class TGIAdapter(Inference):
     async def shutdown(self) -> None:
         pass

-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
         raise NotImplementedError()

     def get_chat_options(self, request: ChatCompletionRequest) -> dict:
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 565130883..2d747351b 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -42,7 +42,14 @@ class TogetherInferenceAdapter(Inference):
     async def shutdown(self) -> None:
         pass

-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
         raise NotImplementedError()

     def _messages_to_together_messages(self, messages: list[Message]) -> list:
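The four adapters above now share the flattened `completion()` signature instead of taking a single `CompletionRequest`. A hypothetical call site could look like the following sketch; note that `completion()` is still a stub in all of these adapters and currently raises `NotImplementedError`.

```python
from llama_stack.apis.inference import *  # noqa: F403


async def try_completion(adapter) -> None:
    # Hypothetical streaming call against any of the adapters above, once
    # completion() is actually implemented; today this raises NotImplementedError.
    async for chunk in adapter.completion(
        model="Meta-Llama3.1-8B-Instruct",
        content="Write a haiku about routing tables.",
        sampling_params=SamplingParams(),
        stream=True,
    ):
        print(chunk)
```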
diff --git a/llama_stack/providers/impls/builtin/models/models.py b/llama_stack/providers/impls/builtin/models/models.py
index 842302000..74d1299a4 100644
--- a/llama_stack/providers/impls/builtin/models/models.py
+++ b/llama_stack/providers/impls/builtin/models/models.py
@@ -10,16 +10,14 @@ from typing import AsyncIterator, Union
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model

+from llama_stack.distribution.distribution import Api, api_providers
+
 from llama_stack.apis.models import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.datatypes import CoreModelId, Model
 from llama_models.sku_list import resolve_model

-from llama_stack.distribution.datatypes import (
-    Api,
-    GenericProviderConfig,
-    StackRunConfig,
-)
+from llama_stack.distribution.datatypes import *  # noqa: F403
 from termcolor import cprint
@@ -28,27 +26,24 @@ class BuiltinModelsImpl(Models):
         self,
         config: StackRunConfig,
     ) -> None:
-        print("BuiltinModelsImpl init")
-
         self.run_config = config
         self.models = {}
-
-        print("BuiltinModelsImpl run_config", config)
-
         # check against inference & safety api
         apis_with_models = [Api.inference, Api.safety]
+        all_providers = api_providers()
+
         for api in apis_with_models:
+            # check against provider_map (simple case single model)
             if api.value in config.provider_map:
+                providers_for_api = all_providers[api]
                 provider_spec = config.provider_map[api.value]
                 core_model_id = provider_spec.config
-                print("provider_spec", provider_spec)
-                model_spec = ModelServingSpec(
-                    provider_config=provider_spec,
-                )
                 # get supported model ids from the provider
-                supported_model_ids = self.get_supported_model_ids(provider_spec)
+                supported_model_ids = self.get_supported_model_ids(
+                    api.value, provider_spec, providers_for_api
+                )
                 for model_id in supported_model_ids:
                     self.models[model_id] = ModelServingSpec(
                         llama_model=resolve_model(model_id),
@@ -58,21 +53,61 @@ class BuiltinModelsImpl(Models):
             # check against provider_routing_table (router with multiple models)
             # with routing table, we use the routing_key as the supported models
+            if api.value in config.provider_routing_table:
+                routing_table = config.provider_routing_table[api.value]
+                for rt_entry in routing_table:
+                    model_id = rt_entry.routing_key
+                    self.models[model_id] = ModelServingSpec(
+                        llama_model=resolve_model(model_id),
+                        provider_config=GenericProviderConfig(
+                            provider_id=rt_entry.provider_id,
+                            config=rt_entry.config,
+                        ),
+                        api=api.value,
+                    )

-    def resolve_supported_model_ids(self) -> list[CoreModelId]:
-        # TODO: for remote providers, provide registry to list supported models
+        print("BuiltinModelsImpl models", self.models)

-        return ["Meta-Llama3.1-8B-Instruct"]
+    def get_supported_model_ids(
+        self,
+        api: str,
+        provider_spec: GenericProviderConfig,
+        providers_for_api: Dict[str, ProviderSpec],
+    ) -> List[str]:
+        serving_models_list = []
+        if api == Api.inference.value:
+            provider_id = provider_spec.provider_id
+            if provider_id == "meta-reference":
+                serving_models_list.append(provider_spec.config["model"])
+            if provider_id in {
+                remote_provider_id("ollama"),
+                remote_provider_id("fireworks"),
+                remote_provider_id("together"),
+            }:
+                adapter_supported_models = providers_for_api[
+                    provider_id
+                ].adapter.supported_model_ids
+                serving_models_list.extend(adapter_supported_models)
+        elif api == Api.safety.value:
+            if provider_spec.config and "llama_guard_shield" in provider_spec.config:
+                llama_guard_shield = provider_spec.config["llama_guard_shield"]
+                serving_models_list.append(llama_guard_shield["model"])
+            if provider_spec.config and "prompt_guard_shield" in provider_spec.config:
+                prompt_guard_shield = provider_spec.config["prompt_guard_shield"]
+                serving_models_list.append(prompt_guard_shield["model"])
+        else:
+            raise NotImplementedError(f"Unsupported api {api} for builtin models")
+
+        return serving_models_list

     async def initialize(self) -> None:
         pass

     async def list_models(self) -> ModelsListResponse:
-        pass
-        # return ModelsListResponse(models_list=list(self.models.values()))
+        return ModelsListResponse(models_list=list(self.models.values()))

     async def get_model(self, core_model_id: str) -> ModelsGetResponse:
-        pass
-        # if core_model_id in self.models:
-        #     return ModelsGetResponse(core_model_spec=self.models[core_model_id])
-        # raise RuntimeError(f"Cannot find {core_model_id} in model registry")
+        if core_model_id in self.models:
+            return ModelsGetResponse(core_model_spec=self.models[core_model_id])
+        print(f"Cannot find {core_model_id} in model registry")
+        return ModelsGetResponse()
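A rough smoke test for the builtin Models provider, assuming `run_config` is an already-parsed `StackRunConfig` built from one of the example YAML files above:

```python
from llama_stack.providers.impls.builtin.models.models import BuiltinModelsImpl


async def smoke_test(run_config) -> None:
    # run_config: a parsed StackRunConfig; constructing one from YAML is out of scope here.
    impl = BuiltinModelsImpl(config=run_config)
    await impl.initialize()

    listing = await impl.list_models()
    print([spec.llama_model for spec in listing.models_list])

    resp = await impl.get_model("Llama-Guard-3-8B")
    print(resp.core_model_spec)
```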
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 10b3d6ccc..bf739eefa 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -32,6 +32,10 @@ def available_providers() -> List[ProviderSpec]:
                 adapter_id="ollama",
                 pip_packages=["ollama"],
                 module="llama_stack.providers.adapters.inference.ollama",
+                supported_model_ids=[
+                    "Meta-Llama3.1-8B-Instruct",
+                    "Meta-Llama3.1-70B-Instruct",
+                ],
             ),
         ),
         remote_provider_spec(
@@ -52,6 +56,11 @@ def available_providers() -> List[ProviderSpec]:
                 ],
                 module="llama_stack.providers.adapters.inference.fireworks",
                 config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig",
+                supported_model_ids=[
+                    "Meta-Llama3.1-8B-Instruct",
+                    "Meta-Llama3.1-70B-Instruct",
+                    "Meta-Llama3.1-405B-Instruct",
+                ],
             ),
         ),
         remote_provider_spec(
@@ -64,6 +73,11 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.adapters.inference.together",
                 config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
                 header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
+                supported_model_ids=[
+                    "Meta-Llama3.1-8B-Instruct",
+                    "Meta-Llama3.1-70B-Instruct",
+                    "Meta-Llama3.1-405B-Instruct",
+                ],
             ),
         ),
     ]
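Building on the registry entries above, a small lookup sketch; it assumes `api_providers()` returns a per-API mapping of provider id to `ProviderSpec`, which is how `BuiltinModelsImpl` consumes it in this PR.

```python
from llama_stack.distribution.distribution import Api, api_providers


def providers_supporting(model_id: str) -> list:
    # Return the inference provider ids whose adapters advertise this model id
    # via the new supported_model_ids field (providers without the field are skipped).
    inference_providers = api_providers()[Api.inference]
    return [
        provider_id
        for provider_id, spec in inference_providers.items()
        if model_id in getattr(getattr(spec, "adapter", None), "supported_model_ids", [])
    ]


if __name__ == "__main__":
    print(providers_supporting("Meta-Llama3.1-405B-Instruct"))
```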