Make all methods async def again; add completion() for meta-reference (#270)

PR #201 had made several changes while trying to fix issues with getting the stream=False branches of inference and agents API working. As part of this, it made a change which was slightly gratuitous. Namely, making chat_completion() and brethren "def" instead of "async def".

The rationale was that this allowed the user (within llama-stack) of this to use it as:

```
async for chunk in api.chat_completion(params)
```

However, it causes unnecessary confusion for several folks. Given that clients (e.g., llama-stack-apps) anyway use the SDK methods (which are completely isolated) this choice was not ideal. Let's revert back so the call now looks like:

```
async for chunk in await api.chat_completion(params)
```

Bonus: Added a completion() implementation for the meta-reference provider. Technically should have been another PR :)
This commit is contained in:
Ashwin Bharambe 2024-10-18 20:50:59 -07:00 committed by GitHub
parent 95a96afe34
commit 2089427d60
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 330 additions and 213 deletions

View file

@ -421,10 +421,8 @@ class Agents(Protocol):
agent_config: AgentConfig,
) -> AgentCreateResponse: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
@webmethod(route="/agents/turn/create")
def create_agent_turn(
async def create_agent_turn(
self,
agent_id: str,
session_id: str,

View file

@ -67,14 +67,14 @@ class AgentsClient(Agents):
response.raise_for_status()
return AgentSessionCreateResponse(**response.json())
def create_agent_turn(
async def create_agent_turn(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
if request.stream:
return self._stream_agent_turn(request)
else:
return self._nonstream_agent_turn(request)
return await self._nonstream_agent_turn(request)
async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
@ -126,7 +126,7 @@ async def _run_agent(
for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agent_turn(
iterator = await api.create_agent_turn(
AgentTurnCreateRequest(
agent_id=create_response.agent_id,
session_id=session_response.session_id,

View file

@ -42,10 +42,10 @@ class InferenceClient(Inference):
async def shutdown(self) -> None:
pass
def completion(self, request: CompletionRequest) -> AsyncGenerator:
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -139,7 +139,8 @@ async def run_main(
else:
logprobs_config = None
iterator = client.chat_completion(
assert stream, "Non streaming not supported here"
iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,

View file

@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
class CompletionResponse(BaseModel):
"""Completion response."""
completion_message: CompletionMessage
content: str
stop_reason: StopReason
logprobs: Optional[List[TokenLogProbs]] = None
@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
class BatchCompletionResponse(BaseModel):
"""Batch completion response."""
completion_message_batch: List[CompletionMessage]
batch: List[CompletionResponse]
@json_schema_type
@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
batch: List[ChatCompletionResponse]
@json_schema_type
@ -181,10 +182,8 @@ class ModelStore(Protocol):
class Inference(Protocol):
model_store: ModelStore
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/completion")
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -196,7 +195,7 @@ class Inference(Protocol):
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion")
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],