made configuration more flexible

hendrik 2025-01-05 10:25:35 +01:00
parent 6915cb3df0
commit 1f36a89b3b
9 changed files with 159 additions and 108 deletions

Cargo.lock (generated)

@@ -557,6 +557,7 @@ dependencies = [
"ndarray",
"rand",
"rand_distr",
"rayon",
"serde-pickle",
"serde_json",
]


@@ -12,4 +12,5 @@ rand_distr = "0.4.3"
ndarray = "0.16.1"
flate2 = "1.0"
image = "0.25.5"
chrono = "0.4"
chrono = "0.4"
rayon = "1.10"


@@ -1,3 +1,5 @@
use rand::seq::SliceRandom;
use crate::{
data::model::{DataInstanceArray, TrainingInstanceArray},
net2::Network2,
@@ -19,30 +21,44 @@ impl InitConfiguration {
}
}
#[derive(Debug, Default)]
pub struct RunConfiguration {
#[derive(Debug, Default, Clone, Copy)]
pub struct RunSettings {
pub cost: CostFunction,
pub epochs: Option<usize>,
pub lambda: RegFunction,
pub training_data: Vec<TrainingInstanceArray>,
pub test_data: Option<Vec<DataInstanceArray>>,
pub mini_batch_size: Option<usize>,
pub eta: Option<f64>,
monitor_cost_train: bool,
monitor_accuracy_train: bool,
monitor_cost_eval: bool,
monitor_accuracy_eval: bool,
pub mini_batch_size: usize,
pub eta: f64,
}
impl RunConfiguration {
#[derive(Debug, Default)]
pub struct DataHandler {
pub training_data: Vec<TrainingInstanceArray>,
pub test_data: Option<Vec<DataInstanceArray>>,
}
impl DataHandler {
pub fn shuffle_trainings(&mut self) {
let mut rng = rand::thread_rng();
self.training_data.shuffle(&mut rng);
}
}
#[derive(Debug, Default)]
pub struct RunHandler {
settings: RunSettings,
epochs: Option<usize>,
data: DataHandler,
monitor: EvalHandler,
}
impl RunHandler {
pub fn new() -> Self {
RunConfiguration {
RunHandler {
..Default::default()
}
}
pub fn cost(mut self, cost: CostFunction) -> Self {
self.cost = cost;
self.settings.cost = cost;
self
}
@@ -52,66 +68,69 @@ impl RunConfiguration {
}
pub fn lambda(mut self, lambda: RegFunction) -> Self {
self.lambda = lambda;
self.settings.lambda = lambda;
self
}
pub fn training_data(mut self, training_data: Vec<TrainingInstanceArray>) -> Self {
self.training_data = training_data;
self.data.training_data = training_data;
self
}
pub fn test_data(mut self, test_data: Vec<DataInstanceArray>) -> Self {
self.test_data = Some(test_data);
self.data.test_data = Some(test_data);
self
}
pub fn mini_batch_size(mut self, mini_batch_size: usize) -> Self {
self.mini_batch_size = Some(mini_batch_size);
self.settings.mini_batch_size = mini_batch_size;
self
}
pub fn eta(mut self, eta: f64) -> Self {
self.eta = Some(eta);
self.settings.eta = eta;
self
}
pub fn monitor_train_cost(mut self, monitor: bool) -> Self {
self.monitor_cost_train = monitor;
self.monitor.monitor_cost_train = monitor;
self
}
pub fn monitor_train_accuracy(mut self, monitor: bool) -> Self {
self.monitor_accuracy_train = monitor;
self.monitor.monitor_accuracy_train = monitor;
self
}
pub fn monitor_eval_cost(mut self, monitor: bool) -> Self {
self.monitor_cost_eval = monitor;
self.monitor.monitor_cost_eval = monitor;
self
}
pub fn monitor_eval_accuracy(mut self, monitor: bool) -> Self {
self.monitor_accuracy_eval = monitor;
self.monitor.monitor_accuracy_eval = monitor;
self
}
pub fn run(self, net: &mut Network2) -> EvalHandler {
// TODO: error handling, maybe more than just panic...
net.sgd(
self.training_data,
self.epochs.unwrap(),
self.mini_batch_size.unwrap(),
self.eta.unwrap(),
self.test_data,
self.lambda,
EvalHandler::new(
self.monitor_cost_train,
self.monitor_accuracy_train,
self.monitor_cost_eval,
self.monitor_accuracy_eval,
),
self.cost,
)
pub fn run(&mut self, net: &mut Network2) {
let now = std::time::Instant::now();
let epochs = self.epochs.unwrap_or(5);
for j in 0..epochs {
let mut last = now.elapsed();
println!("Epoch {} started - shuffeling data {:?}", j, now.elapsed());
self.data.shuffle_trainings();
println!("Epoch {} data shuffeled {:?}", j, now.elapsed() - last);
last = now.elapsed();
net.sgd_single(&self.data.training_data, self.settings);
println!("Epoch {} completed training {:?}", j, now.elapsed() - last);
last = now.elapsed();
net.monitor(self.settings, &mut self.monitor, &self.data);
println!("Epoch {} completed {:?}", j, now.elapsed() - last);
}
println!("Run completed {:?}", now.elapsed());
println!("{:?}", self.monitor);
}
}
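
The former RunConfiguration is split three ways here: RunSettings holds the plain hyperparameters and derives Clone + Copy, DataHandler owns the datasets and knows how to shuffle them, and RunHandler ties both to an EvalHandler and drives the per-epoch loop. A minimal sketch of that idea follows; the struct and field names are simplified stand-ins for illustration, not the crate's real types:

```rust
// Simplified stand-ins for RunSettings / DataHandler / RunHandler (illustration only).
#[derive(Debug, Default, Clone, Copy)]
struct Settings {
    eta: f64,
    mini_batch_size: usize,
}

#[derive(Debug, Default)]
struct Data {
    training: Vec<f64>,
}

struct Handler {
    settings: Settings,
    data: Data,
}

impl Handler {
    fn run(&mut self) {
        for epoch in 0..3 {
            // `settings` is Copy, so it can be handed out by value every epoch
            // while `data` stays owned and mutable inside the handler.
            let s = self.settings;
            self.data.training.reverse(); // stands in for shuffling
            println!("epoch {epoch}: eta = {}, batch = {}", s.eta, s.mini_batch_size);
        }
    }
}

fn main() {
    let mut h = Handler {
        settings: Settings { eta: 3.0, mini_batch_size: 10 },
        data: Data { training: vec![1.0, 2.0, 3.0] },
    };
    h.run();
}
```

Keeping the hyperparameters in a small Copy struct means each epoch can pass them by value into sgd_single and monitor without cloning or extra reference plumbing.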


@@ -2,8 +2,7 @@ use ndarray::Array2;
use crate::math_helper::sigmoid_prime_arr;
#[derive(Debug, Clone, Copy)]
#[derive(Default)]
#[derive(Debug, Clone, Copy, Default)]
pub enum CostFunction {
#[default]
QuadraticCost,
@@ -54,14 +53,16 @@ impl CostFunctionTrait for CrossEntropyCost {
output
.iter()
.zip(target.iter())
.map(|(o, t)| {
.map(|(t, o)| {
let left_ln = t.ln();
let left_ln = if left_ln.is_finite() { left_ln } else { 0.0 };
let left = o * left_ln;
let right_ln = (1.0 - o).ln();
let right_ln = (1.0 - t).ln();
let right_ln = if right_ln.is_finite() { right_ln } else { 0.0 };
let right = (1.0 - t) * right_ln;
left + right
let right = (1.0 - o) * right_ln;
//println!("left: {}, right: {} = {}", left, right, left + right);
-left - right
})
.sum::<f64>()
}
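
The cross-entropy hunk now computes the standard cost C = -Σ [y ln a + (1 - y) ln(1 - a)] over output activations a and targets y (the old version multiplied the wrong operands and never negated the sum), and it clamps non-finite logarithms to 0 so that activations of exactly 0.0 or 1.0 cannot produce NaN. A standalone sketch of the per-element term, with simplified names and not the crate's exact code:

```rust
// Per-element cross-entropy term: `a` is the network's output activation, `y` the target.
// Non-finite logs, which occur when a is exactly 0.0 or 1.0, are clamped to 0.0 as in the diff.
fn cross_entropy_term(a: f64, y: f64) -> f64 {
    let ln_a = if a.ln().is_finite() { a.ln() } else { 0.0 };
    let ln_one_minus_a = if (1.0 - a).ln().is_finite() { (1.0 - a).ln() } else { 0.0 };
    -(y * ln_a) - (1.0 - y) * ln_one_minus_a
}

fn main() {
    // A confident, correct prediction contributes almost nothing ...
    println!("{:.4}", cross_entropy_term(0.99, 1.0));
    // ... while a confident, wrong one contributes a large cost.
    println!("{:.4}", cross_entropy_term(0.01, 1.0));
}
```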


@@ -1,13 +1,13 @@
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct EvalHandler {
evaluation_cost: Vec<f64>,
evaluation_accuracy: Vec<usize>,
training_cost: Vec<f64>,
training_accuracy: Vec<usize>,
monitor_cost_train: bool,
monitor_accuracy_train: bool,
monitor_cost_eval: bool,
monitor_accuracy_eval: bool,
pub monitor_cost_train: bool,
pub monitor_accuracy_train: bool,
pub monitor_cost_eval: bool,
pub monitor_accuracy_eval: bool,
}
impl EvalHandler {
pub fn new(
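
EvalHandler gains a Default derive and public flags so RunHandler can set them through its builder methods. The monitor_* methods it exposes (their bodies lie outside this hunk) take closures, so the expensive full-dataset passes presumably only run when the matching flag is enabled. A sketch of that shape, under the assumption that this is how the helpers behave:

```rust
// Hypothetical sketch of the lazy-monitoring pattern (the real method bodies are
// not part of this diff): the closure is evaluated only when its flag is enabled.
#[derive(Debug, Default)]
struct Eval {
    training_cost: Vec<f64>,
    monitor_cost_train: bool,
}

impl Eval {
    fn monitor_training_cost(&mut self, f: impl FnOnce() -> f64) {
        if self.monitor_cost_train {
            self.training_cost.push(f());
        }
    }
}

fn main() {
    let mut eval = Eval { monitor_cost_train: true, ..Default::default() };
    // The expensive computation stays inside the closure and runs on demand.
    eval.monitor_training_cost(|| 0.42);
    println!("{:?}", eval.training_cost);
}
```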


@@ -2,4 +2,5 @@ pub mod builder;
pub mod cost;
pub mod eval;
pub mod init;
pub mod neuron;
pub mod regularization;


@@ -0,0 +1,5 @@
pub enum Neuron {
Sigmoid,
Tanh,
ReLu,
}
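
The new neuron module only declares the Neuron enum; nothing else in this commit uses it yet. One plausible next step, purely an assumption and not part of this diff, is to dispatch the activation function on the variant:

```rust
// Hypothetical activation dispatch for the new Neuron enum (not in this commit).
pub enum Neuron {
    Sigmoid,
    Tanh,
    ReLu,
}

impl Neuron {
    pub fn activate(&self, z: f64) -> f64 {
        match self {
            Neuron::Sigmoid => 1.0 / (1.0 + (-z).exp()),
            Neuron::Tanh => z.tanh(),
            Neuron::ReLu => z.max(0.0),
        }
    }
}

fn main() {
    for z in [-1.0, 0.0, 1.0] {
        println!(
            "z = {z}: sigmoid = {:.3}, tanh = {:.3}, relu = {:.3}",
            Neuron::Sigmoid.activate(z),
            Neuron::Tanh.activate(z),
            Neuron::ReLu.activate(z),
        );
    }
}
```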


@@ -1,4 +1,4 @@
use neural_network::configuration::builder::RunConfiguration;
use neural_network::configuration::builder::RunHandler;
use neural_network::configuration::cost::CostFunction;
use neural_network::configuration::init::InitDistribution;
use neural_network::configuration::regularization::RegFunction;
@@ -13,20 +13,32 @@ fn main() -> Result<(), anyhow::Error> {
let training_data = load_raw_gz("ressources/train.json.gz")?;
println!("Data loaded");
let mut network = Network2::from_config(vec![784, 30, 10], InitDistribution::NormalizedWeight);
let mut network = Network2::from_config(vec![784, 30, 10], InitDistribution::LargeWeight);
let run_eval = RunConfiguration::new()
let run_eval = RunHandler::new()
.cost(CostFunction::CrossEntropyCost)
.epochs(30)
.lambda(RegFunction::L2(0.1))
.training_data(training_data.into_iter().map(|x| x.into()).collect())
.epochs(10)
// .lambda(RegFunction::L2(0.01))
.training_data(
training_data
.into_iter()
.take(30000)
.map(|x| x.into())
.collect(),
)
.mini_batch_size(10)
.eta(3.0)
.test_data(test_data.into_iter().map(|x| x.into()).collect::<Vec<_>>())
.test_data(
test_data
.into_iter()
.take(300)
.map(|x| x.into())
.collect::<Vec<_>>(),
)
.monitor_eval_accuracy(true)
.monitor_eval_cost(true)
.monitor_train_accuracy(true)
.monitor_train_cost(true)
.monitor_train_accuracy(false)
.monitor_train_cost(false)
.run(&mut network);
println!("{:?}", run_eval);


@@ -1,9 +1,11 @@
use ndarray::Array2;
use rand::seq::SliceRandom;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use crate::{
configuration::{
cost::{CostFunction, CostFunctionTrait},
builder::{DataHandler, RunSettings},
cost::{self, CostFunction, CostFunctionTrait},
eval::EvalHandler,
init::InitDistribution,
regularization::RegFunction,
@@ -112,8 +114,10 @@ impl Network2 {
cost_fun: CostFunction,
lambda: RegFunction,
) -> f64 {
cost_fun.f_n(&self.feedforward2(&data.input), &data.target)
+ lambda.reg_cost_contribution(&self.weights)
let res = cost_fun.f_n(&self.feedforward2(&data.input), &data.target)
+ lambda.reg_cost_contribution(&self.weights);
println!("Cost {:?}", res);
res
}
pub fn cost_data_inst(
@@ -173,24 +177,13 @@ impl Network2 {
data.iter().filter(|d| self.eval_training_inst(d)).count()
}
pub fn sgd_single(
&mut self,
mut training_data: Vec<TrainingInstanceArray>,
mini_batch_size: usize,
eta: f64,
test_data: Option<Vec<DataInstanceArray>>,
lambda: RegFunction,
mut eval_handler: EvalHandler,
cost_fun: CostFunction,
) -> EvalHandler {
pub fn sgd_single(&mut self, training_data: &[TrainingInstanceArray], settings: RunSettings) {
let now = std::time::Instant::now();
let mut last = now.elapsed();
let mut cnt = 0;
let mut rng = rand::thread_rng();
training_data.shuffle(&mut rng);
let n = training_data.len();
for batch in training_data.chunks(mini_batch_size) {
self.update_mini_batch(batch, eta, lambda, n);
for batch in training_data.chunks(settings.mini_batch_size) {
self.update_mini_batch(batch, n, settings);
cnt += batch.len();
if cnt % 2000 == 0 {
println!("complete {cnt} elapsed {:?}", now.elapsed() - last);
@@ -198,25 +191,23 @@ impl Network2 {
}
}
println!("complete Elapsed {:?} - testing ... ", now.elapsed());
eval_handler.monitor_training_cost(|| self.cost_training(&training_data, cost_fun, lambda));
eval_handler.monitor_training_accuracy(|| self.accuracy_training(&training_data));
eval_handler.monitor_evaluation_cost(|| {
if let Some(ref test) = test_data {
self.cost_data(test, cost_fun, lambda)
} else {
0.0
}
});
eval_handler.monitor_evaluation_accuracy(|| {
if let Some(ref test) = test_data {
self.accuracy_data(test)
} else {
0
}
});
eval_handler
}
pub fn monitor(&self, settings: RunSettings, monitor: &mut EvalHandler, data: &DataHandler) {
monitor.monitor_training_cost(|| {
self.cost_training(&data.training_data, settings.cost, settings.lambda)
});
monitor.monitor_training_accuracy(|| self.accuracy_training(&data.training_data));
if let Some(ref test_data) = data.test_data {
monitor.monitor_evaluation_cost(|| {
self.cost_data(test_data, settings.cost, settings.lambda)
});
monitor.monitor_evaluation_accuracy(|| self.accuracy_data(test_data));
}
}
/*
pub fn sgd(
&mut self,
mut training_data: Vec<TrainingInstanceArray>,
@@ -236,7 +227,7 @@ impl Network2 {
training_data.shuffle(&mut rng);
let n = training_data.len();
for batch in training_data.chunks(mini_batch_size) {
self.update_mini_batch(batch, eta, lambda, n);
self.update_mini_batch(batch, eta, lambda, n, cost_fun);
cnt += batch.len();
if cnt % 2000 == 0 {
println!(
@@ -250,7 +241,7 @@ impl Network2 {
println!(
"Epoch {} complete Elapsed {:?} - testing ... ",
j,
now.elapsed()
now.elapsed() - last
);
eval_handler
.monitor_training_cost(|| self.cost_training(&training_data, cost_fun, lambda));
@@ -269,18 +260,38 @@ impl Network2 {
0
}
});
println!(
"Epoch {} with eval complete Elapsed {:?} - testing ... ",
j,
now.elapsed()
);
}
eval_handler
}
}*/
pub fn update_mini_batch(
&mut self,
mini_batch: &[TrainingInstanceArray],
eta: f64,
lambda: RegFunction,
size: usize,
settings: RunSettings,
) {
let cost_fun = settings.cost;
let lambda = settings.lambda;
let eta = settings.eta;
let (nabla_b, nabla_w) = mini_batch
.par_iter()
.map(|d| self.backprop(d, cost_fun))
.reduce(
|| (get_zero_clone(&self.biases), get_zero_clone(&self.weights)),
|acc, x| {
(
combine(&acc.0, &x.0, |nb, dnb| nb + dnb),
combine(&acc.1, &x.1, |nw, dnw| nw + dnw),
)
},
);
/*
let mut nabla_b = get_zero_clone(&self.biases);
let mut nabla_w = get_zero_clone(&self.weights);
for training_instance in mini_batch {
@@ -289,6 +300,7 @@ impl Network2 {
nabla_b = combine(&nabla_b, &delta_nabla_b, |nb, dnb| nb + dnb);
nabla_w = combine(&nabla_w, &delta_nabla_w, |nw, dnw| nw + dnw);
}
*/
self.biases = combine(&self.biases, &nabla_b, |b, nb| {
b - (eta / mini_batch.len() as f64) * nb
});
@@ -300,6 +312,7 @@ impl Network2 {
pub fn backprop(
&self,
training_instance: &TrainingInstanceArray,
cost_fun: CostFunction,
) -> (Vec<Array2<f64>>, Vec<Array2<f64>>) {
let mut nabla_b = get_zero_clone(&self.biases);
let mut nabla_w = get_zero_clone(&self.weights);
@@ -314,9 +327,11 @@ impl Network2 {
zs.push(z);
activations.push(activation.clone());
}
let mut delta = self
.cost_derivative(activations.last().unwrap(), &training_instance.target)
* sigmoid_prime_arr(zs.last().unwrap());
let mut delta = cost_fun.delta(
activations.last().unwrap(),
&training_instance.target,
&zs.last().unwrap(),
);
nabla_b[nb_len - 1] = delta.clone();
nabla_w[nw_len - 1] = delta.dot(&activations[activations.len() - 2].t());
for l in 2..self.num_layers {
@@ -328,8 +343,4 @@ impl Network2 {
}
(nabla_b, nabla_w)
}
fn cost_derivative(&self, output_activations: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
output_activations - y
}
}
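
The substantive change in Network2 is that update_mini_batch now computes the per-instance gradients with rayon's par_iter and folds them together with reduce, replacing the serial loop that is kept above as a comment. The identity closure supplies a zeroed gradient accumulator per worker, and the combining closure adds two gradient sets element-wise. Below is a self-contained sketch of the same map/reduce pattern on plain Vec<f64> "gradients"; it is a simplified stand-in, not the crate's Array2-based code:

```rust
use rayon::prelude::*;

// Parallel gradient accumulation in the style of the new update_mini_batch:
// map each training instance to its gradient, then reduce by element-wise addition.
fn add(a: &[f64], b: &[f64]) -> Vec<f64> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

fn backprop_stub(instance: &f64) -> Vec<f64> {
    // Stand-in for Network2::backprop: one gradient entry per parameter.
    vec![*instance; 3]
}

fn main() {
    let mini_batch = vec![1.0, 2.0, 3.0, 4.0];
    let nabla = mini_batch
        .par_iter()
        .map(backprop_stub)
        .reduce(|| vec![0.0; 3], |acc, g| add(&acc, &g));
    println!("{:?}", nabla); // [10.0, 10.0, 10.0]
}
```

Because addition of gradients is associative, the reduce can combine partial sums from worker threads in any order and still match the result of the old sequential fold.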