diff --git a/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift
index a5394ecff..2c03029e0 100644
--- a/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift
+++ b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift
@@ -40,7 +40,7 @@ public class LocalInference: Inference {
 
   public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
     return AsyncStream { continuation in
-      runnerQueue.async {
+      let workItem = DispatchWorkItem {
         do {
           var tokens: [String] = []
 
@@ -69,9 +69,10 @@ public class LocalInference: Inference {
               continuation.yield(
                 Components.Schemas.ChatCompletionResponseStreamChunk(
                   event: Components.Schemas.ChatCompletionResponseEvent(
-                    delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
-                      content: .case1(""),
-                      parse_status: Components.Schemas.ToolCallParseStatus.started
+                    delta: .tool_call(Components.Schemas.ToolCallDelta(
+                      parse_status: Components.Schemas.ToolCallParseStatus.started,
+                      tool_call: .case1(""),
+                      _type: Components.Schemas.ToolCallDelta._typePayload.tool_call
                     )
                     ),
                     event_type: .progress
@@ -95,14 +96,18 @@ public class LocalInference: Inference {
                 text = token
               }
 
-              var delta: Components.Schemas.ChatCompletionResponseEvent.deltaPayload
+              var delta: Components.Schemas.ContentDelta
               if ipython {
-                delta = .ToolCallDelta(Components.Schemas.ToolCallDelta(
-                  content: .case1(text),
-                  parse_status: .in_progress
+                delta = .tool_call(Components.Schemas.ToolCallDelta(
+                  parse_status: .in_progress,
+                  tool_call: .case1(text),
+                  _type: .tool_call
                 ))
               } else {
-                delta = .case1(text)
+                delta = .text(Components.Schemas.TextDelta(
+                  text: text,
+                  _type: Components.Schemas.TextDelta._typePayload.text)
+                )
               }
 
               if stopReason == nil {
@@ -129,7 +134,12 @@ public class LocalInference: Inference {
               continuation.yield(
                 Components.Schemas.ChatCompletionResponseStreamChunk(
                   event: Components.Schemas.ChatCompletionResponseEvent(
-                    delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(content: .case1(""), parse_status: .failure)),
+                    delta: .tool_call(Components.Schemas.ToolCallDelta(
+                      parse_status: Components.Schemas.ToolCallParseStatus.failed,
+                      tool_call: .case1(""),
+                      _type: Components.Schemas.ToolCallDelta._typePayload.tool_call
+                    )
+                    ),
                     event_type: .progress
                   ) // TODO: stopReason
@@ -141,10 +151,12 @@ public class LocalInference: Inference {
               continuation.yield(
                 Components.Schemas.ChatCompletionResponseStreamChunk(
                   event: Components.Schemas.ChatCompletionResponseEvent(
-                    delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
-                      content: .ToolCall(toolCall),
-                      parse_status: .success
-                    )),
+                    delta: .tool_call(Components.Schemas.ToolCallDelta(
+                      parse_status: Components.Schemas.ToolCallParseStatus.succeeded,
+                      tool_call: Components.Schemas.ToolCallDelta.tool_callPayload.ToolCall(toolCall),
+                      _type: Components.Schemas.ToolCallDelta._typePayload.tool_call
+                    )
+                    ),
                     event_type: .progress
                   ) // TODO: stopReason
@@ -155,7 +167,10 @@ public class LocalInference: Inference {
             continuation.yield(
               Components.Schemas.ChatCompletionResponseStreamChunk(
                 event: Components.Schemas.ChatCompletionResponseEvent(
-                  delta: .case1(""),
+                  delta: .text(Components.Schemas.TextDelta(
+                    text: "",
+                    _type: Components.Schemas.TextDelta._typePayload.text)
+                  ),
                   event_type: .complete
                 ) // TODO: stopReason
@@ -166,6 +181,7 @@ public class LocalInference: Inference {
           print("Inference error: " + error.localizedDescription)
         }
       }
+      runnerQueue.async(execute: workItem)
     }
   }
 }
diff --git a/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift
index 84da42d1b..0ccd646ed 100644
--- a/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift
+++ b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift
@@ -6,7 +6,7 @@ func encodeHeader(role: String) -> String {
   return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
 }
 
-func encodeDialogPrompt(messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload]) -> String {
+func encodeDialogPrompt(messages: [Components.Schemas.Message]) -> String {
   var prompt = ""
 
   prompt.append("<|begin_of_text|>")
@@ -20,24 +20,24 @@ func encodeDialogPrompt(messages: [Components.Schemas.ChatCompletionRequest.mess
   return prompt
 }
 
-func getRole(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
+func getRole(message: Components.Schemas.Message) -> String {
   switch (message) {
-  case .UserMessage(let m):
+  case .user(let m):
     return m.role.rawValue
-  case .SystemMessage(let m):
+  case .system(let m):
     return m.role.rawValue
-  case .ToolResponseMessage(let m):
+  case .tool(let m):
     return m.role.rawValue
-  case .CompletionMessage(let m):
+  case .assistant(let m):
     return m.role.rawValue
   }
 }
 
-func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
+func encodeMessage(message: Components.Schemas.Message) -> String {
   var prompt = encodeHeader(role: getRole(message: message))
 
   switch (message) {
-  case .CompletionMessage(let m):
+  case .assistant(let m):
     if (m.tool_calls.count > 0) {
       prompt += "<|python_tag|>"
     }
@@ -64,37 +64,37 @@ func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPay
   }
 
   switch (message) {
-  case .UserMessage(let m):
+  case .user(let m):
     prompt += _processContent(m.content)
-  case .SystemMessage(let m):
+  case .system(let m):
     prompt += _processContent(m.content)
-  case .ToolResponseMessage(let m):
+  case .tool(let m):
     prompt += _processContent(m.content)
-  case .CompletionMessage(let m):
+  case .assistant(let m):
     prompt += _processContent(m.content)
   }
 
   var eom = false
 
   switch (message) {
-  case .UserMessage(let m):
+  case .user(let m):
     switch (m.content) {
     case .case1(let c):
       prompt += _processContent(c)
-    case .ImageMedia(let c):
+    case .InterleavedContentItem(let c):
       prompt += _processContent(c)
     case .case3(let c):
       prompt += _processContent(c)
     }
-  case .CompletionMessage(let m):
+  case .assistant(let m):
     // TODO: Support encoding past tool call history
     // for t in m.tool_calls {
     //   _processContent(t.)
     //}
     eom = m.stop_reason == Components.Schemas.StopReason.end_of_message
-  case .SystemMessage(_):
+  case .system(_):
     break
-  case .ToolResponseMessage(_):
+  case .tool(_):
     break
   }
 
@@ -107,12 +107,12 @@ func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPay
   return prompt
 }
 
-func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] {
+func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.Message] {
   var existingMessages = request.messages
-  var existingSystemMessage: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload?
+  var existingSystemMessage: Components.Schemas.Message?
   // TODO: Existing system message
 
-  var messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] = []
+  var messages: [Components.Schemas.Message] = []
 
   let defaultGen = SystemDefaultGenerator()
   let defaultTemplate = defaultGen.gen()
@@ -123,7 +123,7 @@ func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -
 
     sysContent += try defaultTemplate.render()
-    messages.append(.SystemMessage(Components.Schemas.SystemMessage(
+    messages.append(.system(Components.Schemas.SystemMessage(
      content: .case1(sysContent),
      role: .system))
    )
@@ -133,7 +133,7 @@ func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -
     let toolGen = FunctionTagCustomToolGenerator()
     let toolTemplate = try toolGen.gen(customTools: request.tools!)
     let tools = try toolTemplate.render()
-    messages.append(.UserMessage(Components.Schemas.UserMessage(
+    messages.append(.user(Components.Schemas.UserMessage(
      content: .case1(tools),
      role: .user)
    ))