这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
import numpy as np
import cv2
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
from generate_video_tool.pano_video_generation import generate_video
from PIL import Image
from exiftool import ExifToolHelper
from datetime import datetime

torch.manual_seed(0)

def get_K_R(FOV, THETA, PHI, height, width):
Expand Down Expand Up @@ -65,26 +69,41 @@ def resize_and_center_crop(img, size):
config_file = 'configs/pano_generation.yaml'
config = yaml.load(open(config_file, 'rb'), Loader=yaml.SafeLoader)
model = PanoGenerator(config)
model.load_state_dict(torch.load('weights/pano.ckpt', map_location='cpu')[
'state_dict'], strict=True)
model.load_state_dict(torch.load('weights/pano.ckpt', map_location='cpu')['state_dict'], strict=True)
#saved_ckpt = torch.load('weights/pano.ckpt', map_location='cpu')
#model.load_state_dict(saved_ckpt, strict=False)
model=model.cuda()
img=None
else:

config_file = 'configs/pano_generation_outpaint.yaml'
config = yaml.load(open(config_file, 'rb'), Loader=yaml.SafeLoader)
model = PanoOutpaintGenerator(config)
model.load_state_dict(torch.load('weights/pano_outpaint.ckpt', map_location='cpu')[
'state_dict'], strict=True)
model.load_state_dict(torch.load('weights/pano_outpaint.ckpt', map_location='cpu')['state_dict'], strict=True)
#saved_ckpt = torch.load('weights/pano_outpaint.ckpt', map_location='cpu')
#model.load_state_dict(saved_ckpt, strict=False)
model=model.cuda()

img=cv2.imread(args.image_path)
img=cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img=resize_and_center_crop(img, config['dataset']['resolution'])
img=img/127.5-1

img=img/127.5-1
img=torch.tensor(img).cuda()

#read stable diffusion prompts from PNG
with ExifToolHelper() as et:
#print EXIF metadata
EXIF_dict = et.get_metadata(args.image_path)
print(f'Metadata: {EXIF_dict}')
if 'PNG:Parameters' in EXIF_dict[0]:
PNGparameters = EXIF_dict[0]['PNG:Parameters']
parsed_parameters = PNGparameters.split("\n")
positive_prompt = parsed_parameters[0]
negative_prompt = parsed_parameters[1]
print(f'Positive promts: ' + positive_prompt)
print(f'Negative promts: ' + negative_prompt)
args.text= positive_prompt

resolution=config['dataset']['resolution']
Rs=[]
Ks=[]
Expand All @@ -107,7 +126,7 @@ def resize_and_center_crop(img, size):
for i, line in enumerate(f):
prompt.append(line.strip())
if len(prompt)<8:
raise ValueError('text file should contain 8 lines')
raise ValueError('text file should contain 8 lines for each camera view')
args.text=prompt[0]
else:
prompt=[args.text]*8
Expand All @@ -121,11 +140,12 @@ def resize_and_center_crop(img, size):
'K': K
}
images_pred=model.inference(batch)
res_dir=args.text[:20]
print('save in fold: {}'.format(res_dir))
#res_dir=os.path.join('outputs/',args.text[:20])
res_dir=os.path.join('outputs/',f'results'+datetime.now().strftime('--%Y%m%d-%H%M%S'))
print('saved to the folder: {}'.format(res_dir))
os.makedirs(res_dir, exist_ok=True)
with open(os.path.join(res_dir, 'prompt.txt'), 'w') as f:
f.write(args.text)
f.write(args.text)
image_paths=[]
for i in range(8):
im = Image.fromarray(images_pred[0,i])
Expand Down