diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index f5321c628..15954ef57 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -90,7 +90,7 @@ async def run_main(host: str, port: int, stream: bool):
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
         ChatCompletionRequest(
-            model="Meta-Llama3.1-8B-Instruct",
+            model="ollama-1",
             messages=[message],
             stream=stream,
         )
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index 9be4f4935..bc7780c9d 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -98,7 +98,13 @@ class OllamaInferenceAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        cprint("!! calling remote ollama !!", "red")
+        cprint("!! calling remote ollama {}, url={}!!".format(model, self.url), "red")
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta="",
+            )
+        )
         # request = ChatCompletionRequest(
         #     model=model,
         #     messages=messages,
diff --git a/llama_stack/providers/routers/inference/inference.py b/llama_stack/providers/routers/inference/inference.py
index f459439d5..3a0ab84f9 100644
--- a/llama_stack/providers/routers/inference/inference.py
+++ b/llama_stack/providers/routers/inference/inference.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, AsyncGenerator, Dict, List, Tuple
 
 from llama_stack.distribution.datatypes import Api
 from llama_stack.apis.inference import *  # noqa: F403
@@ -46,3 +46,14 @@ class InferenceRouterImpl(Inference):
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         print("router chat_completion")
+        async for chunk in self.providers[model].chat_completion(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        ):
+            yield chunk
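
Note: the router hunk above streams by re-yielding every chunk from the async generator of the provider registered for the requested model key (e.g. "ollama-1"). A minimal, self-contained sketch of that delegation pattern follows; EchoProvider and Router here are hypothetical stand-ins for illustration, not the actual llama_stack classes.

# Sketch of the async-generator delegation used in InferenceRouterImpl.chat_completion.
# EchoProvider and Router are hypothetical stand-ins, not llama_stack classes.
import asyncio
from typing import AsyncGenerator, Dict


class EchoProvider:
    async def chat_completion(self, model: str, message: str) -> AsyncGenerator[str, None]:
        # Pretend to stream a response token by token.
        for token in ["hello", " ", "world"]:
            yield token


class Router:
    def __init__(self, providers: Dict[str, EchoProvider]) -> None:
        self.providers = providers

    async def chat_completion(self, model: str, message: str) -> AsyncGenerator[str, None]:
        # Look up the provider registered under this model key and re-yield
        # each chunk it produces, mirroring the `async for ... yield chunk` loop above.
        async for chunk in self.providers[model].chat_completion(model, message):
            yield chunk


async def main() -> None:
    router = Router({"ollama-1": EchoProvider()})
    async for chunk in router.chat_completion("ollama-1", "hi"):
        print(chunk, end="")
    print()


asyncio.run(main())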