diff --git a/Cargo.lock b/Cargo.lock index a878421..16406bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -166,6 +166,34 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "errno" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "getrandom" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasi", +] + [[package]] name = "heck" version = "0.5.0" @@ -190,9 +218,15 @@ version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + [[package]] name = "litsea" -version = "0.1.0" +version = "0.2.0" dependencies = [ "clap", "ctrlc", @@ -200,6 +234,7 @@ dependencies = [ "regex", "serde", "serde_json", + "tempfile", ] [[package]] @@ -220,6 +255,12 @@ dependencies = [ "libc", ] +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + [[package]] name = "once_cell_polyfill" version = "1.70.1" @@ -244,6 +285,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" + [[package]] name = "rayon" version = "1.10.0" @@ -293,6 +340,19 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "rustix" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "ryu" version = "1.0.20" @@ -348,6 +408,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempfile" +version = "3.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" +dependencies = [ + "fastrand", + "getrandom", + "once_cell", + "rustix", + "windows-sys", +] + [[package]] name = "unicode-ident" version = "1.0.18" @@ -360,6 +433,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "wasi" +version = "0.14.2+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +dependencies = [ + 
"wit-bindgen-rt", +] + [[package]] name = "windows-sys" version = "0.59.0" @@ -432,3 +514,12 @@ name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wit-bindgen-rt" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +dependencies = [ + "bitflags", +] diff --git a/Cargo.toml b/Cargo.toml index 0652eae..f4e8db2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "litsea" -version = "0.1.0" +version = "0.2.0" edition = "2021" description = "Litsea is an extreamely compact word segmentation and model training tool implemented in Rust." documentation = "https://docs.rs/litsea" @@ -12,7 +12,7 @@ categories = ["text-processing"] license = "MIT" [features] -default = [] # No directories included +default = [] [dependencies] clap = { version = "4.5.39", features = ["derive"] } @@ -21,3 +21,6 @@ rayon = "1.10.0" regex = "1.10.5" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" + +[dev-dependencies] +tempfile = "3.20.0" diff --git a/README.md b/README.md index 09bd431..6d0ad4d 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,12 @@ Extract the information and features from the corpus: ./target/release/litsea extract ./resources/corpus.txt ./resources/features.txt ``` +The output from the `extract` command is similar to: + +```text +Feature extraction completed successfully. +``` + Train the features output by the above command using AdaBoost. Training stops if the new weak classifier’s accuracy falls below 0.001 or after 10,000 iterations. 
```sh @@ -74,13 +80,17 @@ The output from the `train` command is similar to: ```text finding instances...: 61 instances found - +loading instances...: 61/61 instances loaded Iteration 9999 - margin: 0.16068839956263622 -Result: -Accuracy: 100.00% (61 / 61) -Precision: 100.00% (24 / 24) -Recall: 100.00% (24 / 24) -Confusion Matrix: TP: 24, FP: 0, FN: 0, TN: 37 +Result Metrics: + Accuracy: 100.00% ( 61 / 61 ) + Precision: 100.00% ( 24 / 24 ) + Recall: 100.00% ( 24 / 24 ) + Confusion Matrix: + True Positives: 24 + False Positives: 0 + False Negatives: 0 + True Negatives: 37 ``` ## How to segment sentences into words diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..f1add26 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,80 @@ +max_width = 100 +hard_tabs = false +tab_spaces = 4 +newline_style = "Auto" +# indent_style = "Block" +use_small_heuristics = "Default" +fn_call_width = 80 +attr_fn_like_width = 70 +struct_lit_width = 18 +struct_variant_width = 35 +array_width = 80 +chain_width = 80 +single_line_if_else_max_width = 80 +single_line_let_else_max_width = 80 +# wrap_comments = false +# format_code_in_doc_comments = false +# doc_comment_code_block_width = 100 +# comment_width = 80 +# normalize_comments = false +# normalize_doc_attributes = false +# format_strings = false +# format_macro_matchers = false +# format_macro_bodies = true +# skip_macro_invocations = [] +# hex_literal_case = "Preserve" +# empty_item_single_line = true +# struct_lit_single_line = true +# fn_single_line = false +# where_single_line = false +# imports_indent = "Block" +# imports_layout = "Mixed" +# imports_granularity = "Preserve" +# group_imports = "Preserve" +reorder_imports = true +reorder_modules = true +# reorder_impl_items = false +# type_punctuation_density = "Wide" +# space_before_colon = false +# space_after_colon = true +# spaces_around_ranges = false +# binop_separator = "Front" +remove_nested_parens = true +# combine_control_expr = true +short_array_element_width_threshold = 10 +# overflow_delimited_expr = false +# struct_field_align_threshold = 0 +# enum_discrim_align_threshold = 0 +# match_arm_blocks = true +match_arm_leading_pipes = "Never" +# force_multiline_blocks = false +fn_params_layout = "Tall" +# brace_style = "SameLineWhere" +# control_brace_style = "AlwaysSameLine" +# trailing_semicolon = true +# trailing_comma = "Vertical" +match_block_trailing_comma = false +# blank_lines_upper_bound = 1 +# blank_lines_lower_bound = 0 +edition = "2015" +style_edition = "2015" +# version = "One" +# inline_attribute_width = 0 +# format_generated_files = true +# generated_marker_line_search_limit = 5 +merge_derives = true +use_try_shorthand = false +use_field_init_shorthand = false +force_explicit_abi = true +# condense_wildcard_suffixes = false +# color = "Auto" +# required_version = "1.8.0" +# unstable_features = false +disable_all_formatting = false +# skip_children = false +# show_parse_errors = true +# error_on_line_overflow = false +# error_on_unformatted = false +# ignore = [] +# emit_mode = "Files" +# make_backup = false diff --git a/src/adaboost.rs b/src/adaboost.rs index d2125f1..7dd9cb9 100644 --- a/src/adaboost.rs +++ b/src/adaboost.rs @@ -1,16 +1,36 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::fs::File; use std::io::{BufRead, BufReader, Write}; +use std::path::Path; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; type Label = i8; +/// Structure to hold evaluation metrics. 
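+/// The values follow from the confusion matrix in the usual way: accuracy = (TP + TN) / N, precision = TP / (TP + FP), and recall = TP / (TP + FN), each reported as a percentage.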
+pub struct Metrics { + /// Accuracy in percentage (%) + pub accuracy: f64, + /// Precision in percentage (%) + pub precision: f64, + /// Recall in percentage (%) + pub recall: f64, + /// Number of instances in the dataset + pub num_instances: usize, + /// True Positives count + pub true_positives: usize, + /// False Positives count + pub false_positives: usize, + /// False Negatives count + pub false_negatives: usize, + /// True Negatives count + pub true_negatives: usize, +} + /// AdaBoost implementation for binary classification /// This implementation uses a simple feature extraction method /// and is designed for educational purposes. /// It is not optimized for performance or large datasets. -/// #[derive(Debug)] pub struct AdaBoost { pub threshold: f64, @@ -26,11 +46,15 @@ pub struct AdaBoost { } impl AdaBoost { - /// Creates a new [`AdaBoost`]. + /// Creates a new instance of [`AdaBoost`]. + /// This method initializes the AdaBoost parameters such as threshold, + /// number of iterations, and number of threads. + /// /// # Arguments /// * `threshold`: The threshold for stopping the training. /// * `num_iterations`: The maximum number of iterations for training. /// * `num_threads`: The number of threads to use for training (not used in this implementation). + /// /// # Returns: A new instance of [`AdaBoost`]. pub fn new(threshold: f64, num_iterations: usize, num_threads: usize) -> Self { AdaBoost { @@ -49,11 +73,14 @@ impl AdaBoost { /// Initializes the features from a file. /// The file should contain lines with a label followed by space-separated features. + /// /// # Arguments /// * `filename`: The path to the file containing the features. + /// /// # Returns: A result indicating success or failure. + /// /// # Errors: Returns an error if the file cannot be opened or read. - pub fn initialize_features(&mut self, filename: &str) -> std::io::Result<()> { + pub fn initialize_features(&mut self, filename: &Path) -> std::io::Result<()> { let file = File::open(filename)?; let reader = BufReader::new(file); let mut map = BTreeMap::new(); // preserve order @@ -65,22 +92,20 @@ impl AdaBoost { let line = line?; let mut parts = line.split_whitespace(); let _label = parts.next(); + for h in parts { map.entry(h.to_string()).or_insert(0.0); buf_size += 1; } + self.num_instances += 1; if self.num_instances % 1000 == 0 { - eprint!( - "\rfinding instances...: {} instances found", - self.num_instances - ); + eprint!("\rfinding instances...: {} instances found", self.num_instances); } } - eprintln!( - "\rfinding instances...: {} instances found", - self.num_instances - ); + + eprintln!("\rfinding instances...: {} instances found", self.num_instances); + map.insert("".to_string(), 0.0); self.features = map.keys().cloned().collect(); @@ -96,11 +121,14 @@ impl AdaBoost { /// Initializes the instances from a file. /// The file should contain lines with a label followed by space-separated features. + /// /// # Arguments /// * `filename`: The path to the file containing the instances. + /// /// # Returns: A result indicating success or failure. + /// /// # Errors: Returns an error if the file cannot be opened or read. 
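+ /// For example, a single line might look like `-1 UC4:I UP3:U UW4:い` (the label first, then its features).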
- pub fn initialize_instances(&mut self, filename: &str) -> std::io::Result<()> { + pub fn initialize_instances(&mut self, filename: &Path) -> std::io::Result<()> { let file = File::open(filename)?; let reader = BufReader::new(file); let bias = self.get_bias(); @@ -123,9 +151,7 @@ let end = self.instances_buf.len(); self.instances.push((start, end)); - self.instance_weights - .push((-2.0 * label as f64 * score).exp()); - + self.instance_weights.push((-2.0 * label as f64 * score).exp()); if self.instance_weights.len() % 1000 == 0 { eprint!( "\rloading instances...: {}/{} instances loaded", @@ -134,15 +160,21 @@ ); } } - eprintln!(); + + eprintln!( "\rloading instances...: {}/{} instances loaded", self.instance_weights.len(), self.num_instances ); + Ok(()) } /// Trains the AdaBoost model. /// This method iteratively updates the model based on the training data. + /// # Arguments /// * `running`: An `Arc<AtomicBool>` to control the running state of the training process. - /// # Notes: The training process will stop if `running` is set to false. pub fn train(&mut self, running: Arc<AtomicBool>) { let num_features = self.features.len();
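+ /// For example, the file might contain feature lines such as `BW1:こん\t-0.1262`, with the bias value alone on the last line.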
- /// # Notes: The model is loaded into the `features` and `model` vectors, - /// and the bias is calculated as the negative sum of the weights divided by 2. - pub fn load_model(&mut self, filename: &str) -> std::io::Result<()> { + pub fn load_model(&mut self, filename: &Path) -> std::io::Result<()> { let file = File::open(filename)?; let reader = BufReader::new(file); let mut m: HashMap<String, f64> = HashMap::new(); @@ -280,74 +306,12 @@ Ok(()) } - /// Gets the bias term of the model. - /// The bias is calculated as the negative sum of the model weights divided by 2. - /// # Returns: The bias term as a `f64`. - /// # Notes: This is used to adjust the decision boundary of the model. - pub fn get_bias(&self) -> f64 { - -self.model.iter().sum::<f64>() / 2.0 - } - - /// Displays the result of the model's performance on the training data. - /// It calculates accuracy, precision, recall, and confusion matrix. - /// # Notes: This method iterates through the instances, calculates the score for each, - /// and counts true positives, false positives, true negatives, and false negatives. - pub fn show_result(&self) { - let bias = self.get_bias(); - let mut pp = 0; - let mut pn = 0; - let mut np = 0; - let mut nn = 0; - - for i in 0..self.num_instances { - let label = self.labels[i]; - let (start, end) = self.instances[i]; - let mut score = bias; - for &h in &self.instances_buf[start..end] { - score += self.model[h]; - } - - if score >= 0.0 { - if label > 0 { - pp += 1 - } else { - pn += 1 - } - } else { - if label > 0 { - np += 1 - } else { - nn += 1 - } - } - } - - let acc = (pp + nn) as f64 / self.num_instances as f64 * 100.0; - let prec = pp as f64 / (pp + pn).max(1) as f64 * 100.0; - let recall = pp as f64 / (pp + np).max(1) as f64 * 100.0; - - eprintln!("Result:"); - eprintln!( - "Accuracy: {:.2}% ({} / {})", - acc, - pp + nn, - self.num_instances - ); - eprintln!("Precision: {:.2}% ({} / {})", prec, pp, pp + pn); - eprintln!("Recall: {:.2}% ({} / {})", recall, pp, pp + np); - eprintln!( - "Confusion Matrix: TP: {}, FP: {}, FN: {}, TN: {}", - pp, pn, np, nn - ); - } - /// Adds a new instance to the model. /// The instance is represented by a set of attributes and a label. + /// # Arguments /// * `attributes`: A `HashSet` containing the attributes of the instance. /// * `label`: The label of the instance, represented as an `i8`. - /// # Notes: The attributes are sorted and added to the `features` vector if they do not already exist. - /// The instance is stored in `instances_buf`, and its start and end indices are recorded in `instances`. pub fn add_instance(&mut self, attributes: HashSet<String>, label: i8) { let start = self.instances_buf.len(); let mut attrs: Vec<String> = attributes.into_iter().collect(); @@ -370,10 +334,11 @@ } /// Predicts the label for a given set of attributes. + /// # Arguments /// * `attributes`: A `HashSet` containing the attributes to predict. + /// # Returns: The predicted label as an `i8`, where 1 indicates a positive prediction and -1 indicates a negative prediction. - /// # Notes: The prediction is made by calculating the score based on the model weights for the given attributes. pub fn predict(&self, attributes: HashSet<String>) -> i8 { let mut score = 0.0; for attr in attributes { @@ -387,4 +352,223 @@ -1 } } + + /// Gets the bias term of the model. + /// The bias is calculated as the negative sum of the model weights divided by 2. + /// + /// # Returns: The bias term as a `f64`.
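+ /// For example, with model weights `[0.2, 0.3, -0.1]` the bias is `-(0.2 + 0.3 - 0.1) / 2 = -0.2`.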
+ pub fn get_bias(&self) -> f64 { + -self.model.iter().sum::<f64>() / 2.0 + } + + /// Calculates and returns the performance metrics of the model on the training data. + pub fn get_metrics(&self) -> Metrics { + let bias = self.get_bias(); + let mut true_positives = 0; + let mut false_positives = 0; + let mut false_negatives = 0; + let mut true_negatives = 0; + + for i in 0..self.num_instances { + let label = self.labels[i]; + let (start, end) = self.instances[i]; + let mut score = bias; + for &h in &self.instances_buf[start..end] { + score += self.model[h]; + } + if score >= 0.0 { + if label > 0 { + true_positives += 1; + } else { + false_positives += 1; + } + } else if label > 0 { + false_negatives += 1; + } else { + true_negatives += 1; + } + } + + let accuracy = (true_positives + true_negatives) as f64 / self.num_instances as f64 * 100.0; + let precision = + true_positives as f64 / (true_positives + false_positives).max(1) as f64 * 100.0; + let recall = + true_positives as f64 / (true_positives + false_negatives).max(1) as f64 * 100.0; + + Metrics { + accuracy, + precision, + recall, + num_instances: self.num_instances, + true_positives, + false_positives, + false_negatives, + true_negatives, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::collections::HashSet; + use std::io::Write; + use std::sync::atomic::AtomicBool; + use std::sync::Arc; + + use tempfile::NamedTempFile; + + #[test] + fn test_initialize_features() -> std::io::Result<()> { + // Create a dummy features file + let mut features_file = NamedTempFile::new()?; + writeln!(features_file, "1 feat1 feat2")?; + writeln!(features_file, "0 feat3")?; + features_file.as_file().sync_all()?; + + let mut learner = AdaBoost::new(0.01, 10, 1); + learner.initialize_features(features_file.path())?; + + // Features is an ordered set that should contain "" (empty string), "feat1", "feat2", "feat3" + assert!(learner.features.contains(&"".to_string())); + assert!(learner.features.contains(&"feat1".to_string())); + assert!(learner.features.contains(&"feat2".to_string())); + assert!(learner.features.contains(&"feat3".to_string())); + Ok(()) + } + + #[test] + fn test_initialize_instances() -> std::io::Result<()> { + // First, initialize features in the feature file. + let mut features_file = NamedTempFile::new()?; + writeln!(features_file, "1 feat1 feat2")?; + features_file.as_file().sync_all()?; + + let mut learner = AdaBoost::new(0.01, 10, 1); + learner.initialize_features(features_file.path())?; + + // Create a dummy instance file + let mut instance_file = NamedTempFile::new()?; + // Example: a "1 feat1" line. The learner will consider feat1 a candidate if found by binary_search. + writeln!(instance_file, "1 feat1")?; + instance_file.as_file().sync_all()?; + + learner.initialize_instances(instance_file.path())?; + + // The number of instances should be 1, and the instance_weights, labels, and instances should be updated accordingly. + assert_eq!(learner.num_instances, 1); + assert_eq!(learner.labels.len(), 1); + assert_eq!(learner.instance_weights.len(), 1); + assert_eq!(learner.instances.len(), 1); + + Ok(()) + } + + #[test] + fn test_train() -> std::io::Result<()> { + // Initialize features using a features file.
+ let mut features_file = NamedTempFile::new()?; + writeln!(features_file, "1 feat1 feat2")?; + features_file.as_file().sync_all()?; + + let mut learner = AdaBoost::new(0.01, 3, 1); + learner.initialize_features(features_file.path())?; + + // Create a dummy instance file with one instance. + let mut instance_file = NamedTempFile::new()?; + writeln!(instance_file, "1 feat1")?; + instance_file.as_file().sync_all()?; + learner.initialize_instances(instance_file.path())?; + + // Set running to false to immediately exit the learning loop. + let running = Arc::new(AtomicBool::new(false)); + learner.train(running.clone()); + + // After training, the instance weights should remain normalized. + let weight_sum: f64 = learner.instance_weights.iter().sum(); + + // weight_sum should be 1.0. + assert!((weight_sum - 1.0).abs() < 1e-6); + + Ok(()) + } + + #[test] + fn test_save_and_load_model() -> std::io::Result<()> { + // Prepare a dummy learner. + let mut learner = AdaBoost::new(0.01, 10, 1); + + // Set the features and weights in advance. + learner.features = vec!["feat1".to_string(), "feat2".to_string()]; + learner.model = vec![0.5, -0.3]; + + // Save the model to a temporary file. + let temp_model = NamedTempFile::new()?; + learner.save_model(temp_model.path())?; + + // Load the model with a new learner. + let mut learner2 = AdaBoost::new(0.01, 10, 1); + learner2.load_model(temp_model.path())?; + + // Check that the numbers of features and weights match. + assert_eq!(learner2.features.len(), learner.features.len()); + assert_eq!(learner2.model.len(), learner.model.len()); + + Ok(()) + } + + #[test] + fn test_add_instance_and_predict() { + let mut learner = AdaBoost::new(0.01, 10, 1); + + // features and model start out empty; add_instance registers new attributes on the fly. + let mut attrs = HashSet::new(); + attrs.insert("A".to_string()); + learner.add_instance(attrs.clone(), 1); + + // Predicting with the same attribute returns 1, since the initial weight (0.0) gives a score >= 0. + let prediction = learner.predict(attrs); + assert_eq!(prediction, 1); + } + + #[test] + fn test_get_bias() { + let mut learner = AdaBoost::new(0.01, 10, 1); + + // Set model weights as an example. + learner.model = vec![0.2, 0.3, -0.1]; + + // bias = -sum(model)/2 = -(0.2+0.3-0.1)/2 = -0.4/2 = -0.2 + assert!((learner.get_bias() + 0.2).abs() < 1e-6); + } + + #[test] + fn test_get_metrics() { + let mut learner = AdaBoost::new(0.01, 10, 1); + + // Set features and model for prediction + learner.features = vec!["A".to_string(), "B".to_string()]; + learner.model = vec![0.5, -1.0]; + + // Instance 1: attribute "A" → score = 0.25 + 0.5 = 0.75 (positive example) + let mut attrs1 = HashSet::new(); + attrs1.insert("A".to_string()); + learner.add_instance(attrs1, 1); + + // Instance 2: attribute "B" → score = 0.25 + (-1.0) = -0.75 (negative example) + let mut attrs2 = HashSet::new(); + attrs2.insert("B".to_string()); + learner.add_instance(attrs2, -1); + + let metrics = learner.get_metrics(); + assert_eq!(metrics.true_positives, 1); + assert_eq!(metrics.true_negatives, 1); + assert_eq!(metrics.false_positives, 0); + assert_eq!(metrics.false_negatives, 0); + assert_eq!(metrics.num_instances, 2); + + // Since this is a simple case, the accuracy is 100%.
+ assert!((metrics.accuracy - 100.0).abs() < 1e-6); + } } diff --git a/src/extractor.rs b/src/extractor.rs new file mode 100644 index 0000000..fbb7b92 --- /dev/null +++ b/src/extractor.rs @@ -0,0 +1,124 @@ +use std::collections::HashSet; +use std::error::Error; +use std::fs::File; +use std::io::{self, BufRead, Write}; +use std::path::Path; + +use crate::segmenter::Segmenter; + +/// Extractor struct for processing text data and extracting features. +/// It reads sentences from a corpus file, segments them into words, +/// and writes the extracted features to a specified output file. +pub struct Extractor { + segmenter: Segmenter, +} + +impl Default for Extractor { + /// Creates a new instance of [`Extractor`] with default settings. + /// + /// # Returns + /// Returns a new instance of `Extractor`. + fn default() -> Self { + Self::new() + } +} + +impl Extractor { + /// Creates a new instance of [`Extractor`]. + /// + /// # Returns + /// Returns a new instance of `Extractor` with a new `Segmenter`. + pub fn new() -> Self { + Extractor { + segmenter: Segmenter::new(None), + } + } + + /// Extracts features from a corpus file and writes them to a specified output file. + /// + /// # Arguments + /// * `corpus_path` - The path to the input corpus file containing sentences. + /// * `features_path` - The path to the output file where extracted features will be written. + /// + /// # Returns + /// Returns a Result indicating success or failure. + pub fn extract( + &mut self, + corpus_path: &Path, + features_path: &Path, + ) -> Result<(), Box<dyn Error>> { + // Read sentences from the corpus file + // Each line is treated as a separate sentence + let corpus_file = File::open(corpus_path)?; + let corpus = io::BufReader::new(corpus_file); + + // Create a file to write the features + let features_file = File::create(features_path)?; + let mut features = io::BufWriter::new(features_file); + + // learner closure to write features + // This closure is called for each training instance extracted from the input sentences + // It takes a set of attributes and a label, and writes them to the features file + let mut learner = |attributes: HashSet<String>, label: i8| { + let mut attrs: Vec<String> = attributes.into_iter().collect(); + attrs.sort(); + let mut line = vec![label.to_string()]; + line.extend(attrs); + writeln!(features, "{}", line.join("\t")).expect("Failed to write features"); + }; + + for line in corpus.lines() { + match line { + Ok(line) => { + let line = line.trim(); + if !line.is_empty() { + self.segmenter.add_sentence_with_writer(line, &mut learner); + } + } + Err(err) => { + eprintln!("Error reading input: {}", err); + } + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::fs::File; + use std::io::{Read, Write}; + + use tempfile::NamedTempFile; + + #[test] + fn test_extract() -> Result<(), Box<dyn Error>> { + // Create a temporary file to simulate the corpus input + let mut corpus_file = NamedTempFile::new()?; + writeln!(corpus_file, "これ は テスト です 。")?; + writeln!(corpus_file, "別 の 文 も あり ます 。")?; + corpus_file.as_file().sync_all()?; + + // Create a temporary file for the features output + let features_file = NamedTempFile::new()?; + + // Create an instance of Extractor and extract features + let mut extractor = Extractor::new(); + extractor.extract(corpus_file.path(), features_file.path())?; + + // Read the output from the features file + let mut output = String::new(); + File::open(features_file.path())?.read_to_string(&mut output)?; + + // Check if the output is not empty + assert!(!output.is_empty(), "Extracted features should not be
empty"); + + // Check if the output contains tab-separated values + assert!(output.contains("\t"), "Output should contain tab-separated values"); + + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index feca1db..3d62349 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ pub mod adaboost; +pub mod extractor; pub mod segmenter; +pub mod trainer; const VERERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/src/main.rs b/src/main.rs index 6bccaae..06d9692 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,18 @@ -use std::collections::HashSet; use std::error::Error; -use std::fs::File; use std::io::{self, BufRead, Write}; +use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use clap::{Args, Parser, Subcommand}; use litsea::adaboost::AdaBoost; +use litsea::extractor::Extractor; use litsea::get_version; use litsea::segmenter::Segmenter; +use litsea::trainer::Trainer; +/// Arguments for the extract command. #[derive(Debug, Args)] #[clap( author, @@ -18,10 +20,11 @@ use litsea::segmenter::Segmenter; version = get_version(), )] struct ExtractArgs { - corpus_file: String, - features_file: String, + corpus_file: PathBuf, + features_file: PathBuf, } +/// Arguments for the train command. #[derive(Debug, Args)] #[clap(author, about = "Train a segmenter", @@ -38,21 +41,23 @@ struct TrainArgs { num_threads: usize, #[arg(short = 'm', long)] - load_model: Option, + load_model_file: Option, - instances_file: String, - model_file: String, + features_file: PathBuf, + model_file: PathBuf, } +/// Arguments for the segment command. #[derive(Debug, Args)] #[clap(author, about = "Segment a sentence", version = get_version(), )] struct SegmentArgs { - model_file: String, + model_file: PathBuf, } +/// Subcommands for lietsea CLI. #[derive(Debug, Subcommand)] enum Commands { Extract(ExtractArgs), @@ -60,6 +65,7 @@ enum Commands { Segment(SegmentArgs), } +/// Arguments for the litsea command. #[derive(Debug, Parser)] #[clap( name = "litsea", @@ -72,48 +78,33 @@ struct CommandArgs { command: Commands, } +/// Extract features from a corpus file and write them to a specified output file. +/// This function reads sentences from the corpus file, segments them into words, +/// and writes the extracted features to the output file. +/// +/// # Arguments +/// * `args` - The arguments for the extract command [`ExtractArgs`]. +/// +/// # Returns +/// Returns a Result indicating success or failure. 
fn extract(args: ExtractArgs) -> Result<(), Box<dyn Error>> { - // Create a file to write the features - let features_file = File::create(&args.features_file)?; - let mut features = io::BufWriter::new(features_file); - - // Initialize the segmenter - // No model is loaded, so it will use the default feature extraction - let mut segmenter = Segmenter::new(None); - - // learner function to write features - // This function will be called for each word in the input sentences - // It takes a set of attributes and a label, and writes them to stdout - let mut learner = |attributes: HashSet<String>, label: i8| { - let mut attrs: Vec<String> = attributes.into_iter().collect(); - attrs.sort(); - let mut line = vec![label.to_string()]; - line.extend(attrs); - writeln!(features, "{}", line.join("\t")).expect("Failed to write features"); - }; - - // Read sentences from stdin - // Each line is treated as a separate sentence - let corpus_file = File::open(&args.corpus_file)?; - let corpus = io::BufReader::new(corpus_file); - - for line in corpus.lines() { - match line { - Ok(line) => { - let line = line.trim(); - if !line.is_empty() { - segmenter.add_sentence_with_writer(line, &mut learner); - } - } - Err(err) => { - eprintln!("Error reading input: {}", err); - } - } - } + let mut extractor = Extractor::new(); + + extractor.extract(args.corpus_file.as_path(), args.features_file.as_path())?; + eprintln!("Feature extraction completed successfully."); Ok(()) } +/// Train a segmenter using the provided arguments. +/// This function initializes a Trainer with the specified parameters, +/// loads a model if specified, and trains the model using the features file. +/// +/// # Arguments +/// * `args` - The arguments for the train command [`TrainArgs`]. +/// +/// # Returns
/// Returns a Result indicating success or failure. fn train(args: TrainArgs) -> Result<(), Box<dyn Error>> { let running = Arc::new(AtomicBool::new(true)); let r = running.clone(); @@ -127,56 +118,82 @@ fn train(args: TrainArgs) -> Result<(), Box<dyn Error>> { }) .expect("Error setting Ctrl-C handler"); - let mut boost = AdaBoost::new(args.threshold, args.num_iterations, args.num_threads); + let mut trainer = Trainer::new( + args.threshold, + args.num_iterations, + args.num_threads, + args.features_file.as_path(), + ); - if let Some(model_path) = args.load_model.as_ref() { - boost.load_model(model_path).unwrap(); + if let Some(model_path) = &args.load_model_file { + trainer.load_model(model_path.as_path())?; } - boost.initialize_features(&args.instances_file).unwrap(); - boost.initialize_instances(&args.instances_file).unwrap(); - - boost.train(running.clone()); - boost.save_model(&args.model_file).unwrap(); - boost.show_result(); + let metrics = trainer.train(running, args.model_file.as_path())?; + + eprintln!("Result Metrics:"); + eprintln!( + " Accuracy: {:.2}% ( {} / {} )", + metrics.accuracy, + metrics.true_positives + metrics.true_negatives, + metrics.num_instances + ); + eprintln!( + " Precision: {:.2}% ( {} / {} )", + metrics.precision, + metrics.true_positives, + metrics.true_positives + metrics.false_positives + ); + eprintln!( + " Recall: {:.2}% ( {} / {} )", + metrics.recall, + metrics.true_positives, + metrics.true_positives + metrics.false_negatives + ); + eprintln!( + " Confusion Matrix:\n True Positives: {}\n False Positives: {}\n False Negatives: {}\n True Negatives: {}", + metrics.true_positives, + metrics.false_positives, + metrics.false_negatives, + metrics.true_negatives + ); Ok(()) } +/// Segment a sentence using the trained model.
+/// This function loads the AdaBoost model from the specified file, +/// reads sentences from standard input, segments them into words, +/// and writes the segmented sentences to standard output. +/// +/// # Arguments +/// * `args` - The arguments for the segment command [`SegmentArgs`]. +/// +/// # Returns +/// Returns a Result indicating success or failure. fn segment(args: SegmentArgs) -> Result<(), Box<dyn Error>> { - let model_path = &args.model_file; + let mut learner = AdaBoost::new(0.01, 100, 1); + learner.load_model(args.model_file.as_path())?; - let mut model = AdaBoost::new(0.01, 100, 1); - if let Err(e) = model.load_model(model_path) { - eprintln!("Failed to load model: {}", e); - std::process::exit(1); - } - - let segmenter = Segmenter::new(Some(model)); + let segmenter = Segmenter::new(Some(learner)); let stdin = io::stdin(); let stdout = io::stdout(); let mut writer = io::BufWriter::new(stdout.lock()); for line in stdin.lock().lines() { - match line { - Ok(line) => { - let line = line.trim(); - if line.is_empty() { - continue; - } - let tokens = segmenter.parse(line); - writeln!(writer, "{}", tokens.join(" ")).expect("write failed"); - } - Err(err) => { - eprintln!("Error reading input: {}", err); - } + let line = line?; + let line = line.trim(); + if line.is_empty() { + continue; } + let tokens = segmenter.segment(line); + writeln!(writer, "{}", tokens.join(" "))?; } Ok(()) } -fn main() -> Result<(), Box<dyn Error>> { +fn run() -> Result<(), Box<dyn Error>> { let args = CommandArgs::parse(); match args.command { @@ -185,3 +202,10 @@ Commands::Extract(args) => extract(args), Commands::Train(args) => train(args), Commands::Segment(args) => segment(args), } } + +fn main() { + if let Err(e) = run() { + eprintln!("Error: {}", e); + std::process::exit(1); + } +} diff --git a/src/segmenter.rs b/src/segmenter.rs index 18f2d91..7b69ecb 100644 --- a/src/segmenter.rs +++ b/src/segmenter.rs @@ -3,54 +3,77 @@ use regex::Regex; use std::collections::HashSet; /// Segmenter struct for text segmentation using AdaBoost +/// It uses predefined patterns to classify characters and segments sentences into words. pub struct Segmenter { patterns: Vec<(Regex, &'static str)>, pub learner: AdaBoost, } impl Segmenter { - /// Creates a new Segmenter with the given AdaBoost learner or a default one + /// Creates a new instance of [`Segmenter`]. + /// /// # Arguments /// * `learner` - An optional AdaBoost instance. If None, a default AdaBoost instance is created. + /// /// # Returns /// A new Segmenter instance with the specified or default AdaBoost learner. pub fn new(learner: Option<AdaBoost>) -> Self { let patterns = vec![ - ( - Regex::new(r"[一二三四五六七八九十百千万億兆]").unwrap(), - "M", - ), - (Regex::new(r"[一-龠々〆ヵヶ]").unwrap(), "H"), + // Numbers + (Regex::new(r"[0-9０-９]").unwrap(), "N"), + // Japanese Kanji numbers + (Regex::new(r"[一二三四五六七八九十百千万億兆]").unwrap(), "M"), + // Hiragana (Japanese) (Regex::new(r"[ぁ-ん]").unwrap(), "I"), - (Regex::new(r"[ァ-ヴーｱ-ﾝﾞｰ]").unwrap(), "K"), + // Katakana (Japanese) + (Regex::new(r"[ァ-ヴーｱ-ﾝﾞﾟ]").unwrap(), "K"), + // Hangul (Korean) + (Regex::new(r"[가-힣]").unwrap(), "G"), + // Thai script + (Regex::new(r"[ก-๛]").unwrap(), "T"), + // Kanji (Japanese) + (Regex::new(r"[一-龠々〆ヵヶ]").unwrap(), "H"), + // Kanji (CJK Unified Ideographs) + (Regex::new(r"[㐀-䶵一-鿿]").unwrap(), "Z"), + // Extended Latin (Vietnamese, etc.)
+ (Regex::new(r"[À-ÿĀ-ſƀ-ƿǍ-ɏ]").unwrap(), "E"), + // ASCII + Full-width Latin (Regex::new(r"[a-zA-Za-zA-Z]").unwrap(), "A"), - (Regex::new(r"[0-90-9]").unwrap(), "N"), ]; + Segmenter { patterns, learner: learner.unwrap_or_else(|| AdaBoost::new(0.01, 100, 1)), } } - /// gets the type of a character based on predefined patterns + /// Gets the type of a character based on predefined patterns. + /// /// # Arguments /// * `ch` - A string slice representing a single character. + /// /// # Returns - /// A static string representing the type of the character, such as "M", "H", "I", "K", "A", "N", or "O" (for others). - pub fn get_type(&self, ch: &str) -> &'static str { - for (pattern, s_type) in &self.patterns { + /// A string slice representing the type of the character, such as "N" for number, + /// "I" for Hiragana, "K" for Katakana, etc. If the character does not match any pattern, + /// it returns "O" for Other. + pub fn get_type(&self, ch: &str) -> &str { + for (pattern, label) in &self.patterns { if pattern.is_match(ch) { - return s_type; + return label; } } - "O" + "O" // Other } - /// Adds a sentence to the segmenter with a custom writer function + /// Adds a sentence to the segmenter with a custom writer function. + /// /// # Arguments /// * `sentence` - A string slice representing the sentence to be added. - /// * `writer` - A closure that takes a HashSet of attributes and a label (i8) as arguments. - /// This closure is called for each word in the sentence, allowing custom handling of the attributes and label. + /// * `writer` - A closure that takes a `HashSet` of attributes and a label (`i8`) as arguments. + /// + /// This closure is called for each instance created from the sentence. + /// This method processes the sentence, extracts features, and calls the writer function for each instance. + /// It constructs attributes based on the characters and their types, and uses the AdaBoost learner to add instances. pub fn add_sentence_with_writer(&mut self, sentence: &str, mut writer: F) where F: FnMut(HashSet, i8), @@ -91,9 +114,11 @@ impl Segmenter { } } - /// Adds a sentence to the segmenter for training + /// Adds a sentence to the segmenter for training. + /// /// # Arguments /// * `sentence` - A string slice representing the sentence to be added. + /// /// This method processes the sentence, extracts features, and adds them to the AdaBoost learner. /// It constructs attributes based on the characters and their types, and uses the AdaBoost learner to add instances. /// If the sentence is empty or too short, it does nothing. @@ -130,17 +155,19 @@ impl Segmenter { for i in 4..(chars.len() - 3) { let label = if tags[i] == "B" { 1 } else { -1 }; let attrs = self.get_attributes(i, &tags, &chars, &types); - // ★ ここで毎回 self.learner を呼ぶことで借用がぶつからない! + // Call the learner for each instance; doing so individually avoids borrowing conflicts. self.learner.add_instance(attrs, label); } } - /// Parses a sentence and segments it into words + /// Segments a sentence and segments it into words. + /// /// # Arguments /// * `sentence` - A string slice representing the sentence to be parsed. + /// /// # Returns /// A vector of strings, where each string is a segmented word from the sentence. 
- pub fn parse(&self, sentence: &str) -> Vec<String> { + pub fn segment(&self, sentence: &str) -> Vec<String> { if sentence.is_empty() { return Vec::new(); } @@ -174,12 +201,14 @@ result } - /// Gets the attributes for a specific index in the character and type arrays + /// Gets the attributes for a specific index in the character and type arrays. + /// /// # Arguments /// * `i` - The index for which to get the attributes. /// * `tags` - A slice of strings representing the tags for each character. /// * `chars` - A slice of strings representing the characters in the sentence. /// * `types` - A slice of strings representing the types of each character. + /// /// # Returns /// A HashSet of strings representing the attributes for the specified index. fn get_attributes( @@ -254,3 +283,114 @@ .collect() } } + +#[cfg(test)] +mod tests { + use super::*; + + use std::path::PathBuf; + + #[test] + fn test_add_sentence_with_writer() { + let mut segmenter = Segmenter::new(None); + let sentence = "テスト です"; + let mut collected = Vec::new(); + + segmenter.add_sentence_with_writer(sentence, |attrs, label| { + collected.push((attrs, label)); + }); + + // There should be as many instances as there are characters (excluding padding) + assert!(!collected.is_empty()); + + // Check that labels are either 1 or -1 + for (_, label) in &collected { + assert!(*label == 1 || *label == -1); + } + + // Check that attributes contain expected keys + let (attrs, _) = &collected[0]; + assert!(attrs.iter().any(|a| a.starts_with("UW"))); + assert!(attrs.iter().any(|a| a.starts_with("UC"))); + } + + #[test] + fn test_add_sentence_empty() { + let mut segmenter = Segmenter::new(None); + segmenter.add_sentence(""); + // Should not panic or add anything + } + + #[test] + fn test_segmenter() { + let sentence = "これはテストです。"; + + let model_file = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("./resources").join("RWCP.model"); + let mut learner = AdaBoost::new(0.01, 100, 1); + learner.load_model(model_file.as_path()).unwrap(); + + let mut segmenter = Segmenter::new(Some(learner)); + + segmenter.add_sentence(sentence); + let result = segmenter.segment(sentence); + + assert!(!result.is_empty()); + assert_eq!(result.len(), 5); // Expected segmentation with the bundled RWCP model + assert_eq!(result[0], "これ"); + assert_eq!(result[1], "は"); + assert_eq!(result[2], "テスト"); + assert_eq!(result[3], "です"); + assert_eq!(result[4], "。"); + } + + #[test] + fn test_segment_empty_sentence() { + let segmenter = Segmenter::new(None); + let result = segmenter.segment(""); + assert!(result.is_empty()); + } + + #[test] + fn test_get_type() { + let segmenter = Segmenter::new(None); + + assert_eq!(segmenter.get_type("あ"), "I"); // Hiragana + assert_eq!(segmenter.get_type("漢"), "H"); // Kanji + assert_eq!(segmenter.get_type("A"), "A"); // Latin + assert_eq!(segmenter.get_type("1"), "N"); // Digit + assert_eq!(segmenter.get_type("@"), "O"); // Not matching any pattern + } + + #[test] + fn test_get_attributes_content() { + let segmenter = Segmenter::new(None); + + let tags = vec!["U".to_string(); 7]; + + let chars = vec![ + "B3".to_string(), // index 0 + "B2".to_string(), // index 1 + "B1".to_string(), // index 2 + "あ".to_string(), // index 3 + "い".to_string(), // index 4 + "う".to_string(), // index 5 + "E1".to_string(), // index 6 + ]; + + let types = vec![ + "O".to_string(), // index 0 + "O".to_string(), // index 1 + "O".to_string(), // index 2 + "O".to_string(), // index 3 + "I".to_string(), // index 4 + "I".to_string(), // index 5 +
"O".to_string(), // index 6 + ]; + + let attrs = segmenter.get_attributes(4, &tags, &chars, &types); + assert!(attrs.contains("UW4:い")); + assert!(attrs.contains("UC4:I")); + assert!(attrs.contains("UP3:U")); + } +} diff --git a/src/trainer.rs b/src/trainer.rs new file mode 100644 index 0000000..280e38b --- /dev/null +++ b/src/trainer.rs @@ -0,0 +1,164 @@ +use std::path::Path; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +use crate::adaboost::{AdaBoost, Metrics}; + +/// Trainer struct for managing the AdaBoost training process. +/// It initializes the AdaBoost learner with the specified parameters, +/// loads the model from a file, and provides methods to train the model +/// and save the trained model. +pub struct Trainer { + learner: AdaBoost, +} + +impl Trainer { + /// Creates a new instance of [`Trainer`]. + /// + /// # Arguments + /// * `threshold` - The threshold for the AdaBoost algorithm. + /// * `num_iterations` - The number of iterations for the training. + /// * `num_threads` - The number of threads to use for training. + /// * `features_path` - The path to the features file. + /// + /// # Returns + /// Returns a new instance of `Trainer`. + /// + /// # Errors + /// Returns an error if the features or instances cannot be initialized. + pub fn new( + threshold: f64, + num_iterations: usize, + num_threads: usize, + features_path: &Path, + ) -> Self { + let mut learner = AdaBoost::new(threshold, num_iterations, num_threads); + + learner + .initialize_features(features_path) + .expect("Failed to initialize features"); + learner + .initialize_instances(features_path) + .expect("Failed to initialize instances"); + + Trainer { learner } + } + + /// Load Model from a file + /// + /// # Arguments + /// * `model_path` - The path to the model file to load. + /// + /// # Returns + /// Returns a Result indicating success or failure. + /// + /// # Errors + /// Returns an error if the model cannot be loaded. + pub fn load_model(&mut self, model_path: &Path) -> Result<(), Box> { + // Load the model from the specified file + Ok(self.learner.load_model(model_path)?) + } + + /// Train the AdaBoost model. + /// + /// # Arguments + /// * `running` - An Arc to control the running state of the training process. + /// * `model_path` - The path to save the trained model. + /// + /// # Returns + /// Returns a Result indicating success or failure. + /// + /// # Errors + /// Returns an error if the training fails or if the model cannot be saved. + pub fn train( + &mut self, + running: Arc, + model_path: &Path, + ) -> Result> { + self.learner.train(running.clone()); + + // Save the trained model to the specified file + self.learner.save_model(model_path)?; + + Ok(self.learner.get_metrics()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::Write; + use std::sync::atomic::AtomicBool; + use std::sync::Arc; + + use tempfile::NamedTempFile; + + use crate::adaboost::Metrics; + + // Helper: create a dummy features file. + // This file should contain at least one line for initialize_features and initialize_instances. + fn create_dummy_features_file() -> NamedTempFile { + let mut file = NamedTempFile::new().expect("Failed to create temp file for features"); + + // For example, it could contain "1 feature1" to represent one feature. + writeln!(file, "1 feature1").expect("Failed to write to features file"); + file + } + + // Helper: create a dummy model file. + // This file should contain the model weights and bias. 
+ fn create_dummy_model_file() -> NamedTempFile { + let mut file = NamedTempFile::new().expect("Failed to create temp file for model"); + + // For example, it could contain a single feature weight and a bias term. + // The feature line is "BW1:こん\t-0.1262" and the last line is the bias term "100.0". + writeln!(file, "BW1:こん\t-0.1262").expect("Failed to write feature"); + writeln!(file, "100.0").expect("Failed to write bias"); + file + } + + #[test] + fn test_load_model() -> Result<(), Box<dyn std::error::Error>> { + // Prepare a dummy features file + let features_file = create_dummy_features_file(); + + // Create a Trainer instance + let mut trainer = Trainer::new(0.01, 10, 1, features_file.path()); + + // Prepare a dummy model file + let model_file = create_dummy_model_file(); + + // Load the model file into the Trainer. + // This should not return an error if the model file is correctly formatted; + // otherwise, it returns an error. + trainer.load_model(model_file.path())?; + + Ok(()) + } + + #[test] + fn test_train() -> Result<(), Box<dyn std::error::Error>> { + // Prepare a dummy features file + let features_file = create_dummy_features_file(); + + // Create a Trainer instance with the dummy features file + let mut trainer = Trainer::new(0.01, 5, 1, features_file.path()); + + // Prepare a temporary file for the model output + let model_out = NamedTempFile::new()?; + + // Set the running flag to false so the training loop exits immediately + let running = Arc::new(AtomicBool::new(false)); + + // Execute the train method. + let metrics: Metrics = trainer.train(running, model_out.path())?; + + // Check that the metrics are valid. + // Since the training data is a minimal dummy set, any non-negative value is considered OK here. + assert!(metrics.accuracy >= 0.0); + assert!(metrics.precision >= 0.0); + assert!(metrics.recall >= 0.0); + Ok(()) + } +}