Dear Author,
I hope this message finds you well. Apologies for the interruption, but I would like to ask a fairly specific question: could you share how you achieved over 90% accuracy on the CIFAR-100 dataset? Would it be possible to provide the CIFAR-100 training script you used? I have experimented with several backbones (e.g., ResNet, ConvNeXt) on CIFAR-100 but only reached around 60% validation accuracy and 80% training accuracy. Notably, I did not use pre-trained weights (on ImageNet-1K) or modify the backbone architecture in any way.
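One thing I already suspect is my preprocessing: the script below upsamples the 32×32 CIFAR images to 224×224 and normalizes with ImageNet statistics. For reference, here is a minimal sketch of the CIFAR-native pipeline I am considering instead. The mean/std values are the commonly cited CIFAR-100 statistics (not re-derived by me), and the pad-and-crop augmentation follows the usual CIFAR recipe rather than anything from your repository:

```python
from torchvision import transforms

# Commonly cited per-channel CIFAR-100 statistics (assumed, not re-derived here)
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)

# CIFAR-native augmentation: train at 32x32 instead of upsampling to 224
train_transforms_32 = transforms.Compose([
    transforms.RandomCrop(32, padding=4),   # standard pad-and-crop for CIFAR
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

val_transforms_32 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])
```

I realize a ConvNeXt-style patchify stem may need adjusting to work well at 32×32, so I am unsure whether this is the right direction.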
Below is the full script I used for training:
```python
import os
import time
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from convnext_DAW.convnext_tiny_DWv13 import convnext_tiny_DWv13

# Global configuration
USE_MIXED_PRECISION = True      # enable mixed-precision training
RESUME_FROM_CHECKPOINT = False  # resume training from a checkpoint
CHECKPOINT_PATH = 'convnext_tiny_DWv13_checkpoint.pth'  # checkpoint save path
PTH_PATH = 'convnext_tiny_DWv13_best_weight.pth'        # best-weight save path
LOG_FILE = 'log_convnext_tiny_DWv13.txt'                # training log file

# Hyperparameters
EPOCHS = 100
BATCH_SIZE = 100
INIT_LR = 5e-4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data augmentation
train_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Dataset paths (change to your actual paths)
train_dir = "dataset/CIFAR100/train"
val_dir = "dataset/CIFAR100/val"

# Data loading
train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
val_dataset = datasets.ImageFolder(val_dir, transform=val_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=1, pin_memory=True)

model = convnext_tiny_DWv13().to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=INIT_LR, weight_decay=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler(enabled=USE_MIXED_PRECISION)


def train_one_epoch(epoch):
    model.train()
    train_loss, train_acc = 0.0, 0.0
    start_time = time.time()
    # torch.autograd.set_detect_anomaly(True)  # uncomment to debug NaNs (slow)
    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        # Mixed-precision forward pass; the loss is computed in full precision
        with autocast(enabled=USE_MIXED_PRECISION):
            outputs = model(images)
            loss = criterion(outputs.float(), labels)

        # Abort if the loss has diverged
        if torch.isnan(loss):
            print("loss NaN detected, stopping training.")
            return None

        # Scaled backward pass and optimizer step (outside autocast)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Accumulate metrics
        train_loss += loss.item() * images.size(0)
        preds = torch.argmax(outputs, dim=1)
        train_acc += (preds == labels).sum().item()
    train_loss /= len(train_dataset)
    train_acc /= len(train_dataset)
    scheduler.step()
    return train_loss, train_acc, time.time() - start_time


def validate():
    model.eval()
    val_loss, val_acc = 0.0, 0.0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs.float(), labels)
            val_loss += loss.item() * images.size(0)
            preds = torch.argmax(outputs, dim=1)
            val_acc += (preds == labels).sum().item()
    val_loss /= len(val_dataset)
    val_acc /= len(val_dataset)
    return val_loss, val_acc


def save_checkpoint(epoch, best_acc=0.0):
    checkpoint = {
        'epoch': epoch + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict(),
        'best_acc': best_acc
    }
    torch.save(checkpoint, CHECKPOINT_PATH)


def save_weight():
    checkpoint = {
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict(),
    }
    torch.save(checkpoint, PTH_PATH)


def log_to_file(epoch, train_loss, train_acc, val_loss, val_acc, epoch_time):
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_str = (f"[{timestamp}] Epoch: {epoch + 1}/{EPOCHS} | "
               f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
               f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | "
               f"Time: {epoch_time:.1f}s\n")
    with open(LOG_FILE, 'a') as f:
        f.write(log_str)
    print(log_str.strip())


def main():
    start_epoch = 0
    best_val_acc = 0.0

    # Resume from checkpoint
    if RESUME_FROM_CHECKPOINT and os.path.exists(CHECKPOINT_PATH):
        checkpoint = torch.load(CHECKPOINT_PATH)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['scheduler_state'])
        start_epoch = checkpoint['epoch']
        best_val_acc = checkpoint.get('best_acc', 0.0)
        print(f"Loaded checkpoint from epoch {start_epoch}")

    # Create the log file if needed and append a header
    with open(LOG_FILE, 'a') as f:
        f.write("Training Log\n============\n")

    for epoch in range(start_epoch, EPOCHS):
        # Train + validate
        train_result = train_one_epoch(epoch)
        if train_result is None:
            print("Training stopped due to NaN.")
            break
        train_loss, train_acc, epoch_time = train_result
        val_loss, val_acc = validate()

        # Log the epoch
        log_to_file(epoch, train_loss, train_acc * 100, val_loss, val_acc * 100, epoch_time)

        # Save a resumable checkpoint every epoch, and the weights on a new best
        save_checkpoint(epoch, best_val_acc)
        if val_acc > best_val_acc:
            save_weight()
            best_val_acc = val_acc


if __name__ == '__main__':
    main()
```
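Separately, these are the recipe changes I am planning to try next, based on my possibly incomplete reading of common ConvNeXt training setups: label smoothing, heavier weight decay, and a linear warmup before the cosine schedule. This is only a sketch against the script above (reusing its `model`); the warmup length, smoothing value, and weight decay are my own guesses, not taken from your code:

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# These mirror the script above; WARMUP_EPOCHS and the smoothing/decay values are my guesses
EPOCHS = 100
INIT_LR = 5e-4
WARMUP_EPOCHS = 5

# Label smoothing is built into CrossEntropyLoss since PyTorch 1.10
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Heavier weight decay, closer to what I believe the ConvNeXt recipe uses
optimizer = optim.AdamW(model.parameters(), lr=INIT_LR, weight_decay=0.05)

# Linear warmup for the first few epochs, then cosine decay for the remainder
warmup = LinearLR(optimizer, start_factor=0.01, total_iters=WARMUP_EPOCHS)
cosine = CosineAnnealingLR(optimizer, T_max=EPOCHS - WARMUP_EPOCHS)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[WARMUP_EPOCHS])
```

If your 90%+ result came mainly from something else (e.g., Mixup/CutMix, stochastic depth, or simply many more epochs), I would be glad to know.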
I would greatly appreciate any insights or suggestions you might have. Thank you in advance for your time and assistance.
Best regards,
Mingsheng Chen
mingshengchen2333@163.com