Optimizing an STDP Learning Engine for FPGAs

11 minute read


Written by Aaryan Dhawan, in collaboration with Muhammad Farhan Azmine

What is STDP?

Spike-Timing-Dependent Plasticity (STDP) is a biological learning rule that adjusts the strength of connections between neurons based on the relative timing of their spikes. In simple terms, within a pair of connected neurons, if the presynaptic neuron fires before the post-synaptic neuron, the connection between them is strengthened; conversely, if the presynaptic neuron fires after the post-synaptic neuron, the connection is weakened. In machine learning, STDP is used to train some Spiking Neural Networks (SNNs) by mimicking this biological process, allowing the network to learn from temporal patterns in data. Because it learns directly from temporal patterns, STDP lets SNNs process information in a brain-like way, making them suitable for tasks such as real-time learning, robotics, and edge computing.

On-Chip STDP Learning

On-chip STDP learning refers to implementing the STDP learning algorithm directly within machine learning hardware accelerators, rather than training the network in software and transferring the result to hardware. This approach is particularly beneficial for applications with limited data availability or tight power budgets. In 2019, researchers at Texas A&M University developed FPGA-efficient, STDP-based SNN models for both supervised and unsupervised learning [1] with Liquid State Machines (LSMs). In that work, they built an STDP learning engine that takes the spike-timing difference between the current neuron's output and a presynaptic neuron's output, and uses it to decide between strengthening (potentiation) and weakening (depression) the connection between the two neurons. The adjustment value is given by the following equations:

STDP Potentiation and Depression Equations

\(\Delta w = \begin{cases} \eta_{+} \cdot e^{-\frac{t_{post}-t_{pre}}{\tau_{+}}} & \text{if } t_{pre} < t_{post} \\ -\eta_{-} \cdot e^{-\frac{t_{pre}-t_{post}}{\tau_{-}}} & \text{if } t_{pre} > t_{post} \end{cases}\)

where:

  • \(\Delta w\) is the change in weight
  • \(t_{pre}\) is the time of the pre-synaptic spike
  • \(t_{post}\) is the time of the post-synaptic/current neuron spike
  • \(\tau_{+}\) is the time constant for potentiation
  • \(\tau_{-}\) is the time constant for depression
  • \(\eta_{+}\) is the amplitude of the potentiation curve
  • \(\eta_{-}\) is the amplitude of the depression curve
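
To make the rule concrete, here is a worked example with illustrative constants (\(\eta_{+} = 100\), \(\tau_{+} = 10\) time steps, chosen to match the scale of the look-up table used later). If the presynaptic neuron fires 5 time steps before the post-synaptic neuron:

\(\Delta w = \eta_{+} \cdot e^{-\frac{t_{post}-t_{pre}}{\tau_{+}}} = 100 \cdot e^{-5/10} \approx 61\)

whereas a gap of only 1 time step gives \(\Delta w \approx 90\): the closer the spikes, the larger the weight change.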

Over a given sample period, each neuron produces a spike train, which is compared against the spike trains generated by its presynaptic neurons to compute a timing-difference value, \(\Delta t\). This \(\Delta t\) is used to look up the potentiation or depression value, strengthening or weakening the given connection. Determining potentiation vs. depression can be visualized as follows:

(Figure: STDP Simulation)

If the presynaptic neuron spikes before the post-synaptic neuron, the timing difference \(\Delta t\) is positive, indicating that the connection should be strengthened (potentiation). Conversely, if the presynaptic neuron spikes after the post-synaptic neuron, the timing difference \(\Delta t\) is negative, indicating that the connection should be weakened (depression).
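As a concrete illustration of the hardware view (assuming, as in the code below, that the newest spike occupies the MSB of each 16-bit shift register): if the post-synaptic register holds 1000000000000000 (a spike this cycle) and the selected presynaptic register holds 0001000000000000 (a spike three cycles earlier), the priority encoder counts the three leading zeros of the presynaptic train and reports \(\Delta t = +3\), i.e. potentiation.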

Original Hardware Design

In [1], the authors develop a “hardware-friendly STDP learning engine”, recreated in the simplified diagram below.

(Figure: STDP Learning Engine)

Although the diagram from [1] shows a simplified architecture, the actual datapath should look like the following, including the registers within the datapath to avoid metastability issues:

(Figure: STDP Learning Engine with Registers)

Discrete spike trains from the presynaptic neurons and from the post-synaptic (current) neuron are fed into Serial-In/Parallel-Out (SIPO) shift registers, whose depth equals the timing window tracked between the post-synaptic spike and the presynaptic spikes (16 cycles here). The presynaptic shift registers are selected by an M-to-1 multiplexer, where M is the number of presynaptic neurons connected to the current post-synaptic neuron, which owns the learning engine. Once a presynaptic shift register is selected and its data passes through the multiplexer, the contents of the selected presynaptic and post-synaptic shift registers are compared to determine the \(\Delta t\) used to adjust that weight. A comparator checks which of the two spike trains has its Most Significant Bit (MSB) high (logic 1). If the pre-synaptic shift register's MSB is high, the content of the post-synaptic shift register is sent to the priority encoder and the resulting \(\Delta t\) is a negative timing difference. Conversely, if the post-synaptic shift register's MSB is high, the content of the pre-synaptic shift register is sent to the priority encoder and the resulting \(\Delta t\) is a positive timing difference. We can model this design with the following SystemVerilog code:

always_ff @(posedge clk) begin
        if (rst) begin
            data_reg <= 16'b0;
            flag <= 1'b0;
            LUT_out_reg <= 20'b0;
        end else begin
            if (Post[15] == 1'b1) begin
                data_reg <= Pre;   // post fired this cycle: encode the pre-synaptic train
                flag <= 1'b1;      // positive timing difference (potentiation)
            end
            else if (Pre[15] == 1'b1) begin
                data_reg <= Post;  // pre fired this cycle: encode the post-synaptic train
                flag <= 1'b0;      // negative timing difference (depression)
            end
            else begin
                data_reg <= 16'b0; // no spike at the head of either train
                flag <= 1'b0;
            end
            // Store the LUT output in a pipeline register
            LUT_out_reg <= LUT_out;
        end
    end

    always_comb begin
        if (flag) begin
            // Positive timing differences (potentiation): count leading zeros
            unique casez(data_reg)
                16'b1???????????????: TimingDifference = 6'sd0;  // 0
                16'b01??????????????: TimingDifference = 6'sd1;  // 1
                16'b001?????????????: TimingDifference = 6'sd2;  // 2
                16'b0001????????????: TimingDifference = 6'sd3;  // 3
                16'b00001???????????: TimingDifference = 6'sd4;  // 4
                16'b000001??????????: TimingDifference = 6'sd5;  // 5
                16'b0000001?????????: TimingDifference = 6'sd6;  // 6
                16'b00000001????????: TimingDifference = 6'sd7;  // 7
                16'b000000001???????: TimingDifference = 6'sd8;  // 8
                16'b0000000001??????: TimingDifference = 6'sd9;  // 9
                16'b00000000001?????: TimingDifference = 6'sd10; // 10
                16'b000000000001????: TimingDifference = 6'sd11; // 11
                16'b0000000000001???: TimingDifference = 6'sd12; // 12
                16'b00000000000001??: TimingDifference = 6'sd13; // 13
                16'b000000000000001?: TimingDifference = 6'sd14; // 14
                16'b0000000000000001: TimingDifference = 6'sd15; // 15
                default:              TimingDifference = 6'sd0;  // no spike found
            endcase
        end
        else begin
            // Negative timing differences (depression)
            unique casez(data_reg)
                16'b1???????????????: TimingDifference = 6'sd0;   // 0
                16'b01??????????????: TimingDifference = -6'sd1;  // -1
                16'b001?????????????: TimingDifference = -6'sd2;  // -2
                16'b0001????????????: TimingDifference = -6'sd3;  // -3
                16'b00001???????????: TimingDifference = -6'sd4;  // -4
                16'b000001??????????: TimingDifference = -6'sd5;  // -5
                16'b0000001?????????: TimingDifference = -6'sd6;  // -6
                16'b00000001????????: TimingDifference = -6'sd7;  // -7
                16'b000000001???????: TimingDifference = -6'sd8;  // -8
                16'b0000000001??????: TimingDifference = -6'sd9;  // -9
                16'b00000000001?????: TimingDifference = -6'sd10; // -10
                16'b000000000001????: TimingDifference = -6'sd11; // -11
                16'b0000000000001???: TimingDifference = -6'sd12; // -12
                16'b00000000000001??: TimingDifference = -6'sd13; // -13
                16'b000000000000001?: TimingDifference = -6'sd14; // -14
                16'b0000000000000001: TimingDifference = -6'sd15; // -15
                default:              TimingDifference = 6'sd0;   // no spike found
            endcase
        end
    end

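The SIPO capture stage itself is not shown in the paper's listing. Below is a minimal sketch of one spike-train shift register, under the assumption used above that the newest spike enters at the MSB (the module and port names are my own, not from [1]):

module spike_sipo (
    input  logic        clk,
    input  logic        rst,
    input  logic        spike_in,    // one spike bit per cycle from the neuron
    output logic [15:0] spike_train  // bit 15 = this cycle, bit 0 = 15 cycles ago
);
    always_ff @(posedge clk) begin
        if (rst)
            spike_train <= 16'b0;
        else
            spike_train <= {spike_in, spike_train[15:1]}; // shift right, newest at MSB
    end
endmodule

One such register would be instantiated per presynaptic neuron, plus one for the post-synaptic neuron, ahead of the M-to-1 multiplexer.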
Rather than implementing hardware to compute the potentiation/depression values with the two equations above, the authors use a single Look-Up Table (LUT) containing both positive and negative values, which makes computing \(\Delta w\) cheap since only a limited range of values is possible. \(\Delta w\) is then added to the synaptic weight, read from a memory block, using a simple adder. The updated weight is then written back to the memory block, completing the learning step for that connection. The LUT and adder are implemented as follows:

always_comb begin
        unique casez(TimingDifference)
            -'sd15:  LUT_out = -'sd22;
            -'sd14:  LUT_out = -'sd25;
            -'sd13:  LUT_out = -'sd27;
            -'sd12:  LUT_out = -'sd30;
            -'sd11:  LUT_out = -'sd33;
            -'sd10:  LUT_out = -'sd37;
            -'sd9:   LUT_out = -'sd41;
            -'sd8:   LUT_out = -'sd45;
            -'sd7:   LUT_out = -'sd50;
            -'sd6:   LUT_out = -'sd55;
            -'sd5:   LUT_out = -'sd61;
            -'sd4:   LUT_out = -'sd67;
            -'sd3:   LUT_out = -'sd74;
            -'sd2:   LUT_out = -'sd82;
            -'sd1:   LUT_out = -'sd90;
             'sd0:   LUT_out =  'sd0;
             'sd1:   LUT_out =  'sd90;
             'sd2:   LUT_out =  'sd82;
             'sd3:   LUT_out =  'sd74;
             'sd4:   LUT_out =  'sd67;
             'sd5:   LUT_out =  'sd61;
             'sd6:   LUT_out =  'sd55;
             'sd7:   LUT_out =  'sd50;
             'sd8:   LUT_out =  'sd45;
             'sd9:   LUT_out =  'sd41;
             'sd10:  LUT_out =  'sd37;
             'sd11:  LUT_out =  'sd33;
             'sd12:  LUT_out =  'sd30;
             'sd13:  LUT_out =  'sd27;
             'sd14:  LUT_out =  'sd25;
             'sd15:  LUT_out =  'sd22;
            default: LUT_out =  'sd0;
        endcase
        // Add the registered weight change to the current synaptic weight
        NewW = CurrentW + LUT_out_reg;
    end
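
A brief note on where the table values come from: the entries are consistent with \(\eta_{\pm} = 100\) and \(\tau_{\pm} = 10\), i.e.

\(\Delta w(\Delta t) \approx \pm 100 \cdot e^{-|\Delta t|/10}\)

rounded to integers (for example, \(100 \cdot e^{-0.5} \approx 61\) for \(|\Delta t| = 5\)). This is an observation from the values themselves; [1] may parameterize the curves differently.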

Optimizing the Learning Engine

Looking at this design, we can identify a few areas for optimization, specifically in the timing-difference calculation and the potentiation/depression value calculation. Since the depression values are simply the negatives of the potentiation values, we can halve the LUT size by storing only the absolute weight-change values and using a sign flag to decide whether the value is added to or subtracted from the current weight. We can then also shrink the priority encoder, since it only needs to produce the absolute value of the timing difference rather than both positive and negative values, with simple logic generating the sign flag.

Optimizations:

  • Reduce the potentiation/depression LUT size by storing only magnitudes
  • Shrink the priority encoder output to a magnitude and carry the sign in a single-bit flag
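
Both optimizations rest on the fact that, in this design, the depression magnitudes exactly mirror the potentiation magnitudes, giving the weight-change table odd symmetry:

\(\Delta w(-\Delta t) = -\Delta w(\Delta t)\)

so one unsigned magnitude plus a sign bit carries exactly the same information as the full signed table.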

This results in the following design:

(Figure: Optimized STDP Learning Engine)

Below is the simplified SystemVerilog code for the optimized priority encoder and sign-flag logic:

logic [15:0] data_reg;
    logic        SignedFlag;

    always_ff @(posedge clk) begin
        if (rst) begin
            data_reg <= 16'b0;
            SignedFlag <= 1'b0;
        end else begin
            if (PostSynapticNeuron[15] == 1'b1) begin
                data_reg <= PreSynapticNeuron;
                SignedFlag <= 1'b0; // post fired this cycle: potentiation, add DeltaW
            end else if (PreSynapticNeuron[15] == 1'b1) begin
                data_reg <= PostSynapticNeuron;
                SignedFlag <= 1'b1; // pre fired this cycle: depression, subtract DeltaW
            end
        end
    end

    always_comb begin
        unique casez(data_reg)
            16'b1???????????????: TimingDifference = 4'd0; // 0
            16'b01??????????????: TimingDifference = 4'd1; // 1
            16'b001?????????????: TimingDifference = 4'd2; // 2
            16'b0001????????????: TimingDifference = 4'd3; // 3
            16'b00001???????????: TimingDifference = 4'd4; // 4
            16'b000001??????????: TimingDifference = 4'd5; // 5
            16'b0000001?????????: TimingDifference = 4'd6; // 6
            16'b00000001????????: TimingDifference = 4'd7; // 7
            16'b000000001???????: TimingDifference = 4'd8; // 8
            16'b0000000001??????: TimingDifference = 4'd9; // 9
            16'b00000000001?????: TimingDifference = 4'd10; // 10
            16'b000000000001????: TimingDifference = 4'd11; // 11
            16'b0000000000001???: TimingDifference = 4'd12; // 12
            16'b00000000000001??: TimingDifference = 4'd13; // 13
            16'b000000000000001?: TimingDifference = 4'd14; // 14
            16'b0000000000000001: TimingDifference = 4'd15; // 15
            default:              TimingDifference = 4'd0; // Default case
        endcase 
    end
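
As a quick sanity check of the encoding, a small self-contained testbench (entirely my own sketch, not from [1]) can sweep one-hot spike patterns and confirm that a spike i cycles old encodes to \(|\Delta t| = i\):

module pe_check_tb;
    logic [15:0] train;

    // Behavioral mirror of the priority encoder above:
    // |Δt| equals the number of leading zeros before the first spike bit.
    function automatic logic [3:0] encode(input logic [15:0] spikes);
        for (int i = 0; i < 16; i++)
            if (spikes[15 - i]) return 4'(i);
        return 4'd0;
    endfunction

    initial begin
        for (int i = 0; i < 16; i++) begin
            train = 16'h8000 >> i; // single spike, i cycles ago
            assert (encode(train) == 4'(i))
                else $error("encode(%b) != %0d", train, i);
        end
        $display("priority-encoder behavior checked for all one-hot inputs");
    end
endmodule

Like the casez table, the function gives priority to the newest spike (the highest set bit) when several spikes are present in the window.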

The potentiation LUT is now reduced to 16 entries, since separate depression values are no longer needed. The adder takes the absolute weight-change value and the sign flag, which determines whether the value is added to (potentiation) or subtracted from (depression) the current weight. The LUT and adder can be implemented as follows:

always_comb begin
        // potDepIndex is the 4-bit |Δt| magnitude from the priority encoder;
        // potDepValue feeds the adder below as DeltaW
        unique case (potDepIndex)
            4'd0:    potDepValue = 'sd0;
            4'd1:    potDepValue = 'sd90;
            4'd2:    potDepValue = 'sd82;
            4'd3:    potDepValue = 'sd74;
            4'd4:    potDepValue = 'sd67;
            4'd5:    potDepValue = 'sd61;
            4'd6:    potDepValue = 'sd55;
            4'd7:    potDepValue = 'sd50;
            4'd8:    potDepValue = 'sd45;
            4'd9:    potDepValue = 'sd41;
            4'd10:   potDepValue = 'sd37;
            4'd11:   potDepValue = 'sd33;
            4'd12:   potDepValue = 'sd30;
            4'd13:   potDepValue = 'sd27;
            4'd14:   potDepValue = 'sd25;
            4'd15:   potDepValue = 'sd22;
            default: potDepValue = 'sd0;
        endcase
    end

    always_comb begin
        NewW = CurrentW + (SignedFlag ? -DeltaW : DeltaW); // SignedFlag = 1 means subtract (depression)
    end
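
As a usage note, the conditional negation above typically infers a single adder/subtractor; the equivalent explicit form (same signals as above) may read more clearly:

always_comb begin
    if (SignedFlag)
        NewW = CurrentW - DeltaW; // depression: weaken the connection
    else
        NewW = CurrentW + DeltaW; // potentiation: strengthen the connection
end

Either way, DeltaW is the unsigned LUT magnitude, and the sign decision costs only a one-bit control input rather than a second 16-entry table.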

FPGA Utilization

We designed, synthesized, and implemented the original and optimized STDP learning engines on a Xilinx Zynq-7000 SoC “ZC702” evaluation board. The following table shows the FPGA resource utilization and timing for both designs:

Resource Type    Original Design    Optimized Design
LUTs             179                141
FFs              347                341
Fmax             187 MHz            205 MHz

From the table, we can see that the optimized design reduces LUT utilization by about 21% and increases Fmax by nearly 10% thanks to a shorter critical path, while maintaining a similar number of flip-flops (FFs).

Critical Path Analysis

To understand the increase in Fmax, we can analyze the critical path of both designs. The critical path is the longest delay path through the circuit, and it determines the maximum frequency at which the circuit can operate. In the original design, the critical path reported by Vivado is 5.393 ns, as shown in the Timing Summary capture below.

(Figure: Original Critical Path)

In the optimized design, the critical path is reduced to 4.801 ns, shown in the Vivado Timing Summary below.

(Figure: Optimized Critical Path)
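
As a quick sanity check, \(F_{max}\) is roughly the reciprocal of the critical-path delay: \(1/5.393\,\text{ns} \approx 185\) MHz and \(1/4.801\,\text{ns} \approx 208\) MHz, consistent with the reported 187 MHz and 205 MHz; the small differences come from constraint and margin details in the timing report.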