这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions libraries-ai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@
<modelVersion>4.0.0</modelVersion>
<artifactId>libraries-ai</artifactId>
<name>libraries-ai</name>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>21</source>
<target>21</target>
</configuration>
</plugin>
</plugins>
</build>

<parent>
<groupId>com.baeldung</groupId>
Expand Down Expand Up @@ -90,6 +102,22 @@
<artifactId>openai-java</artifactId>
<version>${openai.version}</version>
</dependency>
<dependency>
<groupId>org.tribuo</groupId>
<artifactId>tribuo-all</artifactId>
<version>${tribuo-all.version}</version>
<type>pom</type>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>${common-lang3.version}</version>
</dependency>
<dependency>
<groupId>com.opencsv</groupId>
<artifactId>opencsv</artifactId>
<version>${opencsv.version}</version>
</dependency>
</dependencies>

<properties>
Expand All @@ -99,6 +127,9 @@
<theokanning.gpt>0.18.2</theokanning.gpt>
<h2o-genmodel.version>3.46.0.6</h2o-genmodel.version>
<openai.version>0.22.0</openai.version>
<tribuo-all.version>4.3.2</tribuo-all.version>
<common-lang3.version>3.17.0</common-lang3.version>
<opencsv.version>5.11</opencsv.version>
</properties>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.baeldung.tribuo;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.impl.ArrayExample;
import org.tribuo.regression.Regressor;

/**
 * Loads the serialized wine-quality regression model and uses it to predict the quality of a
 * single, hard-coded red wine sample.
 */
public class WineQualityPredictor {

    private static final Logger log = LoggerFactory.getLogger(WineQualityPredictor.class);

    /**
     * Deserializes the model from {@code src/main/resources/model/winequality-red-regressor.ser},
     * scores one example with it and logs the predicted quality.
     *
     * @param args unused
     * @throws IOException            if the model file cannot be opened or read
     * @throws ClassNotFoundException if the serialized model refers to classes not on the classpath
     * @throws IllegalStateException  if the file does not contain a regression ({@link Regressor}) model
     */
    public static void main(String[] args) throws IOException, ClassNotFoundException {
        File modelFile = new File("src/main/resources/model/winequality-red-regressor.ser");
        Model<Regressor> loadedModel = loadRegressionModel(modelFile);

        // The true output is unknown at prediction time, so a NaN placeholder is supplied.
        ArrayExample<Regressor> wineAttribute = new ArrayExample<>(new Regressor("quality", Double.NaN));
        // NOTE(review): feature names presumably must match the CSV header columns the model was
        // trained on; a misspelled name would not raise an error here — confirm against the dataset.
        wineAttribute.add("fixed acidity", 7.4f);
        wineAttribute.add("volatile acidity", 0.7f);
        wineAttribute.add("citric acid", 0.47f);
        wineAttribute.add("residual sugar", 1.9f);
        wineAttribute.add("chlorides", 0.076f);
        wineAttribute.add("free sulfur dioxide", 11.0f);
        wineAttribute.add("total sulfur dioxide", 34.0f);
        wineAttribute.add("density", 0.9978f);
        wineAttribute.add("pH", 3.51f);
        wineAttribute.add("sulphates", 0.56f);
        wineAttribute.add("alcohol", 9.4f);

        Prediction<Regressor> prediction = loadedModel.predict(wineAttribute);
        // The model is single-output, so the prediction is the first (only) regressed value.
        double predictQuality = prediction.getOutput()
            .getValues()[0];
        log.info("Predicted wine quality: {}", predictQuality);
    }

    /**
     * Reads a Tribuo model from {@code modelFile} and checks that it predicts {@link Regressor}
     * outputs before handing it out with that static type.
     *
     * @param modelFile file written with Java serialization
     * @return the deserialized model, typed as a regression model
     * @throws IOException            if the file cannot be read
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     * @throws IllegalStateException  if the object is not a Tribuo model of regression outputs
     */
    private static Model<Regressor> loadRegressionModel(File modelFile) throws IOException, ClassNotFoundException {
        Object deserialized;
        // NOTE(review): Java-native deserialization executes code from the stream; only load model
        // files produced by this project, or install an ObjectInputFilter if that can't be guaranteed.
        try (ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(modelFile))) {
            deserialized = objectInputStream.readObject();
        }

        // Generic types are erased, so a plain cast would accept e.g. a Model<Label> and only fail
        // later inside predict(); validate the output type explicitly instead.
        if (!(deserialized instanceof Model<?> model) || !model.validate(Regressor.class)) {
            throw new IllegalStateException("File does not contain a Tribuo regression model: " + modelFile);
        }

        @SuppressWarnings("unchecked") // safe: validate(Regressor.class) confirmed the output type
        Model<Regressor> regressionModel = (Model<Regressor>) model;
        return regressionModel;
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package com.baeldung.tribuo;

import java.io.File;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.Paths;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tribuo.DataSource;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Trainer;
import org.tribuo.common.tree.AbstractCARTTrainer;
import org.tribuo.common.tree.RandomForestTrainer;
import org.tribuo.data.csv.CSVIterator;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.ensemble.AveragingCombiner;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.evaluation.RegressionEvaluator;
import org.tribuo.regression.rtree.CARTRegressionTrainer;
import org.tribuo.regression.rtree.impurity.MeanSquaredError;

import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;

/**
 * Trains a random-forest regressor on the red wine quality CSV dataset with Tribuo, logs its
 * metrics on a train/test split and serializes the trained model to {@link #MODEL_PATH}.
 *
 * <p>The instance methods form a pipeline and must be called in order:
 * {@link #createDatasets()}, {@link #createTrainer()}, {@link #evaluateModels()},
 * {@link #saveModel()}. Calling a step before its prerequisite throws
 * {@link IllegalStateException}.
 */
public class WineQualityRegression {

    public static final Logger log = LoggerFactory.getLogger(WineQualityRegression.class);

    public static final String DATASET_PATH = "src/main/resources/dataset/winequality-red.csv";
    public static final String MODEL_PATH = "src/main/resources/model/winequality-red-regressor.ser";

    /** CSV column used as the regression target. */
    private static final String RESPONSE_COLUMN = "quality";
    /** Fraction of rows placed in the training set; the remainder becomes the test set. */
    private static final double TRAIN_PROPORTION = 0.7;
    /** Seed for the train/test split, so the partition is reproducible between runs. */
    private static final long SPLIT_SEED = 1L;
    /** Number of trees in the random forest ensemble. */
    private static final int NUM_TREES = 10;

    // Pipeline state, filled in by the step methods below.
    public Model<Regressor> model;
    public Trainer<Regressor> trainer;
    public Dataset<Regressor> trainSet;
    public Dataset<Regressor> testSet;

    /** Runs the full pipeline: load and split data, train, evaluate, save. */
    public static void main(String[] args) throws Exception {
        WineQualityRegression wineQualityRegression = new WineQualityRegression();

        wineQualityRegression.createDatasets();
        wineQualityRegression.createTrainer();
        wineQualityRegression.evaluateModels();
        wineQualityRegression.saveModel();
    }

    /**
     * Builds a random-forest trainer of CART regression trees and trains {@link #model} on
     * {@link #trainSet}.
     *
     * @throws IllegalStateException if {@link #createDatasets()} has not been called yet
     */
    public void createTrainer() {
        requireState(trainSet, "trainSet", "createDatasets()");

        // NOTE(review): argument meanings assumed from the Tribuo CARTRegressionTrainer constructor
        // (maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, impurity, seed):
        // unbounded depth, 0.001 minimum impurity decrease, 70% of features sampled per split.
        CARTRegressionTrainer subsamplingTree = new CARTRegressionTrainer(Integer.MAX_VALUE, AbstractCARTTrainer.MIN_EXAMPLES, 0.001f, 0.7f,
            new MeanSquaredError(), Trainer.DEFAULT_SEED);

        // Tree predictions are averaged to form the ensemble's output.
        trainer = new RandomForestTrainer<>(subsamplingTree, new AveragingCombiner(), NUM_TREES);
        model = trainer.train(trainSet);
    }

    /**
     * Loads {@link #DATASET_PATH} (semicolon-separated, target column {@code quality}) and splits
     * it into {@link #trainSet} and {@link #testSet}.
     *
     * @throws Exception if the CSV file cannot be read or parsed
     */
    public void createDatasets() throws Exception {
        RegressionFactory regressionFactory = new RegressionFactory();
        CSVLoader<Regressor> csvLoader = new CSVLoader<>(';', CSVIterator.QUOTE, regressionFactory);
        DataSource<Regressor> dataSource = csvLoader.loadDataSource(Paths.get(DATASET_PATH), RESPONSE_COLUMN);

        TrainTestSplitter<Regressor> dataSplitter = new TrainTestSplitter<>(dataSource, TRAIN_PROPORTION, SPLIT_SEED);

        trainSet = new MutableDataset<>(dataSplitter.getTrain());
        logDatasetShape("Train", trainSet);

        testSet = new MutableDataset<>(dataSplitter.getTest());
        logDatasetShape("Test", testSet);
    }

    /**
     * Logs regression metrics of {@link #model} on both the training and test sets, followed by
     * the model's dataset and trainer provenance.
     *
     * @throws IllegalStateException if the datasets or the model have not been created yet
     */
    public void evaluateModels() throws Exception {
        requireState(trainSet, "trainSet", "createDatasets()");
        requireState(testSet, "testSet", "createDatasets()");
        requireState(model, "model", "createTrainer()");

        log.info("Training model");
        evaluate(model, "trainSet", trainSet);

        log.info("Testing model");
        evaluate(model, "testSet", testSet);

        // Provenance text is logged as an argument, never as the format string, so any "{}" it
        // might contain is not treated as a placeholder.
        log.info("Dataset Provenance: --------------------");
        log.info("{}", ProvenanceUtil.formattedProvenanceString(model.getProvenance()
            .getDatasetProvenance()));
        log.info("Trainer Provenance: --------------------");
        log.info("{}", ProvenanceUtil.formattedProvenanceString(model.getProvenance()
            .getTrainerProvenance()));
    }

    /**
     * Evaluates {@code model} on {@code dataset} and logs MAE, RMSE and R^2.
     *
     * @param model       the trained regression model
     * @param datasetName label used in the log header
     * @param dataset     the data to evaluate on
     */
    public void evaluate(Model<Regressor> model, String datasetName, Dataset<Regressor> dataset) {
        log.info("Results for {}---------------------", datasetName);
        RegressionEvaluator evaluator = new RegressionEvaluator();
        RegressionEvaluation evaluation = evaluator.evaluate(model, dataset);

        // NOTE(review): "DIM-0" is assumed to be the dimension name Tribuo assigns to the single
        // CSV response column; metrics are looked up by this name — confirm if the loader changes.
        Regressor dimension0 = new Regressor("DIM-0", Double.NaN);

        log.info("MAE: {}", evaluation.mae(dimension0));
        log.info("RMSE: {}", evaluation.rmse(dimension0));
        log.info("R^2: {}", evaluation.r2(dimension0));
    }

    /**
     * Serializes {@link #model} to {@link #MODEL_PATH} with Java serialization, creating the
     * parent directory first if it does not exist.
     *
     * @throws IllegalStateException if {@link #createTrainer()} has not been called yet
     * @throws Exception             if the directory or file cannot be written
     */
    public void saveModel() throws Exception {
        requireState(model, "model", "createTrainer()");

        File modelFile = new File(MODEL_PATH);
        File parentDir = modelFile.getParentFile();
        if (parentDir != null) {
            // FileOutputStream fails with FileNotFoundException when the directory is missing.
            Files.createDirectories(parentDir.toPath());
        }
        try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(modelFile))) {
            objectOutputStream.writeObject(model);
        }
    }

    /** Logs the number of examples and distinct features of {@code dataset}. */
    private static void logDatasetShape(String label, Dataset<Regressor> dataset) {
        log.info("{} set size = {}, num of features = {}", label, dataset.size(), dataset.getFeatureMap()
            .size());
    }

    /** Throws an {@link IllegalStateException} naming the missing pipeline step when {@code value} is null. */
    private static void requireState(Object value, String fieldName, String producer) {
        if (value == null) {
            throw new IllegalStateException(fieldName + " is not initialised; call " + producer + " first");
        }
    }
}
Loading