diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index ea726ff75..9be4f4935 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -17,6 +17,7 @@ from ollama import AsyncClient
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
+from termcolor import cprint
 
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
@@ -38,12 +39,13 @@ class OllamaInferenceAdapter(Inference):
         return AsyncClient(host=self.url)
 
     async def initialize(self) -> None:
-        try:
-            await self.client.ps()
-        except httpx.ConnectError as e:
-            raise RuntimeError(
-                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
-            ) from e
+        pass
+        # try:
+        #     await self.client.ps()
+        # except httpx.ConnectError as e:
+        #     raise RuntimeError(
+        #         "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
+        #     ) from e
 
     async def shutdown(self) -> None:
         pass
@@ -96,166 +98,167 @@ class OllamaInferenceAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        request = ChatCompletionRequest(
-            model=model,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
-            stream=stream,
-            logprobs=logprobs,
-        )
+        cprint("!! calling remote ollama !!", "red")
+        # request = ChatCompletionRequest(
+        #     model=model,
+        #     messages=messages,
+        #     sampling_params=sampling_params,
+        #     tools=tools or [],
+        #     tool_choice=tool_choice,
+        #     tool_prompt_format=tool_prompt_format,
+        #     stream=stream,
+        #     logprobs=logprobs,
+        # )
 
-        messages = prepare_messages(request)
-        # accumulate sampling params and other options to pass to ollama
-        options = self.get_ollama_chat_options(request)
-        ollama_model = self.resolve_ollama_model(request.model)
+        # messages = prepare_messages(request)
+        # # accumulate sampling params and other options to pass to ollama
+        # options = self.get_ollama_chat_options(request)
+        # ollama_model = self.resolve_ollama_model(request.model)
 
-        res = await self.client.ps()
-        need_model_pull = True
-        for r in res["models"]:
-            if ollama_model == r["model"]:
-                need_model_pull = False
-                break
+        # res = await self.client.ps()
+        # need_model_pull = True
+        # for r in res["models"]:
+        #     if ollama_model == r["model"]:
+        #         need_model_pull = False
+        #         break
 
-        if need_model_pull:
-            print(f"Pulling model: {ollama_model}")
-            status = await self.client.pull(ollama_model)
-            assert (
-                status["status"] == "success"
-            ), f"Failed to pull model {self.model} in ollama"
+        # if need_model_pull:
+        #     print(f"Pulling model: {ollama_model}")
+        #     status = await self.client.pull(ollama_model)
+        #     assert (
+        #         status["status"] == "success"
+        #     ), f"Failed to pull model {self.model} in ollama"
 
-        if not request.stream:
-            r = await self.client.chat(
-                model=ollama_model,
-                messages=self._messages_to_ollama_messages(messages),
-                stream=False,
-                options=options,
-            )
-            stop_reason = None
-            if r["done"]:
-                if r["done_reason"] == "stop":
-                    stop_reason = StopReason.end_of_turn
-                elif r["done_reason"] == "length":
-                    stop_reason = StopReason.out_of_tokens
+        # if not request.stream:
+        #     r = await self.client.chat(
+        #         model=ollama_model,
+        #         messages=self._messages_to_ollama_messages(messages),
+        #         stream=False,
+        #         options=options,
+        #     )
+        #     stop_reason = None
+        #     if r["done"]:
+        #         if r["done_reason"] == "stop":
+        #             stop_reason = StopReason.end_of_turn
+        #         elif r["done_reason"] == "length":
+        #             stop_reason = StopReason.out_of_tokens
 
-            completion_message = self.formatter.decode_assistant_message_from_content(
-                r["message"]["content"], stop_reason
-            )
-            yield ChatCompletionResponse(
-                completion_message=completion_message,
-                logprobs=None,
-            )
-        else:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.start,
-                    delta="",
-                )
-            )
-            stream = await self.client.chat(
-                model=ollama_model,
-                messages=self._messages_to_ollama_messages(messages),
-                stream=True,
-                options=options,
-            )
+        #     completion_message = self.formatter.decode_assistant_message_from_content(
+        #         r["message"]["content"], stop_reason
+        #     )
+        #     yield ChatCompletionResponse(
+        #         completion_message=completion_message,
+        #         logprobs=None,
+        #     )
+        # else:
+        #     yield ChatCompletionResponseStreamChunk(
+        #         event=ChatCompletionResponseEvent(
+        #             event_type=ChatCompletionResponseEventType.start,
+        #             delta="",
+        #         )
+        #     )
+        #     stream = await self.client.chat(
+        #         model=ollama_model,
+        #         messages=self._messages_to_ollama_messages(messages),
+        #         stream=True,
+        #         options=options,
+        #     )
 
-            buffer = ""
-            ipython = False
-            stop_reason = None
+        #     buffer = ""
+        #     ipython = False
+        #     stop_reason = None
 
-            async for chunk in stream:
-                if chunk["done"]:
-                    if stop_reason is None and chunk["done_reason"] == "stop":
-                        stop_reason = StopReason.end_of_turn
-                    elif stop_reason is None and chunk["done_reason"] == "length":
-                        stop_reason = StopReason.out_of_tokens
-                    break
+        #     async for chunk in stream:
+        #         if chunk["done"]:
+        #             if stop_reason is None and chunk["done_reason"] == "stop":
+        #                 stop_reason = StopReason.end_of_turn
+        #             elif stop_reason is None and chunk["done_reason"] == "length":
+        #                 stop_reason = StopReason.out_of_tokens
+        #             break
 
-                text = chunk["message"]["content"]
+        #         text = chunk["message"]["content"]
 
-                # check if its a tool call ( aka starts with <|python_tag|> )
-                if not ipython and text.startswith("<|python_tag|>"):
-                    ipython = True
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                    buffer += text
-                    continue
+        #         # check if its a tool call ( aka starts with <|python_tag|> )
+        #         if not ipython and text.startswith("<|python_tag|>"):
+        #             ipython = True
+        #             yield ChatCompletionResponseStreamChunk(
+        #                 event=ChatCompletionResponseEvent(
+        #                     event_type=ChatCompletionResponseEventType.progress,
+        #                     delta=ToolCallDelta(
+        #                         content="",
+        #                         parse_status=ToolCallParseStatus.started,
+        #                     ),
+        #                 )
+        #             )
+        #             buffer += text
+        #             continue
 
-                if ipython:
-                    if text == "<|eot_id|>":
-                        stop_reason = StopReason.end_of_turn
-                        text = ""
-                        continue
-                    elif text == "<|eom_id|>":
-                        stop_reason = StopReason.end_of_message
-                        text = ""
-                        continue
+        #         if ipython:
+        #             if text == "<|eot_id|>":
+        #                 stop_reason = StopReason.end_of_turn
+        #                 text = ""
+        #                 continue
+        #             elif text == "<|eom_id|>":
+        #                 stop_reason = StopReason.end_of_message
+        #                 text = ""
+        #                 continue
 
-                    buffer += text
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
+        #             buffer += text
+        #             delta = ToolCallDelta(
+        #                 content=text,
+        #                 parse_status=ToolCallParseStatus.in_progress,
+        #             )
 
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-                else:
-                    buffer += text
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=text,
-                            stop_reason=stop_reason,
-                        )
-                    )
+        #             yield ChatCompletionResponseStreamChunk(
+        #                 event=ChatCompletionResponseEvent(
+        #                     event_type=ChatCompletionResponseEventType.progress,
+        #                     delta=delta,
+        #                     stop_reason=stop_reason,
+        #                 )
+        #             )
+        #         else:
+        #             buffer += text
+        #             yield ChatCompletionResponseStreamChunk(
+        #                 event=ChatCompletionResponseEvent(
+        #                     event_type=ChatCompletionResponseEventType.progress,
+        #                     delta=text,
+        #                     stop_reason=stop_reason,
+        #                 )
+        #             )
 
-            # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message_from_content(
-                buffer, stop_reason
-            )
-            parsed_tool_calls = len(message.tool_calls) > 0
-            if ipython and not parsed_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.failure,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
+        #     # parse tool calls and report errors
+        #     message = self.formatter.decode_assistant_message_from_content(
+        #         buffer, stop_reason
+        #     )
+        #     parsed_tool_calls = len(message.tool_calls) > 0
+        #     if ipython and not parsed_tool_calls:
+        #         yield ChatCompletionResponseStreamChunk(
+        #             event=ChatCompletionResponseEvent(
+        #                 event_type=ChatCompletionResponseEventType.progress,
+        #                 delta=ToolCallDelta(
+        #                     content="",
+        #                     parse_status=ToolCallParseStatus.failure,
+        #                 ),
+        #                 stop_reason=stop_reason,
+        #             )
+        #         )
 
-            for tool_call in message.tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
+        #     for tool_call in message.tool_calls:
+        #         yield ChatCompletionResponseStreamChunk(
+        #             event=ChatCompletionResponseEvent(
+        #                 event_type=ChatCompletionResponseEventType.progress,
+        #                 delta=ToolCallDelta(
+        #                     content=tool_call,
+        #                     parse_status=ToolCallParseStatus.success,
+        #                 ),
+        #                 stop_reason=stop_reason,
+        #             )
+        #         )
 
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
-                    stop_reason=stop_reason,
-                )
-            )
+        #     yield ChatCompletionResponseStreamChunk(
+        #         event=ChatCompletionResponseEvent(
+        #             event_type=ChatCompletionResponseEventType.complete,
+        #             delta="",
+        #             stop_reason=stop_reason,
+        #         )
+        #     )
diff --git a/llama_stack/providers/impls/meta_reference/models/models.py b/llama_stack/providers/impls/meta_reference/models/models.py
index c3a2048c0..ee5b5c339 100644
--- a/llama_stack/providers/impls/meta_reference/models/models.py
+++ b/llama_stack/providers/impls/meta_reference/models/models.py
@@ -16,6 +16,9 @@ from llama_models.datatypes import CoreModelId, Model
 from llama_models.sku_list import resolve_model
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.safety import Safety
+from llama_stack.providers.adapters.inference.ollama.ollama import (
+    OllamaInferenceAdapter,
+)
 
 from llama_stack.providers.impls.meta_reference.inference.inference import (
     MetaReferenceInferenceImpl,
@@ -23,6 +26,7 @@ from llama_stack.providers.impls.meta_reference.inference.inference import (
 from llama_stack.providers.impls.meta_reference.safety.safety import (
     MetaReferenceSafetyImpl,
 )
+from llama_stack.providers.routers.inference.inference import InferenceRouterImpl
 
 from .config import MetaReferenceImplConfig
 
@@ -39,7 +43,7 @@ class MetaReferenceModelsImpl(Models):
         self.safety_api = safety_api
         self.models_list = []
 
-        model = get_model_id_from_api(self.inference_api)
+        # model = get_model_id_from_api(self.inference_api)
 
         # TODO, make the inference route provider and use router provider to do the lookup dynamically
         if isinstance(
@@ -56,6 +60,25 @@ class MetaReferenceModelsImpl(Models):
                 )
             )
 
+        if isinstance(
+            self.inference_api,
+            OllamaInferenceAdapter,
+        ):
+            self.models_list.append(
+                ModelSpec(
+                    providers_spec={
+                        "inference": [{"provider_type": "remote::ollama"}],
+                    },
+                )
+            )
+
+        if isinstance(
+            self.inference_api,
+            InferenceRouterImpl,
+        ):
+            print("Found router")
+            print(self.inference_api.providers)
+
         if isinstance(
             self.safety_api,
             MetaReferenceSafetyImpl,
diff --git a/llama_stack/providers/routers/inference/__init__.py b/llama_stack/providers/routers/inference/__init__.py
new file mode 100644
index 000000000..c6619ffc9
--- /dev/null
+++ b/llama_stack/providers/routers/inference/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, List, Tuple
+
+from llama_stack.distribution.datatypes import Api
+
+
+async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
+    from .inference import InferenceRouterImpl
+
+    impl = InferenceRouterImpl(inner_impls, deps)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/routers/inference/inference.py b/llama_stack/providers/routers/inference/inference.py
new file mode 100644
index 000000000..f459439d5
--- /dev/null
+++ b/llama_stack/providers/routers/inference/inference.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List, Tuple
+
+from llama_stack.distribution.datatypes import Api
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.registry.inference import available_providers
+
+
+class InferenceRouterImpl(Inference):
+    """Routes to an provider based on the memory bank type"""
+
+    def __init__(
+        self,
+        inner_impls: List[Tuple[str, Any]],
+        deps: List[Api],
+    ) -> None:
+        self.inner_impls = inner_impls
+        self.deps = deps
+
+        self.providers = {}
+        for routing_key, provider_impl in inner_impls:
+            self.providers[routing_key] = provider_impl
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        for p in self.providers.values():
+            await p.shutdown()
+
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        # zero-shot tool definitions as input to the model
+        tools: Optional[List[ToolDefinition]] = list,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        print("router chat_completion")