diff --git a/eval_videolm.py b/eval_videolm.py
new file mode 100644
index 0000000..2a49f4e
--- /dev/null
+++ b/eval_videolm.py
@@ -0,0 +1,121 @@
+import argparse
+import os
+import random
+import numpy as np
+import torch
+import warnings
+from PIL import Image
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from model.model_videolm import MiniMindVideoLM
+from model.VLMConfig import VLMConfig
+from transformers import logging as hf_logging
+from model.dataset import video2image
+
+hf_logging.set_verbosity_error()
+
+warnings.filterwarnings('ignore')
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def init_model(lm_config, device):
+    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
+    if args.load == 0:
+        moe_path = '_moe' if args.use_moe else ''
+        modes = {0: 'pretrain_videolm', 1: 'sft_videolm'}
+        ckp = f'./{args.out_dir}/{modes[args.model_mode]}_{args.dim}{moe_path}.pth'
+        model = MiniMindVideoLM(lm_config)
+        state_dict = torch.load(ckp, map_location=device)
+        model.load_state_dict({k: v for k, v in state_dict.items() if 'mask' not in k}, strict=False)
+    else:
+        transformers_model_path = 'MiniMind2-V'
+        tokenizer = AutoTokenizer.from_pretrained(transformers_model_path)
+        model = AutoModelForCausalLM.from_pretrained(transformers_model_path, trust_remote_code=True)
+
+    print(f'VLM parameters: {count_parameters(model) / 1e6:.3f} million')
+
+    vision_model = MiniMindVideoLM.get_vision_model()
+    return model.eval().to(device), tokenizer, vision_model.eval().to(device)
+
+
+def setup_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Chat with MiniMind")
+    parser.add_argument('--lora_name', default='None', type=str)
+    parser.add_argument('--out_dir', default='out', type=str)
+    parser.add_argument('--temperature', default=0.65, type=float)
+    parser.add_argument('--top_p', default=0.85, type=float)
+    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str)
+    # MiniMind2-Small (26M): (dim=512, n_layers=8)
+    # MiniMind2 (104M): (dim=768, n_layers=16)
+    parser.add_argument('--dim', default=512, type=int)
+    parser.add_argument('--n_layers', default=8, type=int)
+    parser.add_argument('--max_seq_len', default=8192, type=int)
+    parser.add_argument('--use_moe', default=False, type=bool)
+    # stream decoded tokens as they are generated (default: True)
+    parser.add_argument('--stream', default=True, type=bool)
+    parser.add_argument('--load', default=0, type=int, help="0: load native torch weights, 1: load via transformers")
+    parser.add_argument('--model_mode', default=0, type=int,
+                        help="0: pretrained model, 1: SFT model")
+    args = parser.parse_args()
+
+    lm_config = VLMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
+
+    model, tokenizer, vision_model = init_model(lm_config, args.device)
+
+
+    def chat_with_vlm(prompt, pixel_tensors, video_names):
+        messages = [{"role": "user", "content": prompt}]
+
+        new_prompt = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )[-args.max_seq_len + 1:]
+
+        print(f'[Video]: {video_names}')
+        with torch.no_grad():
+            x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=args.device).unsqueeze(0)
+            outputs = model.generate(
+                x,
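+                # x holds only the text-prompt token ids; the sampled video frames are passed
+                # separately via pixel_tensors and matched to the image placeholder tokens in the prompt.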
+                eos_token_id=tokenizer.eos_token_id,
+                max_new_tokens=args.max_seq_len,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                stream=args.stream,
+                pad_token_id=tokenizer.pad_token_id,
+                pixel_tensors=pixel_tensors
+            )
+            print('🤖️: ', end='')
+            try:
+                if not args.stream:
+                    print(tokenizer.decode(outputs.squeeze()[x.shape[1]:].tolist(), skip_special_tokens=True), end='')
+                else:
+                    history_idx = 0
+                    for y in outputs:
+                        answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
+                        if (answer and answer[-1] == '�') or not answer:
+                            continue
+                        print(answer[history_idx:], end='', flush=True)
+                        history_idx = len(answer)
+            except StopIteration:
+                print("No answer")
+            print('\n')
+
+
+    video_path = './dataset/eval_videos/video0.mp4'
+    prompt = f"what is a man driving down?\n{model.params.image_special_token}"
+
+    video_tensors = video2image(video_path).to(args.device).unsqueeze(0).unsqueeze(0)
+    chat_with_vlm(prompt, video_tensors, 'video0.mp4')
diff --git a/model/dataset.py b/model/dataset.py
index dae04a7..84ccaf9 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -4,10 +4,48 @@ import torch
 from .model_vlm import MiniMindVLM
 import os
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode
+import cv2
+import numpy as np
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
+def video2image(video_path, num_frames=8, size=224):
+    # CLIP-style preprocessing for a single PIL frame: resize, center-crop, normalize
+    def preprocess(size, img):
+        return Compose([
+            Resize(size, interpolation=InterpolationMode.BICUBIC),
+            CenterCrop(size),
+            lambda image: image.convert("RGB"),
+            ToTensor(),
+            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+        ])(img)
+
+    cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
+    frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = int(cap.get(cv2.CAP_PROP_FPS))
+
+    if fps < 1 or frameCount < 1:
+        # unreadable video: fall back to an all-zero clip with the expected shape
+        images = np.zeros([num_frames, 3, size, size], dtype=np.float32)
+        print("ERROR: problem reading video file: ", video_path)
+    else:
+        # sample up to num_frames distinct frame indices, kept in temporal order
+        frames_idx = np.sort(np.random.choice(frameCount, min(num_frames, frameCount), replace=False))
+        images = np.zeros([len(frames_idx), 3, size, size], dtype=np.float32)
+
+        for i, idx in enumerate(frames_idx):
+            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+            ret, frame = cap.read()
+            if not ret:
+                continue
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            images[i, :, :, :] = preprocess(size, Image.fromarray(frame))
+
+    cap.release()
+    video_frames = torch.tensor(images)  # (T, 3, size, size)
+    return video_frames
+
 class VLMDataset(Dataset):
     def __init__(self, jsonl_path, images_path, tokenizer, preprocess=None, max_length=512,
                  image_special_token='@' * 196):
@@ -84,3 +122,80 @@ def __getitem__(self, index: int):
         image_tensors = torch.stack(image_tensors, dim=0)
 
         return X, Y, loss_mask, image_tensors
+
+
+class VideoLMDataset(Dataset):
+    def __init__(self, jsonl_path, videos_path, tokenizer, max_length=512,
+                 video_special_token='@' * 196):
+        super().__init__()
+        self.samples = self.load_data(jsonl_path)
+        self.videos_path = videos_path
+
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.video_token = video_special_token
+        self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
+
+    def __len__(self):
+        return len(self.samples)
+
+    def load_data(self, path):
+        samples = []
+        with open(path, 'r', encoding='utf-8') as f:
+            for line_num, line in enumerate(f, 1):
+                data = json.loads(line.strip())
+                samples.append(data)
+        return samples
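+
+    # Each JSONL sample is expected to carry a "conversations" list of alternating
+    # user/assistant turns (plus a reference to the video file); _create_chat_prompt
+    # below maps even turns to "user" and odd turns to "assistant" and builds the
+    # chat-template prompt with the video placeholder token substituted in.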
+
+    def _create_chat_prompt(self, conversations):
+        messages = []
+        for i, turn in enumerate(conversations):
+            role = 'user' if i % 2 == 0 else 'assistant'
+            messages.append({"role": role, "content": turn['content'].replace('