models endpoint testing

Xi Yan 2024-09-22 00:01:35 -07:00
parent c0199029e5
commit 0348f26e00
10 changed files with 235 additions and 79 deletions

@@ -30,25 +30,33 @@ OLLAMA_SUPPORTED_SKUS = {
 class OllamaInferenceAdapter(Inference):
     def __init__(self, url: str) -> None:
         self.url = url
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        # tokenizer = Tokenizer.get_instance()
+        # self.formatter = ChatFormat(tokenizer)
 
     @property
     def client(self) -> AsyncClient:
         return AsyncClient(host=self.url)
 
     async def initialize(self) -> None:
-        try:
-            await self.client.ps()
-        except httpx.ConnectError as e:
-            raise RuntimeError(
-                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
-            ) from e
+        print("Ollama init")
+        # try:
+        #     await self.client.ps()
+        # except httpx.ConnectError as e:
+        #     raise RuntimeError(
+        #         "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
+        #     ) from e
 
     async def shutdown(self) -> None:
         pass
 
-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
         raise NotImplementedError()
 
     def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
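
For reference, the connectivity check this commit comments out can be kept around as a standalone helper. The sketch below simply lifts the removed lines into a free function (the name `check_ollama` is illustrative, not part of the codebase); `AsyncClient.ps()` is the ollama client call that lists running models and fails fast when no server is reachable.

import httpx
from ollama import AsyncClient

async def check_ollama(url: str) -> None:
    # ps() lists the models currently loaded by the server; if nothing
    # is listening at `url`, httpx raises ConnectError.
    try:
        await AsyncClient(host=url).ps()
    except httpx.ConnectError as e:
        raise RuntimeError(
            "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
        ) from e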
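
The rewritten `completion` now takes the model and content explicitly instead of a single CompletionRequest, but its body is still a NotImplementedError stub. Below is a hypothetical sketch of one way it could later be wired to Ollama's generate endpoint; rendering `content` with `str()` and ignoring `sampling_params`/`logprobs` are assumptions made for brevity, not part of this commit.

async def completion(
    self,
    model: str,
    content: InterleavedTextMedia,
    sampling_params: Optional[SamplingParams] = SamplingParams(),
    stream: Optional[bool] = False,
    logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
    # With stream=True, ollama's AsyncClient.generate returns an async
    # iterator of partial results, each carrying a "response" text field.
    async for chunk in await self.client.generate(
        model=model, prompt=str(content), stream=True
    ):
        yield chunk["response"]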