forked from phoenix-oss/llama-stack-mirror
Support for Llama3.2 models and Swift SDK (#98)
This commit is contained in:
parent 95abbf576b · commit 56aed59eb4
56 changed files with 3745 additions and 630 deletions
@@ -0,0 +1,16 @@
//
//  LocalInference.h
//  LocalInference
//
//  Created by Dalton Flanagan on 9/23/24.
//

#import <Foundation/Foundation.h>

//! Project version number for LocalInference.
FOUNDATION_EXPORT double LocalInferenceVersionNumber;

//! Project version string for LocalInference.
FOUNDATION_EXPORT const unsigned char LocalInferenceVersionString[];

// In this header, you should import all the public headers of your framework using statements like #import <LocalInference/PublicHeader.h>
@@ -0,0 +1,167 @@
import Foundation

import LLaMARunner
import LlamaStackClient

class RunnerHolder: ObservableObject {
  var runner: Runner?
}

public class LocalInference: Inference {
  private var runnerHolder = RunnerHolder()
  private let runnerQueue: DispatchQueue

  public init(queue: DispatchQueue) {
    runnerQueue = queue
  }

  public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
    runnerHolder.runner = runnerHolder.runner ?? Runner(
      modelPath: modelPath,
      tokenizerPath: tokenizerPath
    )

    runnerQueue.async {
      let runner = self.runnerHolder.runner
      do {
        try runner!.load()
        completion(.success(()))
      } catch let loadError {
        print("error: " + loadError.localizedDescription)
        completion(.failure(loadError))
      }
    }
  }

  public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
    return AsyncStream { continuation in
      runnerQueue.async {
        do {
          var tokens: [String] = []

          let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
          var stopReason: Components.Schemas.StopReason? = nil
          var buffer = ""
          var ipython = false
          var echoDropped = false

          try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
            buffer += token

            // HACK: Workaround until LlamaRunner exposes an echo param;
            // drop the echoed prompt from the front of the output.
            if !echoDropped {
              if buffer.hasPrefix(prompt) {
                buffer = String(buffer.dropFirst(prompt.count))
                echoDropped = true
              }
              return
            }

            tokens.append(token)

            // Entering tool-call mode: either the built-in <|python_tag|>
            // or the "[" prefix of a custom tool call.
            if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[")) {
              ipython = true
              continuation.yield(
                Components.Schemas.ChatCompletionResponseStreamChunk(
                  event: Components.Schemas.ChatCompletionResponseEvent(
                    delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
                      content: .case1(""),
                      parse_status: Components.Schemas.ToolCallParseStatus.started
                    )),
                    event_type: .progress
                  )
                )
              )

              if buffer.starts(with: "<|python_tag|>") {
                buffer = String(buffer.dropFirst("<|python_tag|>".count))
              }
            }

            // TODO: Non-streaming logprobs

            var text = ""
            if token == "<|eot_id|>" {
              stopReason = Components.Schemas.StopReason.end_of_turn
            } else if token == "<|eom_id|>" {
              stopReason = Components.Schemas.StopReason.end_of_message
            } else {
              text = token
            }

            var delta: Components.Schemas.ChatCompletionResponseEvent.deltaPayload
            if ipython {
              delta = .ToolCallDelta(Components.Schemas.ToolCallDelta(
                content: .case1(text),
                parse_status: .in_progress
              ))
            } else {
              delta = .case1(text)
            }

            if stopReason == nil {
              continuation.yield(
                Components.Schemas.ChatCompletionResponseStreamChunk(
                  event: Components.Schemas.ChatCompletionResponseEvent(
                    delta: delta,
                    event_type: .progress
                  )
                )
              )
            }
          }

          if stopReason == nil {
            stopReason = Components.Schemas.StopReason.out_of_tokens
          }

          let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
          // TODO: non-streaming support

          let didParseToolCalls = message.tool_calls.count > 0
          if ipython && !didParseToolCalls {
            continuation.yield(
              Components.Schemas.ChatCompletionResponseStreamChunk(
                event: Components.Schemas.ChatCompletionResponseEvent(
                  delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(content: .case1(""), parse_status: .failure)),
                  event_type: .progress
                )
                // TODO: stopReason
              )
            )
          }

          for toolCall in message.tool_calls {
            continuation.yield(
              Components.Schemas.ChatCompletionResponseStreamChunk(
                event: Components.Schemas.ChatCompletionResponseEvent(
                  delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
                    content: .ToolCall(toolCall),
                    parse_status: .success
                  )),
                  event_type: .progress
                )
                // TODO: stopReason
              )
            )
          }

          continuation.yield(
            Components.Schemas.ChatCompletionResponseStreamChunk(
              event: Components.Schemas.ChatCompletionResponseEvent(
                delta: .case1(""),
                event_type: .complete
              )
              // TODO: stopReason
            )
          )
          // Terminate the stream so consumers' for-await loops exit.
          continuation.finish()
        } catch {
          print("Inference error: " + error.localizedDescription)
          continuation.finish()
        }
      }
    }
  }
}
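For orientation, a hedged sketch of driving this API from a client. The queue label and file paths are made up, and the `model` field on the request initializer is an assumption for illustration; only `messages` and `tools` appear in this diff.

let inference = LocalInference(queue: DispatchQueue(label: "org.example.localinference"))  // hypothetical label

inference.loadModel(
  modelPath: "/path/to/model.pte",            // hypothetical paths
  tokenizerPath: "/path/to/tokenizer.model"
) { result in
  print("Model load: \(result)")
}

// Later, once loading has completed:
let request = Components.Schemas.ChatCompletionRequest(
  messages: [
    .UserMessage(Components.Schemas.UserMessage(
      content: .case1("What is the capital of France?"),
      role: .user
    ))
  ],
  model: "Llama3.2-1B-Instruct"               // assumed field; not shown in this diff
)

Task {
  for await chunk in inference.chatCompletion(request: request) {
    // Each chunk carries either a text delta or a ToolCallDelta.
    print(chunk.event.delta)
  }
}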
@@ -0,0 +1,235 @@
import Foundation

import LlamaStackClient

func encodeHeader(role: String) -> String {
  return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
}

func encodeDialogPrompt(messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload]) -> String {
  var prompt = ""

  prompt.append("<|begin_of_text|>")
  for message in messages {
    let msg = encodeMessage(message: message)
    prompt += msg
  }

  prompt.append(encodeHeader(role: "assistant"))

  return prompt
}
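For reference, a dialog with a single user turn encodes to the Llama 3 chat format (escapes spelled out; the trailing assistant header cues the model to generate):

// encodeDialogPrompt(messages: [.UserMessage(... content: .case1("Hello") ...)]) yields:
// "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"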
func getRole(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
  switch (message) {
  case .UserMessage(let m):
    return m.role.rawValue
  case .SystemMessage(let m):
    return m.role.rawValue
  case .ToolResponseMessage(let m):
    return m.role.rawValue
  case .CompletionMessage(let m):
    return m.role.rawValue
  }
}
func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload) -> String {
  var prompt = encodeHeader(role: getRole(message: message))

  switch (message) {
  case .CompletionMessage(let m):
    if (m.tool_calls.count > 0) {
      prompt += "<|python_tag|>"
    }
  default:
    break
  }

  // Appends raw String content (or each String in an array) to `prompt`;
  // returns "" so call sites can use it on the right-hand side of `+=`.
  func _processContent(_ content: Any) -> String {
    func _process(_ c: Any) {
      if let str = c as? String {
        prompt += str
      }
    }

    if let str = content as? String {
      _process(str)
    } else if let list = content as? [Any] {
      for c in list {
        _process(c)
      }
    }

    return ""
  }

  switch (message) {
  case .UserMessage(let m):
    prompt += _processContent(m.content)
  case .SystemMessage(let m):
    prompt += _processContent(m.content)
  case .ToolResponseMessage(let m):
    prompt += _processContent(m.content)
  case .CompletionMessage(let m):
    prompt += _processContent(m.content)
  }

  var eom = false

  switch (message) {
  case .UserMessage(let m):
    switch (m.content) {
    case .case1(let c):
      prompt += _processContent(c)
    case .case2(let c):
      prompt += _processContent(c)
    }
  case .CompletionMessage(let m):
    // TODO: Support encoding past tool call history
    // for t in m.tool_calls {
    //   _processContent(t.)
    // }
    eom = m.stop_reason == Components.Schemas.StopReason.end_of_message
  case .SystemMessage(_):
    break
  case .ToolResponseMessage(_):
    break
  }

  if (eom) {
    prompt += "<|eom_id|>"
  } else {
    prompt += "<|eot_id|>"
  }

  return prompt
}
func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] {
  let existingMessages = request.messages
  var existingSystemMessage: Components.Schemas.ChatCompletionRequest.messagesPayloadPayload?
  // TODO: Existing system message

  var messages: [Components.Schemas.ChatCompletionRequest.messagesPayloadPayload] = []

  let defaultGen = SystemDefaultGenerator()
  let defaultTemplate = defaultGen.gen()

  var sysContent = ""

  // TODO: Built-in tools

  sysContent += try defaultTemplate.render()

  messages.append(.SystemMessage(Components.Schemas.SystemMessage(
    content: .case1(sysContent),
    role: .system
  )))

  if request.tools?.isEmpty == false {
    // TODO: Separate built-ins and custom tools (right now everything is treated as custom)
    let toolGen = FunctionTagCustomToolGenerator()
    let toolTemplate = try toolGen.gen(customTools: request.tools!)
    let tools = try toolTemplate.render()
    messages.append(.UserMessage(Components.Schemas.UserMessage(
      content: .case1(tools),
      role: .user
    )))
  }

  messages.append(contentsOf: existingMessages)

  return messages
}
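The resulting ordering is fixed: the generated system preamble first, then (only when tools are present) a synthetic user message carrying the rendered tool definitions, then the caller's conversation. Sketched:

// messages == [
//   .SystemMessage(...),  // SystemDefaultGenerator preamble
//   .UserMessage(...),    // FunctionTagCustomToolGenerator output, if request.tools is non-empty
// ] + request.messages    // caller-provided turns, appended last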
struct FunctionCall {
  let name: String
  let params: [String: Any]
}
public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.ToolCall] {
  guard input.hasPrefix("[") && input.hasSuffix("]") else {
    return []
  }

  let trimmed = input.trimmingCharacters(in: CharacterSet(charactersIn: "[]"))
  let calls = trimmed.components(separatedBy: "),").map { $0.hasSuffix(")") ? $0 : $0 + ")" }

  var result: [Components.Schemas.ToolCall] = []

  for call in calls {
    guard let nameEndIndex = call.firstIndex(of: "("),
          let paramsStartIndex = call.firstIndex(of: "{"),
          let paramsEndIndex = call.lastIndex(of: "}") else {
      return []
    }

    let name = String(call[..<nameEndIndex]).trimmingCharacters(in: .whitespacesAndNewlines)
    let paramsString = String(call[paramsStartIndex...paramsEndIndex])

    guard let data = paramsString.data(using: .utf8),
          let params = try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] else {
      return []
    }

    var props: [String: Components.Schemas.ToolCall.argumentsPayload.additionalPropertiesPayload] = [:]
    for (param_name, param) in params {
      switch (param) {
      case let value as String:
        props[param_name] = .case1(value)
      case let value as Int:
        props[param_name] = .case2(value)
      case let value as Double:
        props[param_name] = .case3(value)
      case let value as Bool:
        props[param_name] = .case4(value)
      default:
        return []
      }
    }

    result.append(
      Components.Schemas.ToolCall(
        arguments: .init(additionalProperties: props),
        call_id: UUID().uuidString,
        tool_name: .case2(name) // custom_tool
      )
    )
  }

  return result
}
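A quick usage sketch (the tool name and arguments are hypothetical); note the parser expects `name({...})` entries with a JSON object for the arguments:

let calls = maybeExtractCustomToolCalls(input: #"[get_weather({"city": "Menlo Park", "days": 3})]"#)
// calls.count == 1, calls[0].tool_name == .case2("get_weather")
// arguments: "city" -> .case1("Menlo Park"), "days" -> .case2(3)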
func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.StopReason) -> Components.Schemas.CompletionMessage {
  var content = tokens

  let roles = ["user", "system", "assistant"]
  for role in roles {
    let headerStr = encodeHeader(role: role)
    if content.hasPrefix(headerStr) {
      content = String(content.dropFirst(headerStr.count))
    }
  }

  if content.hasPrefix("<|python_tag|>") {
    content = String(content.dropFirst("<|python_tag|>".count))
  }

  if content.hasSuffix("<|eot_id|>") {
    content = String(content.dropLast("<|eot_id|>".count))
  } else if content.hasSuffix("<|eom_id|>") {
    content = String(content.dropLast("<|eom_id|>".count))
  }

  return Components.Schemas.CompletionMessage(
    content: .case1(content),
    role: .assistant,
    stop_reason: stopReason,
    tool_calls: maybeExtractCustomToolCalls(input: content)
  )
}
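A hedged example of the trimming behavior, reusing encodeHeader from above:

let msg = decodeAssistantMessage(
  tokens: encodeHeader(role: "assistant") + "Paris<|eot_id|>",
  stopReason: .end_of_turn
)
// msg.content == .case1("Paris"); msg.tool_calls == [] (no bracketed tool-call syntax)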
@@ -0,0 +1,12 @@
import Foundation
import Stencil

public struct PromptTemplate {
  let template: String
  let data: [String: Any]

  public func render() throws -> String {
    let template = Template(templateString: self.template)
    return try template.render(self.data)
  }
}
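A minimal sketch of the Stencil round trip (template string and data are made up):

let greeting = PromptTemplate(
  template: "Hello {{ name }}!",  // hypothetical template
  data: ["name": "world"]
)
let rendered = try! greeting.render()  // "Hello world!"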
@@ -0,0 +1,91 @@
import Foundation

import LlamaStackClient

func convertToNativeSwiftType(_ value: Any) -> Any {
  switch value {
  case let number as NSNumber:
    if CFGetTypeID(number) == CFBooleanGetTypeID() {
      return number.boolValue
    }
    if floor(number.doubleValue) == number.doubleValue {
      return number.intValue
    }
    return number.doubleValue
  case let string as String:
    return string
  case let array as [Any]:
    return array.map(convertToNativeSwiftType)
  case let dict as [String: Any]:
    return dict.mapValues(convertToNativeSwiftType)
  case is NSNull:
    return NSNull()
  default:
    return value
  }
}
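The NSNumber branch is the subtle part: JSONSerialization returns NSNumber for booleans and numbers alike, so CFBooleanGetTypeID picks out true booleans, and whole-valued doubles collapse to Int. For illustration:

let obj = try! JSONSerialization.jsonObject(with: Data(#"{"a": true, "b": 3, "c": 3.5}"#.utf8))
let native = convertToNativeSwiftType(obj) as! [String: Any]
// native["a"] as? Bool == true, native["b"] as? Int == 3, native["c"] as? Double == 3.5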
public class SystemDefaultGenerator {
  public init() {}

  public func gen() -> PromptTemplate {
    let templateStr = """
      Cutting Knowledge Date: December 2023
      Today Date: {{ today }}
      """

    let dateFormatter = DateFormatter()
    dateFormatter.dateFormat = "dd MMMM yyyy"

    return PromptTemplate(
      template: templateStr,
      data: ["today": dateFormatter.string(from: Date())]
    )
  }
}


public class FunctionTagCustomToolGenerator {
  public init() {}

  public func gen(customTools: [Components.Schemas.ToolDefinition]) throws -> PromptTemplate {
    // TODO: required params
    // TODO: {{#unless @last}},{{/unless}}

    let templateStr = """
      You are an expert in composing functions. You are given a question and a set of possible functions.
      Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
      If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
      also point it out. You should only return the function call in tools call sections.

      If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
      You SHOULD NOT include any other text in the response.

      Here is a list of functions in JSON format that you can invoke.

      [
      {% for t in custom_tools %}
      {
        "name": "{{t.tool_name}}",
        "description": "{{t.description}}",
        "parameters": {
          "type": "dict",
          "properties": { {{t.parameters}} }
        }
      }
      {% endfor -%}
      ]
      """

    let encoder = JSONEncoder()
    return PromptTemplate(
      template: templateStr,
      data: ["custom_tools": try customTools.map {
        let data = try encoder.encode($0)
        let obj = try JSONSerialization.jsonObject(with: data)
        return convertToNativeSwiftType(obj)
      }]
    )
  }
}
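And a minimal sketch of rendering the default system preamble with the pieces above:

let preamble = try! SystemDefaultGenerator().gen().render()
// "Cutting Knowledge Date: December 2023\nToday Date: 23 September 2024" (date varies)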