models endpoint testing

Xi Yan 2024-09-22 00:01:35 -07:00
parent c0199029e5
commit 0348f26e00
10 changed files with 235 additions and 79 deletions


@@ -6,14 +6,14 @@
 from typing import AsyncGenerator
+from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
-from fireworks.client import Fireworks
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
@@ -42,7 +42,14 @@ class FireworksInferenceAdapter(Inference):
     async def shutdown(self) -> None:
         pass
 
-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
         raise NotImplementedError()
 
     def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
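
Sketch only (not part of this commit): the hunk above replaces the single CompletionRequest parameter with explicit arguments, so a caller might exercise the new completion() signature roughly as below once an adapter implements it. The model id and prompt are placeholders, the types come from the star import shown in the diff, and it assumes InterleavedTextMedia accepts a plain string.

from llama_stack.apis.inference import *  # noqa: F403


async def demo_completion(adapter) -> None:
    # completion() still raises NotImplementedError in this commit, so this
    # only illustrates the intended call shape of the new signature.
    async for chunk in adapter.completion(
        model="llama3-8b-instruct",              # placeholder model identifier
        content="Write a haiku about the sea.",  # plain-text InterleavedTextMedia
        sampling_params=SamplingParams(),        # defaults, matching the new signature
        stream=True,
        logprobs=None,
    ):
        print(chunk)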


@@ -30,25 +30,33 @@ OLLAMA_SUPPORTED_SKUS = {
 class OllamaInferenceAdapter(Inference):
     def __init__(self, url: str) -> None:
         self.url = url
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        # tokenizer = Tokenizer.get_instance()
+        # self.formatter = ChatFormat(tokenizer)
 
     @property
     def client(self) -> AsyncClient:
         return AsyncClient(host=self.url)
 
     async def initialize(self) -> None:
-        try:
-            await self.client.ps()
-        except httpx.ConnectError as e:
-            raise RuntimeError(
-                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
-            ) from e
+        print("Ollama init")
+        # try:
+        #     await self.client.ps()
+        # except httpx.ConnectError as e:
+        #     raise RuntimeError(
+        #         "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
+        #     ) from e
 
     async def shutdown(self) -> None:
         pass
 
-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
         raise NotImplementedError()
 
     def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
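
This hunk comments out the Ollama connectivity check in initialize() and replaces it with a print, so the adapter can start without a running server during endpoint testing. If the check is still wanted somewhere, a standalone helper along the lines of the commented-out code might look like the sketch below; it reuses only what the diff shows (the AsyncClient.ps() probe and the httpx.ConnectError handling), and the ollama import is an assumption since the diff does not show the file's imports.

import httpx
from ollama import AsyncClient


async def ensure_ollama_running(url: str) -> None:
    # Probe the server the same way the commented-out initialize() body did.
    try:
        await AsyncClient(host=url).ps()
    except httpx.ConnectError as e:
        raise RuntimeError(
            "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
        ) from e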


@@ -54,7 +54,14 @@ class TGIAdapter(Inference):
     async def shutdown(self) -> None:
         pass
 
-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
         raise NotImplementedError()
 
     def get_chat_options(self, request: ChatCompletionRequest) -> dict:


@@ -42,7 +42,14 @@ class TogetherInferenceAdapter(Inference):
     async def shutdown(self) -> None:
         pass
 
-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
         raise NotImplementedError()
 
     def _messages_to_together_messages(self, messages: list[Message]) -> list:
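
Every adapter in these hunks keeps completion() as a stub, so the only behavior a test can pin down on the new signature is that it is not yet implemented. A minimal sketch, assuming pytest with pytest-asyncio; the adapter import path is omitted because the diff does not show file names, and the model id and URL are placeholders (11434 is Ollama's default local port).

import pytest

# OllamaInferenceAdapter would be imported from its provider module,
# whose path is not shown in this diff.


@pytest.mark.asyncio
async def test_completion_not_implemented():
    # The adapter takes only a url, per the __init__ shown above.
    adapter = OllamaInferenceAdapter(url="http://localhost:11434")
    with pytest.raises(NotImplementedError):
        await adapter.completion(
            model="llama3-8b",  # placeholder model id (assumption)
            content="hello",
        )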