一步步掌握 PyTorch Image Models (timm) – 从入门到实战完全指南
引言
在深度学习领域,图像模型的训练和应用一直是研究和工程实践的核心任务。无论是计算机视觉的研究者还是工程师,都需要一个强大且易用的工具库来快速加载预训练模型、构建自定义架构以及进行模型微调。PyTorch Image Models(简称 timm)正是这样一个库,它由 Ross Wightman 创建并维护,现已成为 Hugging Face 生态系统中不可或缺的一部分。
timm 库汇集了数百种最新的图像分类模型架构,提供了统一的 API 接口,使得研究人员和开发者能够轻松地在不同模型之间切换比较。这个库不仅包含了经典的 ResNet、VGG 等模型,还收录了大量现代的高性能模型如 EfficientNet、ConvNeXt、Swin Transformer 等。此外,timm 还提供了丰富的图像变换、数据增强以及模型训练工具,能够满足从实验研究到生产部署的各种需求。
本教程将带领读者从零开始,系统地学习 timm 库的安装、核心功能、使用方法以及最佳实践。通过大量的代码示例和实战项目,读者将能够快速掌握这个强大的工具,并在自己的图像项目中高效地应用 timm。无论你是深度学习的新手,还是希望提升开发效率的资深工程师,本教程都将为你提供有价值的参考和指导。
第一部分:环境搭建与基础配置
安装 timm 库
在开始使用 timm 之前,首先需要正确配置开发环境。timm 库的安装非常简单,可以通过 pip 包管理器一键完成。在安装之前,建议首先创建一个独立的虚拟环境,以避免依赖冲突。
# 创建虚拟环境(可选但推荐)
# python -m venv timm_env
# source timm_env/bin/activate # Linux/Mac
# timm_env\Scripts\activate # Windows
# 安装 timm 库
# pip install timm
# 如果需要特定版本,可以指定版本号
# pip install timm==0.9.12
# 验证安装是否成功
import timm
print(f"timm 版本: {timm.__version__}")
timm 库依赖于 PyTorch 和 torchvision,因此在安装 timm 时,这些依赖项会自动安装。如果你的环境中已经安装了 PyTorch,timm 可以很好地与之配合工作。timm 兼容 PyTorch 1.7 及以上版本,但对于最新的模型架构,建议使用 PyTorch 1.12 或更高版本以获得最佳性能和兼容性。
验证安装和环境检查
安装完成后,我们需要验证所有组件是否正常工作。下面的代码可以帮助你检查环境配置是否正确。
# 检查 PyTorch 版本
import torch
print(f"PyTorch 版本: {torch.__version__}")
print(f"CUDA 是否可用: {torch.cuda.is_available()}")
# 检查 timm 是否正确安装
import timm
print(f"timm 版本: {timm.__version__}")
# 列出 timm 库的一些基本信息
print(f"可用模型数量: {len(timm.list_models())}")
print(f"可用预训练模型数量: {len(timm.list_pretrained())}")
如果 CUDA 可用,你还可以检查 GPU 信息,这对于后续的模型训练和推理非常重要。
# 检查 GPU 信息
if torch.cuda.is_available():
print(f"GPU 设备数量: {torch.cuda.device_count()}")
print(f"当前 GPU: {torch.cuda.get_device_name(0)}")
print(f"GPU 计算能力: {torch.cuda.get_device_capability(0)}")
依赖项说明
timm 库的核心依赖包括 PyTorch、torchvision 以及一些常用的数值计算库。理解这些依赖项的作用有助于更好地使用 timm 库的功能。
# 列出 timm 的主要依赖
# PyTorch: 深度学习框架核心
# torchvision: 图像处理和数据增强工具
# PIL/Pillow: Python 图像处理库
# numpy: 数值计算库
# 可以通过以下命令查看 timm 的完整依赖列表
# pip show timm
在实际项目中,有时可能需要特定的依赖版本来确保某些功能正常工作。例如,使用某些最新的模型架构可能需要更新版本的 PyTorch。以下是一个完整的依赖检查函数,可以帮助你快速诊断环境问题。
def check_timm_environment():
"""检查 timm 环境配置"""
results = {}
# 检查 Python 版本
import sys
results['python_version'] = sys.version
# 检查核心依赖
results['torch_version'] = torch.__version__
results['cuda_available'] = torch.cuda.is_available()
# 检查 timm 功能
results['timm_version'] = timm.__version__
results['num_models'] = len(timm.list_models())
results['num_pretrained'] = len(timm.list_pretrained())
# 打印结果
print("=" * 50)
print("timm 环境检查结果")
print("=" * 50)
for key, value in results.items():
print(f"{key}: {value}")
return results
# 运行环境检查
# check_timm_environment()
完成以上步骤后,你的开发环境就已经配置完成,可以开始使用 timm 库进行图像模型的加载、训练和推理了。
第二部分:核心功能详解
模型列表与浏览
timm 库提供了强大的模型浏览功能,让用户能够快速找到所需的网络架构。通过 timm.list_models() 函数,你可以查看所有可用的模型,也可以根据模式匹配来筛选特定类型的模型。
# 查看所有可用模型的数量
all_models = timm.list_models()
print(f"总共有 {len(all_models)} 个模型")
# 使用通配符搜索模型
efficientnet_models = timm.list_models('*efficientnet*')
print(f"EfficientNet 系列模型: {len(efficientnet_models)} 个")
print("前5个 EfficientNet 模型:", efficientnet_models[:5])
# 搜索包含特定关键词的模型
resnet_models = timm.list_models('resnet*')
print(f"\nResNet 系列模型: {len(resnet_models)} 个")
print("部分 ResNet 模型:", resnet_models[:10])
timm 还支持更复杂的模型筛选功能,帮助用户在数百个模型中找到最符合需求的架构。
# 搜索 ViT (Vision Transformer) 模型
vit_models = timm.list_models('*vit*')
print(f"ViT 模型数量: {len(vit_models)}")
print("ViT 模型示例:", vit_models[:5])
# 搜索 ConvNeXt 模型
convnext_models = timm.list_models('*convnext*')
print(f"\nConvNeXt 模型数量: {len(convnext_models)}")
# 搜索 Swin Transformer 模型
swin_models = timm.list_models('*swin*')
print(f"\nSwin Transformer 模型数量: {len(swin_models)}")
# 组合搜索:查找包含特定模式的模型
models = timm.list_models('*resnet*50*')
print(f"\nResNet50 相关模型: {models}")
# 查看特定模型的所有变体
mobilenet_v3_variants = timm.list_models('*mobilenetv3*')
print(f"\nMobileNetV3 变体: {mobilenet_v3_variants}")
预训练模型筛选
timm 库的另一大特色是提供了大量预训练模型。通过 timm.list_pretrained() 函数,用户可以查看所有带预训练权重的模型。
# 查看所有预训练模型
pretrained_models = timm.list_pretrained()
print(f"预训练模型数量: {len(pretrained_models)}")
print("部分预训练模型:", pretrained_models[:10])
# 搜索特定类型的预训练模型
pretrained_efficientnet = timm.list_pretrained('*efficientnet*')
print(f"\n预训练 EfficientNet 模型: {len(pretrained_efficientnet)}")
# 搜索 ImageNet 预训练模型
imagenet_models = timm.list_pretrained('*in21k*') # ImageNet-21k 预训练
print(f"\nImageNet-21k 预训练模型: {len(imagenet_models)}")
# 搜索 ImageNet-1k 预训练模型(标准 ImageNet)
imagenet1k_models = timm.list_pretrained('*in1k*')
print(f"ImageNet-1k 预训练模型: {len(imagenet1k_models)}")
模型信息查询
在使用某个特定模型之前,了解模型的详细信息非常重要。timm 提供了多种方式查询模型的结构和参数信息。
# 获取模型的配置信息
model_name = 'resnet50'
config = timm.get_pretrained_cfg('resnet50')
print(f"ResNet50 配置信息:")
print(f" 输入分辨率: {config.input_size}")
print(f" 预训练数据集: {config.pretrained_cfg.get('url', 'N/A')[:50]}...")
# 创建模型并查看其结构
model = timm.create_model('resnet50', pretrained=True)
print(f"\n模型类型: {type(model)}")
print(f"模型总参数量: {sum(p.numel() for p in model.parameters()):,}")
print(f"可训练参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
# 查看模型的结构
print("\n模型结构概览:")
print(model)
特征提取功能
timm 库提供了灵活的特征提取接口,可以方便地获取模型中间层的特征图。这对于迁移学习、特征可视化以及构建自定义模型非常有帮助。
# 创建带特征提取功能的模型
model = timm.create_model(
'resnet50',
pretrained=True,
features_only=True # 启用特征提取模式
)
# 查看模型输出的特征图信息
print(f"模型输出特征通道数: {model.feature_info.channels()}")
print(f"特征图分辨率变化: {model.feature_info.reduction()}")
# 测试特征提取
import torch
dummy_input = torch.randn(1, 3, 224, 224)
features = model(dummy_input)
print(f"\n特征图数量: {len(features)}")
for i, feat in enumerate(features):
print(f" 特征层 {i}: 形状 {feat.shape}")
分类头配置
timm 允许用户灵活配置分类头的参数,包括类别数量、dropout 概率等。这使得用户可以轻松地将预训练模型应用于不同的分类任务。
# 创建具有自定义类别数的模型
num_classes = 10
model = timm.create_model(
'resnet50',
pretrained=True,
num_classes=num_classes
)
# 查看分类层的配置
print(f"分类类别数: {model.num_classes}")
print(f"Dropout 概率: {getattr(model, 'drop_rate', 0.0)}")
# 创建带 Dropout 的模型
model_with_dropout = timm.create_model(
'resnet50',
pretrained=True,
num_classes=100,
drop_rate=0.5
)
# 创建带自定义池化层的模型
model_custom_pool = timm.create_model(
'resnet50',
pretrained=True,
num_classes=1000,
global_pool='avgmax' # 使用平均和最大池化的结合
)
第三部分:实战教程
基础模型加载与使用
本节将通过详细的代码示例,展示如何加载和使用 timm 库中的预训练模型。我们将从最简单的例子开始,逐步深入到更复杂的应用场景。
# 基础示例:加载预训练的 ResNet50 模型
import timm
import torch
from PIL import Image
from torchvision import transforms
# 加载预训练模型
model = timm.create_model('resnet50', pretrained=True)
model.eval() # 设置为评估模式
print("模型加载成功!")
print(f"模型类别数: {model.num_classes}")
# 定义图像预处理流程
data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False)
print(f"\n预处理配置:")
print(f" 输入尺寸: {data_config['input_size']}")
print(f" 归一化均值: {data_config['mean']}")
print(f" 归一化标准差: {data_config['std']}")
# 加载并预处理一张测试图像
# 这里使用一个示例 URL,实际使用时替换为本地图像路径
from urllib.request import urlopen
from io import BytesIO
# 下载示例图像
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/puppy.jpg"
response = urlopen(url)
image_data = response.read()
image = Image.open(BytesIO(image_data)).convert('RGB')
# 应用预处理
input_tensor = transform(image)
input_batch = input_tensor.unsqueeze(0) # 添加 batch 维度
print(f"\n输入张量形状: {input_batch.shape}")
# 进行推理
with torch.no_grad():
output = model(input_batch)
# 处理输出
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_idx = torch.topk(probabilities, 5)
print("\n预测结果 Top 5:")
for i in range(5):
print(f" 类别 {top5_idx[i].item()}: {top5_prob[i].item():.4f}")
完整的图像分类流程
下面的代码展示了一个完整的图像分类流程,包括模型加载、数据预处理、推理以及结果后处理。
import timm
import torch
from PIL import Image
import numpy as np
class ImageClassifier:
"""图像分类器封装类"""
def __init__(self, model_name='resnet50', pretrained=True, device=None):
"""
初始化图像分类器
参数:
model_name: 模型名称
pretrained: 是否使用预训练权重
device: 计算设备 (cpu/cuda)
"""
# 设置设备
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = torch.device(device)
# 加载模型
self.model = timm.create_model(
model_name,
pretrained=pretrained,
num_classes=1000
)
self.model = self.model.to(self.device)
self.model.eval()
# 获取预处理配置
self.data_config = timm.data.resolve_model_data_config(self.model)
self.transform = timm.data.create_transform(
**self.data_config,
is_training=False
)
# 加载 ImageNet 类别标签
self.labels = self._load_labels()
print(f"模型: {model_name}")
print(f"设备: {self.device}")
print(f"输入尺寸: {self.data_config['input_size']}")
def _load_labels(self):
"""加载 ImageNet 类别标签"""
# ImageNet 1000 类别的简要描述
# 实际应用中可以从本地文件加载完整的标签列表
labels_url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
try:
from urllib.request import urlopen
response = urlopen(labels_url)
labels = [line.strip().decode('utf-8') for line in response.readlines()]
return labels
except:
return [f"类别_{i}" for i in range(1000)]
def preprocess_image(self, image_path):
"""
预处理图像
参数:
image_path: 图像路径或 PIL Image 对象
返回:
预处理后的张量
"""
if isinstance(image_path, str):
image = Image.open(image_path).convert('RGB')
else:
image = image_path
return self.transform(image)
def predict(self, image_path, top_k=5):
"""
预测图像类别
参数:
image_path: 图像路径
top_k: 返回前 k 个最可能的类别
返回:
预测结果列表 [(类别索引, 类别名, 概率), ...]
"""
# 预处理图像
input_tensor = self.preprocess_image(image_path)
input_batch = input_tensor.unsqueeze(0).to(self.device)
# 推理
with torch.no_grad():
output = self.model(input_batch)
# 计算概率
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top_k_probs, top_k_indices = torch.topk(probabilities, top_k)
# 整理结果
results = []
for i in range(top_k):
idx = top_k_indices[i].item()
prob = top_k_probs[i].item()
label = self.labels[idx] if idx < len(self.labels) else f"类别_{idx}"
results.append((idx, label, prob))
return results
def batch_predict(self, image_paths, top_k=5):
"""
批量预测多个图像
参数:
image_paths: 图像路径列表
top_k: 返回前 k 个最可能的类别
返回:
每个图像的预测结果列表
"""
results = []
for path in image_paths:
result = self.predict(path, top_k)
results.append(result)
return results
# 使用示例
def main():
"""主函数"""
# 创建分类器
classifier = ImageClassifier(model_name='resnet50')
# 准备测试图像(这里使用示例代码,实际使用时提供真实图像路径)
# image_path = "path/to/your/image.jpg"
# results = classifier.predict(image_path)
# print("\n预测结果:")
# for idx, label, prob in results:
# print(f" {label}: {prob:.4f}")
# if __name__ == "__main__":
# main()
模型比较与选择
在实际项目中,选择合适的模型架构非常重要。下面的代码展示了如何比较不同模型的性能和特点。
import timm
import torch
import time
from functools import wraps
def measure_inference_time(func):
"""测量函数执行时间的装饰器"""
@wraps(func)
def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
end = time.time()
return result, end - start
return wrapper
def count_parameters(model):
"""计算模型参数量"""
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
return total, trainable
def get_model_info(model_name):
"""获取模型信息"""
try:
# 创建模型
model = timm.create_model(model_name, pretrained=False)
# 计算参数量
total_params, trainable_params = count_parameters(model)
# 获取配置信息
config = timm.data.resolve_model_data_config(model)
# 计算模型大小(MB)
param_size = total_params * 4 / (1024 ** 2) # 假设使用 float32
return {
'name': model_name,
'total_params': total_params,
'trainable_params': trainable_params,
'param_size_mb': param_size,
'input_size': config.get('input_size', 'N/A'),
}
except Exception as e:
return {'name': model_name, 'error': str(e)}
def compare_models(model_names, num_runs=10):
"""
比较多个模型的性能
参数:
model_names: 模型名称列表
num_runs: 运行次数
返回:
比较结果字典
"""
results = []
dummy_input = torch.randn(1, 3, 224, 224)
print("模型比较中...")
print("=" * 80)
print(f"{'模型名称':<30} {'参数量(M)':<12} {'大小(MB)':<10} {'推理时间(ms)':<12}")
print("=" * 80)
for name in model_names:
try:
# 创建模型
model = timm.create_model(name, pretrained=False)
model.eval()
# 预热
with torch.no_grad():
for _ in range(3):
_ = model(dummy_input)
# 测量推理时间
times = []
with torch.no_grad():
for _ in range(num_runs):
start = time.time()
_ = model(dummy_input)
times.append((time.time() - start) * 1000)
avg_time = sum(times) / len(times)
total_params = sum(p.numel() for p in model.parameters()) / 1e6
param_size = total_params * 4 / 1024 # 转换为 MB
print(f"{name:<30} {total_params:<12.2f} {param_size:<10.2f} {avg_time:<12.2f}")
results.append({
'name': name,
'params_m': total_params,
'size_mb': param_size,
'inference_time_ms': avg_time
})
except Exception as e:
print(f"{name:<30} 错误: {str(e)}")
print("=" * 80)
return results
# 比较示例
# model_list = [
# 'resnet18', 'resnet34', 'resnet50',
# 'efficientnet_b0', 'efficientnet_b1',
# 'mobilenetv3_small_100', 'mobilenetv3_large_100'
# ]
# compare_results = compare_models(model_list)
迁移学习实战
迁移学习是 timm 库最重要的应用场景之一。下面的代码展示了如何将预训练模型应用于自定义数据集。
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
class CustomDataset(Dataset):
"""自定义数据集类"""
def __init__(self, image_dir, transform=None):
"""
初始化数据集
参数:
image_dir: 图像目录路径
transform: 数据变换
"""
self.image_dir = image_dir
self.transform = transform
self.samples = [] # 存储 (图像路径, 标签) 对
# 加载数据(这里需要根据实际数据结构进行调整)
# 示例结构: image_dir/class_name/image.jpg
self.classes = sorted([
d for d in os.listdir(image_dir)
if os.path.isdir(os.path.join(image_dir, d))
])
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
# 收集所有图像路径
for class_name in self.classes:
class_dir = os.path.join(image_dir, class_name)
for img_name in os.listdir(class_dir):
if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
img_path = os.path.join(class_dir, img_name)
self.samples.append((img_path, self.class_to_idx[class_name]))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
"""获取数据样本"""
img_path, label = self.samples[idx]
# 加载图像
image = Image.open(img_path).convert('RGB')
# 应用变换
if self.transform:
image = self.transform(image)
return image, label
def get_class_name(self, idx):
"""根据索引获取类别名称"""
return self.classes[idx]
class TransferLearningClassifier:
"""迁移学习分类器"""
def __init__(self, model_name='resnet50', num_classes=10, device=None):
"""
初始化迁移学习分类器
参数:
model_name: 基础模型名称
num_classes: 目标类别数
device: 计算设备
"""
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.num_classes = num_classes
# 创建模型(使用预训练权重)
self.model = timm.create_model(
model_name,
pretrained=True,
num_classes=num_classes
)
self.model = self.model.to(self.device)
# 冻结大部分层(可选)
# self._freeze_layers()
# 设置优化器
self.optimizer = optim.AdamW(self.model.parameters(), lr=1e-4)
# 设置损失函数
self.criterion = nn.CrossEntropyLoss()
# 学习率调度器
self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
self.optimizer, T_max=10
)
print(f"迁移学习模型已创建: {model_name}")
print(f"目标类别数: {num_classes}")
print(f"设备: {self.device}")
def _freeze_layers(self, freeze_until='layer4'):
"""
冻结模型的部分层
参数:
freeze_until: 冻结到此层为止
"""
freeze = True
for name, param in self.model.named_parameters():
if freeze_until in name:
freeze = False
if freeze:
param.requires_grad = False
# 统计可训练参数
trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.model.parameters())
print(f"冻结后 - 可训练参数: {trainable:,} / {total:,} ({trainable/total*100:.1f}%)")
def create_data_loaders(self, train_dir, val_dir, batch_size=32):
"""
创建数据加载器
参数:
train_dir: 训练数据目录
val_dir: 验证数据目录
batch_size: 批次大小
返回:
(train_loader, val_loader)
"""
# 训练数据增强
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# 验证数据变换
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# 创建数据集
train_dataset = CustomDataset(train_dir, transform=train_transform)
val_dataset = CustomDataset(val_dir, transform=val_transform)
# 创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=4,
pin_memory=True
)
print(f"训练样本数: {len(train_dataset)}")
print(f"验证样本数: {len(val_dataset)}")
print(f"训练批次数: {len(train_loader)}")
print(f"验证批次数: {len(val_loader)}")
return train_loader, val_loader
def train_epoch(self, train_loader):
"""训练一个 epoch"""
self.model.train()
total_loss = 0
correct = 0
total = 0
for batch_idx, (images, labels) in enumerate(train_loader):
images = images.to(self.device)
labels = labels.to(self.device)
# 前向传播
self.optimizer.zero_grad()
outputs = self.model(images)
loss = self.criterion(outputs, labels)
# 反向传播
loss.backward()
self.optimizer.step()
# 统计
total_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
if (batch_idx + 1) % 10 == 0:
print(f" 批次 {batch_idx + 1}/{len(train_loader)}, "
f"损失: {loss.item():.4f}, "
f"准确率: {100.*correct/total:.2f}%")
return total_loss / len(train_loader), 100. * correct / total
def validate(self, val_loader):
"""验证模型"""
self.model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for images, labels in val_loader:
images = images.to(self.device)
labels = labels.to(self.device)
outputs = self.model(images)
loss = self.criterion(outputs, labels)
total_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
return total_loss / len(val_loader), 100. * correct / total
def train(self, train_loader, val_loader, num_epochs=10):
"""
训练模型
参数:
train_loader: 训练数据加载器
val_loader: 验证数据加载器
num_epochs: 训练轮数
"""
best_acc = 0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
for epoch in range(num_epochs):
print(f"\n{'='*60}")
print(f"Epoch {epoch + 1}/{num_epochs}")
print(f"{'='*60}")
# 训练
train_loss, train_acc = self.train_epoch(train_loader)
history['train_loss'].append(train_loss)
history['train_acc'].append(train_acc)
# 验证
val_loss, val_acc = self.validate(val_loader)
history['val_loss'].append(val_loss)
history['val_acc'].append(val_acc)
# 学习率调整
self.scheduler.step()
# 打印结果
print(f"\nEpoch {epoch + 1} 结果:")
print(f" 训练 - 损失: {train_loss:.4f}, 准确率: {train_acc:.2f}%")
print(f" 验证 - 损失: {val_loss:.4f}, 准确率: {val_acc:.2f}%")
print(f" 学习率: {self.optimizer.param_groups[0]['lr']:.6f}")
# 保存最佳模型
if val_acc > best_acc:
best_acc = val_acc
self.save_model('best_model.pth')
print(f" 已保存最佳模型 (准确率: {val_acc:.2f}%)")
return history
def save_model(self, path):
"""保存模型"""
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'num_classes': self.num_classes
}, path)
print(f"模型已保存至: {path}")
def load_model(self, path):
"""加载模型"""
checkpoint = torch.load(path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
print(f"模型已从 {path} 加载")
# 使用示例
# classifier = TransferLearningClassifier(
# model_name='efficientnet_b0',
# num_classes=10
# )
# train_loader, val_loader = classifier.create_data_loaders(
# train_dir='path/to/train',
# val_dir='path/to/val',
# batch_size=32
# )
# history = classifier.train(train_loader, val_loader, num_epochs=10)
模型微调的高级技巧
在迁移学习过程中,合理的微调策略可以显著提升模型性能。以下代码展示了一些高级微调技巧。
“`python
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
class AdvancedFineTuner:
“””高级模型微调器”””
def __init__(self, model_name, num_classes, device=None):
"""
初始化微调器
参数:
model_name: 模型名称
num_classes: 目标类别数
device: 计算设备
"""
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 创建模型
self.model = timm.create_model(
model_name,
pretrained=True,
num_classes=num_classes
)
self.model = self.model.to(self.device)
print(f"模型: {model_name}")
print(f"设备: {self.device}")
def get_layer_groups(self):
"""
获取模型的分层参数组
返回:
不同层的参数组,用于差异化学习率
"""
# 获取模型的所有参数和它们的名称
param_groups = []
# 分类头参数(最高学习率)
classifier_params = []
# 主干网络参数(较低学习率)
backbone_params = []
for name, param in self.model.named_parameters():
if not param.requires_grad:
continue
if 'classifier' in name or 'fc' in name or 'head' in name:
classifier_params.append(param)
else:
backbone_params.append(param)
param_groups = [
{'params': classifier_params, 'lr': 1e-3},
{'params': backbone_params, 'lr': 1e-4}
]
return param_groups
def progressive_resizing(self, model, target_size=224, start_size=160, epochs=10):
"""
渐进式调整图像尺寸训练策略
参数:
model: 模型
target_size: 目标图像尺寸
start_size: 起始图像尺寸
epochs: 训练轮数
"""
sizes = [start_size + i * ((target_size - start_size) // epochs)
for i in range(epochs + 1)]
print("渐进式尺寸训练策略:")
for i, size in enumerate(sizes):
print(f" Epoch {i}: 图像尺寸 {size}x{size}")
return sizes
def mixup_data(self, x, y, alpha=1.0):
"""
MixUp 数据增强
参数:
x: 输入张量
y: 标签
alpha: MixUp 参数
返回:
混合后的输入、标签
"""
if alpha > 0:
lam = np.random.beta(alpha, alpha)
else:
lam = 1
batch_size = x.size(0)
index = torch.randperm(batch_size).to(self.device)
mixed_x = lam * x + (1 - lam) * x[index, :
评论区