From 1c442c27ca0f89f7b4d7e18dc9bfded8d65ea30a Mon Sep 17 00:00:00 2001
From: Minoru Osuka
Date: Wed, 4 Jun 2025 00:08:21 +0900
Subject: [PATCH] Return training metrics instead of printing them

Replace AdaBoost::show_result, which printed the evaluation summary to
stderr itself, with AdaBoost::get_metrics, which returns it as a Metrics
value. Trainer::train now returns those metrics and main.rs prints them.
The extract command's completion message moves from stdout to stderr,
next to the other progress output.

get_metrics guards the accuracy division with .max(1), as precision and
recall already do, so an empty training set reports 0.00% instead of NaN.
---
 src/adaboost.rs | 146 +++++++++++++++++++++++++++++++-----------------
 src/main.rs     |  32 ++++++++++-
 src/trainer.rs  |   7 +--
 3 files changed, 127 insertions(+), 58 deletions(-)

diff --git a/src/adaboost.rs b/src/adaboost.rs
index e0f9f63..2272343 100644
--- a/src/adaboost.rs
+++ b/src/adaboost.rs
@@ -7,6 +7,27 @@ use std::sync::Arc;
 
 type Label = i8;
 
+/// Structure to hold evaluation metrics.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct Metrics {
+    /// Accuracy in percentage (%)
+    pub accuracy: f64,
+    /// Precision in percentage (%)
+    pub precision: f64,
+    /// Recall in percentage (%)
+    pub recall: f64,
+    /// Number of instances in the dataset
+    pub num_instances: usize,
+    /// True Positives count
+    pub true_positives: usize,
+    /// False Positives count
+    pub false_positives: usize,
+    /// False Negatives count
+    pub false_negatives: usize,
+    /// True Negatives count
+    pub true_negatives: usize,
+}
+
 /// AdaBoost implementation for binary classification
 /// This implementation uses a simple feature extraction method
 /// and is designed for educational purposes.
@@ -72,16 +93,20 @@ impl AdaBoost {
             let line = line?;
             let mut parts = line.split_whitespace();
             let _label = parts.next();
+
             for h in parts {
                 map.entry(h.to_string()).or_insert(0.0);
                 buf_size += 1;
             }
+
             self.num_instances += 1;
             if self.num_instances % 1000 == 0 {
                 eprint!("\rfinding instances...: {} instances found", self.num_instances);
             }
         }
+
         eprintln!("\rfinding instances...: {} instances found", self.num_instances);
+
         map.insert("".to_string(), 0.0);
         self.features = map.keys().cloned().collect();
 
@@ -128,7 +153,6 @@ impl AdaBoost {
             let end = self.instances_buf.len();
             self.instances.push((start, end));
             self.instance_weights.push((-2.0 * label as f64 * score).exp());
-
             if self.instance_weights.len() % 1000 == 0 {
                 eprint!(
                     "\rloading instances...: {}/{} instances loaded",
@@ -137,7 +161,13 @@ impl AdaBoost {
                 );
             }
         }
-        eprintln!();
+
+        eprintln!(
+            "\rloading instances...: {}/{} instances loaded",
+            self.instance_weights.len(),
+            self.num_instances
+        );
+
         Ok(())
     }
 
@@ -277,55 +307,6 @@ impl AdaBoost {
         Ok(())
     }
 
-    /// Gets the bias term of the model.
-    /// The bias is calculated as the negative sum of the model weights divided by 2.
-    ///
-    /// # Returns:The bias term as a `f64`.
-    pub fn get_bias(&self) -> f64 {
-        -self.model.iter().sum::<f64>() / 2.0
-    }
-
-    /// Displays the result of the model's performance on the training data.
-    /// It calculates accuracy, precision, recall, and confusion matrix.
-    pub fn show_result(&self) {
-        let bias = self.get_bias();
-        let mut pp = 0;
-        let mut pn = 0;
-        let mut np = 0;
-        let mut nn = 0;
-
-        for i in 0..self.num_instances {
-            let label = self.labels[i];
-            let (start, end) = self.instances[i];
-            let mut score = bias;
-            for &h in &self.instances_buf[start..end] {
-                score += self.model[h];
-            }
-
-            if score >= 0.0 {
-                if label > 0 {
-                    pp += 1
-                } else {
-                    pn += 1
-                }
-            } else if label > 0 {
-                np += 1
-            } else {
-                nn += 1
-            }
-        }
-
-        let acc = (pp + nn) as f64 / self.num_instances as f64 * 100.0;
-        let prec = pp as f64 / (pp + pn).max(1) as f64 * 100.0;
-        let recall = pp as f64 / (pp + np).max(1) as f64 * 100.0;
-
-        eprintln!("Result:");
-        eprintln!("Accuracy: {:.2}% ({} / {})", acc, pp + nn, self.num_instances);
-        eprintln!("Precision: {:.2}% ({} / {})", prec, pp, pp + pn);
-        eprintln!("Recall: {:.2}% ({} / {})", recall, pp, pp + np);
-        eprintln!("Confusion Matrix: TP: {}, FP: {}, FN: {}, TN: {}", pp, pn, np, nn);
-    }
-
     /// Adds a new instance to the model.
     /// The instance is represented by a set of attributes and a label.
     ///
@@ -372,4 +353,67 @@ impl AdaBoost {
             -1
         }
     }
+
+    /// Gets the bias term of the model.
+    /// The bias is calculated as the negative sum of the model weights divided by 2.
+    ///
+    /// # Returns
+    ///
+    /// The bias term as a `f64`.
+    pub fn get_bias(&self) -> f64 {
+        -self.model.iter().sum::<f64>() / 2.0
+    }
+
+    /// Calculates and returns the performance metrics of the model on the training data.
+    ///
+    /// An instance is predicted positive when its score (the bias plus the model
+    /// weights of its features) is non-negative. A percentage whose denominator is
+    /// zero, e.g. accuracy on an empty training set, is reported as `0.0`, not NaN.
+    pub fn get_metrics(&self) -> Metrics {
+        let bias = self.get_bias();
+        let mut true_positives = 0;
+        let mut false_positives = 0;
+        let mut false_negatives = 0;
+        let mut true_negatives = 0;
+
+        for i in 0..self.num_instances {
+            let label = self.labels[i];
+            let (start, end) = self.instances[i];
+            let mut score = bias;
+            for &h in &self.instances_buf[start..end] {
+                score += self.model[h];
+            }
+            if score >= 0.0 {
+                if label > 0 {
+                    true_positives += 1;
+                } else {
+                    false_positives += 1;
+                }
+            } else if label > 0 {
+                false_negatives += 1;
+            } else {
+                true_negatives += 1;
+            }
+        }
+
+        // A zero denominator implies a zero numerator here, so `.max(1)` turns the
+        // otherwise-NaN 0/0 into 0.0 without affecting any non-empty case.
+        let accuracy =
+            (true_positives + true_negatives) as f64 / self.num_instances.max(1) as f64 * 100.0;
+        let precision =
+            true_positives as f64 / (true_positives + false_positives).max(1) as f64 * 100.0;
+        let recall =
+            true_positives as f64 / (true_positives + false_negatives).max(1) as f64 * 100.0;
+
+        Metrics {
+            accuracy,
+            precision,
+            recall,
+            num_instances: self.num_instances,
+            true_positives,
+            false_positives,
+            false_negatives,
+            true_negatives,
+        }
+    }
 }
diff --git a/src/main.rs b/src/main.rs
index e31f697..ff85dde 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -92,7 +92,7 @@ fn extract(args: ExtractArgs) -> Result<(), Box<dyn std::error::Error>> {
 
     extractor.extract(args.corpus_file.as_path(), args.features_file.as_path())?;
 
-    println!("Feature extraction completed successfully.");
+    eprintln!("Feature extraction completed successfully.");
     Ok(())
 }
 
@@ -129,9 +129,35 @@ fn train(args: TrainArgs) -> Result<(), Box<dyn std::error::Error>> {
         trainer.load_model(model_path.as_path())?;
     }
 
-    trainer.train(running, args.model_file.as_path())?;
+    let metrics = trainer.train(running, args.model_file.as_path())?;
+
+    eprintln!("Result Metrics:");
+    eprintln!(
+        "  Accuracy: {:.2}% ( {} / {} )",
+        metrics.accuracy,
+        metrics.true_positives + metrics.true_negatives,
+        metrics.num_instances
+    );
+    eprintln!(
+        "  Precision: {:.2}% ( {} / {} )",
+        metrics.precision,
+        metrics.true_positives,
+        metrics.true_positives + metrics.false_positives
+    );
+    eprintln!(
+        "  Recall: {:.2}% ( {} / {} )",
+        metrics.recall,
+        metrics.true_positives,
+        metrics.true_positives + metrics.false_negatives
+    );
+    eprintln!(
+        "  Confusion Matrix:\n    True Positives: {}\n    False Positives: {}\n    False Negatives: {}\n    True Negatives: {}",
+        metrics.true_positives,
+        metrics.false_positives,
+        metrics.false_negatives,
+        metrics.true_negatives
+    );
 
-    println!("Training completed successfully.");
     Ok(())
 }
 
diff --git a/src/trainer.rs b/src/trainer.rs
index dd76a94..1bce18d 100644
--- a/src/trainer.rs
+++ b/src/trainer.rs
@@ -2,7 +2,7 @@ use std::path::Path;
 use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 
-use crate::adaboost::AdaBoost;
+use crate::adaboost::{AdaBoost, Metrics};
 
 /// Trainer struct for managing the AdaBoost training process.
 /// It initializes the AdaBoost learner with the specified parameters,
@@ -74,13 +74,12 @@ impl Trainer {
         &mut self,
         running: Arc<AtomicBool>,
         model_path: &Path,
-    ) -> Result<(), Box<dyn std::error::Error>> {
+    ) -> Result<Metrics, Box<dyn std::error::Error>> {
         self.learner.train(running.clone());
 
         // Save the trained model to the specified file
         self.learner.save_model(model_path)?;
 
-        self.learner.show_result();
-        Ok(())
+        Ok(self.learner.get_metrics())
     }
 }