added visualizer to write grayscale data as image
This commit is contained in:
parent
0a508cfeb8
commit
066a892bdd
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
||||
target
|
||||
output
|
||||
922
Cargo.lock
generated
922
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -10,4 +10,6 @@ serde_json = "1.0"
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
ndarray = "0.16.1"
|
||||
flate2 = "1.0"
|
||||
flate2 = "1.0"
|
||||
image = "0.25.5"
|
||||
chrono = "0.4"
|
||||
44
src/consts.rs
Normal file
44
src/consts.rs
Normal file
@ -0,0 +1,44 @@
|
||||
use std::{path::PathBuf, sync::OnceLock};
|
||||
|
||||
pub static OUTPUT_FOLDER: OnceLock<PathBuf> = OnceLock::new();
|
||||
pub static DATA_FOLDER: OnceLock<PathBuf> = OnceLock::new();
|
||||
|
||||
pub fn get_output_folder() -> &'static PathBuf {
|
||||
OUTPUT_FOLDER.get_or_init(|| {
|
||||
PathBuf::from("/")
|
||||
.join("root")
|
||||
.join("rust")
|
||||
.join("neural-network")
|
||||
.join("output")
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_data_folder() -> &'static PathBuf {
|
||||
OUTPUT_FOLDER.get_or_init(|| {
|
||||
PathBuf::from("/")
|
||||
.join("root")
|
||||
.join("rust")
|
||||
.join("neural-network")
|
||||
.join("ressources")
|
||||
})
|
||||
}
|
||||
|
||||
fn get_input_file_path(file: &InputFiles) -> PathBuf {
|
||||
let folder = get_data_folder();
|
||||
match file {
|
||||
InputFiles::Train => folder.join("train.json.gz"),
|
||||
InputFiles::Test => folder.join("test.json.gz"),
|
||||
InputFiles::Validate => folder.join("valid.json.gz"),
|
||||
}
|
||||
}
|
||||
pub enum InputFiles {
|
||||
Train,
|
||||
Test,
|
||||
Validate,
|
||||
}
|
||||
|
||||
impl InputFiles {
|
||||
pub fn get_file_path(&self) -> PathBuf {
|
||||
get_input_file_path(self)
|
||||
}
|
||||
}
|
||||
2
src/data/mod.rs
Normal file
2
src/data/mod.rs
Normal file
@ -0,0 +1,2 @@
|
||||
pub mod model;
|
||||
pub mod visualize;
|
||||
@ -1,4 +1,6 @@
|
||||
use ndarray::Array2;
|
||||
use std::borrow::Cow;
|
||||
|
||||
use ndarray::{Array2, ArrayView2};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RawInstance {
|
||||
@ -73,3 +75,36 @@ impl From<RawInstance> for TrainingInstanceArray {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reshape_data<'a>(data: Array2<f64>) -> Result<Array2<f64>, anyhow::Error> {
|
||||
let (rows, cols) = data.dim();
|
||||
if cols != 1 {
|
||||
return Err(anyhow::anyhow!("Expected an Array2 of shape (n^2, 1)"));
|
||||
}
|
||||
|
||||
// Compute n (assuming n^2 = rows)
|
||||
let n = (rows as f64).sqrt() as usize;
|
||||
if n * n != rows {
|
||||
return Err(anyhow::anyhow!("Number of rows is not a perfect square"));
|
||||
}
|
||||
|
||||
// Reshape the array into (n, n)
|
||||
let reshaped = data.into_shape_clone((n, n))?;
|
||||
Ok(reshaped)
|
||||
}
|
||||
|
||||
pub struct DisplayInstance {
|
||||
pub data: Array2<f64>,
|
||||
pub val: u8,
|
||||
}
|
||||
|
||||
impl From<RawInstance> for DisplayInstance {
|
||||
fn from(data: RawInstance) -> Self {
|
||||
let rows = data.raw.len();
|
||||
let n = (rows as f64).sqrt() as usize;
|
||||
DisplayInstance {
|
||||
data: Array2::from_shape_vec((n, n), data.raw).unwrap(),
|
||||
val: data.val,
|
||||
}
|
||||
}
|
||||
}
|
||||
50
src/data/visualize.rs
Normal file
50
src/data/visualize.rs
Normal file
@ -0,0 +1,50 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use image::{GrayImage, Luma};
|
||||
use ndarray::Array2;
|
||||
|
||||
use crate::{consts::get_output_folder, data::model::DisplayInstance};
|
||||
|
||||
fn generate_image(data: &Array2<f64>) -> Result<GrayImage, anyhow::Error> {
|
||||
let (h, w) = data.dim();
|
||||
let mut img = GrayImage::new(w as u32, h as u32);
|
||||
for ((y, x), &value) in data.indexed_iter() {
|
||||
let intensity = (255.0 - value.clamp(0.0, 1.0) * 255.0) as u8;
|
||||
img.put_pixel(x as u32, y as u32, Luma([intensity]));
|
||||
}
|
||||
Ok(img)
|
||||
}
|
||||
|
||||
pub fn save_instance<T: Into<DisplayInstance>>(
|
||||
data: T,
|
||||
note: Option<&str>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let display_instance: DisplayInstance = data.into();
|
||||
let img = generate_image(&display_instance.data)?;
|
||||
let output = generate_output_path(display_instance.val.into(), note);
|
||||
img.save(output)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_output_path(val: usize, note: Option<&str>) -> PathBuf {
|
||||
let output_dir = get_output_folder();
|
||||
let name = generate_name(val, note);
|
||||
output_dir.clone().join(name)
|
||||
}
|
||||
|
||||
fn generate_name(val: usize, note: Option<&str>) -> String {
|
||||
let mut name = get_formatted_datetime();
|
||||
name.push('_');
|
||||
name.push_str(&val.to_string());
|
||||
if let Some(note) = note {
|
||||
name.push('_');
|
||||
name.push_str(note);
|
||||
}
|
||||
name.push_str(".png");
|
||||
name
|
||||
}
|
||||
|
||||
fn get_formatted_datetime() -> String {
|
||||
let now = chrono::Utc::now();
|
||||
now.format("%Y-%m-%d_%H-%M-%S").to_string()
|
||||
}
|
||||
@ -1,4 +1,5 @@
|
||||
pub mod consts;
|
||||
pub mod data;
|
||||
pub mod math_helper;
|
||||
pub mod model;
|
||||
pub mod network;
|
||||
pub mod reader;
|
||||
|
||||
96
src/main.rs
96
src/main.rs
@ -1,99 +1,33 @@
|
||||
use ndarray::Array;
|
||||
use neural_network::model::DataInstanceArray;
|
||||
use neural_network::data::visualize::save_instance;
|
||||
use neural_network::network::Network;
|
||||
use neural_network::reader::{load_raw, load_raw_gz};
|
||||
use serde_json::{from_reader, Value};
|
||||
use std::error::Error;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
fn main_o() -> Result<(), Box<dyn Error>> {
|
||||
load_raw("ressources/test.json");
|
||||
// Open the file
|
||||
let file = File::open("ressources/test.json")?;
|
||||
let reader = BufReader::new(file);
|
||||
// Deserialize into a generic Value
|
||||
//let value: (Vec<f64>, Vec<u8>) = serde_json::from_reader(reader)?;
|
||||
|
||||
let value: Vec<Vec<Value>> = serde_json::from_reader(reader)?;
|
||||
|
||||
let value = value
|
||||
.into_iter()
|
||||
.skip(0)
|
||||
.next()
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.skip(2)
|
||||
.next()
|
||||
.unwrap();
|
||||
// Inspect the deserialized value
|
||||
|
||||
if let Value::Array(arr) = value {
|
||||
println!("Array: {:?}", arr.len());
|
||||
let a = arr.into_iter().next().unwrap();
|
||||
match a {
|
||||
Value::Object(obj) => {
|
||||
println!("keys: {:?}", obj.keys().into_iter().collect::<Vec<_>>());
|
||||
}
|
||||
Value::Array(d) => {
|
||||
println!("Array: {:?}", d.into_iter().next().unwrap());
|
||||
}
|
||||
_ => println!("Not an object"),
|
||||
}
|
||||
} else {
|
||||
println!("{:?}", value);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn amain() {
|
||||
let array = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]); // 1D array
|
||||
let result = &array + 5.0; // Add scalar to each element
|
||||
println!("Result: {:?}", result); // Example: Create an array `z`
|
||||
let z = Array::from_vec(vec![0.0, 1.0, -1.0, 2.0, -2.0]);
|
||||
|
||||
// Perform the computation: 1 / (1 + exp(-z))
|
||||
|
||||
// Print the result
|
||||
println!("Result: {:?}", result);
|
||||
}
|
||||
use neural_network::reader::load_raw_gz;
|
||||
|
||||
fn main() -> Result<(), anyhow::Error> {
|
||||
let mut nnet = Network::new(vec![784, 30, 10]);
|
||||
/*let test_data = load_raw("ressources/test.json")?
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
let test_data_2: DataInstanceArray = test_data.clone().into();
|
||||
nnet.print_values();
|
||||
println!(
|
||||
"FEED FORWARD: {:?}",
|
||||
nnet.feedforward(test_data_2.raw.clone())
|
||||
);
|
||||
nnet.update_mini_batch(&vec![test_data.clone().into()], 0.01);
|
||||
println!(
|
||||
"FEED FORWARD: {:?} (expected {:?})",
|
||||
nnet.feedforward(test_data_2.raw.clone()),
|
||||
test_data_2.val
|
||||
);
|
||||
nnet.print_values();
|
||||
return Ok(());
|
||||
*/
|
||||
let test_data = load_raw_gz("ressources/test.json.gz")?;
|
||||
let first = test_data.into_iter().next().unwrap();
|
||||
println!("{:?}", save_instance(first, Some("first")));
|
||||
return Ok(());
|
||||
|
||||
let training_data = load_raw_gz("ressources/train.json.gz")?;
|
||||
println!("Data loaded");
|
||||
let mut network = Network::new(vec![784, 30, 10]);
|
||||
network.sgd(
|
||||
training_data
|
||||
.into_iter()
|
||||
.take(3000)
|
||||
//.take(3000)
|
||||
.map(|x| x.into())
|
||||
.collect(),
|
||||
30,
|
||||
10,
|
||||
0.1,
|
||||
3.0,
|
||||
//None,
|
||||
Some(test_data.into_iter().take(300).map(|x| x.into()).collect()),
|
||||
Some(
|
||||
test_data
|
||||
.into_iter()
|
||||
//.take(300)
|
||||
.map(|x| x.into())
|
||||
.collect(),
|
||||
),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
use std::vec;
|
||||
|
||||
use ndarray::Array2;
|
||||
|
||||
pub fn sigmoid(z: f64) -> f64 {
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
use crate::{
|
||||
math_helper::{dot_product, get_zero_clone, sigmoid_arr, sigmoid_prime_arr, sigmoid_vec},
|
||||
model::{DataInstanceArray, RawInstance, TrainingInstance, TrainingInstanceArray},
|
||||
data::model::{DataInstanceArray, TrainingInstanceArray},
|
||||
math_helper::{get_zero_clone, sigmoid_arr, sigmoid_prime_arr},
|
||||
};
|
||||
use ndarray::{arr2, Array2};
|
||||
use ndarray::Array2;
|
||||
use rand::{seq::SliceRandom, Rng};
|
||||
use rand_distr::{Distribution, Normal, NormalError};
|
||||
use rand_distr::{Distribution, Normal};
|
||||
pub struct Network {
|
||||
num_layers: usize,
|
||||
sizes: Vec<usize>,
|
||||
|
||||
@ -3,7 +3,7 @@ use serde_json::Value;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
|
||||
use crate::model::RawInstance;
|
||||
use crate::data::model::RawInstance;
|
||||
|
||||
pub fn load_raw_gz(path: &str) -> Result<Vec<RawInstance>, anyhow::Error> {
|
||||
let file = File::open(path)?;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user