I Tried Learning Rust Through Building a Linear Regression Model


🦀 TL;DR: I learned Rust by building a linear regression model from scratch. No tutorials. Just code, docs, and pain.

I’ve often heard programmers and tech bros say that the best way to learn a new programming language is by building a project. This advice initially seemed obvious to me: after all, pain is often said to be the greatest teacher, and debugging a project in an unfamiliar language is certainly painful.

However, I recently realized I’d never actually learned a language this way before. All my experience learning languages like Python, JavaScript, C++, and C# came entirely from YouTube tutorials, Udemy, or online courses like CS50 or Boot.dev. That is why, for Rust, I decided to put this notion to the test and learn it entirely by coding a project and reading the official documentation as needed, without watching any tutorials or guides.

Since I had just finished learning about linear regression, I decided that the best project for learning Rust would be to code a linear model from scratch.

In the project, I used the Taiwan real estate price estimation dataset, which can be found on Kaggle. The code for this project is hosted on my personal GitHub account.

I started by creating a Rust project with Cargo, Rust’s build tool and package manager. Coming from a background of C++ and OpenGL, I consider Cargo one of my favorite features Rust offers. I no longer have to deal with linking headers and compiling dependencies by hand: everything is just a cargo add away, and I am finally freed from CMake hell.

Then, I needed to create functions to read the dataset – a single .csv file with each row containing properties of a piece of land – which required me to learn how to split a Rust project into modules and separate files, using a mod.rs file as the entry point of a module directory (a common way to organize modules, though not the only one) and pub mod declarations to expose modules within the project. During this time I also learned about the println! macro, Rust’s basic data types, control flow, loops and how paths and namespaces work. I also learned about file I/O, iterators and the collect method on iterators. Finally, I came up with a single function to parse the CSV file.
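
To make the iterator and collect part concrete, here is a minimal sketch (not the project's actual code) of parsing one CSV row. The function name parse_row and the sample rows are made up for illustration. Collecting into a Result stops at the first field that fails to parse, which is one way to reject a bad row in a single expression.

```rust
use std::num::ParseFloatError;

/// Parse one comma-separated row into floats (hypothetical helper).
/// Collecting into `Result<Vec<_>, _>` short-circuits on the first field
/// that fails to parse and returns that error.
fn parse_row(line: &str) -> Result<Vec<f32>, ParseFloatError> {
    line.split(',')
        .map(|field| field.trim().parse::<f32>())
        .collect()
}

fn main() {
    // Made-up rows: the second one contains a non-numeric field.
    let rows = ["2012.917,32.0,84.87", "2013.5,oops,306.59"];

    for (index, row) in rows.iter().enumerate() {
        match parse_row(row) {
            Ok(values) => println!("row {index}: {values:?}"),
            Err(e) => println!("row {index} skipped: {e}"),
        }
    }
}
```

The same split, map and collect chain also works with a plain Vec<f32> as the target if the fields are already known to be valid.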

A Rust feature that left me confused at first was the match expression and how it handles enums. My feeble dynamically typed mind initially thought of this as something similar to a switch statement in JavaScript, but in Rust a match has to be exhaustive: every possible case must be covered, either by naming it explicitly or with a catch-all _ arm (Rust’s equivalent of default), and every arm has to produce the same type. That meant I couldn’t apply my normal way of writing a switch statement.

I initially thought it was really annoying to write out every branch, but after seeing all the LSP support and type support, match has proven to be the superior version of switch and from that point on, I learned to never doubt our great crustacean overlords again.
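
As a small illustration of that exhaustiveness rule, here is a sketch with a hypothetical Cell enum (it is not part of the project). The describe function names every variant, so the compiler would reject it if a new variant were added later without a matching arm, while is_numeric uses the _ wildcard as a catch-all.

```rust
/// Hypothetical enum describing what a CSV cell might contain.
#[derive(Debug)]
enum Cell {
    Number(f32),
    Text(String),
    Empty,
}

/// Turn a raw string into a `Cell`. `parse` returns a `Result`, which is
/// itself an enum with exactly two variants, so two arms are exhaustive.
fn classify(raw: &str) -> Cell {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return Cell::Empty;
    }
    match trimmed.parse::<f32>() {
        Ok(value) => Cell::Number(value),
        Err(_) => Cell::Text(trimmed.to_string()),
    }
}

/// Every variant is listed, and every arm evaluates to a `String`.
fn describe(cell: &Cell) -> String {
    match cell {
        Cell::Number(n) => format!("number {n}"),
        Cell::Text(s) => format!("text of length {}", s.len()),
        Cell::Empty => String::from("empty"),
    }
}

/// The `_` arm plays the role of `default` in a switch statement.
fn is_numeric(cell: &Cell) -> bool {
    match cell {
        Cell::Number(_) => true,
        _ => false,
    }
}

fn main() {
    for raw in ["3.5", "abc", "  "] {
        let cell = classify(raw);
        println!("{:?} -> {} (numeric: {})", cell, describe(&cell), is_numeric(&cell));
    }
}
```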

The final read_dataset function ended up looking something like this:

use std::fs::File;
use std::io::{self, BufRead};
use std::path::Path;

pub fn read_dataset() -> (Vec<String>, Vec<Vec<f32>>) {
    let file_path = Path::new("dataset").join("taiwan_real_estate.csv");
    let age_col_index = 2;

    let mut header: Vec<String> = Vec::new();
    let mut body: Vec<Vec<f32>> = Vec::new();

    let file = File::open(file_path);
    let file = match file {
        Ok(f) => f,
        Err(_) => {
            println!("Error in opening file, aborting now");
            return (header, body);
        }
    };

    let reader = io::BufReader::new(file);

    for (index, line) in reader.lines().enumerate() {
        let line_extracted = match line {
            Ok(s) => s,
            Err(_) => return (header, body),
        };

        let split_line = line_extracted.split(',');

        let mut temp_body = Vec::new();
        let mut valid_line = true;

        for (chunk_index, chunk) in split_line.enumerate() {
            if index == 0 {
                header.push(String::from(chunk));
            } else {
                if chunk_index == age_col_index {
                    let val_split: Vec<&str> = chunk.split(' ').collect();

                    let start = match val_split[0].parse::<f32>() {
                        Ok(value) => value,
                        Err(e) => {
                            println!("{:?}", e);
                            valid_line = false;
                            break;
                        }
                    };

                    let end = match val_split[val_split.len() - 1].parse::<f32>() {
                        Ok(value) => value,
                        Err(e) => {
                            println!("{:?}", e);
                            valid_line = false;
                            break;
                        }
                    };

                    temp_body.push(end - start);
                } else {
                    let value = chunk.parse::<f32>();

                    // if value is not a valid number, skip that line
                    let value = match value {
                        Ok(value) => value,
                        Err(_) => {
                            println!("Skipping line {index} from float parsing error");
                            // mark the row invalid so a partially parsed row is not kept
                            valid_line = false;
                            break;
                        }
                    };

                    temp_body.push(value);
                }
            }
        }

        if !temp_body.is_empty() && valid_line {
            body.push(temp_body);
        }
    }

    return (header, body);
}

So far, the learning process had been smooth, partly because of the easy-to-navigate documentation that Rust provides and partly because these are concepts I’ve dealt with a lot in previous languages, especially C and C++.

It’s now time for me to write the Linear layer and the Model itself, which means dealing with structs, references and the infamous borrow checker. As a segfault enthusiast myself, I’m used to the pain of debugging null pointer dereferences and memory corruption, which also means concepts like ownership, borrowing, immutability by default, shadowing and lifetimes are things I kept in a dark corner of my mind, well hidden from everything else. Well, in Rust, these concepts are upheld to the highest degree, and the borrow checker will scream in your face with a million warnings and errors if you don’t follow them.
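
For readers who have not met these concepts yet, here is a minimal sketch of ownership, borrowing and shadowing using only the standard library. The function names total, scale and consume are made up for illustration and do not come from the project.

```rust
/// Immutable borrow: the caller keeps ownership and nothing is modified.
fn total(values: &[f32]) -> f32 {
    values.iter().sum()
}

/// Mutable borrow: the caller keeps ownership, but the data may change.
fn scale(values: &mut [f32], factor: f32) {
    for v in values.iter_mut() {
        *v *= factor;
    }
}

/// Taking the Vec by value moves ownership into the function.
fn consume(values: Vec<f32>) -> usize {
    values.len()
}

fn main() {
    // Bindings are immutable by default, so `mut` is required here.
    let mut prices = vec![1.0, 2.0, 3.0];

    let sum = total(&prices);
    scale(&mut prices, 2.0);

    // Shadowing: a new binding reuses the name `sum`.
    let sum = sum * 2.0;
    println!("doubled sum: {sum}, recomputed: {}", total(&prices));

    let count = consume(prices);
    // println!("{:?}", prices); // would not compile: `prices` was moved into `consume`
    println!("consumed {count} values");
}
```

Uncommenting the marked line is a quick way to see what a borrow checker error looks like in practice.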

Believe it or not, my horror with the borrow checker started when I tried to add a name property to the Linear layer (which is the Dense struct in the code). The Dense struct was defined as

struct Dense {
    pub weights: ndarray::Array<f64, Dim<[usize; 2]>>,
    pub bias: f32,
    pub name: &str, // does not compile as written: a reference in a struct needs a lifetime (error E0106)
}

and the name has a type of &str because I wanted to initialize the Dense struct by passing in a string literal, something like this: Dense::new("Dense1"). For avid Rust programmers, your first thought was probably “Why didn’t you just use String and pass in String::from("Dense1") so the struct takes ownership of the string?”, and to that my response is: I don’t really know, dude, I just didn’t. My stubborn use of &str led me down a spiral of dealing with lifetimes and passing them through layers of functions as I tried to fight the borrow checker into letting me add names to my layers.
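
The trade-off between the two field types can be sketched without ndarray. BorrowedLayer and OwnedLayer below are hypothetical stand-ins for Dense, not the project's code. The borrowing version needs an explicit lifetime parameter that then appears everywhere the type is used, while the owning version has none. Accepting impl Into<String> is one possible way to keep the call site as short as passing a literal.

```rust
/// Borrowing version: the struct must not outlive the string it points at,
/// so the lifetime `'a` has to be written out on the type.
struct BorrowedLayer<'a> {
    name: &'a str,
}

/// Owning version: the struct holds its own String, so no lifetime is needed.
struct OwnedLayer {
    name: String,
}

impl OwnedLayer {
    /// `impl Into<String>` accepts both string literals and `String` values.
    fn new(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }
}

fn main() {
    let borrowed = BorrowedLayer { name: "Dense1" };
    let owned = OwnedLayer::new("Dense1");
    let from_string = OwnedLayer::new(String::from("Dense2"));

    println!("{} {} {}", borrowed.name, owned.name, from_string.name);
}
```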

After a while, I managed to recognize my mistakes and refine the Dense struct to something that’s much easier to work with:

use ndarray::{Array, Dim, Ix2, ShapeBuilder};
use rand::Rng;

#[derive(Debug)]
pub(crate) struct Dense {
    pub weights: ndarray::Array<f64, Dim<[usize; 2]>>,
    pub bias: f32,
    pub name: String,
}

impl Dense {
    pub fn new(in_features: usize, out_features: usize, name: String) -> Self {
        let upper = 6.0 / (in_features as f64 + out_features as f64).sqrt();
        let lower = -upper;
        let difference = upper - lower;

        let mut weights = Array::from_elem((in_features, out_features).f(), lower as f64);

        for i in 0..in_features {
            for j in 0..out_features {
                weights[[i, j]] += rand::thread_rng().gen_range(0.0..1.0) as f64 * difference;
            }
        }

        return Self {
            weights: weights,
            bias: rand::thread_rng().gen_range(0.0..1.0),
            name: name,
        };
    }

    pub fn get_weight(&self) -> &ndarray::ArrayBase<ndarray::OwnedRepr<f64>, Ix2> {
        return &self.weights;
    }
}

Thousands of warnings and errors and what felt like personal attacks from the borrow checker later, I was finally able to deal with its errors accurately, and by “accurately”, I mean a handsome 50% success rate.

Marching forward, I then dealt with traits and implementing them on different structs. While not classical inheritance in the C++ sense, traits are Rust’s way of defining shared behavior that different types can implement. They are, in a way, similar to virtual functions combined with inheritance in C++, because a trait can supply default implementations for its methods, which implementing types can either keep or override. I decided to create the Compute trait to serve the Dense and Model structs, each with its own implementation of the trait.

impl Compute for Dense {
    fn compute_single(
        &self,
        x: ndarray::ArrayBase<ndarray::OwnedRepr<f64>, Ix2>,
    ) -> ndarray::ArrayBase<ndarray::OwnedRepr<f64>, Ix2> {
        assert_eq!(
            self.weights.shape()[0],
            x.shape()[x.shape().len() - 1],
            "Shape mismatched. Input has shape: {:?} but weight has shape: {:?}",
            x.shape(),
            self.weights.shape()
        );

        let result = x.dot(&self.weights);
        return result;
    }
}
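
To show the default-implementation idea in isolation, here is a std-only sketch. It reuses the name Compute from the project but simplifies the signature to plain slices. The compute_batch method and the Dot and Constant types are assumptions added for illustration. Dot keeps the default compute_batch, while Constant overrides it.

```rust
/// Simplified, std-only stand-in for the Compute trait.
trait Compute {
    /// Required method: every implementor must provide it.
    fn compute_single(&self, x: &[f64]) -> f64;

    /// Provided method: a default built on `compute_single`.
    /// Implementors may keep it or override it.
    fn compute_batch(&self, batch: &[Vec<f64>]) -> Vec<f64> {
        batch.iter().map(|x| self.compute_single(x)).collect()
    }
}

/// Hypothetical layer computing a dot product with its weights.
struct Dot {
    weights: Vec<f64>,
}

/// Hypothetical layer that ignores its input.
struct Constant(f64);

impl Compute for Dot {
    fn compute_single(&self, x: &[f64]) -> f64 {
        assert_eq!(self.weights.len(), x.len(), "shape mismatch");
        self.weights.iter().zip(x).map(|(w, v)| w * v).sum()
    }
    // `compute_batch` is inherited from the trait's default.
}

impl Compute for Constant {
    fn compute_single(&self, _x: &[f64]) -> f64 {
        self.0
    }

    // Override the default: no need to look at the inputs at all.
    fn compute_batch(&self, batch: &[Vec<f64>]) -> Vec<f64> {
        vec![self.0; batch.len()]
    }
}

fn main() {
    // Trait objects allow different types to be stored behind one interface.
    let layers: Vec<Box<dyn Compute>> = vec![
        Box::new(Dot { weights: vec![0.5, 2.0] }),
        Box::new(Constant(1.0)),
    ];
    let batch = vec![vec![2.0, 1.0], vec![4.0, 0.0]];

    for layer in &layers {
        println!("{:?}", layer.compute_batch(&batch));
    }
}
```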

I also learned about how formatted printing works in Rust. While #[derive(Debug)] automatically provides {:?} (for developer-focused debugging output), for more user-friendly printing with the println! macro and the {} format, I implemented the fmt::Display trait. This also introduced me to the ? operator, which is a convenient way to propagate Result errors from functions that can fail.

impl fmt::Display for Dense {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Name: {}\nWeights: {:?}\nBias: {}\n",
            self.name, self.weights, self.bias
        )?;
        Ok(())
    }
}
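
Outside of fmt, the ? operator works the same way in any function that returns a Result. The sketch below uses a hypothetical Point type and parse_point function (neither is in the project) to show Debug, Display and ? side by side.

```rust
use std::fmt;
use std::num::ParseFloatError;

/// Hypothetical type used only for this illustration.
#[derive(Debug)]
struct Point {
    x: f32,
    y: f32,
}

impl fmt::Display for Point {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each `?` returns early with the error if a write fails.
        write!(f, "(")?;
        write!(f, "{}, {}", self.x, self.y)?;
        // The last write's result is returned directly.
        write!(f, ")")
    }
}

/// Parse text such as "1.5, 2" into a Point.
/// A missing comma leaves `y` empty, which then fails to parse.
fn parse_point(text: &str) -> Result<Point, ParseFloatError> {
    let (x, y) = text.split_once(',').unwrap_or((text, ""));
    Ok(Point {
        x: x.trim().parse()?, // propagate the parse error to the caller
        y: y.trim().parse()?,
    })
}

fn main() {
    match parse_point("1.5, 2") {
        Ok(point) => println!("Display: {point}\nDebug: {point:?}"),
        Err(e) => println!("could not parse point: {e}"),
    }
}
```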

The remaining major learning points along the way were explicit error handling, panic!, the Option<> enum and Rust’s various bits of syntactic sugar, all of which were precious learning experiences.
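
As a rough sketch of what working with Option can look like, here are two helpers, column_index and mean, with made-up column names. None of this is taken from the project. Returning None makes the "nothing to return" case explicit, and the caller decides whether to fall back, branch or panic.

```rust
/// Find the position of a column by name (hypothetical helper).
fn column_index(header: &[&str], name: &str) -> Option<usize> {
    header.iter().position(|column| *column == name)
}

/// Mean of a slice, or None when the slice is empty.
fn mean(values: &[f32]) -> Option<f32> {
    if values.is_empty() {
        None
    } else {
        Some(values.iter().sum::<f32>() / values.len() as f32)
    }
}

fn main() {
    let header = ["age", "price"];

    // `match` forces both the Some and the None case to be handled.
    match column_index(&header, "price") {
        Some(index) => println!("price is column {index}"),
        None => println!("no price column"),
    }

    // `if let` handles only the case of interest.
    if let Some(average) = mean(&[2.0, 4.0]) {
        println!("average: {average}");
    }

    // `unwrap_or` supplies a fallback instead of panicking.
    let fallback = mean(&[]).unwrap_or(0.0);
    println!("fallback: {fallback}");

    // mean(&[]).expect("need at least one value"); // this line would panic
}
```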

Now, to put a definitive answer to the burning question of “Is learning a programming language through coding a project a good idea?”, I would answer with a definitive “Maybe”. The reason I say maybe and not yes, despite learning a lot of concepts and remembering them really well after doing the project, is that the process is long and tedious.

Having to figure out the basics of the language while you’re building something disrupts your workflow immensely, and it gets frustrating after a while. This is especially relevant if you’re coming from another language that you already have decent proficiency in. Imagine being in the zone, coding a complex concept only to be interrupted because you don’t know how to print an array or why you can’t mutate a reference. Also, reading language documentation is not always the easiest thing ever, and as well written as Rust’s documentation is, I frequently felt lost and confused when learning new concepts and how they tie in together.

In conclusion, if you’re trying to learn a programming language, remember the concepts and learn how they can be applied, and you’re not afraid of taking your time to grind and suffer, then I’d say go ahead and build something with the language. However, if you just want to learn the syntax of the language quickly and without the frustration of reading the documentation, maybe a combination of YouTube tutorials and a project where you follow someone’s baseline code is the better and faster option.

Disclaimer: Work in progress

At the time of writing this blog, I am still working through the gradient descent and optimization of the model, but since I’m just too excited to share what I’ve learned so far, I decided to post the blog first. I will update the blog to include this section once I’m completely done with it. For the conclusion, I think I have enough information to support my views right now, but they may change after the project is complete.

P/S: I coded the project using Neovim btw 🤓


