这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ categories = ["text-processing"]
license = "MIT"

[features]
default = [] # No directories included
default = []

[dependencies]
clap = { version = "4.5.39", features = ["derive"] }
Expand Down
19 changes: 9 additions & 10 deletions src/adaboost.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

Expand Down Expand Up @@ -53,7 +54,7 @@ impl AdaBoost {
/// * `filename`: The path to the file containing the features.
/// # Returns: A result indicating success or failure.
/// # Errors: Returns an error if the file cannot be opened or read.
pub fn initialize_features(&mut self, filename: &str) -> std::io::Result<()> {
pub fn initialize_features(&mut self, filename: &Path) -> std::io::Result<()> {
let file = File::open(filename)?;
let reader = BufReader::new(file);
let mut map = BTreeMap::new(); // preserve order
Expand Down Expand Up @@ -100,7 +101,7 @@ impl AdaBoost {
/// * `filename`: The path to the file containing the instances.
/// # Returns: A result indicating success or failure.
/// # Errors: Returns an error if the file cannot be opened or read.
pub fn initialize_instances(&mut self, filename: &str) -> std::io::Result<()> {
pub fn initialize_instances(&mut self, filename: &Path) -> std::io::Result<()> {
let file = File::open(filename)?;
let reader = BufReader::new(file);
let bias = self.get_bias();
Expand Down Expand Up @@ -173,7 +174,7 @@ impl AdaBoost {
// Find the best hypothesis
let mut h_best = 0;
let mut best_error_rate = positive_weight_sum / instance_weight_sum;
for h in 1..num_features {
for (h, _) in errors.iter().enumerate().take(num_features).skip(1) {
let mut e = errors[h] + positive_weight_sum;
e /= instance_weight_sum;
if (0.5 - e).abs() > (0.5 - best_error_rate).abs() {
Expand Down Expand Up @@ -232,7 +233,7 @@ impl AdaBoost {
/// # Errors: Returns an error if the file cannot be created or written to.
/// # Notes: The bias term is calculated as the negative sum of the weights divided by 2.
/// The model is saved in a way that can be easily loaded later.
pub fn save_model(&self, filename: &str) -> std::io::Result<()> {
pub fn save_model(&self, filename: &Path) -> std::io::Result<()> {
let mut file = File::create(filename)?;
let mut bias = -self.model[0];
for (h, &w) in self.features.iter().zip(self.model.iter()).skip(1) {
Expand All @@ -254,7 +255,7 @@ impl AdaBoost {
/// # Errors: Returns an error if the file cannot be opened or read.
/// # Notes: The model is loaded into the `features` and `model` vectors,
/// and the bias is calculated as the negative sum of the weights divided by 2.
pub fn load_model(&mut self, filename: &str) -> std::io::Result<()> {
pub fn load_model(&mut self, filename: &Path) -> std::io::Result<()> {
let file = File::open(filename)?;
let reader = BufReader::new(file);
let mut m: HashMap<String, f64> = HashMap::new();
Expand Down Expand Up @@ -313,12 +314,10 @@ impl AdaBoost {
} else {
pn += 1
}
} else if label > 0 {
np += 1
} else {
if label > 0 {
np += 1
} else {
nn += 1
}
nn += 1
}
}

Expand Down
67 changes: 67 additions & 0 deletions src/extractor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use std::collections::HashSet;
use std::error::Error;
use std::fs::File;
use std::io::{self, BufRead, Write};
use std::path::Path;

use crate::segmenter::Segmenter;

/// Extracts labeled training features from a text corpus by running each
/// sentence through a [`Segmenter`].
pub struct Extractor {
    // Segmenter driving feature extraction; created via Segmenter::new(None)
    // in Extractor::new(), i.e. with no pre-trained model loaded.
    segmenter: Segmenter,
}

impl Default for Extractor {
fn default() -> Self {
Self::new()
}
}

impl Extractor {
    /// Creates an extractor backed by a segmenter constructed with no model
    /// (`Segmenter::new(None)`), so it performs plain feature extraction.
    pub fn new() -> Self {
        Self {
            segmenter: Segmenter::new(None),
        }
    }

    /// Reads the corpus at `corpus_path` one line at a time (each non-empty,
    /// trimmed line is treated as a sentence) and writes one tab-separated
    /// feature record per word to `features_path`.
    ///
    /// Each record is `label` followed by the word's attributes in sorted
    /// order, joined with tabs. Lines that fail to read are reported on
    /// stderr and skipped rather than aborting the run.
    ///
    /// # Errors
    /// Returns an error if the corpus file cannot be opened or the features
    /// file cannot be created.
    pub fn extract(
        &mut self,
        corpus_path: &Path,
        features_path: &Path,
    ) -> Result<(), Box<dyn Error>> {
        let reader = io::BufReader::new(File::open(corpus_path)?);
        let mut writer = io::BufWriter::new(File::create(features_path)?);

        // Callback handed to the segmenter: invoked once per word with its
        // attribute set and label; serializes them as a tab-separated record.
        let mut emit = |attributes: HashSet<String>, label: i8| {
            let mut sorted: Vec<String> = attributes.into_iter().collect();
            sorted.sort();
            let mut record = vec![label.to_string()];
            record.extend(sorted);
            writeln!(writer, "{}", record.join("\t")).expect("Failed to write features");
        };

        for read in reader.lines() {
            let line = match read {
                Ok(l) => l,
                Err(err) => {
                    // Best-effort: report the bad line and keep going.
                    eprintln!("Error reading input: {}", err);
                    continue;
                }
            };
            let sentence = line.trim();
            if !sentence.is_empty() {
                self.segmenter.add_sentence_with_writer(sentence, &mut emit);
            }
        }

        Ok(())
    }
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
pub mod adaboost;
pub mod extractor;
pub mod segmenter;
pub mod trainer;

const VERERSION: &str = env!("CARGO_PKG_VERSION");

Expand Down
115 changes: 40 additions & 75 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use std::collections::HashSet;
use std::error::Error;
use std::fs::File;
use std::io::{self, BufRead, Write};
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use clap::{Args, Parser, Subcommand};

use litsea::adaboost::AdaBoost;
use litsea::extractor::Extractor;
use litsea::get_version;
use litsea::segmenter::Segmenter;
use litsea::trainer::Trainer;

#[derive(Debug, Args)]
#[clap(
Expand All @@ -18,8 +19,8 @@ use litsea::segmenter::Segmenter;
version = get_version(),
)]
struct ExtractArgs {
corpus_file: String,
features_file: String,
corpus_file: PathBuf,
features_file: PathBuf,
}

#[derive(Debug, Args)]
Expand All @@ -38,10 +39,10 @@ struct TrainArgs {
num_threads: usize,

#[arg(short = 'm', long)]
load_model: Option<String>,
load_model_file: Option<PathBuf>,

instances_file: String,
model_file: String,
features_file: PathBuf,
model_file: PathBuf,
}

#[derive(Debug, Args)]
Expand All @@ -50,7 +51,7 @@ struct TrainArgs {
version = get_version(),
)]
struct SegmentArgs {
model_file: String,
model_file: PathBuf,
}

#[derive(Debug, Subcommand)]
Expand All @@ -73,44 +74,11 @@ struct CommandArgs {
}

fn extract(args: ExtractArgs) -> Result<(), Box<dyn Error>> {
// Create a file to write the features
let features_file = File::create(&args.features_file)?;
let mut features = io::BufWriter::new(features_file);

// Initialize the segmenter
// No model is loaded, so it will use the default feature extraction
let mut segmenter = Segmenter::new(None);

// learner function to write features
// This function will be called for each word in the input sentences
// It takes a set of attributes and a label, and writes them to stdout
let mut learner = |attributes: HashSet<String>, label: i8| {
let mut attrs: Vec<String> = attributes.into_iter().collect();
attrs.sort();
let mut line = vec![label.to_string()];
line.extend(attrs);
writeln!(features, "{}", line.join("\t")).expect("Failed to write features");
};

// Read sentences from stdin
// Each line is treated as a separate sentence
let corpus_file = File::open(&args.corpus_file)?;
let corpus = io::BufReader::new(corpus_file);

for line in corpus.lines() {
match line {
Ok(line) => {
let line = line.trim();
if !line.is_empty() {
segmenter.add_sentence_with_writer(line, &mut learner);
}
}
Err(err) => {
eprintln!("Error reading input: {}", err);
}
}
}
let mut extractor = Extractor::new();

extractor.extract(args.corpus_file.as_path(), args.features_file.as_path())?;

println!("Feature extraction completed successfully.");
Ok(())
}

Expand All @@ -127,56 +95,46 @@ fn train(args: TrainArgs) -> Result<(), Box<dyn Error>> {
})
.expect("Error setting Ctrl-C handler");

let mut boost = AdaBoost::new(args.threshold, args.num_iterations, args.num_threads);
let mut trainer = Trainer::new(
args.threshold,
args.num_iterations,
args.num_threads,
args.features_file.as_path(),
);

if let Some(model_path) = args.load_model.as_ref() {
boost.load_model(model_path).unwrap();
if let Some(model_path) = &args.load_model_file {
trainer.load_model(model_path.as_path())?;
}

boost.initialize_features(&args.instances_file).unwrap();
boost.initialize_instances(&args.instances_file).unwrap();

boost.train(running.clone());
boost.save_model(&args.model_file).unwrap();
boost.show_result();
trainer.train(running, args.model_file.as_path())?;

println!("Training completed successfully.");
Ok(())
}

fn segment(args: SegmentArgs) -> Result<(), Box<dyn Error>> {
let model_path = &args.model_file;
let mut leaner = AdaBoost::new(0.01, 100, 1);
leaner.load_model(args.model_file.as_path())?;

let mut model = AdaBoost::new(0.01, 100, 1);
if let Err(e) = model.load_model(model_path) {
eprintln!("Failed to load model: {}", e);
std::process::exit(1);
}

let segmenter = Segmenter::new(Some(model));
let segmenter = Segmenter::new(Some(leaner));
let stdin = io::stdin();
let stdout = io::stdout();
let mut writer = io::BufWriter::new(stdout.lock());

for line in stdin.lock().lines() {
match line {
Ok(line) => {
let line = line.trim();
if line.is_empty() {
continue;
}
let tokens = segmenter.parse(line);
writeln!(writer, "{}", tokens.join(" ")).expect("write failed");
}
Err(err) => {
eprintln!("Error reading input: {}", err);
}
let line = line?;
let line = line.trim();
if line.is_empty() {
continue;
}
let tokens = segmenter.parse(line);
writeln!(writer, "{}", tokens.join(" "))?;
}

Ok(())
}

fn main() -> Result<(), Box<dyn Error>> {
fn run() -> Result<(), Box<dyn std::error::Error>> {
let args = CommandArgs::parse();

match args.command {
Expand All @@ -185,3 +143,10 @@ fn main() -> Result<(), Box<dyn Error>> {
Commands::Segment(args) => segment(args),
}
}

/// Process entry point: delegates all work to `run()`, printing any error to
/// stderr and exiting with a nonzero status code on failure.
fn main() {
    match run() {
        Ok(()) => {}
        Err(e) => {
            eprintln!("Error: {}", e);
            std::process::exit(1);
        }
    }
}
Loading