本仓库基于 AdaAttN 进行了一些改动。
AdaAttN 是对任意神经风格迁移中的注意力机制的重新审视,其论文 "AdaAttN: Revisit Attention Mechanism in Arbitrary Neural Style Transfer, ICCV 2021"。
相比于原仓库,本仓库进行了以下改动:
- 修改 ONNX 不支持的算子:
- 将模型中所有
torch.randperm
的调用改为torch.arange
,以避免 ONNX 不支持的随机算子。 - 该改动可能会对推理结果产生轻微的变化。
- 将模型中所有
- 调整
BaseModel
的继承:- 原仓库
BaseModel
继承ABC
,本仓库修改为同时继承nn.Module
和ABC
以优化兼容性。
- 原仓库
- Result HTML 展示内容调整:
- 修改了表格格式及标题,使其更易阅读和比较。
- 直接导出的onnx模型
adaattn.onnx
- 用onnxslim优化后的onnx模型
adaattn_slim.onnx
- 导出原始大小151MB。
- 经过onnxslim图优化后,AdaAttN ONNX 模型文件大小约 102MB。
- 输入 1:
content
,形状[b, 3, h, w]
- 输入 2:
style
,形状[b, 3, h, w]
- 输出:
output
,形状[b, 3, h, w]
-
Python 3.10
-
依赖库安装:
pip install torch==2.5.1+cu124 torchvision==0.20+cu124 -f https://download.pytorch.org/whl/torch_stable.html pip install torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124 pip install dominate pip install numpy==1.26.4 pip install pillow
python test.py ^
--content_path datasets/contents ^
--style_path datasets/styles ^
--name AdaAttN ^
--model adaattn ^
--dataset_mode unaligned ^
--load_size 1024 ^
--crop_size 1024 ^
--image_encoder_path checkpoints/vgg_normalised.pth ^
--gpu_ids 0 ^
--skip_connection_3 ^
--shallow_layer
WebDemo在线演示地址: https://whyb.github.io/AdaAttN-WebGPU/ 以下为本仓库风格迁移的部分效果示例: 左1:content 左2:Style 右1:Result