models routing work

Xi Yan 2024-09-19 08:48:10 -07:00
parent f3ff3a3001
commit 9bdd4e3dd9
3 changed files with 20 additions and 3 deletions


@@ -90,7 +90,7 @@ async def run_main(host: str, port: int, stream: bool):
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
         ChatCompletionRequest(
-            model="Meta-Llama3.1-8B-Instruct",
+            model="ollama-1",
            messages=[message],
             stream=stream,
         )
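Note: the client now sends the routing identifier "ollama-1" instead of a concrete model name; the streamed result is consumed the same way as before. A minimal consumption sketch, assuming each streamed chunk has the event/delta shape used in the adapter hunk below (the helper name is illustrative, not part of this commit):

# Illustrative helper, not part of the commit: drains the async iterator returned above.
async def print_stream(iterator):
    async for chunk in iterator:
        event = chunk.event
        # Text deltas carry generated tokens; start/complete events have an empty delta.
        if event.delta:
            print(event.delta, end="", flush=True)
    print()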


@@ -98,7 +98,13 @@ class OllamaInferenceAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        cprint("!! calling remote ollama !!", "red")
+        cprint("!! calling remote ollama {}, url={}!!".format(model, self.url), "red")
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta="",
+            )
+        )
         # request = ChatCompletionRequest(
         #     model=model,
         #     messages=messages,
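Note: at this point the adapter only emits a start event and leaves the real Ollama request commented out. A rough sketch, not the commit's implementation, of the full start/progress/complete sequence a streaming version would eventually emit, using the same response types the hunk already uses; the placeholder token source stands in for the commented-out Ollama call:

# Sketch only: the placeholder generator stands in for the real Ollama request.
async def _placeholder_tokens():
    for token in ["Hello", ", ", "moon"]:
        yield token

async def _stream_chat_sketch():
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.start,
            delta="",
        )
    )
    async for token in _placeholder_tokens():
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=token,
            )
        )
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta="",
        )
    )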


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, AsyncGenerator, Dict, List, Tuple
 
 from llama_stack.distribution.datatypes import Api
 from llama_stack.apis.inference import * # noqa: F403
@@ -46,3 +46,14 @@ class InferenceRouterImpl(Inference):
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         print("router chat_completion")
+        async for chunk in self.providers[model].chat_completion(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        ):
+            yield chunk
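Note: this new loop is the actual routing step: the model identifier sent by the client (e.g. "ollama-1") is used as a key into self.providers and the matching provider's chat_completion stream is passed through unchanged. A minimal sketch of the shape this assumes; the names below are illustrative and the real provider registration happens in the distribution setup, not in this diff:

# Illustrative only: the real mapping is assembled by the distribution setup code.
# "ollama_adapter" stands in for a configured OllamaInferenceAdapter instance.
providers: Dict[str, Inference] = {
    "ollama-1": ollama_adapter,  # key matches the model= value the client sends
}

# With self.providers set to such a mapping, a request for model="ollama-1"
# streams straight through the Ollama adapter, e.g.:
#   async for chunk in router.chat_completion(model="ollama-1", messages=[message], stream=True):
#       ...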