拿LoRA代码来微调大模型

发布时间:2024-01-18  


本文引用地址:

1 简介LoRA

上一期介绍了如何复用免费的源代码,来搭配企业的专有数据而训练出形形色色的自用小模型。免费代码既省成本、可靠、省算力、又自有IP,可谓取之不尽、用之不竭的资源,岂不美哉!

重头开始训练自己的小模型,是一条鸟语花香之路。然而,基于别人的预训练(Pre-trained) ,搭配自有数据而进行微调(Fine-tuning),常常更是一条康庄大道。

随着LLM 等日益繁荣发展,基于这些大模型的迁移学习(Transfer learning),将其预训练好的模型加以微调(Fine tune),来适应到下游的各项新任务,已经成为热门的议题。关于微调技术,其中LoRA 是一种资源消耗较小的训练方法,它能在较少训练参数时就得到比较稳定的效果。

由于LoRA 的外挂模型参数非常轻量,对于各个下游任务来说,只需要搭配特定的训练数据,并独立维护自身的LoRA 参数即可。在训练时可以冻结原模型( 如ResNet50 或MT5) 的既有参数,只需要更新较轻量的LoRA 参数即可,因而微调训练的效率很高。

LoRA 的全名是:Low-Rank Adaptation of Large Language Models( 及大语言模型的低阶适应)。使用这种LoRA 微调方法进行训练时,并不需要调整原( 大)模型的参数值( 图1 里的蓝色部分),而只需要训练LoRA 模型的参数( 图1 里的棕色部分)。

image.png

图1 LoRA的架构

(引自:https://heidloff.net/article/efficient-fine-tuning-lora/)

典型的LoRA 微调途径是,使用下游任务的数据来对< 原模型 + LoRA> 进行重新训练,让该协同模型的性能在该下游任务上表现出最佳效果。

2 简介ResNet50

ResNet50是很通用的AI模型,他擅长于图像的特征提取(Feature extraction),然后依据特征来进行分类(Classification)。所以,它能帮您瞬间探索任何一张图像的特征,然后帮您识别出图片里的人或物的种类。目前的ResNet50 可以准确地识别出1000 种人或物,如日常生活中常遇到的狗、猫、食物、汽车和各种家居物品等。

3 下载LoRA源代码

首先访问这个cccntu 网页,从Github 下载minLoRA源码 ( 图2)。

1705572670833303.png

图2 Github上的免费LoRA源码

然后,按下<code> 就自动把minLoRA 源码下载到本机里了。接着,把所下载的源代码压缩檔解开,放置于Wibdows 本机的Python 工作区里,例如 /Python310/目录区里( 图3)。

image.png

图3 放置于本机的Python环境里

这样,就能先在本机里做简单的测试,例如创建模型并拿简单数据( 或假数据) 来测试,有助于提升成功的自信心。

4 展开微调训练

Step 1:准备训练&测试数据

首先,准备了/ox_lora_data/train/ 训练图像集,包含2 个类--- 水母(Jellyfish) 和蘑菇(Mushroom),各有12 张图像,如图4。

image.png

image.png

图4 12张图像实例

此外,也准备了/ox_lora_data/test/ 测试图像集,也是水母和蘑菇,各有8 张图像。

Step 2:准备ResNet50预训练模型

本范例从torchvision.models 里加载resnet50 预训练模型。这ResNet50 属于大模型,其泛化能力很好。然而,然而对于本范例的较少类的预测( 推论) 准确度就常显得不足。现在,就拿本范例的测试图像集,来检测一下。程序码如下:

# Lora_ResNet50_001_test.py

import torch

import torch.nn as nn

from torchvision import transforms

from torchvision.datasets import ImageFolder

from torch.utils.data import Dataset, DataLoader

import torchvision.models as models

path = ‘c:/ox_lora_data/’

#----------------------------------

model = models.resnet50(

w e i g h t s = m o d e l s . R e s N e t 5 0 _We i g h t s .

IMAGENET1K_V1)

#----------------------------------

def process_lx(labels, batch_size):

lx = labels.clone()

for i in range(batch_size):

if(labels[i]==0): lx[i]=107

elif(labels[i]==1): lx[i]=947

return lx

#----------------------------------

T = transforms.Compose([

transforms.Resize((224, 224)),

transforms.ToTensor()

])

#----------------------------------

test_ds = ImageFolder(path + ‘test/’, transform=T)

test_dl = DataLoader(test_ds, batch_size=1)

model.eval()

with torch.no_grad():

j, m = (0, 0)

for idx, (image, la) in enumerate(test_dl):

labels = process_lx(la, 1)

pred = model(image)

k = torch.argmax(pred[0])

if(la[0]==0 and k==107):

j += 1

elif(la[0] == 1 and k==947):

m += 1

print(“n 水母(Jellyfish) 的正确辨识率:”, j / 8)

print(“n 蘑菇(Mushroom) 的正确辨识率:”, m / 8)

#------------------

#END

在本范例里,其图像分为2 个类:水母和蘑菇。所以在此程序里,其< 水母、蘑菇> 的类标签(Label)分别为:[0, 1]。而在ResNet50 预训练模型里,其<水母、蘑菇> 类标签分别为:[107, 947]。于是,使用process_lx() 函数,来把此程序里的类标签,转换为ResNet50 的类别标签值。在此范例里,我们拿测试数据集里的< 水母、蘑菇> 各8 张图像来给ResNet50 进行分类预测。执行时,输出如下:

-2

这显示出:蘑菇的预测准确度为:0.125,并不理想。亦即,可以观察到了,大模型ResNet50 在这范例里的下游任务上,其预测的准确度并不美好。于是,LoRA微调方法就派上用场了。

Step 3:定义LoRA模型,并展开协同训练兹回顾LoRA 的架构( 图1)。在刚才的范例里,我们加载的ResNet50 模型,就是上图里的Pretrained Weights( 即蓝色) 部分。现在,就准备添加LoRA 模型,也就是上图里的A 和B( 即棕色) 部分。

当我们把A&B 部分添加上去了,就能展开协同训练了。在协同训练时,我们会先冻结Pretrained Weights部分的参数,不去更改它;而只更新LoRA 的A&B 参数。一旦协同训练完成了,就会把LoRA 与ResNet50 的参数合并起来( 即上图右方的橘色部分。请来看看程序码:

# Lora_ResNet50_002_train.py

import numpy as np

import torch

import torch.nn as nn

from torchvision import transforms

from torchvision.datasets import ImageFolder

from torch.utils.data import Dataset, DataLoader

import torchvision.models as models

from functools import partial

import min_lora_model as Min_LoRA

import min_lora_utils as Min_LoRA_Util

path = ‘c:/ox_lora_data/’

#----------------------------------

# 把图片转换成Tensor

T = transforms.Compose([

transforms.Resize((224, 224)),

transforms.ToTensor()

])

def process_lx(labels, batch_size):

lx = labels.clone()

for i in range(batch_size):

if(labels[i]==0): lx[i]=107

elif(labels[i]==1): lx[i]=947

return lx

#----------------------------------

model = models.resnet50(

weight s =mode l s .ResNet50_We ight s .

IMAGENET1K_V1)

#-------- 添加LoRA --------

my_lora_config = { nn.Linear: { “weight”: partial(

Min_LoRA.LoRAParametrization.from_linear,

rank=16),

}, }

#---- 把LoRA 参数添加到原模型 ------

Min_LoRA.add_lora(model, lora_config=my_lora_

config)

parameters = [

{ “params”: list(Min_LoRA_Util.get_lora_

params(model))}, ]

# 只更新LoRA 的Weights

optimizer = torch.optim.Adam(parameters, lr=1e-3)

loss_fn = nn.CrossEntropyLoss()

model.train()

bz = 4

train_ds = ImageFolder(path + ‘ train/ ’ ,

transform=T)

train_dl = DataLoader(train_ds, batch_size=bz,

shuffle=True)

length = len(train_ds)

#----------------------------------

print(‘n------ 外挂LoRA 模型, 协同训

练 ------’)

epochs = 25

for ep in range(epochs+1):

total_loss = 0

for idx, (images, la) in enumerate(train_dl):

labels = process_lx(la, bz)

pred = model(images)

loss = loss_fn(pred, labels)

loss.backward()

optimizer.step()

optimizer.zero_grad()

total_loss += loss.item() * bz

if(ep%5 == 0):

print(‘ ep=’, ep, ‘, loss=’, total_loss /

length )

#-------------- testing ---------------

test_ds = ImageFolder(path + ‘test/’, transform=T)

test_dl = DataLoader(test_ds, batch_size=1)

model.eval()

with torch.no_grad():

j, m = (0, 0)

for idx, (image, la) in enumerate(test_dl):

labels = process_lx(la, 1)

pred = model(image)

k = torch.argmax(pred[0])

if(la[0]==0 and k==107): j += 1

elif(la[0] == 1 and k==947): m += 1

print(“n 水母(Jellyfish) 的正确辨识率:”, j / 8)

print(“n 蘑菇(Mushroom) 的正确辨识率:”, m / 8)

#END

在此范例程序里, 把minLoRA 的源代码, 与ResNet50预训练模型结合,展开100 回合的微调协同训练。并输出如下:

1705572969519996.png

从上述的输出结果,于是我们可以观察到,当ResNet50 在未加挂LoRA 时,其< 蘑菇> 测试的预测准确率是:0.125。当我们完成协同训练100 回合之后,其预测准确度提升到:0.75,达到微调的目的了。

5 结束语

本文就ResNet50 为例,说明如何拿LoRA 源代码,来对ResNet50 进行微调。您已经发现到了,微调可以让ResNet50 更加贴心,满足您的需求。这种途径可以适用于各种大模型,例如MT5 大语言模型、以及StableDiffusion绘图大模型等。

(本文来源于《EEPW》2024.1-2)

文章来源于:电子产品世界    原文链接
本站所有转载文章系出于传递更多信息之目的,且明确注明来源,不希望被转载的媒体或个人可与我们联系,我们将立即进行删除处理。

相关文章

我们与500+贴片厂合作,完美满足客户的定制需求。为品牌提供定制化的推广方案、专属产品特色页,多渠道推广,SEM/SEO精准营销以及与公众号的联合推广...详细>>

利用葫芦芯平台的卓越技术服务和新产品推广能力,原厂代理能轻松打入消费物联网(IOT)、信息与通信(ICT)、汽车及新能源汽车、工业自动化及工业物联网、装备及功率电子...详细>>

充分利用其强大的电子元器件采购流量,创新性地为这些物料提供了一个全新的窗口。我们的高效数字营销技术,不仅可以助你轻松识别与连接到需求方,更能够极大地提高“闲置物料”的处理能力,通过葫芦芯平台...详细>>

我们的目标很明确:构建一个全方位的半导体产业生态系统。成为一家全球领先的半导体互联网生态公司。目前,我们已成功打造了智能汽车、智能家居、大健康医疗、机器人和材料等五大生态领域。更为重要的是...详细>>

我们深知加工与定制类服务商的价值和重要性,因此,我们倾力为您提供最顶尖的营销资源。在我们的平台上,您可以直接接触到100万的研发工程师和采购工程师,以及10万的活跃客户群体...详细>>

凭借我们强大的专业流量和尖端的互联网数字营销技术,我们承诺为原厂提供免费的产品资料推广服务。无论是最新的资讯、技术动态还是创新产品,都可以通过我们的平台迅速传达给目标客户...详细>>

我们不止于将线索转化为潜在客户。葫芦芯平台致力于形成业务闭环,从引流、宣传到最终销售,全程跟进,确保每一个potential lead都得到妥善处理,从而大幅提高转化率。不仅如此...详细>>