Make all methods async def again; add completion() for meta-reference (#270)

PR #201 had made several changes while trying to fix issues with getting the stream=False branches of inference and agents API working. As part of this, it made a change which was slightly gratuitous. Namely, making chat_completion() and brethren "def" instead of "async def".

The rationale was that this allowed the user (within llama-stack) of this to use it as:

```
async for chunk in api.chat_completion(params)
```

However, it causes unnecessary confusion for several folks. Given that clients (e.g., llama-stack-apps) anyway use the SDK methods (which are completely isolated) this choice was not ideal. Let's revert back so the call now looks like:

```
async for chunk in await api.chat_completion(params)
```

Bonus: Added a completion() implementation for the meta-reference provider. Technically should have been another PR :)
This commit is contained in:
Ashwin Bharambe 2024-10-18 20:50:59 -07:00 committed by GitHub
parent 95a96afe34
commit 2089427d60
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 330 additions and 213 deletions

View file

@ -70,7 +70,7 @@ class InferenceRouter(Inference):
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -93,11 +93,11 @@ class InferenceRouter(Inference):
)
provider = self.routing_table.get_provider_impl(model)
if stream:
return (chunk async for chunk in provider.chat_completion(**params))
return (chunk async for chunk in await provider.chat_completion(**params))
else:
return provider.chat_completion(**params)
return await provider.chat_completion(**params)
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -114,9 +114,9 @@ class InferenceRouter(Inference):
logprobs=logprobs,
)
if stream:
return (chunk async for chunk in provider.completion(**params))
return (chunk async for chunk in await provider.completion(**params))
else:
return provider.completion(**params)
return await provider.completion(**params)
async def embeddings(
self,