made configuration more flexible

parent 6915cb3df0
commit 1f36a89b3b

Cargo.lock (generated) | 1
@@ -557,6 +557,7 @@ dependencies = [
  "ndarray",
  "rand",
  "rand_distr",
+ "rayon",
  "serde-pickle",
  "serde_json",
 ]
Cargo.toml
@@ -12,4 +12,5 @@ rand_distr = "0.4.3"
 ndarray = "0.16.1"
 flate2 = "1.0"
 image = "0.25.5"
 chrono = "0.4"
+rayon = "1.10"
src/configuration/builder.rs
@@ -1,3 +1,5 @@
+use rand::seq::SliceRandom;
+
 use crate::{
     data::model::{DataInstanceArray, TrainingInstanceArray},
     net2::Network2,
@@ -19,30 +21,44 @@ impl InitConfiguration {
     }
 }
 
-#[derive(Debug, Default)]
-pub struct RunConfiguration {
+#[derive(Debug, Default, Clone, Copy)]
+pub struct RunSettings {
     pub cost: CostFunction,
-    pub epochs: Option<usize>,
     pub lambda: RegFunction,
-    pub training_data: Vec<TrainingInstanceArray>,
-    pub test_data: Option<Vec<DataInstanceArray>>,
-    pub mini_batch_size: Option<usize>,
-    pub eta: Option<f64>,
-    monitor_cost_train: bool,
-    monitor_accuracy_train: bool,
-    monitor_cost_eval: bool,
-    monitor_accuracy_eval: bool,
+    pub mini_batch_size: usize,
+    pub eta: f64,
 }
 
-impl RunConfiguration {
+#[derive(Debug, Default)]
+pub struct DataHandler {
+    pub training_data: Vec<TrainingInstanceArray>,
+    pub test_data: Option<Vec<DataInstanceArray>>,
+}
+
+impl DataHandler {
+    pub fn shuffle_trainings(&mut self) {
+        let mut rng = rand::thread_rng();
+        self.training_data.shuffle(&mut rng);
+    }
+}
+
+#[derive(Debug, Default)]
+pub struct RunHandler {
+    settings: RunSettings,
+    epochs: Option<usize>,
+    data: DataHandler,
+    monitor: EvalHandler,
+}
+
+impl RunHandler {
     pub fn new() -> Self {
-        RunConfiguration {
+        RunHandler {
             ..Default::default()
         }
     }
 
     pub fn cost(mut self, cost: CostFunction) -> Self {
-        self.cost = cost;
+        self.settings.cost = cost;
         self
     }
 
@@ -52,66 +68,69 @@ impl RunConfiguration {
     }
 
     pub fn lambda(mut self, lambda: RegFunction) -> Self {
-        self.lambda = lambda;
+        self.settings.lambda = lambda;
         self
     }
 
     pub fn training_data(mut self, training_data: Vec<TrainingInstanceArray>) -> Self {
-        self.training_data = training_data;
+        self.data.training_data = training_data;
         self
     }
 
     pub fn test_data(mut self, test_data: Vec<DataInstanceArray>) -> Self {
-        self.test_data = Some(test_data);
+        self.data.test_data = Some(test_data);
         self
     }
 
     pub fn mini_batch_size(mut self, mini_batch_size: usize) -> Self {
-        self.mini_batch_size = Some(mini_batch_size);
+        self.settings.mini_batch_size = mini_batch_size;
         self
     }
 
     pub fn eta(mut self, eta: f64) -> Self {
-        self.eta = Some(eta);
+        self.settings.eta = eta;
         self
     }
 
     pub fn monitor_train_cost(mut self, monitor: bool) -> Self {
-        self.monitor_cost_train = monitor;
+        self.monitor.monitor_cost_train = monitor;
         self
     }
 
     pub fn monitor_train_accuracy(mut self, monitor: bool) -> Self {
-        self.monitor_accuracy_train = monitor;
+        self.monitor.monitor_accuracy_train = monitor;
         self
     }
 
     pub fn monitor_eval_cost(mut self, monitor: bool) -> Self {
-        self.monitor_cost_eval = monitor;
+        self.monitor.monitor_cost_eval = monitor;
         self
     }
 
     pub fn monitor_eval_accuracy(mut self, monitor: bool) -> Self {
-        self.monitor_accuracy_eval = monitor;
+        self.monitor.monitor_accuracy_eval = monitor;
         self
     }
 
-    pub fn run(self, net: &mut Network2) -> EvalHandler {
-        // TODO: errorhandling, maybe more than just panic...
-        net.sgd(
-            self.training_data,
-            self.epochs.unwrap(),
-            self.mini_batch_size.unwrap(),
-            self.eta.unwrap(),
-            self.test_data,
-            self.lambda,
-            EvalHandler::new(
-                self.monitor_cost_train,
-                self.monitor_accuracy_train,
-                self.monitor_cost_eval,
-                self.monitor_accuracy_eval,
-            ),
-            self.cost,
-        )
+    pub fn run(&mut self, net: &mut Network2) {
+        let now = std::time::Instant::now();
+        let epochs = self.epochs.unwrap_or(5);
+        for j in 0..epochs {
+            let mut last = now.elapsed();
+            println!("Epoch {} started - shuffling data {:?}", j, now.elapsed());
+            self.data.shuffle_trainings();
+            println!("Epoch {} data shuffled {:?}", j, now.elapsed() - last);
+            last = now.elapsed();
+            net.sgd_single(&self.data.training_data, self.settings);
+            println!("Epoch {} completed training {:?}", j, now.elapsed() - last);
+            last = now.elapsed();
+
+            net.monitor(self.settings, &mut self.monitor, &self.data);
+
+            println!("Epoch {} completed {:?}", j, now.elapsed() - last);
+        }
+
+        println!("Run completed {:?}", now.elapsed());
+        println!("{:?}", self.monitor);
     }
 }
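Note: the old all-in-one RunConfiguration is split three ways here. RunSettings carries only the plain hyperparameters (cost, lambda, mini_batch_size, eta) and derives Clone + Copy so it can be passed around by value; DataHandler owns the datasets and knows how to shuffle them; RunHandler owns both plus the EvalHandler and now drives the epoch loop itself, where the old run() handed the whole schedule to net.sgd() in one call. Unset values also fail soft now: epochs falls back to 5 via unwrap_or instead of panicking on unwrap().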
src/configuration/cost.rs
@@ -2,8 +2,7 @@ use ndarray::Array2;
 
 use crate::math_helper::sigmoid_prime_arr;
 
-#[derive(Debug, Clone, Copy)]
-#[derive(Default)]
+#[derive(Debug, Clone, Copy, Default)]
 pub enum CostFunction {
     #[default]
     QuadraticCost,
@@ -54,14 +53,16 @@ impl CostFunctionTrait for CrossEntropyCost {
         output
             .iter()
             .zip(target.iter())
-            .map(|(o, t)| {
+            .map(|(t, o)| {
                 let left_ln = t.ln();
                 let left_ln = if left_ln.is_finite() { left_ln } else { 0.0 };
                 let left = o * left_ln;
-                let right_ln = (1.0 - o).ln();
+                let right_ln = (1.0 - t).ln();
                 let right_ln = if right_ln.is_finite() { right_ln } else { 0.0 };
-                let right = (1.0 - t) * right_ln;
-                left + right
+                let right = (1.0 - o) * right_ln;
+
+                //println!("left: {}, right: {} = {}", left, right, left + right);
+                -left - right
             })
             .sum::<f64>()
     }
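Note: for reference, the per-example cross-entropy the new closure computes is

    C = -\sum_j [ y_j \ln a_j + (1 - y_j) \ln(1 - a_j) ]

with a the network output and y the target. The old body crossed the operands in the first term (o * t.ln() is a·ln y, not y·ln a) and returned left + right instead of -(left + right). Mind the bindings in the fixed closure: .zip() yields (output, target), so t is the activation and o the target here. The is_finite guards replace the -inf from ln(0) with 0.0 when an activation saturates at exactly 0 or 1.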
src/configuration/eval.rs
@@ -1,13 +1,13 @@
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct EvalHandler {
     evaluation_cost: Vec<f64>,
     evaluation_accuracy: Vec<usize>,
     training_cost: Vec<f64>,
     training_accuracy: Vec<usize>,
-    monitor_cost_train: bool,
-    monitor_accuracy_train: bool,
-    monitor_cost_eval: bool,
-    monitor_accuracy_eval: bool,
+    pub monitor_cost_train: bool,
+    pub monitor_accuracy_train: bool,
+    pub monitor_cost_eval: bool,
+    pub monitor_accuracy_eval: bool,
 }
 impl EvalHandler {
     pub fn new(
src/configuration/mod.rs
@@ -2,4 +2,5 @@ pub mod builder;
 pub mod cost;
 pub mod eval;
 pub mod init;
+pub mod neuron;
 pub mod regularization;
src/configuration/neuron.rs (new file) | 5
@@ -0,0 +1,5 @@
+pub enum Neuron {
+    Sigmoid,
+    Tanh,
+    ReLu,
+}
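Note: the new Neuron enum is not referenced anywhere else in this commit; presumably it is groundwork for making the activation function configurable the same way CostFunction already is (sigmoid is still hard-coded via sigmoid_prime_arr). A plausible next step, purely a sketch and not part of this commit:

    // Hypothetical sketch only: how the new enum could dispatch
    // activations once it is wired into the network.
    pub enum Neuron {
        Sigmoid,
        Tanh,
        ReLu,
    }

    impl Neuron {
        // Scalar activation; an Array2 version would map this elementwise.
        pub fn activate(&self, z: f64) -> f64 {
            match self {
                Neuron::Sigmoid => 1.0 / (1.0 + (-z).exp()),
                Neuron::Tanh => z.tanh(),
                Neuron::ReLu => z.max(0.0),
            }
        }
    }

Mirroring the CostFunction pattern, the enum would likely also grow a derivative method for use in backprop.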
src/main.rs | 30
@@ -1,4 +1,4 @@
-use neural_network::configuration::builder::RunConfiguration;
+use neural_network::configuration::builder::RunHandler;
 use neural_network::configuration::cost::CostFunction;
 use neural_network::configuration::init::InitDistribution;
 use neural_network::configuration::regularization::RegFunction;
@@ -13,20 +13,32 @@ fn main() -> Result<(), anyhow::Error> {
 
     let training_data = load_raw_gz("ressources/train.json.gz")?;
     println!("Data loaded");
-    let mut network = Network2::from_config(vec![784, 30, 10], InitDistribution::NormalizedWeight);
+    let mut network = Network2::from_config(vec![784, 30, 10], InitDistribution::LargeWeight);
 
-    let run_eval = RunConfiguration::new()
+    let run_eval = RunHandler::new()
         .cost(CostFunction::CrossEntropyCost)
-        .epochs(30)
-        .lambda(RegFunction::L2(0.1))
-        .training_data(training_data.into_iter().map(|x| x.into()).collect())
+        .epochs(10)
+        // .lambda(RegFunction::L2(0.01))
+        .training_data(
+            training_data
+                .into_iter()
+                .take(30000)
+                .map(|x| x.into())
+                .collect(),
+        )
         .mini_batch_size(10)
         .eta(3.0)
-        .test_data(test_data.into_iter().map(|x| x.into()).collect::<Vec<_>>())
+        .test_data(
+            test_data
+                .into_iter()
+                .take(300)
+                .map(|x| x.into())
+                .collect::<Vec<_>>(),
+        )
         .monitor_eval_accuracy(true)
         .monitor_eval_cost(true)
-        .monitor_train_accuracy(true)
-        .monitor_train_cost(true)
+        .monitor_train_accuracy(false)
+        .monitor_train_cost(false)
         .run(&mut network);
 
     println!("{:?}", run_eval);
src/net2.rs | 105
@@ -1,9 +1,11 @@
 use ndarray::Array2;
 use rand::seq::SliceRandom;
+use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
 
 use crate::{
     configuration::{
-        cost::{CostFunction, CostFunctionTrait},
+        builder::{DataHandler, RunSettings},
+        cost::{self, CostFunction, CostFunctionTrait},
         eval::EvalHandler,
         init::InitDistribution,
         regularization::RegFunction,
@@ -112,8 +114,10 @@ impl Network2 {
         cost_fun: CostFunction,
         lambda: RegFunction,
     ) -> f64 {
-        cost_fun.f_n(&self.feedforward2(&data.input), &data.target)
-            + lambda.reg_cost_contribution(&self.weights)
+        let res = cost_fun.f_n(&self.feedforward2(&data.input), &data.target)
+            + lambda.reg_cost_contribution(&self.weights);
+        println!("Cost {:?}", res);
+        res
     }
 
     pub fn cost_data_inst(
@@ -173,24 +177,13 @@ impl Network2 {
         data.iter().filter(|d| self.eval_training_inst(d)).count()
     }
 
-    pub fn sgd_single(
-        &mut self,
-        mut training_data: Vec<TrainingInstanceArray>,
-        mini_batch_size: usize,
-        eta: f64,
-        test_data: Option<Vec<DataInstanceArray>>,
-        lambda: RegFunction,
-        mut eval_handler: EvalHandler,
-        cost_fun: CostFunction,
-    ) -> EvalHandler {
+    pub fn sgd_single(&mut self, training_data: &[TrainingInstanceArray], settings: RunSettings) {
         let now = std::time::Instant::now();
         let mut last = now.elapsed();
         let mut cnt = 0;
-        let mut rng = rand::thread_rng();
-        training_data.shuffle(&mut rng);
         let n = training_data.len();
-        for batch in training_data.chunks(mini_batch_size) {
-            self.update_mini_batch(batch, eta, lambda, n);
+        for batch in training_data.chunks(settings.mini_batch_size) {
+            self.update_mini_batch(batch, n, settings);
             cnt += batch.len();
             if cnt % 2000 == 0 {
                 println!("complete {cnt} elapsed {:?}", now.elapsed() - last);
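Note: sgd_single now borrows the batch source (&[TrainingInstanceArray]) instead of consuming a Vec, and takes the whole RunSettings by value, which is cheap because the struct is Copy. Shuffling moved out to DataHandler::shuffle_trainings so the RunHandler epoch loop in builder.rs re-shuffles once per epoch, and evaluation moved into the separate monitor() below, leaving this function to do exactly one pass over the data.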
@@ -198,25 +191,23 @@ impl Network2 {
             }
         }
-        println!("complete Elapsed {:?} - testing ... ", now.elapsed());
-        eval_handler.monitor_training_cost(|| self.cost_training(&training_data, cost_fun, lambda));
-        eval_handler.monitor_training_accuracy(|| self.accuracy_training(&training_data));
-        eval_handler.monitor_evaluation_cost(|| {
-            if let Some(ref test) = test_data {
-                self.cost_data(test, cost_fun, lambda)
-            } else {
-                0.0
-            }
-        });
-        eval_handler.monitor_evaluation_accuracy(|| {
-            if let Some(ref test) = test_data {
-                self.accuracy_data(test)
-            } else {
-                0
-            }
-        });
-
-        eval_handler
     }
 
+    pub fn monitor(&self, settings: RunSettings, monitor: &mut EvalHandler, data: &DataHandler) {
+        monitor.monitor_training_cost(|| {
+            self.cost_training(&data.training_data, settings.cost, settings.lambda)
+        });
+        monitor.monitor_training_accuracy(|| self.accuracy_training(&data.training_data));
+
+        if let Some(ref test_data) = data.test_data {
+            monitor.monitor_evaluation_cost(|| {
+                self.cost_data(test_data, settings.cost, settings.lambda)
+            });
+            monitor.monitor_evaluation_accuracy(|| self.accuracy_data(test_data));
+        }
+    }
+
+    /*
     pub fn sgd(
         &mut self,
         mut training_data: Vec<TrainingInstanceArray>,
@@ -236,7 +227,7 @@ impl Network2 {
         training_data.shuffle(&mut rng);
         let n = training_data.len();
         for batch in training_data.chunks(mini_batch_size) {
-            self.update_mini_batch(batch, eta, lambda, n);
+            self.update_mini_batch(batch, eta, lambda, n, cost_fun);
             cnt += batch.len();
             if cnt % 2000 == 0 {
                 println!(
@@ -250,7 +241,7 @@ impl Network2 {
             println!(
                 "Epoch {} complete Elapsed {:?} - testing ... ",
                 j,
-                now.elapsed()
+                now.elapsed() - last
             );
             eval_handler
                 .monitor_training_cost(|| self.cost_training(&training_data, cost_fun, lambda));
@@ -269,18 +260,38 @@ impl Network2 {
                     0
                 }
             });
             println!(
                 "Epoch {} with eval complete Elapsed {:?} - testing ... ",
                 j,
                 now.elapsed()
             );
         }
 
         eval_handler
-    }
+    }*/
 
     pub fn update_mini_batch(
         &mut self,
         mini_batch: &[TrainingInstanceArray],
-        eta: f64,
-        lambda: RegFunction,
         size: usize,
+        settings: RunSettings,
     ) {
+        let cost_fun = settings.cost;
+        let lambda = settings.lambda;
+        let eta = settings.eta;
+        let (nabla_b, nabla_w) = mini_batch
+            .par_iter()
+            .map(|d| self.backprop(d, cost_fun))
+            .reduce(
+                || (get_zero_clone(&self.biases), get_zero_clone(&self.weights)),
+                |acc, x| {
+                    (
+                        combine(&acc.0, &x.0, |nb, dnb| nb + dnb),
+                        combine(&acc.1, &x.1, |nw, dnw| nw + dnw),
+                    )
+                },
+            );
+        /*
         let mut nabla_b = get_zero_clone(&self.biases);
         let mut nabla_w = get_zero_clone(&self.weights);
         for training_instance in mini_batch {
@@ -289,6 +300,7 @@ impl Network2 {
             nabla_b = combine(&nabla_b, &delta_nabla_b, |nb, dnb| nb + dnb);
             nabla_w = combine(&nabla_w, &delta_nabla_w, |nw, dnw| nw + dnw);
         }
+        */
         self.biases = combine(&self.biases, &nabla_b, |b, nb| {
             b - (eta / mini_batch.len() as f64) * nb
         });
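Note: update_mini_batch now accumulates per-example gradients with rayon's parallel map/reduce (hence the new rayon dependency) instead of the sequential fold that is commented out above. A minimal standalone sketch of the same pattern, with toy Vec<f64> gradients standing in for the (nabla_b, nabla_w) pairs of Array2s:

    use rayon::prelude::*;

    fn main() {
        let examples: Vec<f64> = (0..8).map(f64::from).collect();
        // map: compute a per-example "gradient"; reduce: add gradients
        // elementwise, starting from a zeroed identity like get_zero_clone.
        let grad_sum = examples
            .par_iter()
            .map(|x| vec![*x, 2.0 * *x])
            .reduce(
                || vec![0.0, 0.0],
                |a, b| a.iter().zip(b.iter()).map(|(l, r)| l + r).collect(),
            );
        assert_eq!(grad_sum, vec![28.0, 56.0]);
    }

Each worker folds its share locally and the partial results are combined pairwise, so the total matches the sequential loop up to floating-point association order.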
@@ -300,6 +312,7 @@ impl Network2 {
     pub fn backprop(
         &self,
         training_instance: &TrainingInstanceArray,
+        cost_fun: CostFunction,
     ) -> (Vec<Array2<f64>>, Vec<Array2<f64>>) {
         let mut nabla_b = get_zero_clone(&self.biases);
         let mut nabla_w = get_zero_clone(&self.weights);
@@ -314,9 +327,11 @@ impl Network2 {
             zs.push(z);
             activations.push(activation.clone());
         }
-        let mut delta = self
-            .cost_derivative(activations.last().unwrap(), &training_instance.target)
-            * sigmoid_prime_arr(zs.last().unwrap());
+        let mut delta = cost_fun.delta(
+            activations.last().unwrap(),
+            &training_instance.target,
+            &zs.last().unwrap(),
+        );
         nabla_b[nb_len - 1] = delta.clone();
         nabla_w[nw_len - 1] = delta.dot(&activations[activations.len() - 2].t());
         for l in 2..self.num_layers {
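Note: the output-layer error is now delegated to the cost function rather than hard-coded as cost_derivative(..) * sigmoid_prime_arr(..). Assuming the usual textbook definitions (delta's implementations are outside this diff):

    \delta^L_{quad} = (a^L - y) \odot \sigma'(z^L)
    \delta^L_{xent} = a^L - y

For cross-entropy with sigmoid outputs the \sigma'(z^L) factor cancels, which is exactly why the fixed sigmoid_prime_arr multiplication had to move behind the CostFunction enum and why the now-unused cost_derivative helper is deleted below.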
@@ -328,8 +343,4 @@ impl Network2 {
         }
         (nabla_b, nabla_w)
     }
-
-    fn cost_derivative(&self, output_activations: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
-        output_activations - y
-    }
 }