Make all API methods async def again

Ashwin Bharambe 2024-10-18 16:50:57 -07:00
parent 95a96afe34
commit 627edaf407
17 changed files with 120 additions and 145 deletions

@@ -47,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         self.client.close()

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -283,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         )
         return tool_config

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],

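The change is mechanical and repeats across every adapter in this commit: completion and chat_completion go from plain def back to async def, so call sites can uniformly await them. A minimal sketch of the before/after shape (ExampleAdapter and its helper are hypothetical stand-ins, not the Bedrock code):

from typing import AsyncGenerator


class ExampleAdapter:
    # Old shape (the `-` lines): a synchronous method whose annotation
    # promised an AsyncGenerator, so the method was sync but its result
    # was async:
    #
    #   def chat_completion(self, model: str) -> AsyncGenerator: ...
    #
    # New shape (the `+` lines): the method itself is a coroutine
    # function, so every caller writes `await adapter.chat_completion(...)`.
    async def chat_completion(self, model: str) -> str:
        return await self._nonstream_chat_completion(model)

    async def _nonstream_chat_completion(self, model: str) -> str:
        return f"stub response from {model}"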
@@ -48,7 +48,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -58,7 +58,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -84,7 +84,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI

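One detail worth noting in the stream branch, which recurs in every adapter below: only the non-streaming path gains an await. _stream_chat_completion is an async generator function, and calling it merely constructs the generator (awaiting it would raise a TypeError), while _nonstream_chat_completion is a coroutine that must now be awaited because chat_completion itself is async def. A self-contained sketch of that distinction, using illustrative names:

import asyncio
from typing import AsyncGenerator


async def _nonstream() -> str:
    # A coroutine: calling it yields a coroutine object that must be awaited.
    return "complete response"


async def _stream() -> AsyncGenerator[str, None]:
    # An async generator: calling it returns the generator immediately,
    # so there is nothing to await here.
    for chunk in ("a", "b", "c"):
        yield chunk


async def chat_completion(stream: bool = False):
    if stream:
        return _stream()       # no await: hand the generator to the caller
    return await _nonstream()  # await: resolve the coroutine to its value


async def main() -> None:
    print(await chat_completion())  # "complete response"
    async for chunk in await chat_completion(stream=True):
        print(chunk)                # "a", "b", "c"


asyncio.run(main())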
@@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -61,7 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -87,7 +87,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks

@@ -84,7 +84,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

         return ret

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -94,7 +94,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -118,7 +118,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)

     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {

@@ -66,7 +66,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -76,7 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -101,7 +101,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest

@@ -64,7 +64,7 @@ class TogetherInferenceAdapter(
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -101,7 +101,7 @@ class TogetherInferenceAdapter(
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Together
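Seen from the caller's side, the effect across all 17 files is a uniform calling convention: every inference entry point is awaited, and the awaited result is either the finished response or an async iterator of chunks. A hypothetical call site under that assumption (the adapter object and the message/response shapes here are illustrative, not the actual llama-stack client API):

# Hypothetical call site; `adapter` and the message dicts are stand-ins.
async def run_inference(adapter, model: str, prompt: str) -> None:
    # Non-streaming: awaiting the method yields the finished response.
    response = await adapter.chat_completion(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=False,
    )
    print(response)

    # Streaming: awaiting the method yields an async generator of chunks.
    stream = await adapter.chat_completion(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk)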