forked from phoenix-oss/llama-stack-mirror
Support for Llama3.2 models and Swift SDK (#98)
This commit is contained in:
parent
95abbf576b
commit
56aed59eb4
56 changed files with 3745 additions and 630 deletions
|
@ -0,0 +1,167 @@
|
|||
import Foundation
|
||||
|
||||
import LLaMARunner
|
||||
import LlamaStackClient
|
||||
|
||||
class RunnerHolder: ObservableObject {
|
||||
var runner: Runner?
|
||||
}
|
||||
|
||||
public class LocalInference: Inference {
|
||||
private var runnerHolder = RunnerHolder()
|
||||
private let runnerQueue: DispatchQueue
|
||||
|
||||
public init (queue: DispatchQueue) {
|
||||
runnerQueue = queue
|
||||
}
|
||||
|
||||
public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
|
||||
runnerHolder.runner = runnerHolder.runner ?? Runner(
|
||||
modelPath: modelPath,
|
||||
tokenizerPath: tokenizerPath
|
||||
)
|
||||
|
||||
|
||||
runnerQueue.async {
|
||||
let runner = self.runnerHolder.runner
|
||||
do {
|
||||
try runner!.load()
|
||||
completion(.success(()))
|
||||
} catch let loadError {
|
||||
print("error: " + loadError.localizedDescription)
|
||||
completion(.failure(loadError))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
|
||||
return AsyncStream { continuation in
|
||||
runnerQueue.async {
|
||||
do {
|
||||
var tokens: [String] = []
|
||||
|
||||
let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
|
||||
var stopReason: Components.Schemas.StopReason? = nil
|
||||
var buffer = ""
|
||||
var ipython = false
|
||||
var echoDropped = false
|
||||
|
||||
try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
|
||||
buffer += token
|
||||
|
||||
// HACK: Workaround until LlamaRunner exposes echo param
|
||||
if (!echoDropped) {
|
||||
if (buffer.hasPrefix(prompt)) {
|
||||
buffer = String(buffer.dropFirst(prompt.count))
|
||||
echoDropped = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
tokens.append(token)
|
||||
|
||||
if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[") ) {
|
||||
ipython = true
|
||||
continuation.yield(
|
||||
Components.Schemas.ChatCompletionResponseStreamChunk(
|
||||
event: Components.Schemas.ChatCompletionResponseEvent(
|
||||
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
|
||||
content: .case1(""),
|
||||
parse_status: Components.Schemas.ToolCallParseStatus.started
|
||||
)
|
||||
),
|
||||
event_type: .progress
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if (buffer.starts(with: "<|python_tag|>")) {
|
||||
buffer = String(buffer.dropFirst("<|python_tag|>".count))
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Non-streaming lobprobs
|
||||
|
||||
var text = ""
|
||||
if token == "<|eot_id|>" {
|
||||
stopReason = Components.Schemas.StopReason.end_of_turn
|
||||
} else if token == "<|eom_id|>" {
|
||||
stopReason = Components.Schemas.StopReason.end_of_message
|
||||
} else {
|
||||
text = token
|
||||
}
|
||||
|
||||
var delta: Components.Schemas.ChatCompletionResponseEvent.deltaPayload
|
||||
if ipython {
|
||||
delta = .ToolCallDelta(Components.Schemas.ToolCallDelta(
|
||||
content: .case1(text),
|
||||
parse_status: .in_progress
|
||||
))
|
||||
} else {
|
||||
delta = .case1(text)
|
||||
}
|
||||
|
||||
if stopReason == nil {
|
||||
continuation.yield(
|
||||
Components.Schemas.ChatCompletionResponseStreamChunk(
|
||||
event: Components.Schemas.ChatCompletionResponseEvent(
|
||||
delta: delta,
|
||||
event_type: .progress
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if stopReason == nil {
|
||||
stopReason = Components.Schemas.StopReason.out_of_tokens
|
||||
}
|
||||
|
||||
let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
|
||||
// TODO: non-streaming support
|
||||
|
||||
let didParseToolCalls = message.tool_calls.count > 0
|
||||
if ipython && !didParseToolCalls {
|
||||
continuation.yield(
|
||||
Components.Schemas.ChatCompletionResponseStreamChunk(
|
||||
event: Components.Schemas.ChatCompletionResponseEvent(
|
||||
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(content: .case1(""), parse_status: .failure)),
|
||||
event_type: .progress
|
||||
)
|
||||
// TODO: stopReason
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
for toolCall in message.tool_calls {
|
||||
continuation.yield(
|
||||
Components.Schemas.ChatCompletionResponseStreamChunk(
|
||||
event: Components.Schemas.ChatCompletionResponseEvent(
|
||||
delta: .ToolCallDelta(Components.Schemas.ToolCallDelta(
|
||||
content: .ToolCall(toolCall),
|
||||
parse_status: .success
|
||||
)),
|
||||
event_type: .progress
|
||||
)
|
||||
// TODO: stopReason
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
continuation.yield(
|
||||
Components.Schemas.ChatCompletionResponseStreamChunk(
|
||||
event: Components.Schemas.ChatCompletionResponseEvent(
|
||||
delta: .case1(""),
|
||||
event_type: .complete
|
||||
)
|
||||
// TODO: stopReason
|
||||
)
|
||||
)
|
||||
}
|
||||
catch (let error) {
|
||||
print("Inference error: " + error.localizedDescription)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue