forked from phoenix-oss/llama-stack-mirror
Make all methods async def again; add completion() for meta-reference (#270)
PR #201 made several changes while trying to get the stream=False branches of the inference and agents APIs working. As part of this, it made one change that was slightly gratuitous: turning chat_completion() and its brethren into plain "def" instead of "async def". The rationale was that it allowed callers within llama-stack to write:

```
async for chunk in api.chat_completion(params)
```

However, this caused unnecessary confusion for several folks. Given that clients (e.g., llama-stack-apps) use the SDK methods (which are completely isolated) anyway, this choice was not ideal. Let's revert, so the call now looks like:

```
async for chunk in await api.chat_completion(params)
```

Bonus: added a completion() implementation for the meta-reference provider. Technically this should have been a separate PR :)
Parent: 95a96afe34
Commit: 2089427d60
23 changed files with 330 additions and 213 deletions
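For context, the calling convention this commit restores is: the API methods are async def, and when stream=True the awaited call hands back an async generator that the caller then iterates. A minimal sketch of how a caller consumes both branches (the `api` object, model name, and messages here are illustrative placeholders, not code from the repo):

```python
# Hedged sketch of the restored calling convention; `api`, `model`, and
# `messages` are placeholders for whatever the caller already has in hand.

async def consume(api, model, messages):
    # stream=False: awaiting the coroutine returns the full response object.
    response = await api.chat_completion(model=model, messages=messages, stream=False)
    print(response.completion_message)

    # stream=True: awaiting the coroutine returns an async generator of chunks,
    # hence the `async for chunk in await api.chat_completion(...)` shape above.
    async for chunk in await api.chat_completion(model=model, messages=messages, stream=True):
        print(chunk)
```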
@@ -21,7 +21,7 @@
     "info": {
         "title": "[DRAFT] Llama Stack Specification",
         "version": "0.0.1",
-        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
+        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988"
     },
     "servers": [
         {

@@ -2830,8 +2830,11 @@
         "CompletionResponse": {
             "type": "object",
             "properties": {
-                "completion_message": {
-                    "$ref": "#/components/schemas/CompletionMessage"
+                "content": {
+                    "type": "string"
+                },
+                "stop_reason": {
+                    "$ref": "#/components/schemas/StopReason"
                 },
                 "logprobs": {
                     "type": "array",

@@ -2842,7 +2845,8 @@
             },
             "additionalProperties": false,
             "required": [
-                "completion_message"
+                "content",
+                "stop_reason"
             ],
             "title": "Completion response."
         },

@@ -6075,49 +6079,49 @@
     ],
     "tags": [
         {
-            "name": "Evaluations"
-        },
-        {
-            "name": "Inspect"
+            "name": "Models"
         },
         {
             "name": "RewardScoring"
         },
         {
-            "name": "Datasets"
-        },
-        {
-            "name": "Models"
-        },
-        {
-            "name": "Telemetry"
-        },
-        {
-            "name": "PostTraining"
-        },
-        {
-            "name": "SyntheticDataGeneration"
-        },
-        {
-            "name": "BatchInference"
-        },
-        {
-            "name": "Inference"
-        },
-        {
-            "name": "Agents"
-        },
-        {
-            "name": "Memory"
-        },
-        {
-            "name": "Safety"
+            "name": "MemoryBanks"
         },
         {
             "name": "Shields"
         },
         {
-            "name": "MemoryBanks"
+            "name": "SyntheticDataGeneration"
+        },
+        {
+            "name": "Inference"
+        },
+        {
+            "name": "Inspect"
+        },
+        {
+            "name": "BatchInference"
+        },
+        {
+            "name": "Memory"
+        },
+        {
+            "name": "Datasets"
+        },
+        {
+            "name": "Agents"
+        },
+        {
+            "name": "PostTraining"
+        },
+        {
+            "name": "Telemetry"
+        },
+        {
+            "name": "Safety"
+        },
+        {
+            "name": "Evaluations"
         },
         {
             "name": "BuiltinTool",

@@ -501,14 +501,17 @@ components:
   CompletionResponse:
     additionalProperties: false
     properties:
-      completion_message:
-        $ref: '#/components/schemas/CompletionMessage'
+      content:
+        type: string
       logprobs:
         items:
           $ref: '#/components/schemas/TokenLogProbs'
         type: array
+      stop_reason:
+        $ref: '#/components/schemas/StopReason'
     required:
-    - completion_message
+    - content
+    - stop_reason
     title: Completion response.
     type: object
   CompletionResponseStreamChunk:

@@ -2507,7 +2510,7 @@ info:
   description: "This is the specification of the llama stack that provides\n    \
     \ a set of endpoints and their corresponding interfaces that are tailored\
     \ to\n    best leverage Llama Models. The specification is still in\
-    \ draft and subject to change.\n    Generated at 2024-10-10 15:29:56.831109"
+    \ draft and subject to change.\n    Generated at 2024-10-18 20:48:17.730988"
   title: '[DRAFT] Llama Stack Specification'
   version: 0.0.1
 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema

@@ -3712,21 +3715,21 @@ security:
 servers:
 - url: http://any-hosted-llama-stack.com
 tags:
-- name: Evaluations
-- name: Inspect
-- name: RewardScoring
-- name: Datasets
 - name: Models
-- name: Telemetry
-- name: PostTraining
-- name: SyntheticDataGeneration
-- name: BatchInference
-- name: Inference
-- name: Agents
-- name: Memory
-- name: Safety
-- name: Shields
+- name: RewardScoring
 - name: MemoryBanks
+- name: Shields
+- name: SyntheticDataGeneration
+- name: Inference
+- name: Inspect
+- name: BatchInference
+- name: Memory
+- name: Datasets
+- name: Agents
+- name: PostTraining
+- name: Telemetry
+- name: Safety
+- name: Evaluations
 - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
   name: BuiltinTool
 - description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
@@ -421,10 +421,8 @@ class Agents(Protocol):
         agent_config: AgentConfig,
     ) -> AgentCreateResponse: ...

-    # This method is not `async def` because it can result in either an
-    # `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
     @webmethod(route="/agents/turn/create")
-    def create_agent_turn(
+    async def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
@@ -67,14 +67,14 @@ class AgentsClient(Agents):
         response.raise_for_status()
         return AgentSessionCreateResponse(**response.json())

-    def create_agent_turn(
+    async def create_agent_turn(
         self,
         request: AgentTurnCreateRequest,
     ) -> AsyncGenerator:
         if request.stream:
             return self._stream_agent_turn(request)
         else:
-            return self._nonstream_agent_turn(request)
+            return await self._nonstream_agent_turn(request)

     async def _stream_agent_turn(
         self, request: AgentTurnCreateRequest

@@ -126,7 +126,7 @@ async def _run_agent(

     for content in user_prompts:
         cprint(f"User> {content}", color="white", attrs=["bold"])
-        iterator = api.create_agent_turn(
+        iterator = await api.create_agent_turn(
             AgentTurnCreateRequest(
                 agent_id=create_response.agent_id,
                 session_id=session_response.session_id,
@@ -42,10 +42,10 @@ class InferenceClient(Inference):
     async def shutdown(self) -> None:
         pass

-    def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],

@@ -139,7 +139,8 @@ async def run_main(
     else:
         logprobs_config = None

-    iterator = client.chat_completion(
+    assert stream, "Non streaming not supported here"
+    iterator = await client.chat_completion(
         model=model,
         messages=[message],
         stream=stream,
@@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
 class CompletionResponse(BaseModel):
     """Completion response."""

-    completion_message: CompletionMessage
+    content: str
+    stop_reason: StopReason
     logprobs: Optional[List[TokenLogProbs]] = None


@@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
 class BatchCompletionResponse(BaseModel):
     """Batch completion response."""

-    completion_message_batch: List[CompletionMessage]
+    batch: List[CompletionResponse]


 @json_schema_type

@@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):

 @json_schema_type
 class BatchChatCompletionResponse(BaseModel):
-    completion_message_batch: List[CompletionMessage]
+    batch: List[ChatCompletionResponse]


 @json_schema_type

@@ -181,10 +182,8 @@ class ModelStore(Protocol):
 class Inference(Protocol):
     model_store: ModelStore

-    # This method is not `async def` because it can result in either an
-    # `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/completion")
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,

@@ -196,7 +195,7 @@ class Inference(Protocol):
     # This method is not `async def` because it can result in either an
     # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/chat_completion")
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
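To make the schema change above concrete, the non-streaming completion path now returns a flat response rather than a nested CompletionMessage. A small illustrative construction (the import path and field values are assumptions based on the hunks above, not code copied from the repo):

```python
# Illustrative only: the reshaped CompletionResponse per the hunk above.
# The re-exported names in llama_stack.apis.inference are assumed here.
from llama_stack.apis.inference import CompletionResponse, StopReason

response = CompletionResponse(
    content="violets are blue",          # plain string instead of a CompletionMessage
    stop_reason=StopReason.end_of_turn,  # stop reason now lives on the response itself
    logprobs=None,                       # still optional
)
print(response.content, response.stop_reason)
```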
@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
     async def register_model(self, model: ModelDef) -> None:
         await self.routing_table.register_model(model)

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],

@@ -93,11 +93,11 @@ class InferenceRouter(Inference):
         )
         provider = self.routing_table.get_provider_impl(model)
         if stream:
-            return (chunk async for chunk in provider.chat_completion(**params))
+            return (chunk async for chunk in await provider.chat_completion(**params))
         else:
-            return provider.chat_completion(**params)
+            return await provider.chat_completion(**params)

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,

@@ -114,9 +114,9 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         )
         if stream:
-            return (chunk async for chunk in provider.completion(**params))
+            return (chunk async for chunk in await provider.completion(**params))
         else:
-            return provider.completion(**params)
+            return await provider.completion(**params)

     async def embeddings(
         self,
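The router hunks above show the dispatch pattern that repeats across every provider touched by this commit: the method is async def; when stream=True it returns an async generator for the caller to iterate, and otherwise it awaits the non-streaming helper and returns the finished response. A condensed sketch of that shape, with `provider` and `params` standing in for a resolved provider implementation and its prepared arguments:

```python
# Hedged sketch of the stream / non-stream dispatch used by the router and
# provider adapters in this commit; `provider` and `params` are placeholders.
from typing import Any, AsyncGenerator, Union

async def chat_completion(provider, params: dict, stream: bool) -> Union[AsyncGenerator, Any]:
    if stream:
        # Awaiting the provider call yields an async generator; re-wrap it so
        # the caller can simply `async for` over the chunks.
        return (chunk async for chunk in await provider.chat_completion(**params))
    else:
        # Non-streaming: the awaited call already returns the full response.
        return await provider.chat_completion(**params)
```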
@@ -47,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         self.client.close()

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,

@@ -283,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         )
         return tool_config

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -48,7 +48,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,

@@ -58,7 +58,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],

@@ -84,7 +84,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
@@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,

@@ -61,7 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],

@@ -87,7 +87,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
@@ -84,7 +84,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

         return ret

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,

@@ -94,7 +94,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],

@@ -118,7 +118,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)

     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
@@ -66,7 +66,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,

@@ -76,7 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],

@@ -101,7 +101,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
@@ -64,7 +64,7 @@ class TogetherInferenceAdapter(
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],

@@ -101,7 +101,7 @@ class TogetherInferenceAdapter(
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Together
@@ -424,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
         stop_reason = None

         with tracing.span("inference"):
-            async for chunk in self.inference_api.chat_completion(
+            async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
                 tools=self._get_tools(),
@@ -105,7 +105,7 @@ class MetaReferenceAgentsImpl(Agents):
             session_id=session_id,
         )

-    def create_agent_turn(
+    async def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
@@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3.api.datatypes import (
-    InterleavedTextMedia,
-    Message,
-    ToolPromptFormat,
-)
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (

@@ -38,7 +33,11 @@ from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 from termcolor import cprint

+from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
+)

 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig


@@ -297,15 +296,12 @@ class Llama:
             if all(eos_reached):
                 break

-    def text_completion(
+    def completion(
         self,
-        content: InterleavedTextMedia,
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        echo: bool = False,
+        request: CompletionRequest,
     ) -> Generator:
+        sampling_params = request.sampling_params
+        max_gen_len = sampling_params.max_tokens
         if (
             max_gen_len is None
             or max_gen_len == 0

@@ -313,26 +309,25 @@ class Llama:
         ):
             max_gen_len = self.model.params.max_seq_len - 1

-        model_input = self.formatter.encode_content(content)
+        model_input = self.formatter.encode_content(request.content)

         yield from self.generate(
             model_input=model_input,
             max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
-            echo=echo,
+            temperature=sampling_params.temperature,
+            top_p=sampling_params.top_p,
+            logprobs=bool(request.logprobs),
+            include_stop_token=True,
+            echo=False,
         )

     def chat_completion(
         self,
-        messages: List[Message],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+        request: ChatCompletionRequest,
     ) -> Generator:
+        messages = chat_completion_request_to_messages(request)
+
+        sampling_params = request.sampling_params
+        max_gen_len = sampling_params.max_tokens
         if (
             max_gen_len is None
             or max_gen_len == 0

@@ -343,12 +338,12 @@ class Llama:
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
                 messages,
-                tool_prompt_format,
+                request.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
+            temperature=sampling_params.temperature,
+            top_p=sampling_params.top_p,
+            logprobs=bool(request.logprobs),
             include_stop_token=True,
         )
@@ -13,9 +13,6 @@ from llama_models.sku_list import resolve_model
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_messages,
-)

 from .config import MetaReferenceInferenceConfig
 from .generation import Llama

@@ -58,7 +55,18 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         if self.config.create_distributed_process_group:
             self.generator.stop()

-    def completion(
+    def check_model(self, request) -> None:
+        model = resolve_model(request.model)
+        if model is None:
+            raise RuntimeError(
+                f"Unknown model: {request.model}, Run `llama model list`"
+            )
+        elif model.descriptor() != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {request.model} != {self.model.descriptor()}"
+            )
+
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,

@@ -66,9 +74,114 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        raise NotImplementedError()
-
-    def chat_completion(
+        if logprobs:
+            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        self.check_model(request)
+
+        if request.stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        def impl():
+            stop_reason = None
+
+            for token_result in self.generator.completion(request):
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
+
+                logprobs = None
+                if stop_reason is None:
+                    if request.logprobs:
+                        assert len(token_result.logprobs) == 1
+
+                        logprobs = [
+                            TokenLogProbs(
+                                logprobs_by_token={
+                                    token_result.text: token_result.logprobs[0]
+                                }
+                            )
+                        ]
+
+                yield CompletionResponseStreamChunk(
+                    delta=text,
+                    stop_reason=stop_reason,
+                    logprobs=logprobs if request.logprobs else None,
+                )

+            if stop_reason is None:
+                yield CompletionResponseStreamChunk(
+                    delta="",
+                    stop_reason=StopReason.out_of_tokens,
+                )
+
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                for x in impl():
+                    yield x
+        else:
+            for x in impl():
+                yield x
+
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> CompletionResponse:
+        def impl():
+            tokens = []
+            logprobs = []
+            stop_reason = None
+
+            tokenizer = self.generator.formatter.tokenizer
+            for token_result in self.generator.completion(request):
+                tokens.append(token_result.token)
+
+                if token_result.token in tokenizer.stop_tokens:
+                    # not quite right semantically
+                    stop_reason = StopReason.end_of_turn
+
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
+
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            content = self.generator.formatter.tokenizer.decode(tokens)
+            return CompletionResponse(
+                content=content,
+                stop_reason=stop_reason,
+                logprobs=logprobs if request.logprobs else None,
+            )
+
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                return impl()
+        else:
+            return impl()
+
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],

@@ -93,16 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             stream=stream,
             logprobs=logprobs,
         )
-        model = resolve_model(request.model)
-        if model is None:
-            raise RuntimeError(
-                f"Unknown model: {request.model}, Run `llama model list`"
-            )
-        elif model.descriptor() != self.model.descriptor():
-            raise RuntimeError(
-                f"Model mismatch: {request.model} != {self.model.descriptor()}"
-            )
+        self.check_model(request)

         if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():

@@ -111,26 +215,17 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         if request.stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         def impl():
-            messages = chat_completion_request_to_messages(request)
-
             tokens = []
             logprobs = []
             stop_reason = None

-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
+            for token_result in self.generator.chat_completion(request):
                 tokens.append(token_result.token)

                 if token_result.text == "<|eot_id|>":

@@ -170,8 +265,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         def impl():
-            messages = chat_completion_request_to_messages(request)
-
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.start,

@@ -184,14 +277,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             stop_reason = None
             ipython = False

-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
+            for token_result in self.generator.chat_completion(request):
                 tokens.append(token_result.token)

                 if not ipython and token_result.text.startswith("<|python_tag|>"):
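Putting the new meta-reference pieces together, a caller of the provider's completion() sees the same two-branch behavior as chat_completion(). A hedged sketch of consuming it (the `impl` handle and model name are placeholders; argument names follow the signature in the hunks above):

```python
# Hedged sketch: consuming the new meta-reference completion() both ways.
# `impl` stands in for whatever Inference implementation the caller resolved.
from llama_stack.apis.inference import SamplingParams

async def demo(impl, model: str):
    # stream=False: a single CompletionResponse with content and stop_reason.
    response = await impl.completion(
        model=model,
        content="Roses are red,",
        sampling_params=SamplingParams(max_tokens=50),
        stream=False,
    )
    print(response.content, response.stop_reason)

    # stream=True: an async generator of CompletionResponseStreamChunk values,
    # each carrying an incremental `delta` and eventually a stop_reason.
    async for chunk in await impl.completion(
        model=model,
        content="Roses are red,",
        sampling_params=SamplingParams(max_tokens=50),
        stream=True,
    ):
        print(chunk.delta, end="")
```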
@@ -7,16 +7,17 @@
 import os
 from copy import deepcopy
 from functools import partial
-from typing import Generator, List, Optional
+from typing import Any, Generator

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama, model_checkpoint_dir
-from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
+from .parallel_utils import ModelParallelProcessGroup


 class ModelRunner:

@@ -24,15 +25,13 @@ class ModelRunner:
         self.llama = llama

     # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
-    def __call__(self, task: InferenceArgs):
-        return self.llama.chat_completion(
-            task.messages,
-            task.temperature,
-            task.top_p,
-            task.max_gen_len,
-            task.logprobs,
-            task.tool_prompt_format,
-        )
+    def __call__(self, req: Any):
+        if isinstance(req, ChatCompletionRequest):
+            return self.llama.chat_completion(req)
+        elif isinstance(req, CompletionRequest):
+            return self.llama.completion(req)
+        else:
+            raise ValueError(f"Unexpected task type {type(req)}")


 def init_model_cb(config: MetaReferenceInferenceConfig):

@@ -77,23 +76,18 @@ class LlamaModelParallelGenerator:
     def __exit__(self, exc_type, exc_value, exc_traceback):
         self.group.stop()

-    def chat_completion(
+    def completion(
         self,
-        messages: List[Message],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+        request: CompletionRequest,
     ) -> Generator:
-        req_obj = InferenceArgs(
-            messages=deepcopy(messages),
-            temperature=temperature,
-            top_p=top_p,
-            max_gen_len=max_gen_len,
-            logprobs=logprobs or False,
-            tool_prompt_format=tool_prompt_format,
-        )
+        req_obj = deepcopy(request)
+        gen = self.group.run_inference(req_obj)
+        yield from gen
+
+    def chat_completion(
+        self,
+        request: ChatCompletionRequest,
+    ) -> Generator:
+        req_obj = deepcopy(request)
         gen = self.group.run_inference(req_obj)
         yield from gen
@@ -4,6 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 import json
 import multiprocessing
 import os

@@ -11,10 +17,9 @@ import tempfile
 import time
 import uuid
 from enum import Enum
-from typing import Callable, Generator, List, Literal, Optional, Union
+from typing import Callable, Generator, Literal, Optional, Union

 import torch

 import zmq

 from fairscale.nn.model_parallel.initialize import (

@@ -23,25 +28,16 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_src_rank,
 )

-from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
-
 from pydantic import BaseModel, Field

 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 from typing_extensions import Annotated

+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+
 from .generation import TokenResult


-class InferenceArgs(BaseModel):
-    messages: List[Message]
-    temperature: float
-    top_p: float
-    max_gen_len: int
-    logprobs: bool
-    tool_prompt_format: ToolPromptFormat
-
-
 class ProcessingMessageName(str, Enum):
     ready_request = "ready_request"
     ready_response = "ready_response"

@@ -80,7 +76,7 @@ class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = (
         ProcessingMessageName.task_request
     )
-    task: InferenceArgs
+    task: Union[CompletionRequest, ChatCompletionRequest]


 class TaskResponse(BaseModel):

@@ -349,11 +345,13 @@ class ModelParallelProcessGroup:
         self.process.join()
         self.started = False

-    def run_inference(self, inference_args: InferenceArgs) -> Generator:
+    def run_inference(
+        self, req: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Generator:
         assert not self.running, "inference already running"

         self.running = True
-        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
+        self.request_socket.send(encode_msg(TaskRequest(task=req)))
         try:
             while True:
                 obj_json = self.request_socket.recv()
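The parallel_utils change above is what lets the model-parallel worker carry whole request objects over the wire: TaskRequest.task is now a Union of CompletionRequest and ChatCompletionRequest, and run_inference() forwards whichever it is given. A small sketch of the caller side under that assumption (the `group` and `request` names are placeholders):

```python
# Hedged sketch: feeding the model-parallel group a request object directly,
# as LlamaModelParallelGenerator does above; `group` and `request` are placeholders.
from copy import deepcopy

def stream_tokens(group, request):
    # `request` may be a CompletionRequest or a ChatCompletionRequest; the
    # worker-side ModelRunner branches on its type and runs the matching method.
    yield from group.run_inference(deepcopy(request))
```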
@@ -184,7 +184,7 @@ class LlamaGuardShield(ShieldBase):

         # TODO: llama-stack inference protocol has issues with non-streaming inference code
         content = ""
-        async for chunk in self.inference_api.chat_completion(
+        async for chunk in await self.inference_api.chat_completion(
             model=self.model,
             messages=[shield_input_message],
             stream=True,
@@ -134,7 +134,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         if self.engine:
             self.engine.shutdown_background_loop()

-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,

@@ -152,7 +152,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
             logprobs=logprobs,
         )

-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: list[Message],

@@ -189,7 +189,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, results_generator)
         else:
-            return self._nonstream_chat_completion(request, results_generator)
+            return await self._nonstream_chat_completion(request, results_generator)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
@@ -116,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages):
     )

     turn_response = [
-        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
     ]

     assert len(turn_response) > 0

@@ -204,7 +204,7 @@ async def test_rag_agent_as_attachments(
     )

     turn_response = [
-        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
     ]

     assert len(turn_response) > 0

@@ -218,7 +218,7 @@ async def test_rag_agent_as_attachments(
     )

     turn_response = [
-        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
     ]

     assert len(turn_response) > 0

@@ -270,7 +270,7 @@ async def test_create_agent_turn_with_brave_search(
     )

     turn_response = [
-        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
     ]

     assert len(turn_response) > 0
@@ -126,6 +126,45 @@ async def test_model_list(inference_settings):
     assert model_def.identifier == params["model"]


+@pytest.mark.asyncio
+async def test_completion(inference_settings):
+    inference_impl = inference_settings["impl"]
+    params = inference_settings["common_params"]
+
+    provider = inference_impl.routing_table.get_provider_impl(params["model"])
+    if provider.__provider_id__ != "meta-reference":
+        pytest.skip("Other inference providers don't support completion() yet")
+
+    response = await inference_impl.completion(
+        content="Roses are red,",
+        stream=False,
+        model=params["model"],
+        sampling_params=SamplingParams(
+            max_tokens=50,
+        ),
+    )
+
+    assert isinstance(response, CompletionResponse)
+    assert "violets are blue" in response.content
+
+    chunks = [
+        r
+        async for r in await inference_impl.completion(
+            content="Roses are red,",
+            stream=True,
+            model=params["model"],
+            sampling_params=SamplingParams(
+                max_tokens=50,
+            ),
+        )
+    ]
+
+    assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
+    assert len(chunks) == 51
+    last = chunks[-1]
+    assert last.stop_reason == StopReason.out_of_tokens
+
+
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]

@@ -146,7 +185,7 @@ async def test_chat_completion_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
     response = [
         r
-        async for r in inference_impl.chat_completion(
+        async for r in await inference_impl.chat_completion(
             messages=sample_messages,
             stream=True,
             **inference_settings["common_params"],

@@ -217,7 +256,7 @@ async def test_chat_completion_with_tool_calling_streaming(

     response = [
         r
-        async for r in inference_impl.chat_completion(
+        async for r in await inference_impl.chat_completion(
             messages=messages,
             tools=[sample_tool_definition],
             stream=True,