原文: https://www.raywenderlich.com/7960296-core-ml-and-vision-tutorial-on-device-training-on-ios
本教程向你介绍 Core ML 和 Vision 这两个最新的 iOS 框架,以及如何在设备上“微调”模型。
工程文件下载:https://koenig-media.raywenderlich.com/uploads/2020/05/Vibes-1.zip
译者传了百度网盘分享,如果上方的链接失效了可以点击 这里 ,提取码是 vec5
原文运行环境:Swift 5, iOS 13, Xcode 11
Update note: Christine Abernathy updated this tutorial for Xcode 11, Swift 5 and iOS 13. Audrey Tam wrote the original.
译者已经将代码提升兼容至 Swift 5.4, iOS 13, Xcode 12.5
苹果在iOS 11中发布了 Core ML 和 Vision 。 Core ML 为开发者提供了将机器学习模型引入App的方法。从而可以实现在iOS设备上直接使用 AI 相关的功能,比如物体检测。
iOS 13 在 Core ML 3 中增加了 on-device training (基于设备的模型训练)功能,为框架增加了基于设备进行模型个性化微调能力。
在本教程中,你将学习如何使用 Core ML 和 Vision 框架在设备上“微调”一个模型。为了学习这些内容,你可以从文章开始的链接下载一个示例的工程 Vibes ,这是一个根据所选图像随机生成箴言的应用程序。在完成教程后,这个项目还能在训练模型后使用“快捷方式”添加你指定的表情符号并且放到绘制“快捷方式”的位置。
开端
开始前,从文章开头下载示例的工程。下载文件是一个 zip 压缩包,将这个压缩包解压后,你会看到 starter 和 final 这两个文件夹。双击在 starter 中的 Vibes.xcodeproj 打开项目。
编译并且运行(如果你要在真机上运行,需要改BundleID,译者注 )项目,你会看到这个界面。
点击左下角的相机按钮,选择一个相册中预置(模拟器会自带一些图片,最好使用模拟器完成这个教程,译者注 )的照片,可以看到下发自动生成了箴言。
接下来,点击右下角表情符号的图标,选择一个 emoji 表情贴纸添加到图片上。可以使用手指,将贴纸移动到任何位置。
基于这个初始的项目,我们有2点可以改进:
箴言似乎是随机生成的,为何不生成和图片最相关的箴言呢?
表情符号贴纸添加起来太麻烦了,有没有更方便的添加贴纸的方式呢?
你在本教程中的目标是使用机器学习 来解决这两个问题。
什么是机器学习 ?
如果你是机器学习的新手,以下是对于这个领域术语的一些解释。
人工智能 (AI)定义为:以编程方式添加到机器上以模仿人类的行动和思想的能力。
机器学习 (ML)是人工智能的子集,训练机器执行某些特定任务。例如,你可以使用 ML 来训练机器识别图像中的猫,或将文字从一种语言翻译成另一种语言。
深度学习 是一种机器训练方法。这种技术模仿人脑,由组织在网络中的神经元组成。深度学习从提供的数据中训练出一个人工神经网络。
假设你想让机器识别图像中的一只猫。你可以给机器提供大量人工标注为 是猫 和 不是猫 的图像。然后你可以从中建立一个模型,使用这个模型可以对另外的图片做出准确的猜测或预测。
使用模型进行训练
苹果将模型 (Model) 定义为“将机器学习算法应用于一组训练数据的结果”。把模型 看作是一个函数,它接受一个输入,对给定的输入进行特定的操作,使其达到最佳效果,比如学习,然后进行预测和分类,并产生合适的输出。
用标记的数据进行训练被称为监督学习 (supervised learning)。你需要大量的优质数据来建立一个优质模型。什么是 优质 ?优质数据要尽可能全面,因为最终建立的模型全部依赖于喂给机器的数据。
比如,如果你想让你的模型识别所有的猫,但只给它提供一个特定的品种,它可能会不认识在这些品种之外的猫。用残缺的数据进行训练会导致不想要的结果。
训练过程是计算密集型的,通常在服务器上完成。凭借其并行计算能力,使用 GPU 通常会加快训练的速度。
一旦训练完成,你可以将你的模型部署到生产中,在真实世界的数据上运行预测或推理。
预测推理 并不像训练那样需要计算。然而在过去,移动 App 必须远程调用服务器接口才能进行模型推理。现在,移动芯片性能的进步为设备上的推理打开了大门。其好处包括减少延迟,减少对网络的依赖和改善隐私。但是,由于推理运算提高了硬件负载,应用程序大小会增加,推理时电池消耗也会有明显的提升。
本教程展示了如何使用 Core ML 进行基于设备的推理(on-device inference )和基于设备的训练(on-device training )。
苹果提供的框架和机器学习工具
Core ML 与特定领域的框架一起工作,如用于图像分析的 Vision 。 Vision 框架提供了再图像或视频上执行计算机视觉算法的高阶API封装。 Vision 可以使用苹果提供的内置模型或者自定义的 Core ML 模型对图像进行分类(classify)。
Core ML 是建立在低级别基元(lower-level primitives): BNNS加速 和 Metal高效着色器 之上的。
其他可以和 Core ML 共同工作的特定领域还包含:用于处理文本的自然语言识别 和用于声音分析的音频识别 。
将Core ML模型集成到你的APP中
为了在 App 中集成 Core ML ,你需要一个 Core ML 格式的模型。Apple 提供了一些 预训练(就是使用少量数据训练的,效果比较一般的,译者注)的模型 ,可以用于图像分类等需求。如果这些模型不能满足你的要求,你可以自行去AI社区查找一番或者训练你自己的模型。
对于项目 Vibes 的一些改进,需要一个能够将图像分类的模型。可使用的模型具备不同程度的准确性,而这些模型也有不同的大小。(对于这个教程项目来说,)你可以使用提供的 SqueezeNet ,这是一个经过训练,可以识别常见物品的小型模型。
拖拽在 starter – Models 中的 SqueezeNet.mlmodel 文件到你 Xcode 中已经打开的 Vibes 项目中,把它放到 Models 这个组下。
选中 SqueezeNet.mlmodel ,可以再右侧看到这个模型的详细信息。
在预测(Prediction)部分列出了预期的输入和输出
输入侧,期望得到一个 227 * 227 尺寸的图像
输出侧有两种类型, classLabelProbs 返回一个字典,其中包含图像分类的概率, classLabel 返回概率最高的分类。
点击上方的 Model Class 区域:
Xcode 为模型自动生成一个文件,其中包括输入、输出和主类。主类包括用于预测的各种方法。
Vision 框架的标准工作流程是:
首先,创建 Core ML 模型(model)。
然后,创建一到多个请求(requests)。
最后,创建和运行多个请求句柄(request handler)。
现在,你已经有了一个现成的 Core ML 模型 SqueezeNet.mlmodel ,直接进行下一步,创建请求。
创建请求
在 CreateQuoteViewController.swift 文件头部 UIKit 后面增加需要的框架导入声明。
import CoreML
import Vision
Vision 有助于处理图像,如将其转换为所需的格式。
为 CreateQuoteViewController 增加如下的 property:
// 定义了一个懒加载的图像分析请求
private lazy var classificationRequest: VNCoreMLRequest = {
do {
// 创建模型实例
let model = try VNCoreMLModel(for: SqueezeNet(configuration: MLModelConfiguration()).model)
// 基于该模型实例化一个图像分析请求对象。完成处理程序后,接收分类结果并打印出来。
let request = VNCoreMLRequest(model: model) { request, _ in
if let classifications =
request.results as? [VNClassificationObservation] {
print("Classification results: \(classifications)")
}
}
// 使用Vision来裁剪输入的图像以符合模型的期望
request.imageCropAndScaleOption = .centerCrop
return request
} catch {
// 通过杀死应用程序来处理模型加载错误。模型是应用程序捆绑的一部分,应该不会走到这。实际应用中,埋个点吧。
fatalError("Failed to load Vision ML model: \(error)")
}
}()
集成请求
将以下代码添加到 CreateQuoteViewController 私有扩展的底部。(写有 Private methods 注释的那个extension ,译者注)
func classifyImage(_ image: UIImage) {
// 获取图像的方向,并且将图像转换成CIImage格式提供给后续步骤
guard let orientation = CGImagePropertyOrientation(rawValue: UInt32(image.imageOrientation.rawValue)) else {
return
}
guard let ciImage = CIImage(image: image) else {
fatalError("Unable to create \(CIImage.self) from \(image).")
}
// 在一个后台队列中启动一个异步分类请求,当句柄在外部被创建并且安排上时,这个Vision请求将被执行
DispatchQueue.global(qos: .userInitiated).async {
let handler = VNImageRequestHandler(ciImage: ciImage, orientation: orientation)
do {
try handler.perform([self.classificationRequest])
} catch {
print("Failed to perform classification.\n\(error.localizedDescription)")
}
}
}
最后,在 imagePickerController(_:didFinishPickingMediaWithInfo:) 方法最后添加这段代码。
classifyImage(image) // 当用户选择一张图片时,触发分类请求
重新编译、运行App,点击相机图标选择一张照片,这是看上去和之前没什么区别。
但是在 Xcode 控制台应该打印了分类的相关元信息。
在上面这个截图中,图像分类器给了cliff, drop, drop-off (悬崖,跌落,落差)约27.88%的置信度。修改 classificationRequest 中的打印语句以直观的数据输出这个结果。
// print("Classification results: \(classifications)")
let topClassifications = classifications.prefix(2).map {
(confidence: $0.confidence, identifier: $0.identifier)
}
print("Top classifications: \(topClassifications)")
重新编译App并且运行,选择一张照片,控制台这时候直接输出了最高置信度的分类结果。
现在你可以使用分类器提供的预测细节来显示与图片相关的箴言了!
添加具有相关性的箴言
在 imagePickerController(_:didFinishPickingMediaWithInfo:) 方法中,注释掉以下代码。
if let quote = getQuote() {
quoteTextView.text = quote.text
}
不需要使用这个方法来随机展示箴言了。接下来,你将使用 VNClassificationObservation 的结果来给App添加挑选最接近的箴言的逻辑。还是在 CreateQuoteViewController 的私有 extension 中增加这个函数。
func processClassifications(for request: VNRequest, error: Error?) {
DispatchQueue.main.async {
// 处理来自图像分类请求的结果
if let classifications =
request.results as? [VNClassificationObservation] {
// 提取前两个预测结果
let topClassifications = classifications.prefix(2).map {
(confidence: $0.confidence, identifier: $0.identifier)
}
print("Top classifications: \(topClassifications)")
let topIdentifiers =
topClassifications.map {$0.identifier.lowercased() }
// 将预测结果传入getQuote(for:)获得一个相关的箴言
if let quote = self.getQuote(for: topIdentifiers) {
self.quoteTextView.text = quote.text
}
}
}
}
因为箴言的视图会进行UI更新,所以这个方法必须在主线程执行。
最后,在 classificationRequest 方法中将 request 的初始化方法修改如下。
let request = VNCoreMLRequest(model: model) { [weak self] request, error in
guard let self = self else {
return
}
// 当预测结束时,调用方法更新箴言
self.processClassifications(for: request, error: error)
}
重新编译运行,选择一张含有柠檬或者柠檬树的照片(模拟器自带,译者注)。这时候显示的箴言应该和柠檬相关了,而不是随机展示的箴言。
观察在Xcode中分类器的日志,
你可以多尝试几次验证结果的一致性。
Great Stuff! 你已经学会了如何使用 Core ML 进行设备上的模型推理。:]
在设备上对模型进行个性化处理
通过最新的 Core ML 3 ,可以在运行期间在设备上对一个可更新的模型进行微调。这意味着你可以为每个用户提供个性化的体验。
基于设备的个性化处理是 FaceID 背后的逻辑。Apple 将一个通用的模型下发到设备上,识别一般的人脸,在 FaceID 设置过程中,每个用户可以对这个模型进行微调以识别他们自己的脸。
将这一更新的模型返回给 Apple,再部署给其他用户,是没有意义的。所以这个功能也凸显了基于设备个性化处理隐私方面的优势。
可更新的模型 是一个标记为“可更新”的 Core ML 模型,你也可以将你自己训练的模型定义为可更新的。
K最近邻分类(k-NN)算法
接下来,你会使用一个可更新的绘图分类器来改进 Vibes 这个项目。这个分类器是基于k-NN(k-Nearest Neighbors)的。这时黑人问号脸就出来了, k-NN 是什么?
k-NN 算法,简单的说,就是“同性相吸”。
它通过比较特征向量 (feature vectors ) 来达到想要的结果,一个特征向量包含描述一个物体特征的关键信息,比如使用特征向量R、G、B 来表示RGB颜色。
比较特征向量之间的距离是查看两个物体是否相似的简单方法,k-NN 对临接的K个邻居进行分类。
下面的例子展示了一个已经按照正方形和圆形形状进行分类分布的场景,如果你想识别神秘的红色属于哪一组。
根据距离画个圈,可以看到 k = 3 的时候可以预测到这个新图形是一个正方形。
k-NN 模型简单而又迅速,不需要很多数据就可以训练。但是随着样例数据越多,它的性能也会变得越慢。
k-NN 是 Core ML 支持训练的模型类型之一。在接下来的示例中, Vibes 这个项目将使用可更新的绘图分类器完成:
一个作为特征提取器的神经网络,它知道如何识别图形。你需要为 k-NN 模型提取特征。
一个用于基于设备绘图个性化处理的 k-NN 模型
接下来,在 Vibes 这个项目中,用户可以添加一个一次性画3个表情符号的快捷方式。你将会以表情符号作为标签,以绘图作为实例来训练模型。
设置绘图训练流程
首先,准备好响应用户在屏幕上的点击事件,用来训练你的模型:
增加一个界面用于选择表情符号。
增加点击保存的动作。
从 stickerLabel 中删除 UIPanGestureRecognizer 手势。
打开 AddStickerViewController.swift 文件,在 collectionView(_:didSelectItemAt:) 中注释掉原本 performSegue(withIdentifier:sender:) 这一行代码,并且替换为:
performSegue(withIdentifier: "AddShortcutSegue", sender: self)
当用户点击一个表情符号时,会前往一个新的页面。
接着打开 AddShortcutViewController.swift 文件在 savePressed(_:) 方法中添加以下代码:
print("Training data ready for label: \(selectedEmoji ?? "")")
// 当用户点击保存按钮时,回到首页
performSegue(
withIdentifier: "AddShortcutUnwindSegue",
sender: self)
最后,打开文件 CreateQuoteViewController.swift 并且注释掉以下代码:
stickerLabel.isUserInteractionEnabled = true
let panGestureRecognizer = UIPanGestureRecognizer(
target: self,
action: #selector(handlePanGesture(_:)))
stickerLabel.addGestureRecognizer(panGestureRecognizer)
通过注释这段代码,禁止用户移动表情符号。这个功能只有用户在无法控制贴纸位置的时候才有用。
重新编译运行 App,选择一张图片,点击贴纸图标选择一个表情符号。你会看到你选择的表情符号和三个用于绘图的画布。
现在在画布上绘制3个相似的图形,全部完成后 Save 按钮才会被设置为可以点击的状态。
译者注 :如果你发现在模拟器上无法绘图,是因为 Vibes 这个项目在这里用了 PKCanvasView 作为画布,可以在 DrawingView.swift 文件中 setupPencilKitCanvas 方法在设置 canvasView 属性的时候,添加以下代码来允许模拟器模拟Apple pencil绘图。
if #available(iOS 14.0, *) {
canvasView.drawingPolicy = .anyInput
} else {
canvasView.allowsFingerDrawing = true
}
然后点击 Save 按钮,你会看到在 Xcode 的控制台中打印了以下日志:
接下来你就可以将工作聚焦到如何使用保存的快捷方式上。
添加快捷方式绘制视图
现在是时候通过以下步骤来实现在图像上进行直接的绘制了:
首先声明一个 DrawingView 。
接下来在主视图中添加绘制视图。
然后,从 viewDidLoad() 中调用 addCanvasForDrawing 。
最后,在完成图像选择后清除画布。
打开 CreateQuoteViewController.swift 文件,在 @IBOutlet 定义区域后面添加以下声明:
var drawingView: DrawingView! // 用户绘制快捷方式的画布视图
接下来添加以下代码以实现 addCanvasForDrawing() 方法:
// 创建绘图视图示例
drawingView = DrawingView(frame: stickerView.bounds)
// 添加到主视图
view.addSubview(drawingView)
// 添加约束防止和贴纸视图重叠
drawingView.translatesAutoresizingMaskIntoConstraints = false
NSLayoutConstraint.activate([
drawingView.topAnchor.constraint(equalTo: stickerView.topAnchor),
drawingView.leftAnchor.constraint(equalTo: stickerView.leftAnchor),
drawingView.rightAnchor.constraint(equalTo: stickerView.rightAnchor),
drawingView.bottomAnchor.constraint(equalTo: stickerView.bottomAnchor)
])
然后在 viewDidLoad() 末尾添加以下内容:
addCanvasForDrawing()
drawingView.isHidden = true
// 添加绘图视图,初始隐藏
现在在 imagePickerController(_:didFinishPickingMediaWithInfo:) 方法中,在 addStickerButton 的 isEnabled 属性在被设置为 true 之后添加以下代码。
// 清除画布,隐藏绘图视图以便表情符号的贴纸可以正确展示
drawingView.clearCanvas()
drawingView.isHidden = false
编译并且运行 App,选择一张照片,使用鼠标或者手指,验证下载图片上可以绘制图形。
一个小目标搞定!我们继续。
进行模型预测
从 starter 的 Models 文件夹中拖拽 UpdatableDrawingClassifier.mlmodel 到 Xcode 项目视图中 Models 组里。
在项目导航窗格里选中 UpdatableDrawingClassifier.mlmodel 文件,在 Updates 这个标签里列出了模型在训练期间期望的两个输入,一个代表绘制的图形,另一个代表表情符号的标签。
Predictions 部分展示了输入和输出,格式和训练期间使用的格式一致,输出项表示表情符号的标签。
在Xcode的项目导航窗格中选择 Models 文件夹,然后:
点击 File ▸ New ▸ File…
在对话框中选择 iOS ▸ Source ▸ Swift File,点击 Next 。
将创建的文件命名为 UpdatableModel.swift 点击 Create 。
将文件头部的 import Foundation 替换为:
import CoreML // 引入机器学习框架
接下来在文件末尾添加以下代码。
extension UpdatableDrawingClassifier {
// 确保图像与模型所期望的一致
var imageConstraint: MLImageConstraint {
return model.modelDescription
.inputDescriptionsByName["drawing"]!
.imageConstraint!
}
// 用绘图的CVPixelBuffer调用模型的预测方法。
// 返回预测的表情符号标签,如果没有匹配的则返回nil
func predictLabelFor(_ value: MLFeatureValue) -> String? {
guard
let pixelBuffer = value.imageBufferValue,
let prediction = try? prediction(drawing: pixelBuffer).label
else {
return nil
}
if prediction == "unknown" {
print("No prediction found")
return nil
}
return prediction
}
}
更新模型
在 UpdatableModel.swift 文件的 import 区域后面添加以下代码:
// 映射可更新的模型
struct UpdatableModel {
private static var updatedDrawingClassifier: UpdatableDrawingClassifier?
private static let appDirectory = FileManager.default.urls(
for: .applicationSupportDirectory,
in: .userDomainMask).first!
// 指向原始编译模型
private static let defaultModelURL =
UpdatableDrawingClassifier.urlOfModelInThisBundle
// 保存模型的位置
private static var updatedModelURL =
appDirectory.appendingPathComponent("personalized.mlmodelc")
private static var tempUpdatedModelURL =
appDirectory.appendingPathComponent("personalized_tmp.mlmodelc")
private init() { }
static var imageConstraint: MLImageConstraint {
guard let model = try? updatedDrawingClassifier ?? UpdatableDrawingClassifier(configuration: MLModelConfiguration()) else {
fatalError("init UpdatableDrawingClassifier error")
}
return model.imageConstraint
}
}
TIPS: Core ML使用一个扩展名为.mlmodelc的编译模型文件,它实际上是一个文件夹。
将模型加载到内存
接下来,在上面这个 struct 定义后面添加以下代码:
private extension UpdatableModel {
// 加载模型
static func loadModel() {
let fileManager = FileManager.default
if !fileManager.fileExists(atPath: updatedModelURL.path) {
do {
let updatedModelParentURL =
updatedModelURL.deletingLastPathComponent()
try fileManager.createDirectory(
at: updatedModelParentURL,
withIntermediateDirectories: true,
attributes: nil)
let toTemp = updatedModelParentURL
.appendingPathComponent(defaultModelURL.lastPathComponent)
try fileManager.copyItem(
at: defaultModelURL,
to: toTemp)
try fileManager.moveItem(
at: toTemp,
to: updatedModelURL)
} catch {
print("Error: \(error)")
return
}
}
guard let model = try? UpdatableDrawingClassifier(
contentsOf: updatedModelURL) else {
return
}
// 模型加载到内存
updatedDrawingClassifier = model
}
}
以上代码将已经更新、编译完毕的模型加载到内存。接下来在 struct 定义后面添加这个公开的扩展。
extension UpdatableModel {
static func predictLabelFor(_ value: MLFeatureValue) -> String? {
loadModel()
return updatedDrawingClassifier?.predictLabelFor(value)
}
}
predict 方法将模型加载到内存中并且在你添加的扩展方法中调用预测方法。
接下来打开 Drawing.swift 文件在 import PencilKit 后面添加以下导入代码。
开始准备预测输入的信息。
预测准备
Core ML 希望开发人员将预测的输入数据包装在一个 MLFeatureValue 对象中,这个对象包含数据类型和数据值。
在 Drawing.swift 文件已经定义的 struct Drawing 尾部添加以下代码:
var featureValue: MLFeatureValue {
let imageConstraint = UpdatableModel.imageConstraint
let preparedImage = whiteTintedImage
let imageFeatureValue =
try? MLFeatureValue(cgImage: preparedImage, constraint: imageConstraint)
return imageFeatureValue!
}
这段代码定义了一个计算属性,用于设置绘图的特征对象,这个特征对象是基于一个全白的图像和图像的相关约束。
现在你已经准备好了输入数据,可以专注于触发预测动作了。
首先,打开 CreateQuoteViewController.swift ,在文件末尾添加 DrawingViewDelegate 扩展。
extension CreateQuoteViewController: DrawingViewDelegate {
func drawingDidChange(_ drawingView: DrawingView) {
// 绘图的边界,防止越界
let drawingRect = drawingView.boundingSquare()
// 绘图实例
let drawing = Drawing(
drawing: drawingView.canvasView.drawing,
rect: drawingRect)
// 为绘图预测输入创建特征值
let imageFeatureValue = drawing.featureValue
// 进行预测,以获得与该绘制图形相对应的表情符号
let drawingLabel =
UpdatableModel.predictLabelFor(imageFeatureValue)
// 更新主队列中的视图,清除画布并将预测的表情符号添加到主视图中
DispatchQueue.main.async {
drawingView.clearCanvas()
guard let emoji = drawingLabel else {
return
}
self.addStickerToCanvas(emoji, at: drawingRect)
}
}
}
回忆一下,上面的步骤中你已经添加了一个 DrawingView 来绘制表情符号的快捷方式,在以上代码中遵循了 DrawingViewDelegate ,在每次绘图发生变化时,这部分代码就会得到响应。
接着在 imagePickerController(_:didFinishPickingMediaWithInfo:) 删除以下重置画布的操作,因为在上面进行预测的方法中已经清理了画布。
drawingView.clearCanvas()
测试预测
接下来在 addCanvasForDrawing() 方法中添加 drawingView 的代理。
drawingView.delegate = self
这使得视图控制器成为绘图视图代理。
编译并运行该应用程序,选择一张照片。在画布上绘图,验证绘图完成后,画布是否被重置,并在控制台中记录了以下内容。
那是意料之中的事。你还没有添加表情贴纸的快捷方式呢~
现在来到添加表情贴纸快捷方式的流程。在你回到所选照片的视图后,绘制同样的快捷方式。
哎呀,贴纸还是没有被添加! 你可以查看控制台日志来查看问题。
经过一番折腾之后,你可能会注意到训练的模型对所要添加的贴纸完全无法理解。是时候解决这个问题了!
更新模型
更新模型需要创建一个 MLUpdateTask 。更新任务的初始化方法需要编译后的模型文件、续联数据和一个完成后的回调句柄。一般来说,把更新后的模型保存到磁盘并且重新加载,新的预测就会使用最新 的模型数据。
首先你需要根据快捷方式的绘制图形准备训练数据。
回顾一下,之前通过传入一个 MLFeatureProvider 特征对象输入来进行模型预测流程,同样的,你也可以传入一个 MLFeatureProvider 特征对象输入来训练一个模型。另外你也可以传入一个包含多个特 征的 MLBatchProvider 来进行批量预测或者批量训练。
要进行以上的操作,首先,打开 DrawingDataStore.swift 文件替换 import Foundation 为:
译者注:导入CoreML 的时候,内部已经导入过Foundation框架,所以这里直接替换掉就可以了
然后在扩展的末尾添加以下方法:
func prepareTrainingData() throws -> MLBatchProvider {
// 初始化一个空的 MLFeatureProvider 数组
var featureProviders: [MLFeatureProvider] = []
// 定义模型训练输入的名称
let inputName = "drawing"
let outputName = "label"
// 循环浏览数据存储中的图形数据
for drawing in drawings {
if let drawing = drawing {
// 将绘图训练输入包在一个特征值中
let inputValue = drawing.featureValue
// 将emoji训练输入包在一个特征值中
let outputValue = MLFeatureValue(string: emoji)
// 为训练输入创建一个MLFeatureValue集合。这是一个训练输入名称和特征值的字典。
let dataPointFeatures: [String: MLFeatureValue] =
[inputName: inputValue,
outputName: outputValue]
// 为 MLFeatureValue 集合创建一个 MLFeatureProvider,并将其追加到数组中
if let provider =
try? MLDictionaryFeatureProvider(
dictionary: dataPointFeatures) {
featureProviders.append(provider)
}
}
}
// 最后,从MLFeatureProvider数组中创建一个批处理对象(MLArrayBatchProvider)
return MLArrayBatchProvider(array: featureProviders)
}
现在,打开 UpdatableModel.swift 文件,在 UpdatableDrawingClassifier 扩展末尾添加以下代码:
static func updateModel(
at url: URL,
with trainingData: MLBatchProvider,
completionHandler: @escaping (MLUpdateContext) -> Void
) {
do {
let updateTask = try MLUpdateTask(
forModelAt: url,
trainingData: trainingData,
configuration: nil,
completionHandler: completionHandler)
updateTask.resume()
} catch {
print("Couldn't create an MLUpdateTask.")
}
}
以上代码使用编译后的模型URL创建 MLUpdateTask ,还传入了一个带有训练数据批处理 MLBatchProvider 。对这个任务调用 resume() 开始训练,当训练结束时, completionHandler 被调用。
保存模型
接下来,在 UpdatableModel 私有扩展(private extension)末尾添加以下代码:
static func saveUpdatedModel(_ updateContext: MLUpdateContext) {
// 首先,从内存中获取更新的模型。这与原始模型不一样。
let updatedModel = updateContext.model
let fileManager = FileManager.default
do {
// 然后,创建一个中间文件夹来保存更新的模型。
try fileManager.createDirectory(
at: tempUpdatedModelURL,
withIntermediateDirectories: true,
attributes: nil)
// 把更新的模型写到一个临时文件夹中
try updatedModel.write(to: tempUpdatedModelURL)
// 替换模型文件夹的内容
// 直接覆盖现有的mlmodelc文件夹会出现错误。
// 解决方案是保存到一个中间文件夹,然后把内容复制过来。
_ = try fileManager.replaceItemAt(
updatedModelURL,
withItemAt: tempUpdatedModelURL)
print("Updated model saved to:\n\t\(updatedModelURL)")
} catch let error {
print("Could not save updated model to the file system: \(error)")
return
}
}
这个辅助类完成了保存更新模型的任务,它接收了一个 MLUpdateContext ,其中包含了训练相关的有用信息。
执行更新后的模型
在 UpdatableModel 的公开扩展(public extension)末尾添加以下代码:
static func updateWith(
trainingData: MLBatchProvider,
completionHandler: @escaping () -> Void
) {
loadModel()
UpdatableDrawingClassifier.updateModel(
at: updatedModelURL,
with: trainingData) { context in
saveUpdatedModel(context)
DispatchQueue.main.async { completionHandler() }
}
}
以上代码将模型加载到内存,然后调用私有扩展中定义的更新方法,完成处理流程后保存更新后的模型,然后再运行这个流程中的成功回调句柄。
然后,打开 AddShortcutViewController.swift 文件,替换 savePressed(_:) 方法的实现:
do {
let trainingData = try drawingDataStore.prepareTrainingData()
DispatchQueue.global(qos: .userInitiated).async {
UpdatableModel.updateWith(trainingData: trainingData) {
DispatchQueue.main.async {
self.performSegue(
withIdentifier: "AddShortcutUnwindSegue",
sender: self)
}
}
}
} catch {
print("Error updating model", error)
}
在这里,你已经吧一切都混合在一起。设置好训练数据后,启动了一个后台线程来更新模型,更新完毕后调用 AddShortcutUnwindSegue 跳转会主视图。
编译并且运行App,通过以下步骤来创建一个快捷方式。
当你点击 Save 按钮时,观察以下 Xcode 的控制台输出。
在选定的照片上绘制同样的快捷方式团,并验证是否显示了正确的表情符号。
恭喜你,获得了机器学习忍者称号!
接下来呢?
你可以看下下载项目中已经完成的项目(在 final 文件夹里,但是Xode 12.5可能跑不起来。。译者注 )。
查看iOS中的 机器学习视频课程 ,了解更多关于如何使用 Create ML 和 Turi Create 来训练自己的模型。Beginning Machine Learning with Keras & Core ML 指导你如何训练一个神经网络并将其转换为Core ML。
Create ML应用可以让你建立、训练和部署机器学习模型,不需要机器学习的专业知识。你还可以查看WWDC 2019的官方视频,What’s New in Machine Learning 和 Training Object Detection Models in Create ML 。