diff --git a/dataset/get_pretrain_t2i_data.py b/dataset/get_pretrain_t2i_data.py
new file mode 100644
index 0000000..ffb6e45
--- /dev/null
+++ b/dataset/get_pretrain_t2i_data.py
@@ -0,0 +1,36 @@
+import json
+from tqdm import tqdm
+
+# Input and output file paths
+input_file = 'pretrain_data.jsonl'
+output_file = 'pretrain_t2i_data.jsonl'
+
+# Open the input and output files
+with open(input_file, 'r', encoding='utf-8') as infile, open(output_file, 'w', encoding='utf-8') as outfile:
+    for line in tqdm(infile):
+        # Parse the JSON record on each line
+        data = json.loads(line.strip())
+
+        # Get the conversation turns and the image field
+        conversations = data.get("conversations", [])
+        image = data.get("image", "")
+
+        # Rewrite the conversation for text-to-image training
+        if len(conversations) == 2:
+            # the original assistant caption becomes the user prompt; the assistant turn holds the image placeholder
+            user_content = conversations[0]["content"]
+            assistant_content = conversations[1]["content"]
+
+            conversations[0]["content"] = assistant_content
+            conversations[1]["content"] = "<image>"  # replaced with 256 image special tokens by the dataset
+
+            # Assemble the new record
+            new_data = {
+                "conversations": conversations,
+                "image": image
+            }
+
+            # Write the processed record to the new jsonl file
+            outfile.write(json.dumps(new_data, ensure_ascii=False) + "\n")
+
+print("处理完成,结果保存在", output_file)
diff --git a/eval_t2i.py b/eval_t2i.py
new file mode 100644
index 0000000..6db11a7
--- /dev/null
+++ b/eval_t2i.py
@@ -0,0 +1,93 @@
+import argparse
+import os
+import random
+import numpy as np
+import torch
+import warnings
+import torch.nn.functional as F
+from PIL import Image
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from model.model_t2i import MiniMindT2I
+from model.VLMConfig import VLMConfig
+from transformers import logging as hf_logging
+
+hf_logging.set_verbosity_error()
+
+warnings.filterwarnings('ignore')
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def init_model(lm_config, device):
+    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
+    moe_path = '_moe' if args.use_moe else ''
+    ckp = f'./{args.out_dir}/sft_t2i_{args.dim}{moe_path}.pth'
+    model = MiniMindT2I(lm_config)
+    state_dict = torch.load(ckp, map_location=device)
+    model.load_state_dict({k: v for k, v in state_dict.items() if 'mask' not in k}, strict=False)
+
+    print(f'T2I参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
+
+    return model.eval().to(device), tokenizer
+
+
+def setup_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Chat with MiniMind")
+    parser.add_argument('--lora_name', default='None', type=str)
+    parser.add_argument('--out_dir', default='out', type=str)
+    parser.add_argument('--temperature', default=0.65, type=float)
+    parser.add_argument('--top_p', default=0.85, type=float)
+    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str)
+    # MiniMind2-Small (26M):(dim=512, n_layers=8)
+    # MiniMind2 (104M):(dim=768, n_layers=16)
+    parser.add_argument('--dim', default=512, type=int)
+    parser.add_argument('--n_layers', default=8, type=int)
+    parser.add_argument('--max_seq_len', default=640, type=int)
+    parser.add_argument('--use_moe', default=False, type=bool)
+    args = parser.parse_args()
+
+    lm_config = VLMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
+
+    model, tokenizer = init_model(lm_config, args.device)
+
+
+    def chat_with_vlm(prompt):
+        messages = [{"role": "user", "content": prompt}]
+
+        new_prompt = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )[-args.max_seq_len + 1:]
+
+        with torch.no_grad():
+            x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=args.device).unsqueeze(0)
+            outputs = model.generate(
+                x,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                pad_token_id=tokenizer.pad_token_id,
+            )
+            image = model.image_tokenizer.decode_code(outputs, (1, 8, 16, 16))  # 256 codes -> 16x16 latent grid -> pixels
+            image = F.interpolate(image, size=[256, 256], mode='bicubic').permute(0, 2, 3, 1)[0]
+            image = torch.clamp(127.5 * image + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
+            Image.fromarray(image).save('output.jpg')
+            print('🤖️: 图片保存至output.jpg')
+            print('\n')
+
+
+    image_dir = './dataset/eval_images/'
+    prompt = "一个年轻人准备踢足球"
+    chat_with_vlm(prompt)
\ No newline at end of file
diff --git a/model/dataset.py b/model/dataset.py
index dae04a7..e07eab9 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -4,9 +4,26 @@
 import torch
 from .model_vlm import MiniMindVLM
 import os
+import numpy as np
+from torchvision import transforms
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+def center_crop_arr(pil_image, image_size):
+    while min(*pil_image.size) >= 2 * image_size:
+        pil_image = pil_image.resize(
+            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+        )
+
+    scale = image_size / min(*pil_image.size)
+    pil_image = pil_image.resize(
+        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+    )
+
+    arr = np.array(pil_image)
+    crop_y = (arr.shape[0] - image_size) // 2
+    crop_x = (arr.shape[1] - image_size) // 2
+    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
 
 class VLMDataset(Dataset):
     def __init__(self, jsonl_path, images_path, tokenizer, preprocess=None, max_length=512,
@@ -84,3 +101,96 @@ def __getitem__(self, index: int):
         image_tensors = torch.stack(image_tensors, dim=0)
 
         return X, Y, loss_mask, image_tensors
+
+class T2IDataset(Dataset):
+    def __init__(self, jsonl_path, images_path, tokenizer, max_length=512, img_pre_process=False,
+                 image_special_token='@' * 256):
+
+        super().__init__()
+        self.samples = self.load_data(jsonl_path)
+        self.images_path = images_path
+
+        self.img_pre_process = img_pre_process  # whether to load image embeddings/tokens preprocessed offline
+
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.image_size = 256
+        self.image_token = image_special_token
+        self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
+
+    def __len__(self):
+        return len(self.samples)
+
+    def load_data(self, path):
+        samples = []
+        with open(path, 'r', encoding='utf-8') as f:
+            for line_num, line in enumerate(f, 1):
+                data = json.loads(line.strip())
+                samples.append(data)
+        return samples
+
+    def _create_chat_prompt(self, conversations):
+        messages = []
+        for i, turn in enumerate(conversations):
+            role = 'user' if i % 2 == 0 else 'assistant'
+            messages.append({"role": role, "content": turn['content'].replace('<image>', self.image_token)})
+        return self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=False
+        )
+
+    def _generate_loss_mask(self, input_ids):
+        loss_mask = [0] * len(input_ids)
+        i = 0
+        while i < len(input_ids):
+            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
+                start = i + len(self.bos_id)
+                end = start
+                while end < len(input_ids):
+                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
+                        break
+                    end += 1
+                for j in range(start + 1, min(end + 
len(self.eos_id) + 1, self.max_length)): + loss_mask[j] = 1 + i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids) + else: + i += 1 + return loss_mask + + def __getitem__(self, index: int): + sample = self.samples[index] + image_paths = sample['image'] + prompt = self._create_chat_prompt(sample['conversations']) + input_ids = self.tokenizer(prompt).input_ids[:self.max_length] + input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids)) + loss_mask = self._generate_loss_mask(input_ids) + + X = torch.tensor(input_ids[:-1], dtype=torch.long) + Y = torch.tensor(input_ids[1:], dtype=torch.long) + loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long) + + if self.img_pre_process: + image_emb_path = image_paths.replace('.jpg', '_emb.npy') + image_emb = np.load(f'{self.images_path}/{image_emb_path}') + image_emb = torch.tensor(image_emb, dtype=torch.float32) + # 加载预处理的图像token + image_token_path = image_paths.replace('.jpg', '_token.npy') + image_token = np.load(f'{self.images_path}/{image_token_path}') + image_token = torch.tensor(image_token, dtype=torch.long) + return X, Y, loss_mask, (image_emb, image_token) + else: + image_tensors = [] + for image_name in image_paths.split(','): + image_name = image_name.strip() + image = Image.open(f'{self.images_path}/{image_name}').convert("RGB") + image = center_crop_arr(image, self.image_size) + image = np.array(image) / 255. + image = 2.0 * image - 1.0 + image = torch.tensor(image, dtype=torch.float32) + image_tensor = torch.einsum('hwc->chw', image) + image_tensors.append(image_tensor) + image_tensors = torch.stack(image_tensors, dim=0) + + return X, Y, loss_mask, image_tensors diff --git a/model/model_t2i.py b/model/model_t2i.py new file mode 100644 index 0000000..2af43ef --- /dev/null +++ b/model/model_t2i.py @@ -0,0 +1,199 @@ +from .VLMConfig import VLMConfig +from .model import * +from typing import Optional, Tuple, List +from torch import nn +import warnings +from model.model_vlm import MiniMindVLM, VisionProj +from model.model_vq import VQ_models +import torch +from einops import rearrange + +warnings.filterwarnings('ignore') + + +# class VisionProj(nn.Module): +# def __init__(self, ve_dim=768, lm_dim=512): +# super().__init__() +# self.ve_dim = ve_dim +# self.lm_dim = lm_dim +# self.vision_proj = nn.Sequential( +# nn.Linear(self.ve_dim, self.lm_dim), +# nn.ReLU(), +# nn.Linear(self.lm_dim, self.lm_dim) +# ) + +# def forward(self, image_encoders): +# vision_proj = self.vision_proj(image_encoders) +# return vision_proj + + +# 继承自语言模型 +class MiniMindT2I(MiniMindVLM): + config_class = VLMConfig + + def __init__(self, params: VLMConfig = None): + super().__init__(params) + if not params: params = VLMConfig() + self.params = params + self.image_tokenizer = self.__class__.get_image_tokenizer() + # self.output = nn.Linear(params.dim, 16384) + self.vision_proj = VisionProj(ve_dim=8, lm_dim=params.dim) + + @staticmethod + def get_image_tokenizer(model_path="./model/minimind_img_tokenizer/minimind_img_tokenizer.pt"): + model = VQ_models['VQ-16']( + codebook_size=6400, + codebook_embed_dim=8) + + # 加载模型 + checkpoint = torch.load(model_path, map_location="cpu") + if "ema" in checkpoint: # ema + model_weight = checkpoint["ema"] + elif "model" in checkpoint: # ddp + model_weight = checkpoint["model"] + elif "state_dict" in checkpoint: + model_weight = checkpoint["state_dict"] + else: + raise Exception("please check model weight") + model.load_state_dict(model_weight) + del checkpoint + + # 冻结 image_tokenizer 的所有参数 + for 
param in model.parameters(): + param.requires_grad = False + return model.eval() + + @staticmethod + def get_image_embeddings(image_tensors, image_tokenizer): + with torch.no_grad(): + # latent为离散化向量, indices为离散化向量的tokenid + latent, _, [_, _, indices] = image_tokenizer.encode(image_tensors) + # 展平latent [B, C, H, W] -> [B, H*W, C] + img_embedding = rearrange(latent, 'b c h w -> b (h w) c') + indices = indices.reshape(-1, 256) + return img_embedding, indices + + # 替换"@......"为图片的token + def count_token_replace(self, tokens, vision_token=None, seqlen=512): + def find_indices(tokens, image_ids): + image_ids_tensor = torch.tensor(image_ids).to(tokens.device) + len_image_ids = len(image_ids) + if len_image_ids > tokens.size(1): + return None + tokens_view = tokens.unfold(1, len_image_ids, 1) + matches = (tokens_view == image_ids_tensor).all(dim=2) + return { + batch_idx: [(idx.item(), idx.item() + len_image_ids - 1) for idx in + matches[batch_idx].nonzero(as_tuple=True)[0]] + for batch_idx in range(tokens.size(0)) if matches[batch_idx].any() + } or None + + image_indices = find_indices(tokens, self.params.image_ids) + + if vision_token is not None and image_indices: + new_tokens = [] + for i in range(tokens.size(0)): + if i in image_indices: + token_i = tokens[i] + img_idx = 0 + for start_idx, end_idx in image_indices[i]: + if img_idx < vision_token.size(1): + token_i = torch.cat( + (token_i[:start_idx], vision_token[i][img_idx], token_i[end_idx + 1:]), dim=0 + )[:seqlen] + img_idx += 1 + new_tokens.append(token_i) + else: + new_tokens.append(tokens[i]) + + return torch.stack(new_tokens, dim=0) + return tokens + + def forward(self, + input_ids: Optional[torch.Tensor] = None, + target_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + use_cache: bool = False, + **args): + start_pos = args.get('start_pos', 0) + pixel_tensors = args.get('pixel_tensors', None) + is_image = args.get('is_image', False) + if is_image: + input_ids = self.image_tokenizer.quantize.get_codebook_entry(input_ids) + h = self.vision_proj(input_ids) + else: + h = self.tok_embeddings(input_ids) + + if pixel_tensors is not None and start_pos == 0: + if isinstance(pixel_tensors, torch.Tensor): + pixel_tensors = pixel_tensors.to(h.device) + if len(pixel_tensors.shape) == 6: + pixel_tensors = pixel_tensors.squeeze(2) + bs, num, c, im_h, im_w = pixel_tensors.shape + stack_dim = 1 if bs > 1 else 0 + # 获取图片的embedding + vision_tensors = torch.stack([ + MiniMindT2I.get_image_embeddings(pixel_tensors[:, i, :, :, :], self.image_tokenizer)[0] + for i in range(num) + ], dim=stack_dim) + h = self.count_vision_proj(tokens=input_ids, h=h, vision_tensors=vision_tensors, seqlen=input_ids.shape[1]) + # 获取图片的token + vision_tokens = torch.stack([ + MiniMindT2I.get_image_embeddings(pixel_tensors[:, i, :, :, :], self.image_tokenizer)[1] + for i in range(num) + ], dim=stack_dim) + # 替换target中的'@......'为图片的token + target_ids = self.count_token_replace(tokens=target_ids, vision_token=vision_tokens, seqlen=target_ids.shape[1]) + else: + vision_tensors, vision_tokens = pixel_tensors + vision_tensors = vision_tensors.to(h.device) + vision_tokens = vision_tokens.to(h.device) + h = self.count_vision_proj(tokens=input_ids, h=h, vision_tensors=vision_tensors, seqlen=input_ids.shape[1]) + target_ids = self.count_token_replace(tokens=target_ids, vision_token=vision_tokens, seqlen=target_ids.shape[1]) + + pos_cis = self.pos_cis[start_pos:start_pos + input_ids.shape[1]] + past_kvs = [] + for l, layer 
in enumerate(self.layers): + h, past_kv = layer( + h, pos_cis, + past_key_value=past_key_values[l] if past_key_values else None, + use_cache=use_cache + ) + past_kvs.append(past_kv) + + logits = self.output(self.norm(h)) + aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward)) + + self.OUT.__setitem__('logits', logits) + self.OUT.__setitem__('aux_loss', aux_loss) + self.OUT.__setitem__('past_key_values', past_kvs) + self.OUT.__setitem__('target_ids', target_ids) # 添加target_ids + return self.OUT + + @torch.inference_mode() + def generate(self, input_ids, img_token_num=256, temperature=0.75, top_p=0.90, rp=1., use_cache=True, pad_token_id=0, **args): + idx = 0 + input_ids = input_ids[input_ids != pad_token_id].unsqueeze(0) + start, first_seq, past_kvs = input_ids.shape[1], True, None + while idx < img_token_num: + if first_seq or not use_cache: + out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False + else: + out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache, + start_pos=input_ids.shape[1] - 1, is_image=True, **args) + logits, past_kvs = out.logits[:, -1, :], out.past_key_values + logits[:, list(set(input_ids.tolist()[0]))] /= rp + logits /= (temperature + 1e-9) + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + sorted_probs = F.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = False + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = -float('Inf') + input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + input_ids = torch.cat((input_ids, input_ids_next), dim=1) + idx += 1 + return input_ids[:, -img_token_num:] \ No newline at end of file diff --git a/model/model_vq.py b/model/model_vq.py new file mode 100644 index 0000000..3a0ca4e --- /dev/null +++ b/model/model_vq.py @@ -0,0 +1,473 @@ +# Modified from: +# taming-transformers: https://github.com/CompVis/taming-transformers +# maskgit: https://github.com/google-research/maskgit +from dataclasses import dataclass, field +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +import random +import math +from PIL import Image + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. 
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + +def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): + min_smaller_dim_size = math.ceil(image_size / max_crop_frac) + max_smaller_dim_size = math.ceil(image_size / min_crop_frac) + smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) + + # We are not on a new enough PIL to support the `reducing_gap` + # argument, which uses BOX downsampling at powers of two first. + # Thus, we do it by hand to improve downsample quality. + while min(*pil_image.size) >= 2 * smaller_dim_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = smaller_dim_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = random.randrange(arr.shape[0] - image_size + 1) + crop_x = random.randrange(arr.shape[1] - image_size + 1) + return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) + + +@dataclass +class ModelArgs: + codebook_size: int = 16384 + codebook_embed_dim: int = 8 + codebook_l2_norm: bool = True + codebook_show_usage: bool = True + commit_loss_beta: float = 0.25 + entropy_loss_ratio: float = 0.0 + + encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + z_channels: int = 256 + dropout_p: float = 0.0 + + + +class VQModel(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p) + self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p) + + self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim, + config.commit_loss_beta, config.entropy_loss_ratio, + config.codebook_l2_norm, config.codebook_show_usage) + self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1) + self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b, shape=None, channel_first=True): + quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + return dec, diff + + + +class Encoder(nn.Module): + def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, + norm_type='group', dropout=0.0, resamp_with_conv=True, 
z_channels=256): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1) + + # downsampling + in_ch_mult = (1,) + tuple(ch_mult) + self.conv_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for _ in range(self.num_res_blocks): + res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type)) + block_in = block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + conv_block.attn = attn_block + # downsample + if i_level != self.num_resolutions-1: + conv_block.downsample = Downsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # middle + self.mid = nn.ModuleList() + self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1) + + + def forward(self, x): + h = self.conv_in(x) + # downsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.downsample(h) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + + +class Decoder(nn.Module): + def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group", + dropout=0.0, resamp_with_conv=True, out_channels=3): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + + block_in = ch*ch_mult[self.num_resolutions-1] + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.ModuleList() + self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)) + + # upsampling + self.conv_blocks = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type)) + block_in = block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + conv_block.attn = attn_block + # downsample + if i_level != 0: + conv_block.upsample = Upsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + + @property + def last_layer(self): + return self.conv_out.weight + + def forward(self, z): + # z to block_in + h = 
self.conv_in(z) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # upsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks + 1): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VectorQuantizer(nn.Module): + def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.entropy_loss_ratio = entropy_loss_ratio + self.l2_norm = l2_norm + self.show_usage = show_usage + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + if self.l2_norm: + self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1) + if self.show_usage: + self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536))) + + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = torch.einsum('b c h w -> b h w c', z).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + if self.l2_norm: + z = F.normalize(z, p=2, dim=-1) + z_flattened = F.normalize(z_flattened, p=2, dim=-1) + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(embedding**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding)) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = embedding[min_encoding_indices].view(z.shape) + perplexity = None + min_encodings = None + vq_loss = None + commit_loss = None + entropy_loss = None + codebook_usage = 0 + + if self.show_usage and self.training: + cur_len = min_encoding_indices.shape[0] + self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone() + self.codebook_used[-cur_len:] = min_encoding_indices + codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e + + # compute loss for embedding + if self.training: + vq_loss = torch.mean((z_q - z.detach()) ** 2) + commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = torch.einsum('b h w c -> b c h w', z_q) + + return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape=None, channel_first=True): + # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel) + if self.l2_norm: + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + z_q = embedding[indices] # (b*h*w, c) + + if shape is not None: + if channel_first: + z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1]) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + else: + z_q = z_q.view(shape) + return z_q + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group'): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is 
None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = Normalize(out_channels, norm_type) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type='group'): + super().__init__() + self.norm = Normalize(in_channels, norm_type) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, norm_type='group'): + assert norm_type in ['group', 'batch'] + if norm_type == 'group': + return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == 'batch': + return nn.SyncBatchNorm(in_channels) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01): + flat_affinity = affinity.reshape(-1, affinity.shape[-1]) + flat_affinity /= temperature + probs = F.softmax(flat_affinity, 
dim=-1) + log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1) + if loss_type == "softmax": + target_probs = probs + else: + raise ValueError("Entropy loss {} not supported".format(loss_type)) + avg_probs = torch.mean(target_probs, dim=0) + avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5)) + sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1)) + loss = sample_entropy - avg_entropy + return loss + + +################################################################################# +# VQ Model Configs # +################################################################################# +def VQ_8(**kwargs): + return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs)) + +def VQ_16(**kwargs): + return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs)) + +VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8} \ No newline at end of file diff --git a/model/vq_demo.py b/model/vq_demo.py new file mode 100644 index 0000000..4caf695 --- /dev/null +++ b/model/vq_demo.py @@ -0,0 +1,90 @@ +import torch +import torch.nn.functional as F + +import os +import argparse +import numpy as np +from PIL import Image + +from model_vq import VQ_models, center_crop_arr + + +def main(args): + # Setup PyTorch: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + # create and load model + model = VQ_models[args.vq_model]( + codebook_size=args.codebook_size, + codebook_embed_dim=args.codebook_embed_dim) + model.to(device) + model.eval() + checkpoint = torch.load(args.vq_ckpt, map_location="cpu") + if "ema" in checkpoint: # ema + print("load ema") + model_weight = checkpoint["ema"] + elif "model" in checkpoint: # ddp + print("load model") + model_weight = checkpoint["model"] + elif "state_dict" in checkpoint: + print("load state_dict") + model_weight = checkpoint["state_dict"] + else: + raise Exception("please check model weight") + model.load_state_dict(model_weight) + del checkpoint + + # output dir + os.makedirs(args.output_dir, exist_ok=True) + out_path = args.image_path.replace('.jpg', '_{}.jpg'.format(args.suffix)) + out_path = out_path.replace('.jpeg', '_{}.jpeg'.format(args.suffix)) + out_path = out_path.replace('.png', '_{}.png'.format(args.suffix)) + out_filename = out_path.split('/')[-1] + out_path = os.path.join(args.output_dir, out_filename) + + # load image + pil_image = Image.open(args.image_path).convert("RGB") + img = center_crop_arr(pil_image, args.image_size) + # # preprocess + # size_org = img.size + # img = img.resize((input_size, input_size)) + img = np.array(img) / 255. 
+ x = 2.0 * img - 1.0 # x value is between [-1, 1] + x = torch.tensor(x) + x = x.unsqueeze(dim=0) + x = torch.einsum('nhwc->nchw', x) + x_input = x.float().to("cuda") + + # inference + with torch.no_grad(): + latent, _, [_, _, indices] = model.encode(x_input) + print("latent shape: ", latent.shape) + print("indices shape: ", indices.shape) + # print(latent) + # print(indices) + output = model.decode_code(indices, latent.shape) # output value is between [-1, 1] + + # postprocess + output = F.interpolate(output, size=[args.image_size, args.image_size], mode='bicubic').permute(0, 2, 3, 1)[0] + sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy() + + # save + Image.fromarray(sample).save(out_path) + print("Reconstructed image is saved to {}".format(out_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image-path", type=str, help="你想测试的图片路径") + parser.add_argument("--output-dir", type=str, help="输出文件夹路径") + parser.add_argument("--suffix", type=str, default="tokenizer_image", help="输出文件名后缀") + parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16", help="VQGAN的类型") + parser.add_argument("--vq-ckpt", type=str, help="VQGAN的图片路径") + parser.add_argument("--codebook-size", type=int, default=16384, help="VQGAN的codebook大小,可理解为vocab size") + parser.add_argument("--codebook-embed-dim", type=int, default=8, help="VQGAN离散化向量的维度") + parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512, 1024], default=256) + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/train_pretrain_t2i.py b/train_pretrain_t2i.py new file mode 100644 index 0000000..3e25827 --- /dev/null +++ b/train_pretrain_t2i.py @@ -0,0 +1,216 @@ +import os +import platform +import argparse +import time +import math +import warnings +import json + +import pandas as pd +import torch +import torch.nn.functional as F +import torch.distributed as dist +from contextlib import nullcontext + +from torch import optim, nn +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader, DistributedSampler +from transformers import AutoTokenizer, AutoModel +from model.model_t2i import MiniMindT2I +from model.VLMConfig import VLMConfig +from model.dataset import T2IDataset + +warnings.filterwarnings('ignore') + + +def Logger(content): + if not ddp or dist.get_rank() == 0: + print(content) + + +def get_lr(current_step, total_steps, lr): + return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps)) + + +def train_epoch(epoch, wandb): + loss_fct = nn.CrossEntropyLoss(reduction='none') + start_time = time.time() + for step, (X, Y, loss_mask, pixel_tensors) in enumerate(train_loader): + X = X.to(args.device) + Y = Y.to(args.device) + loss_mask = loss_mask.to(args.device) + pixel_tensors = pixel_tensors + lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + with ctx: + res = model(X, Y, pixel_tensors=pixel_tensors) + Y = res.target_ids + loss = loss_fct( + res.logits.view(-1, res.logits.size(-1)), + res.target_ids.view(-1) + ).view(Y.size()) + + loss = (loss * loss_mask).sum() / loss_mask.sum() + loss += res.aux_loss + loss = loss / args.accumulation_steps + + scaler.scale(loss).backward() + + if (step + 1) % args.accumulation_steps == 0: + scaler.unscale_(optimizer) + 
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + + scaler.step(optimizer) + scaler.update() + + optimizer.zero_grad(set_to_none=True) + + if step % args.log_interval == 0: + spend_time = time.time() - start_time + Logger( + 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format( + epoch + 1, + args.epochs, + step, + iter_per_epoch, + loss.item(), + optimizer.param_groups[-1]['lr'], + spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60)) + + if (wandb is not None) and (not ddp or dist.get_rank() == 0): + wandb.log({"loss": loss, + "lr": optimizer.param_groups[-1]['lr'], + "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60}) + + if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): + model.eval() + moe_path = '_moe' if model_config.use_moe else '' + ckp = f'{args.save_dir}/pretrain_t2i_{model_config.dim}{moe_path}.pth' + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + clean_state_dict = { + key: value for key, value in state_dict.items() if not key.startswith('vision_encoder.') + } + torch.save(clean_state_dict, ckp) + model.train() + + +def init_model(model_config: VLMConfig): + tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') + moe_path = '_moe' if model_config.use_moe else '' + # 加载纯语言模型权重 + ckp = f'./out/lm_{model_config.dim}{moe_path}.pth' + model = MiniMindT2I(model_config) + # state_dict = torch.load(ckp, map_location=args.device) + # model.load_state_dict(state_dict, strict=False) + + # 冻结除 vision_proj 外的所有参数 + for name, param in model.named_parameters(): + if 'vision_proj' not in name: + param.requires_grad = False + # 可训练 + if hasattr(model, "layers"): + last_two_layers = model.layers[-1:] + for layer in last_two_layers: + for param in layer.parameters(): + param.requires_grad = True + + Logger(f'T2I可训练参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万') + + return model.to(args.device), tokenizer + + +def init_distributed_mode(): + if not ddp: return + global ddp_local_rank, DEVICE + + dist.init_process_group(backend="nccl") + ddp_rank = int(os.environ["RANK"]) + ddp_local_rank = int(os.environ["LOCAL_RANK"]) + ddp_world_size = int(os.environ["WORLD_SIZE"]) + DEVICE = f"cuda:{ddp_local_rank}" + torch.cuda.set_device(DEVICE) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="MiniMind-V Pretrain") + parser.add_argument("--out_dir", type=str, default="out") + parser.add_argument("--epochs", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--learning_rate", type=float, default=4e-4) + parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") + parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--wandb_project", type=str, default="MiniMind-V") + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--data_path", type=str, default="./dataset/pretrain_t2i_data.jsonl") + parser.add_argument("--images_path", type=str, default="./dataset/pretrain_t2i_code") + parser.add_argument("--ddp", action="store_true") + parser.add_argument("--accumulation_steps", type=int, default=1) + parser.add_argument("--grad_clip", type=float, default=1.0) + parser.add_argument("--warmup_iters", type=int, 
default=0) + parser.add_argument("--log_interval", type=int, default=100) + parser.add_argument("--save_interval", type=int, default=100) + parser.add_argument('--local_rank', type=int, default=-1) + parser.add_argument('--dim', default=512, type=int) + parser.add_argument('--n_layers', default=8, type=int) + parser.add_argument('--max_seq_len', default=640, type=int) + parser.add_argument('--use_moe', default=False, type=bool) + parser.add_argument('--img_process', type=bool, default=True) + args = parser.parse_args() + + model_config = VLMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, + image_special_token = '@' * 256, image_ids = [34] * 256) + max_seq_len = model_config.max_seq_len + args.save_dir = os.path.join(args.out_dir) + os.makedirs(args.save_dir, exist_ok=True) + os.makedirs(args.out_dir, exist_ok=True) + tokens_per_iter = args.batch_size * max_seq_len + torch.manual_seed(1337) + device_type = "cuda" if "cuda" in args.device else "cpu" + + args.wandb_run_name = f"MiniMind-V Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" + + ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast() + ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run? + ddp_local_rank, DEVICE = 0, "cuda:0" + if ddp: + init_distributed_mode() + args.device = torch.device(DEVICE) + + if args.use_wandb and (not ddp or ddp_local_rank == 0): + import wandb + + wandb.init(project=args.wandb_project, name=args.wandb_run_name) + else: + wandb = None + + model, tokenizer = init_model(model_config) + + train_ds = T2IDataset(args.data_path, args.images_path, tokenizer, max_seq_len, args.img_process, + image_special_token=model_config.image_special_token) + train_sampler = DistributedSampler(train_ds) if ddp else None + train_loader = DataLoader( + train_ds, + batch_size=args.batch_size, + pin_memory=True, + drop_last=False, + shuffle=False, + num_workers=args.num_workers, + sampler=train_sampler + ) + + scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16'])) + optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.learning_rate) + + if ddp: + model._ddp_params_and_buffers_to_ignore = {"pos_cis"} + model = DistributedDataParallel(model, device_ids=[ddp_local_rank]) + + iter_per_epoch = len(train_loader) + for epoch in range(args.epochs): + train_epoch(epoch, wandb) diff --git a/train_sft_t2i.py b/train_sft_t2i.py new file mode 100644 index 0000000..25fd5df --- /dev/null +++ b/train_sft_t2i.py @@ -0,0 +1,205 @@ +import os +import platform +import argparse +import time +import math +import warnings +import json + +import pandas as pd +import torch +import torch.nn.functional as F +import torch.distributed as dist +from contextlib import nullcontext + +from torch import optim, nn +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader, DistributedSampler +from transformers import AutoTokenizer, AutoModel +from model.model_t2i import MiniMindT2I +from model.VLMConfig import VLMConfig +from model.dataset import T2IDataset + +warnings.filterwarnings('ignore') + + +def Logger(content): + if not ddp or dist.get_rank() == 0: + print(content) + + +def get_lr(current_step, total_steps, lr): + return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps)) + + +def train_epoch(epoch, wandb): + loss_fct = nn.CrossEntropyLoss(reduction='none') + start_time = time.time() + for step, (X, Y, loss_mask, pixel_tensors) in 
enumerate(train_loader): + X = X.to(args.device) + Y = Y.to(args.device) + loss_mask = loss_mask.to(args.device) + pixel_tensors = pixel_tensors + lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + with ctx: + res = model(X, Y, pixel_tensors=pixel_tensors) + Y = res.target_ids + loss = loss_fct( + res.logits.view(-1, res.logits.size(-1)), + res.target_ids.view(-1) + ).view(Y.size()) + + loss = (loss * loss_mask).sum() / loss_mask.sum() + loss += res.aux_loss + loss = loss / args.accumulation_steps + + scaler.scale(loss).backward() + + if (step + 1) % args.accumulation_steps == 0: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + + scaler.step(optimizer) + scaler.update() + + optimizer.zero_grad(set_to_none=True) + + if step % args.log_interval == 0: + spend_time = time.time() - start_time + Logger( + 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format( + epoch + 1, + args.epochs, + step, + iter_per_epoch, + loss.item(), + optimizer.param_groups[-1]['lr'], + spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60)) + + if (wandb is not None) and (not ddp or dist.get_rank() == 0): + wandb.log({"loss": loss, + "lr": optimizer.param_groups[-1]['lr'], + "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60}) + + if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): + model.eval() + moe_path = '_moe' if model_config.use_moe else '' + ckp = f'{args.save_dir}/sft_t2i_{model_config.dim}{moe_path}.pth' + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + clean_state_dict = { + key: value for key, value in state_dict.items() if not key.startswith('vision_encoder.') + } + torch.save(clean_state_dict, ckp) + model.train() + + +def init_model(model_config: VLMConfig): + tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') + moe_path = '_moe' if model_config.use_moe else '' + # 加载纯语言模型权重 + ckp = f'./out/pretrain_t2i_{model_config.dim}{moe_path}.pth' + model = MiniMindT2I(model_config) + state_dict = torch.load(ckp, map_location=args.device) + model.load_state_dict(state_dict, strict=False) + + Logger(f'T2I可训练参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万') + + return model.to(args.device), tokenizer + + +def init_distributed_mode(): + if not ddp: return + global ddp_local_rank, DEVICE + + dist.init_process_group(backend="nccl") + ddp_rank = int(os.environ["RANK"]) + ddp_local_rank = int(os.environ["LOCAL_RANK"]) + ddp_world_size = int(os.environ["WORLD_SIZE"]) + DEVICE = f"cuda:{ddp_local_rank}" + torch.cuda.set_device(DEVICE) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="MiniMind-V Pretrain") + parser.add_argument("--out_dir", type=str, default="out") + parser.add_argument("--epochs", type=int, default=6) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--learning_rate", type=float, default=4e-4) + parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") + parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--wandb_project", type=str, default="MiniMind-V") + parser.add_argument("--num_workers", 
type=int, default=8) + parser.add_argument("--data_path", type=str, default="./dataset/pretrain_t2i_data.jsonl") + parser.add_argument("--images_path", type=str, default="./dataset/pretrain_t2i_code") + parser.add_argument("--ddp", action="store_true") + parser.add_argument("--accumulation_steps", type=int, default=1) + parser.add_argument("--grad_clip", type=float, default=1.0) + parser.add_argument("--warmup_iters", type=int, default=0) + parser.add_argument("--log_interval", type=int, default=100) + parser.add_argument("--save_interval", type=int, default=100) + parser.add_argument('--local_rank', type=int, default=-1) + parser.add_argument('--dim', default=512, type=int) + parser.add_argument('--n_layers', default=8, type=int) + parser.add_argument('--max_seq_len', default=640, type=int) + parser.add_argument('--use_moe', default=False, type=bool) + parser.add_argument('--img_process', type=bool, default=True) + args = parser.parse_args() + + model_config = VLMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, + image_special_token = '@' * 256, image_ids = [34] * 256) + max_seq_len = model_config.max_seq_len + args.save_dir = os.path.join(args.out_dir) + os.makedirs(args.save_dir, exist_ok=True) + os.makedirs(args.out_dir, exist_ok=True) + tokens_per_iter = args.batch_size * max_seq_len + torch.manual_seed(1337) + device_type = "cuda" if "cuda" in args.device else "cpu" + + args.wandb_run_name = f"MiniMind-V Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}" + + ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast() + ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run? + ddp_local_rank, DEVICE = 0, "cuda:0" + if ddp: + init_distributed_mode() + args.device = torch.device(DEVICE) + + if args.use_wandb and (not ddp or ddp_local_rank == 0): + import wandb + + wandb.init(project=args.wandb_project, name=args.wandb_run_name) + else: + wandb = None + + model, tokenizer = init_model(model_config) + + train_ds = T2IDataset(args.data_path, args.images_path, tokenizer, max_seq_len, args.img_process, + image_special_token=model_config.image_special_token) + train_sampler = DistributedSampler(train_ds) if ddp else None + train_loader = DataLoader( + train_ds, + batch_size=args.batch_size, + pin_memory=True, + drop_last=False, + shuffle=False, + num_workers=args.num_workers, + sampler=train_sampler + ) + + scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16'])) + optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.learning_rate) + + if ddp: + model._ddp_params_and_buffers_to_ignore = {"pos_cis"} + model = DistributedDataParallel(model, device_ids=[ddp_local_rank]) + + iter_per_epoch = len(train_loader) + for epoch in range(args.epochs): + train_epoch(epoch, wandb)
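
Usage note (a minimal sketch, assuming the defaults in the scripts above): train_pretrain_t2i.py trains MiniMindT2I on the T2I-flipped captions and saves out/pretrain_t2i_512.pth, train_sft_t2i.py continues from that checkpoint and saves out/sft_t2i_512.pth, and eval_t2i.py turns a prompt into 256 VQ codes that the frozen VQ-16 tokenizer decodes into a 256x256 image. The snippet below condenses that inference path from eval_t2i.py; the checkpoint path and the 26M (dim=512, n_layers=8) config are assumptions taken from the argparse defaults, not guaranteed artifacts.

import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoTokenizer
from model.model_t2i import MiniMindT2I
from model.VLMConfig import VLMConfig

# Assumes ./out/sft_t2i_512.pth was produced by train_sft_t2i.py with the default config.
config = VLMConfig(dim=512, n_layers=8, max_seq_len=640)
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindT2I(config)
model.load_state_dict(torch.load('./out/sft_t2i_512.pth', map_location='cpu'), strict=False)
model.eval()

# Build the chat prompt and sample 256 image-token ids autoregressively.
prompt = tokenizer.apply_chat_template([{"role": "user", "content": "一个年轻人准备踢足球"}],
                                       tokenize=False, add_generation_prompt=True)
x = torch.tensor(tokenizer(prompt)['input_ids']).unsqueeze(0)
with torch.no_grad():
    codes = model.generate(x, pad_token_id=tokenizer.pad_token_id)      # [1, 256] VQ code ids
    img = model.image_tokenizer.decode_code(codes, (1, 8, 16, 16))      # decode 16x16 latent grid to pixels in [-1, 1]

# Map [-1, 1] floats to uint8 pixels, as in eval_t2i.py.
img = F.interpolate(img, size=[256, 256], mode='bicubic').permute(0, 2, 3, 1)[0]
img = torch.clamp(127.5 * img + 128.0, 0, 255).to(torch.uint8).numpy()
Image.fromarray(img).save('sample.jpg')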