136 changes: 85 additions & 51 deletions src/adaboost.rs
@@ -7,6 +7,26 @@ use std::sync::Arc;

type Label = i8;

/// Structure to hold evaluation metrics.
pub struct Metrics {
/// Accuracy in percentage (%)
pub accuracy: f64,
/// Precision in percentage (%)
pub precision: f64,
/// Recall in percentage (%)
pub recall: f64,
/// Number of instances in the dataset
pub num_instances: usize,
/// True Positives count
pub true_positives: usize,
/// False Positives count
pub false_positives: usize,
/// False Negatives count
pub false_negatives: usize,
/// True Negatives count
pub true_negatives: usize,
}

/// AdaBoost implementation for binary classification
/// This implementation uses a simple feature extraction method
/// and is designed for educational purposes.
@@ -72,16 +92,20 @@ impl AdaBoost {
let line = line?;
let mut parts = line.split_whitespace();
let _label = parts.next();

for h in parts {
map.entry(h.to_string()).or_insert(0.0);
buf_size += 1;
}

self.num_instances += 1;
if self.num_instances % 1000 == 0 {
eprint!("\rfinding instances...: {} instances found", self.num_instances);
}
}

eprintln!("\rfinding instances...: {} instances found", self.num_instances);

map.insert("".to_string(), 0.0);

self.features = map.keys().cloned().collect();
@@ -128,7 +152,6 @@ impl AdaBoost {
let end = self.instances_buf.len();
self.instances.push((start, end));
self.instance_weights.push((-2.0 * label as f64 * score).exp());

if self.instance_weights.len() % 1000 == 0 {
eprint!(
"\rloading instances...: {}/{} instances loaded",
@@ -137,7 +160,13 @@
);
}
}
eprintln!();

eprintln!(
"\rloading instances...: {}/{} instances loaded",
self.instance_weights.len(),
self.num_instances
);

Ok(())
}

@@ -277,55 +306,6 @@ impl AdaBoost {
Ok(())
}

/// Gets the bias term of the model.
/// The bias is calculated as the negative sum of the model weights divided by 2.
///
/// # Returns:The bias term as a `f64`.
pub fn get_bias(&self) -> f64 {
-self.model.iter().sum::<f64>() / 2.0
}

/// Displays the result of the model's performance on the training data.
/// It calculates accuracy, precision, recall, and confusion matrix.
pub fn show_result(&self) {
let bias = self.get_bias();
let mut pp = 0;
let mut pn = 0;
let mut np = 0;
let mut nn = 0;

for i in 0..self.num_instances {
let label = self.labels[i];
let (start, end) = self.instances[i];
let mut score = bias;
for &h in &self.instances_buf[start..end] {
score += self.model[h];
}

if score >= 0.0 {
if label > 0 {
pp += 1
} else {
pn += 1
}
} else if label > 0 {
np += 1
} else {
nn += 1
}
}

let acc = (pp + nn) as f64 / self.num_instances as f64 * 100.0;
let prec = pp as f64 / (pp + pn).max(1) as f64 * 100.0;
let recall = pp as f64 / (pp + np).max(1) as f64 * 100.0;

eprintln!("Result:");
eprintln!("Accuracy: {:.2}% ({} / {})", acc, pp + nn, self.num_instances);
eprintln!("Precision: {:.2}% ({} / {})", prec, pp, pp + pn);
eprintln!("Recall: {:.2}% ({} / {})", recall, pp, pp + np);
eprintln!("Confusion Matrix: TP: {}, FP: {}, FN: {}, TN: {}", pp, pn, np, nn);
}

/// Adds a new instance to the model.
/// The instance is represented by a set of attributes and a label.
///
@@ -372,4 +352,58 @@ impl AdaBoost {
-1
}
}

/// Gets the bias term of the model.
/// The bias is calculated as the negative sum of the model weights divided by 2.
///
/// # Returns: The bias term as a `f64`.
pub fn get_bias(&self) -> f64 {
-self.model.iter().sum::<f64>() / 2.0
}

/// Calculates and returns the performance metrics of the model on the training data.
pub fn get_metrics(&self) -> Metrics {
let bias = self.get_bias();
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
let mut true_negatives = 0;

for i in 0..self.num_instances {
let label = self.labels[i];
let (start, end) = self.instances[i];
let mut score = bias;
for &h in &self.instances_buf[start..end] {
score += self.model[h];
}
if score >= 0.0 {
if label > 0 {
true_positives += 1;
} else {
false_positives += 1;
}
} else if label > 0 {
false_negatives += 1;
} else {
true_negatives += 1;
}
}

let accuracy = (true_positives + true_negatives) as f64 / self.num_instances as f64 * 100.0;
let precision =
true_positives as f64 / (true_positives + false_positives).max(1) as f64 * 100.0;
let recall =
true_positives as f64 / (true_positives + false_negatives).max(1) as f64 * 100.0;

Metrics {
accuracy,
precision,
recall,
num_instances: self.num_instances,
true_positives,
false_positives,
false_negatives,
true_negatives,
}
}
}
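
For reference, the percentages that `get_metrics` returns follow the standard confusion-matrix formulas. The sketch below is illustration only and not part of the PR: the counts are invented, and the function simply mirrors the arithmetic above, including the `.max(1)` guard against an empty predicted or actual positive class.

```rust
/// Illustration only: recomputes the percentages the way `get_metrics` does.
/// The counts passed in `main` are invented for the example.
fn derive(tp: usize, fp: usize, fn_: usize, tn: usize) -> (f64, f64, f64) {
    let total = tp + fp + fn_ + tn;
    let accuracy = (tp + tn) as f64 / total as f64 * 100.0;
    // `.max(1)` avoids a division by zero when nothing is predicted
    // (or labeled) positive, as in the diff above.
    let precision = tp as f64 / (tp + fp).max(1) as f64 * 100.0;
    let recall = tp as f64 / (tp + fn_).max(1) as f64 * 100.0;
    (accuracy, precision, recall)
}

fn main() {
    let (acc, prec, rec) = derive(90, 10, 5, 95);
    eprintln!("Accuracy: {acc:.2}%  Precision: {prec:.2}%  Recall: {rec:.2}%");
}
```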
32 changes: 29 additions & 3 deletions src/main.rs
@@ -92,7 +92,7 @@ fn extract(args: ExtractArgs) -> Result<(), Box<dyn Error>> {

extractor.extract(args.corpus_file.as_path(), args.features_file.as_path())?;

println!("Feature extraction completed successfully.");
eprintln!("Feature extraction completed successfully.");
Ok(())
}

@@ -129,9 +129,35 @@ fn train(args: TrainArgs) -> Result<(), Box<dyn Error>> {
trainer.load_model(model_path.as_path())?;
}

trainer.train(running, args.model_file.as_path())?;
let metrics = trainer.train(running, args.model_file.as_path())?;

eprintln!("Result Metrics:");
eprintln!(
" Accuracy: {:.2}% ( {} / {} )",
metrics.accuracy,
metrics.true_positives + metrics.true_negatives,
metrics.num_instances
);
eprintln!(
" Precision: {:.2}% ( {} / {} )",
metrics.precision,
metrics.true_positives,
metrics.true_positives + metrics.false_positives
);
eprintln!(
" Recall: {:.2}% ( {} / {} )",
metrics.recall,
metrics.true_positives,
metrics.true_positives + metrics.false_negatives
);
eprintln!(
" Confusion Matrix:\n True Positives: {}\n False Positives: {}\n False Negatives: {}\n True Negatives: {}",
metrics.true_positives,
metrics.false_positives,
metrics.false_negatives,
metrics.true_negatives
);

println!("Training completed successfully.");
Ok(())
}
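
Because `train` now hands the `Metrics` value back instead of printing inside `AdaBoost`, callers can also act on the numbers programmatically rather than only report them. A minimal sketch, not part of the PR (the 95.0 cutoff is an arbitrary example value, and the import path assumes the same `crate::adaboost::Metrics` that trainer.rs uses):

```rust
use crate::adaboost::Metrics;

/// Illustration only: fail fast if training accuracy falls below a chosen cutoff.
fn check_accuracy(metrics: &Metrics, min_accuracy: f64) -> Result<(), String> {
    if metrics.accuracy < min_accuracy {
        return Err(format!(
            "training accuracy {:.2}% is below the required {:.2}%",
            metrics.accuracy, min_accuracy
        ));
    }
    Ok(())
}
```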

7 changes: 3 additions & 4 deletions src/trainer.rs
@@ -2,7 +2,7 @@ use std::path::Path;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;

use crate::adaboost::AdaBoost;
use crate::adaboost::{AdaBoost, Metrics};

/// Trainer struct for managing the AdaBoost training process.
/// It initializes the AdaBoost learner with the specified parameters,
@@ -74,13 +74,12 @@ impl Trainer {
&mut self,
running: Arc<AtomicBool>,
model_path: &Path,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<Metrics, Box<dyn std::error::Error>> {
self.learner.train(running.clone());

// Save the trained model to the specified file
self.learner.save_model(model_path)?;
self.learner.show_result();

Ok(())
Ok(self.learner.get_metrics())
}
}
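
The `running: Arc<AtomicBool>` parameter that `Trainer::train` forwards to the learner is a cooperative stop flag; how it gets flipped is outside this diff. One hypothetical way to wire it up is shown below, using the `ctrlc` crate, which this repository may or may not actually depend on.

```rust
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

/// Hypothetical helper (not part of this PR): installs a Ctrl-C handler that
/// clears the flag so the training loop can notice it and stop early.
fn install_stop_flag() -> Arc<AtomicBool> {
    let running = Arc::new(AtomicBool::new(true));
    let r = Arc::clone(&running);
    ctrlc::set_handler(move || r.store(false, Ordering::SeqCst))
        .expect("failed to install Ctrl-C handler");
    running
}
```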