diff --git a/.envrc b/.envrc
new file mode 100644
index 0000000..446626f
--- /dev/null
+++ b/.envrc
@@ -0,0 +1 @@
+source activate minimind
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 5f10ded..f0b6441 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 model/clip_model/clip-vit-base-patch32/
+model/siglip_model/siglip-vit-base-patch16/
 out/*.pth
 full.json
 trans_json.py
diff --git a/1-pretrain_vlm.py b/1-pretrain_vlm.py
index 8cec330..e42840e 100644
--- a/1-pretrain_vlm.py
+++ b/1-pretrain_vlm.py
@@ -128,7 +128,7 @@ def init_model(lm_config):
     print(f'Trainable model parameters: {count_parameters(model) / 1e6} million = {count_parameters(model) / 1e9} B (Billion)')

-    (vision_model, preprocess) = get_vision_model()
+    (vision_model, preprocess) = get_vision_model(args.visual_encoder)
     vision_model = vision_model.to(args.device)

     return model, tokenizer, (vision_model, preprocess)
@@ -166,10 +166,15 @@ def init_distributed_mode():
     parser.add_argument("--log_interval", type=int, default=10, help="Logging interval")
     parser.add_argument("--save_interval", type=int, default=100, help="Model saving interval")
     parser.add_argument('--local_rank', type=int, default=-1, help='local rank for distributed training')
+    parser.add_argument('--visual_encoder', type=str, default="clip", help='type of visual encoder')

     args = parser.parse_args()

-    lm_config = LMConfig()
+    if args.visual_encoder == "clip":
+        lm_config = LMConfig()
+    else:
+        lm_config = LMConfig(image_special_token='<'*98+'>'*98, image_ids=[30]*98+[32]*98)
+
     max_seq_len = lm_config.max_seq_len
     args.save_dir = os.path.join(args.out_dir)
     os.makedirs(args.save_dir, exist_ok=True)
diff --git a/2-sft_vlm.py b/2-sft_vlm.py
index 768a6e1..0a2b1a6 100644
--- a/2-sft_vlm.py
+++ b/2-sft_vlm.py
@@ -148,7 +148,7 @@ def init_model(lm_config):
     print(f'Trainable model parameters: {count_parameters(model) / 1e6} million = {count_parameters(model) / 1e9} B (Billion)')

-    (vision_model, preprocess) = get_vision_model()
+    (vision_model, preprocess) = get_vision_model(args.visual_encoder)
     vision_model = vision_model.to(args.device)

     return model, tokenizer, (vision_model, preprocess)
@@ -190,10 +190,15 @@ def init_distributed_mode():
     parser.add_argument('--local_rank', type=int, default=-1, help='local rank for distributed training')
     parser.add_argument('--multi', type=bool, default=False, help='multi-images training')
     parser.add_argument('--save_last', type=bool, default=True, help='save last step model')
+    parser.add_argument('--visual_encoder', type=str, default="clip", help='type of visual encoder')

     args = parser.parse_args()

-    lm_config = LMConfig()
+    if args.visual_encoder == "clip":
+        lm_config = LMConfig()
+    else:
+        lm_config = LMConfig(image_special_token='<'*98+'>'*98, image_ids=[30]*98+[32]*98)
+
     max_seq_len = lm_config.max_seq_len
     args.save_dir = os.path.join(args.out_dir)
     os.makedirs(args.save_dir, exist_ok=True)
diff --git a/3-eval_chat.py b/3-eval_chat.py
index 84fa53f..aa7c91a 100644
--- a/3-eval_chat.py
+++ b/3-eval_chat.py
@@ -41,7 +41,7 @@ def init_model(lm_config, device, multi):
     model = model.to(device)
     print(f'Model parameters: {count_parameters(model) / 1e6} million = {count_parameters(model) / 1e9} B (Billion)')

-    vision_model, preprocess = get_vision_model()
+    vision_model, preprocess = get_vision_model(encoder_type="clip")
     vision_model = vision_model.to(device)

     return model, tokenizer, vision_model, preprocess
@@ -66,7 +66,12 @@ def setup_seed(seed):
     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
     dtype = 'bfloat16'
     max_seq_len = 1024
-    lm_config = LMConfig()
+    encoder_type="clip"
+    # lm_config = LMConfig()
+    if encoder_type == "clip":
+        lm_config = LMConfig()
+    else:
+        lm_config = LMConfig(image_special_token='<'*98+'>'*98, image_ids=[30]*98+[32]*98)
     lm_config.max_seq_len = max_seq_len
     model, tokenizer, vision_model, preprocess = init_model(lm_config, device, multi)
     model.eval()
diff --git a/model/__pycache__/LMConfig.cpython-310.pyc b/model/__pycache__/LMConfig.cpython-310.pyc
index 3b4c852..22d8930 100644
Binary files a/model/__pycache__/LMConfig.cpython-310.pyc and b/model/__pycache__/LMConfig.cpython-310.pyc differ
diff --git a/model/__pycache__/dataset.cpython-310.pyc b/model/__pycache__/dataset.cpython-310.pyc
index ef0806e..93d0132 100644
Binary files a/model/__pycache__/dataset.cpython-310.pyc and b/model/__pycache__/dataset.cpython-310.pyc differ
diff --git a/model/__pycache__/model.cpython-310.pyc b/model/__pycache__/model.cpython-310.pyc
index 2df8cd0..5043f05 100644
Binary files a/model/__pycache__/model.cpython-310.pyc and b/model/__pycache__/model.cpython-310.pyc differ
diff --git a/model/__pycache__/vision_utils.cpython-310.pyc b/model/__pycache__/vision_utils.cpython-310.pyc
index 97b6ac5..f995df7 100644
Binary files a/model/__pycache__/vision_utils.cpython-310.pyc and b/model/__pycache__/vision_utils.cpython-310.pyc differ
diff --git a/model/dataset.py b/model/dataset.py
index f5b7511..bd6ba9f 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -54,7 +54,7 @@ def __getitem__(self, index: int):
         sample = self.data[index]
         image_name = sample['image']
         conversation = sample['conversations']
-        # minimind-v's special image placeholder: each image is split into 10 tokens, matching the count in get_img_process
+        # minimind-v's special image placeholder: each image is split into M tokens, matching the count in get_img_process
         messages = []
         # iterate over the conversation list
         for i in range(0, len(conversation), 2):
diff --git a/model/model.py b/model/model.py
index 3db3ce0..bfa4c65 100644
--- a/model/model.py
+++ b/model/model.py
@@ -326,23 +326,23 @@ class Transformer(PreTrainedModel):
     config_class = LMConfig
     last_loss: Optional[torch.Tensor]

-    def __init__(self, params: LMConfig = None):
+    def __init__(self, params: LMConfig = None, vocab_size = 6400):
         super().__init__(params)
         if not params:
             params = LMConfig()
         self.params = params
-        self.vocab_size = params.vocab_size
+        self.vocab_size = vocab_size
         self.n_layers = params.n_layers
         # special image placeholder: each image is split into M tokens, matching the count in get_img_process
         self.image_ids = params.image_ids

-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.tok_embeddings = nn.Embedding(self.vocab_size, params.dim)
         self.dropout = nn.Dropout(params.dropout)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(self.n_layers):
             self.layers.append(TransformerBlock(layer_id, params))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.output = nn.Linear(params.dim, self.vocab_size, bias=False)
         self.tok_embeddings.weight = self.output.weight
         pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
         self.register_buffer("pos_cis", pos_cis, persistent=False)
@@ -372,17 +372,27 @@ def count_vision_proj(self, tokens, h, image_encoders=None, seqlen=200):
         # find the indices of the placeholder segment in the tokens, in preparation for replacement
         def find_indices(tokens, image_ids):
             image_ids_tensor = torch.tensor(image_ids).to(tokens.device)
-            indices = []
-
+            len_image_ids = len(image_ids)
+
+            # during .generate, skip directly after initialization
+            if len_image_ids > tokens.size(1):
+                # print(f"len_image_ids ({len_image_ids}) is greater than sequence length ({tokens.size(1)}), skipping.")
+                return None
+
+            # use unfold to create a view for sliding-window processing
+            tokens_view = tokens.unfold(1, len_image_ids, 1)  # sliding windows along the second dimension
+            # check whether each sliding window equals image_ids_tensor
+            matches = (tokens_view == image_ids_tensor).all(dim=2)  # compare every row within the window
+
+            # extract the matching indices
+            indices = {}
             for batch_idx in range(tokens.size(0)):
-                for i in range(tokens.size(1) - len(image_ids) + 1):
-                    if torch.equal(tokens[batch_idx, i:i + len(image_ids)], image_ids_tensor):
-                        indices.append([batch_idx, i, i + len(image_ids) - 1])  # return batch_idx and the start/end indices
-
+                match_indices = matches[batch_idx].nonzero(as_tuple=True)[0]  # get the non-zero (matching) indices
+                if match_indices.numel() > 0:  # if there are matches
+                    indices[batch_idx] = [(idx.item(), idx.item() + len_image_ids - 1) for idx in match_indices]
             return indices if indices else None

-        image_indices = find_indices(tokens,
-                                     self.image_ids)  # [0, 4, 53], [0, 54, 103], [0, 104, 153], [0, 154, 203] or [1, 4, 53], [1, 54, 103]
+        image_indices = find_indices(tokens, self.image_ids)  # indices stored as a dict

         # if there are image encodings at this point
         if image_encoders is not None:
@@ -394,8 +404,8 @@ def find_indices(tokens, image_ids):
             for i in range(h.size(0)):
                 # i is the current_batch_idx
                 img_idx = 0
-                for batch_idx, start_idx, end_idx in image_indices:
-                    if batch_idx == i:
+                if i in image_indices:  # look up directly from the dict
+                    for start_idx, end_idx in image_indices[i]:
                         # insert the vision_proj features
                         before = h[i][:start_idx, :]
                         after = h[i][end_idx + 1:, :]
diff --git a/model/siglip_model/README.md b/model/siglip_model/README.md
new file mode 100644
index 0000000..e9e0702
--- /dev/null
+++ b/model/siglip_model/README.md
@@ -0,0 +1,5 @@
+* The siglip-base-patch16-224 model needs to be downloaded into this directory
+
+```bash
+git clone https://hf-mirror.com/google/siglip-base-patch16-224
+```
\ No newline at end of file
diff --git a/model/vision_utils.py b/model/vision_utils.py
index 56d70fe..61e14ea 100644
--- a/model/vision_utils.py
+++ b/model/vision_utils.py
@@ -1,5 +1,5 @@
 import warnings
-from transformers import CLIPProcessor, CLIPModel
+from transformers import CLIPProcessor, CLIPModel, SiglipProcessor, SiglipModel
 from PIL import Image
 import requests
 import torch
@@ -8,19 +8,27 @@ warnings.filterwarnings('ignore')

-def get_vision_model():
+def get_vision_model(encoder_type):
     # load the pretrained CLIP model and processor
-    model_path = "./model/clip_model/clip-vit-base-patch32"
-    model = CLIPModel.from_pretrained(model_path)
-    processor = CLIPProcessor.from_pretrained(model_path)
+    if encoder_type == "clip":
+        model_path = "./model/clip_model/clip-vit-base-patch32"
+        model = CLIPModel.from_pretrained(model_path)
+        processor = CLIPProcessor.from_pretrained(model_path)
+    else:
+        model_path = "./model/siglip_model/siglip-vit-base-patch16"
+        model = SiglipModel.from_pretrained(model_path)
+        processor = SiglipProcessor.from_pretrained(model_path)
     return (model, processor)


 def get_img_process(image, processor):
-    # resize the image to 144*144
+    # resize the image to 224*224
     image = image.resize((224, 224))
+    if image.mode in ['RGBA', 'LA']:  # handle images with an alpha channel
+        image = image.convert('RGB')
     # process each patch with the CLIPProcessor
-    inputs = processor(images=image, return_tensors="pt", clean_up_tokenization_spaces=False)
+    # inputs = processor(images=image, return_tensors="pt", clean_up_tokenization_spaces=False)
+    inputs = processor(images=image, return_tensors="pt")
     return inputs


@@ -32,7 +40,7 @@ def hook_fn(module, input, output):
         embeddings.append(output.last_hidden_state)

     # extract the image tensor from the BatchEncoding
-    if isinstance(batch_encoding, transformers.tokenization_utils_base.BatchEncoding):
+    if isinstance(batch_encoding, transformers.tokenization_utils_base.BatchEncoding) or isinstance(batch_encoding, transformers.feature_extraction_utils.BatchFeature):
         image_tensor = batch_encoding['pixel_values']
     else:
         image_tensor = batch_encoding  # torch.Size([32, 4, 3, 224, 224])
@@ -58,5 +66,5 @@ def hook_fn(module, input, output):
     hook.remove()

     # stack all feature vectors into a single tensor
-    all_embeddings = torch.stack(embeddings, dim=0).squeeze()  # torch.Size([32, 4, 50, 768])
+    all_embeddings = torch.stack(embeddings, dim=0).squeeze()  # torch.Size([32, 4, 50, 768]) or torch.Size([32, 2, 196, 768])
     return all_embeddings
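
For context, a minimal usage sketch of the new encoder_type path added in model/vision_utils.py, assuming the SigLIP weights sit at ./model/siglip_model/siglip-vit-base-patch16 as the patched code expects; the image file name and import paths are assumptions based on the repo layout, not part of the patch. SigLIP base-patch16 at 224x224 yields (224/16)^2 = 196 patch embeddings (versus 50 tokens for CLIP ViT-B/32), which is why the non-CLIP branch configures a 196-token placeholder ('<'*98 + '>'*98 with image_ids [30]*98 + [32]*98).

```python
# Sketch only: mirrors the patched get_vision_model/get_img_process API;
# "example.jpg" is a hypothetical input image, not part of the diff.
from PIL import Image
from model.LMConfig import LMConfig
from model.vision_utils import get_vision_model, get_img_process

# 196 = (224 // 16) ** 2 SigLIP patch tokens, split as 98 '<' ids + 98 '>' ids.
lm_config = LMConfig(image_special_token='<' * 98 + '>' * 98,
                     image_ids=[30] * 98 + [32] * 98)

vision_model, preprocess = get_vision_model("siglip")  # or "clip" for the default encoder
image = Image.open("example.jpg")                      # hypothetical input image
inputs = get_img_process(image, preprocess)            # resized to 224x224, RGB-converted if needed
print(inputs["pixel_values"].shape)                    # torch.Size([1, 3, 224, 224])
```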
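
The rewritten find_indices in model/model.py replaces the per-position torch.equal loop with a vectorized sliding-window comparison via unfold. Below is a self-contained sketch of that technique, using a shortened placeholder run in place of the real [30]*98 + [32]*98.

```python
import torch

# Shortened stand-in for the real placeholder run ([30]*98 + [32]*98 in the patch).
image_ids = [30] * 3 + [32] * 3
tokens = torch.tensor([
    [1, 30, 30, 30, 32, 32, 32, 5],   # one placeholder run starting at index 1
    [7, 8, 9, 10, 11, 12, 13, 14],    # no placeholder run
])

ids = torch.tensor(image_ids)
windows = tokens.unfold(1, len(ids), 1)  # (batch, n_windows, len(ids)) sliding windows
matches = (windows == ids).all(dim=2)    # True where a whole window equals the run

indices = {}
for b in range(tokens.size(0)):
    starts = matches[b].nonzero(as_tuple=True)[0]
    if starts.numel() > 0:
        indices[b] = [(s.item(), s.item() + len(ids) - 1) for s in starts]

print(indices)  # {0: [(1, 6)]} -- dict keyed by batch index, as in the patched code
```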
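
Downstream, each matched (start_idx, end_idx) span of hidden states is replaced by the projected image features, via the before/after slices visible in the second model.py hunk. A toy illustration of that splice, with made-up shapes (the real cat-and-truncate step is outside the hunk context):

```python
import torch

dim, seqlen = 4, 8
h = torch.zeros(seqlen, dim)       # hidden states for one sequence
vision_proj = torch.ones(6, dim)   # 6 projected image tokens (196 in the SigLIP case)
start_idx, end_idx = 1, 6          # span returned by find_indices

before, after = h[:start_idx, :], h[end_idx + 1:, :]
h_new = torch.cat((before, vision_proj, after), dim=0)[:seqlen]
print(h_new.shape)                 # torch.Size([8, 4])
```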