36 changes: 36 additions & 0 deletions dataset/get_pretrain_t2i_data.py
@@ -0,0 +1,36 @@
import json
from tqdm import tqdm

# Input and output file paths
input_file = 'pretrain_data.jsonl'
output_file = 'pretrain_t2i_data.jsonl'

# Open the input file for reading and the output file for writing
with open(input_file, 'r', encoding='utf-8') as infile, open(output_file, 'w', encoding='utf-8') as outfile:
    for line in tqdm(infile):
        # Parse each line as a JSON record
        data = json.loads(line.strip())

        # Extract the conversation turns and the image field
        conversations = data.get("conversations", [])
        image = data.get("image", "")

        # Rework the conversation for text-to-image training
        if len(conversations) == 2:
            # Move the caption (originally the assistant reply) into the user
            # turn, and replace the assistant turn with the <image> placeholder
            assistant_content = conversations[1]["content"]
            conversations[0]["content"] = assistant_content
            conversations[1]["content"] = "<image>"

        # Build the converted record
        new_data = {
            "conversations": conversations,
            "image": image
        }

        # Write the converted record to the new jsonl file
        outfile.write(json.dumps(new_data, ensure_ascii=False) + "\n")

print("Done. Results saved to", output_file)
93 changes: 93 additions & 0 deletions eval_t2i.py
@@ -0,0 +1,93 @@
import argparse
import os
import random
import numpy as np
import torch
import warnings
import torch.nn.functional as F
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model_t2i import MiniMindT2I
from model.VLMConfig import VLMConfig
from transformers import logging as hf_logging

hf_logging.set_verbosity_error()

warnings.filterwarnings('ignore')


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def init_model(lm_config, device):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    moe_path = '_moe' if args.use_moe else ''
    ckp = f'./{args.out_dir}/sft_t2i_{args.dim}{moe_path}.pth'
    model = MiniMindT2I(lm_config)
    state_dict = torch.load(ckp, map_location=device)
    model.load_state_dict({k: v for k, v in state_dict.items() if 'mask' not in k}, strict=False)

    print(f'T2I parameters: {count_parameters(model) / 1e6:.3f} M')

    return model.eval().to(device), tokenizer


def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind text-to-image eval")
    parser.add_argument('--lora_name', default='None', type=str)
    parser.add_argument('--out_dir', default='out', type=str)
    parser.add_argument('--temperature', default=0.65, type=float)
    parser.add_argument('--top_p', default=0.85, type=float)
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str)
    # MiniMind2-Small (26M): (dim=512, n_layers=8)
    # MiniMind2 (104M): (dim=768, n_layers=16)
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=640, type=int)
    # store_true instead of type=bool: argparse would treat any non-empty
    # string (including "False") as True
    parser.add_argument('--use_moe', action='store_true')
    args = parser.parse_args()

    lm_config = VLMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)

    model, tokenizer = init_model(lm_config, args.device)


    def chat_with_vlm(prompt):
        messages = [{"role": "user", "content": prompt}]

        new_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )[-args.max_seq_len + 1:]

        with torch.no_grad():
            x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=args.device).unsqueeze(0)
            outputs = model.generate(
                x,
                temperature=args.temperature,
                top_p=args.top_p,
                pad_token_id=tokenizer.pad_token_id,
            )
            # Decode the generated image tokens back into pixels, upsample to
            # 256x256 and de-normalize to uint8
            image = model.image_tokenizer.decode_code(outputs, (1, 8, 16, 16))
            image = F.interpolate(image, size=[256, 256], mode='bicubic').permute(0, 2, 3, 1)[0]
            image = torch.clamp(127.5 * image + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
            Image.fromarray(image).save('output.jpg')
            print('🤖️: image saved to output.jpg')
            print('\n')


    prompt = "一个年轻人准备踢足球"  # "A young man getting ready to play football"
    chat_with_vlm(prompt)
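One detail worth making explicit: `generate` emits a sequence of discrete image tokens, and `decode_code` maps them back through the VQ decoder; the `(1, 8, 16, 16)` shape appears to be (batch, codebook embedding dim, latent height, latent width), i.e. a 16x16 token grid of 256 tokens. The post-processing that follows can be sanity-checked in isolation; a self-contained sketch on a dummy decoder output (random data in [-1, 1], not a real model output, and an arbitrary small spatial size):

```python
import torch
import torch.nn.functional as F

# Dummy decoder output in [-1, 1]; the real spatial size depends on the VQ
# decoder, 16x16 here just keeps the example small
decoded = torch.rand(1, 3, 16, 16) * 2 - 1

# Same post-processing as eval_t2i.py: bicubic upsample, CHW -> HWC, map to uint8
img = F.interpolate(decoded, size=[256, 256], mode='bicubic').permute(0, 2, 3, 1)[0]
img = torch.clamp(127.5 * img + 128.0, 0, 255).to(dtype=torch.uint8)

assert img.shape == (256, 256, 3)  # ready for PIL.Image.fromarray
```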
110 changes: 110 additions & 0 deletions model/dataset.py
@@ -4,9 +4,26 @@
import torch
from .model_vlm import MiniMindVLM
import os
import numpy as np
from torchvision import transforms

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def center_crop_arr(pil_image, image_size):
    # Repeatedly halve the image with a box filter while it is still at least
    # twice the target size, then bicubic-resize so the short side equals
    # image_size, and take a center crop
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

class VLMDataset(Dataset):
    def __init__(self, jsonl_path, images_path, tokenizer, preprocess=None, max_length=512,
@@ -84,3 +101,96 @@ def __getitem__(self, index: int):
        image_tensors = torch.stack(image_tensors, dim=0)

        return X, Y, loss_mask, image_tensors

class T2IDataset(Dataset):
    def __init__(self, jsonl_path, images_path, tokenizer, max_length=512, img_pre_process=False,
                 image_special_token='@' * 256):

        super().__init__()
        self.samples = self.load_data(jsonl_path)
        self.images_path = images_path

        self.img_pre_process = img_pre_process  # whether to use pre-processed image features

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.image_size = 256
        self.image_token = image_special_token
        self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        messages = []
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content'].replace('<image>', self.image_token)})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

    def _generate_loss_mask(self, input_ids):
        # Only tokens inside assistant responses (between bos_id and eos_id)
        # contribute to the loss
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask

    def __getitem__(self, index: int):
        sample = self.samples[index]
        image_paths = sample['image']
        prompt = self._create_chat_prompt(sample['conversations'])
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        loss_mask = self._generate_loss_mask(input_ids)

        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)

        if self.img_pre_process:
            # Load the pre-computed image embeddings
            image_emb_path = image_paths.replace('.jpg', '_emb.npy')
            image_emb = np.load(f'{self.images_path}/{image_emb_path}')
            image_emb = torch.tensor(image_emb, dtype=torch.float32)
            # Load the pre-computed image tokens
            image_token_path = image_paths.replace('.jpg', '_token.npy')
            image_token = np.load(f'{self.images_path}/{image_token_path}')
            image_token = torch.tensor(image_token, dtype=torch.long)
            return X, Y, loss_mask, (image_emb, image_token)
        else:
            image_tensors = []
            for image_name in image_paths.split(','):
                image_name = image_name.strip()
                image = Image.open(f'{self.images_path}/{image_name}').convert("RGB")
                image = center_crop_arr(image, self.image_size)
                # Scale pixel values to [-1, 1] and convert HWC -> CHW
                image = np.array(image) / 255.
                image = 2.0 * image - 1.0
                image = torch.tensor(image, dtype=torch.float32)
                image_tensor = torch.einsum('hwc->chw', image)
                image_tensors.append(image_tensor)
            image_tensors = torch.stack(image_tensors, dim=0)

            return X, Y, loss_mask, image_tensors
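For completeness, a minimal sketch of wiring `T2IDataset` into a `DataLoader`. The paths, the image directory name, and the batch settings below are assumptions for illustration (following the conventions elsewhere in this PR), not part of the diff:

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from model.dataset import T2IDataset

tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

# Hypothetical paths: the jsonl comes from get_pretrain_t2i_data.py; the
# image directory name is assumed
train_ds = T2IDataset(
    jsonl_path='./dataset/pretrain_t2i_data.jsonl',
    images_path='./dataset/pretrain_images',
    tokenizer=tokenizer,
    max_length=640,  # must cover the text prompt plus the 256 '@' image placeholders
)
loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4)

for X, Y, loss_mask, image_tensors in loader:
    # X, Y: (B, max_length-1) shifted token ids; loss_mask is 1 only on
    # assistant spans; image_tensors: (B, num_images, 3, 256, 256) in [-1, 1]
    break
```

Note that default collation only works when every sample holds the same number of images; mixed image counts would need a custom `collate_fn`.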