From ab1673159a07231b39a6db13454994d8f05d2837 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez-Fierro?= Date: Thu, 8 Feb 2018 13:01:04 +0000 Subject: [PATCH] Movielens fix #7 --- data_utils/movie_lense_data_converter.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/data_utils/movie_lense_data_converter.py b/data_utils/movie_lense_data_converter.py index 8b310f5..eb8b26d 100644 --- a/data_utils/movie_lense_data_converter.py +++ b/data_utils/movie_lense_data_converter.py @@ -31,7 +31,7 @@ def main(args): min_ts = 100000000000 max_ts = 0 total_rating_count = 0 - with open(inpt, 'r') as inpt_f: + with open(inpt, 'r') as inpt_f: #ratings.csv headers: userId,movieId,rating,timestamp for line in inpt_f: if 'userId' in line: continue @@ -65,19 +65,29 @@ def main(args): training_data = dict() validation_data = dict() test_data = dict() + train_set_items = set() + for userId in data.keys(): if len(data[userId]) < 2: - print("WARNING") + #print("WARNING, userId {} has less than 2 ratings, skipping user...".format(userId)) continue time_sorted_ratings = sorted(data[userId], key=lambda x: x[2]) # sort by timestamp last_train_ind = floor(percent * len(time_sorted_ratings)) training_data[userId] = time_sorted_ratings[:last_train_ind] + for rating_item in time_sorted_ratings[:last_train_ind]: + train_set_items.add(rating_item[0]) # keep track of items from training set p = random.random() if p <= 0.5: validation_data[userId] = time_sorted_ratings[last_train_ind:] else: test_data[userId] = time_sorted_ratings[last_train_ind:] + # remove items not seen in training set + for userId, userRatings in test_data.items(): + test_data[userId] = [rating for rating in userRatings if rating[0] in train_set_items] + for userId, userRatings in validation_data.items(): + validation_data[userId] = [rating for rating in userRatings if rating[0] in train_set_items] + print("Training Data") print_stats(training_data) 
save_data_to_file(training_data, out_prefix+".train") @@ -91,4 +101,5 @@ def main(args): if __name__ == "__main__": - main(sys.argv) \ No newline at end of file + main(sys.argv) +