Support for Llama3.2 models and Swift SDK (#98)

This commit is contained in:
Ashwin Bharambe 2024-09-25 10:29:58 -07:00 committed by GitHub
parent 95abbf576b
commit 56aed59eb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
56 changed files with 3745 additions and 630 deletions

View file

@ -0,0 +1,16 @@
//
// LocalInference.h
// LocalInference
//
// Created by Dalton Flanagan on 9/23/24.
//
#import <Foundation/Foundation.h>
//! Project version number for LocalInference.
FOUNDATION_EXPORT double LocalInferenceVersionNumber;
//! Project version string for LocalInference.
FOUNDATION_EXPORT const unsigned char LocalInferenceVersionString[];
// In this header, you should import all the public headers of your framework using statements like #import <LocalInference/PublicHeader.h>

View file

@ -0,0 +1,167 @@
import Foundation
import LLaMARunner
import LlamaStackClient
class RunnerHolder: ObservableObject {
var runner: Runner?
}
public class LocalInference: Inference {
private var runnerHolder = RunnerHolder()
private let runnerQueue: DispatchQueue
public init (queue: DispatchQueue) {
runnerQueue = queue
}
public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
runnerHolder.runner = runnerHolder.runner ?? Runner(
modelPath: modelPath,
tokenizerPath: tokenizerPath
)
runnerQueue.async {
let runner = self.runnerHolder.runner
do {
try runner!.load()
completion(.success(()))
} catch let loadError {
print("error: " + loadError.localizedDescription)
completion(.failure(loadError))
}
}
}
public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
return AsyncStream { continuation in
runnerQueue.async {
do {
var tokens: [String] = []
let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
var stopReason: Components.Schemas.StopReason? = nil
var buffer = ""
var ipython = false
var echoDropped = false
try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
buffer += token
// HACK: Workaround until LlamaRunner exposes echo param
if (!echoDropped) {
if (buffer.hasPrefix(prompt)) {
buffer = String(buffer.dropFirst(prompt.count))
echoDropped = true
}
return
}
tokens.append(token)
if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[") ) {
ipython = true
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(""),
parse_status: Components.Schemas.ToolCallParseStatus.started
)
),
event_type: .progress
)
)
)
if (buffer.starts(with: "<|python_tag|>")) {
buffer = String(buffer.dropFirst("<|python_tag|>".count))
}
}
// TODO: Non-streaming lobprobs
var text = ""
if token == "<|eot_id|>" {
stopReason = Components.Schemas.StopReason.end_of_turn
} else if token == "<|eom_id|>" {
stopReason = Components.Schemas.StopReason.end_of_message
} else {
text = token
}
var delta: Components.Schemas.ChatCompletionResponseEvent.deltaPayload
if ipython {
delta = .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .case1(text),
parse_status: .in_progress
))
} else {
delta = .case1(text)
}
if stopReason == nil {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: delta,
event_type: .progress
)
)
)
}
}
if stopReason == nil {
stopReason = Components.Schemas.StopReason.out_of_tokens
}
let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
// TODO: non-streaming support
let didParseToolCalls = message.tool_calls.count > 0
if ipython && !didParseToolCalls {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(content: .case1(""), parse_status: .failure)),
event_type: .progress
)
// TODO: stopReason
)
)
}
for toolCall in message.tool_calls {
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
content: .ToolCall(toolCall),
parse_status: .success
)),
event_type: .progress
)
// TODO: stopReason
)
)
}
continuation.yield(
Components.Schemas.ChatCompletionResponseStreamChunk(
event: Components.Schemas.ChatCompletionResponseEvent(
delta: .case1(""),
event_type: .complete
)
// TODO: stopReason
)
)
}
catch (let error) {
print("Inference error: " + error.localizedDescription)
}
}
}
}
}

View file

@ -0,0 +1,235 @@
import Foundation
import LlamaStackClient
func encodeHeader(role: String) -> String {
return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
}
func encodeDialogPrompt(messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload]) -> String {
var prompt = ""
prompt.append("<|begin_of_text|>")
for message in messages {
let msg = encodeMessage(message: message)
prompt += msg
}
prompt.append(encodeHeader(role: "assistant"))
return prompt
}
func getRole(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
switch (message) {
case .UserMessage(let m):
return m.role.rawValue
case .SystemMessage(let m):
return m.role.rawValue
case .ToolResponseMessage(let m):
return m.role.rawValue
case .CompletionMessage(let m):
return m.role.rawValue
}
}
func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
var prompt = encodeHeader(role: getRole(message: message))
switch (message) {
case .CompletionMessage(let m):
if (m.tool_calls.count > 0) {
prompt += "<|python_tag|>"
}
default:
break
}
func _processContent(_ content: Any) -> String {
func _process(_ c: Any) {
if let str = c as? String {
prompt += str
}
}
if let str = content as? String {
_process(str)
} else if let list = content as? [Any] {
for c in list {
_process(c)
}
}
return ""
}
switch (message) {
case .UserMessage(let m):
prompt += _processContent(m.content)
case .SystemMessage(let m):
prompt += _processContent(m.content)
case .ToolResponseMessage(let m):
prompt += _processContent(m.content)
case .CompletionMessage(let m):
prompt += _processContent(m.content)
}
var eom = false
switch (message) {
case .UserMessage(let m):
switch (m.content) {
case .case1(let c):
prompt += _processContent(c)
case .case2(let c):
prompt += _processContent(c)
}
case .CompletionMessage(let m):
// TODO: Support encoding past tool call history
// for t in m.tool_calls {
// _processContent(t.)
//}
eom = m.stop_reason == Components.Schemas.StopReason.end_of_message
case .SystemMessage(_):
break
case .ToolResponseMessage(_):
break
}
if (eom) {
prompt += "<|eom_id|>"
} else {
prompt += "<|eot_id|>"
}
return prompt
}
func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] {
var existingMessages = request.messages
var existingSystemMessage: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload?
// TODO: Existing system message
var messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] = []
let defaultGen = SystemDefaultGenerator()
let defaultTemplate = defaultGen.gen()
var sysContent = ""
// TODO: Built-in tools
sysContent += try defaultTemplate.render()
messages.append(.SystemMessage(Components.Schemas.SystemMessage(
content: .case1(sysContent),
role: .system))
)
if request.tools?.isEmpty == false {
// TODO: Separate built-ins and custom tools (right now everything treated as custom)
let toolGen = FunctionTagCustomToolGenerator()
let toolTemplate = try toolGen.gen(customTools: request.tools!)
let tools = try toolTemplate.render()
messages.append(.UserMessage(Components.Schemas.UserMessage(
content: .case1(tools),
role: .user)
))
}
messages.append(contentsOf: existingMessages)
return messages
}
struct FunctionCall {
let name: String
let params: [String: Any]
}
public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.ToolCall] {
guard input.hasPrefix("[") && input.hasSuffix("]") else {
return []
}
do {
let trimmed = input.trimmingCharacters(in: CharacterSet(charactersIn: "[]"))
let calls = trimmed.components(separatedBy: "),").map { $0.hasSuffix(")") ? $0 : $0 + ")" }
var result: [Components.Schemas.ToolCall] = []
for call in calls {
guard let nameEndIndex = call.firstIndex(of: "("),
let paramsStartIndex = call.firstIndex(of: "{"),
let paramsEndIndex = call.lastIndex(of: "}") else {
return []
}
let name = String(call[..<nameEndIndex]).trimmingCharacters(in: .whitespacesAndNewlines)
let paramsString = String(call[paramsStartIndex...paramsEndIndex])
guard let data = paramsString.data(using: .utf8),
let params = try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] else {
return []
}
var props: [String : Components.Schemas.ToolCall.argumentsPayload.additionalPropertiesPayload] = [:]
for (param_name, param) in params {
switch (param) {
case let value as String:
props[param_name] = .case1(value)
case let value as Int:
props[param_name] = .case2(value)
case let value as Double:
props[param_name] = .case3(value)
case let value as Bool:
props[param_name] = .case4(value)
default:
return []
}
}
result.append(
Components.Schemas.ToolCall(
arguments: .init(additionalProperties: props),
call_id: UUID().uuidString,
tool_name: .case2(name) // custom_tool
)
)
}
return result.isEmpty ? [] : result
} catch {
return []
}
}
func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.StopReason) -> Components.Schemas.CompletionMessage {
var content = tokens
let roles = ["user", "system", "assistant"]
for role in roles {
let headerStr = encodeHeader(role: role)
if content.hasPrefix(headerStr) {
content = String(content.dropFirst(encodeHeader(role: role).count))
}
}
if content.hasPrefix("<|python_tag|>") {
content = String(content.dropFirst("<|python_tag|>".count))
}
if content.hasSuffix("<|eot_id|>") {
content = String(content.dropLast("<|eot_id|>".count))
} else {
content = String(content.dropLast("<|eom_id|>".count))
}
return Components.Schemas.CompletionMessage(
content: .case1(content),
role: .assistant,
stop_reason: stopReason,
tool_calls: maybeExtractCustomToolCalls(input: content)
)
}

View file

@ -0,0 +1,12 @@
import Foundation
import Stencil
public struct PromptTemplate {
let template: String
let data: [String: Any]
public func render() throws -> String {
let template = Template(templateString: self.template)
return try template.render(self.data)
}
}

View file

@ -0,0 +1,91 @@
import Foundation
import LlamaStackClient
func convertToNativeSwiftType(_ value: Any) -> Any {
switch value {
case let number as NSNumber:
if CFGetTypeID(number) == CFBooleanGetTypeID() {
return number.boolValue
}
if floor(number.doubleValue) == number.doubleValue {
return number.intValue
}
return number.doubleValue
case let string as String:
return string
case let array as [Any]:
return array.map(convertToNativeSwiftType)
case let dict as [String: Any]:
return dict.mapValues(convertToNativeSwiftType)
case is NSNull:
return NSNull()
default:
return value
}
}
public class SystemDefaultGenerator {
public init() {}
public func gen() -> PromptTemplate {
let templateStr = """
Cutting Knowledge Date: December 2023
Today Date: {{ today }}
"""
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "dd MMMM yyyy"
return PromptTemplate(
template: templateStr,
data: ["today": dateFormatter.string(from: Date())]
)
}
}
public class FunctionTagCustomToolGenerator {
public init() {}
public func gen(customTools: [Components.Schemas.ToolDefinition]) throws -> PromptTemplate {
// TODO: required params
// TODO: {{#unless @last}},{{/unless}}
let templateStr = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{% for t in custom_tools %}
{
"name": "{{t.tool_name}}",
"description": "{{t.description}}",
"parameters": {
"type": "dict",
"properties": { {{t.parameters}} }
}
{{/let}}
{% endfor -%}
]
"""
let encoder = JSONEncoder()
return PromptTemplate(
template: templateStr,
data: ["custom_tools": try customTools.map {
let data = try encoder.encode($0)
let obj = try JSONSerialization.jsonObject(with: data)
return convertToNativeSwiftType(obj)
}]
)
}
}