""" Generate VHDL or Verilog code for a signed multiplier. Usage: genmul.py --lang=vhdl|verilog [--nolib] Xbits Ybits npipe Usage: genmul.py --lang=vhdl|verilog --lib --lang=... Specify VHDL or Verilog code --nolib Do not generate library components --lib Generate only library components Xbits Length of input word in bits (Xbits >= 4) Ybits Length of input word in bits (Ybits >= Xbits) npipe Number of register stages (0 or 1 or 2) See also: G. Knagge, "ASIC Design for Signal Processing", http://www.geoffknagge.com/fyp/booth.shtml, 2010. L. Dadda, "Some schemes for parallel multipliers", Associazione Elettrotecnica et Elettronica Italiana, 1965. R. P. Brent, H. T. Kung, "A Regular Layout for Parallel Adders", IEEE Transactions on Computers, 1982. """ # # Copyright 2016 Joris van Rantwijk # # Copying and distribution of this file, with or without modification, # are permitted in any medium without royalty provided the copyright # notice and this notice are preserved. # import sys import argparse class Expr: """Represent a node in the expression tree.""" wire = None done = False class ConstBit(Expr): """Represent constant '0' or '1' bit.""" def __init__(self, v): assert v in (0, 1) self.v = v class InBit(Expr): """Represent an input bit in the expression tree.""" def __init__(self, xy, p): assert xy in ('x', 'y') self.xy = xy self.p = p class Reg(Expr): """Represent flip-flop.""" def __init__(self, v): self.v = v class NotBit(Expr): """Represent inverter.""" def __init__(self, v): self.v = v class BoothNeg(Expr): """Represent calculation of radix-4 Booth sign-inversion flag.""" def __init__(self, pat): assert len(pat) == 3 self.pat = pat class BoothProd(Expr): """Represent calculation of partial product bit with radix-4 Booth.""" def __init__(self, pat, b): assert len(pat) == 3 assert len(b) == 2 self.pat = pat self.b = b class AddBitD(Expr): """Represent selection of data bit from adder.""" def __init__(self, v): self.v = v class AddBitC(Expr): """Represent selection of carry bit from adder.""" def __init__(self, v): self.v = v class HalfAdd(Expr): """Represent half adder.""" def __init__(self, a, b): self.a = a self.b = b class FullAdd(Expr): """Represent full adder.""" def __init__(self, a, b, c): self.a = a self.b = b self.c = c class CarryProp(Expr): """Represent base node of carry propagation tree.""" def __init__(self, a, b): self.a = a self.b = b class CarryMerge(Expr): """Represent internal node of carry propagation tree.""" def __init__(self, p0, p1): self.p0 = p0 self.p1 = p1 class CarryEval(Expr): """Represent logic to calculate carry-out.""" def __init__(self, p, c): self.p = p self.c = c def gen_partial_products(xvec, yvec): """Generate list of partial products using radix-4 Booth algorithm. Return [ (exponent, bit), ... ]. """ partial_products = [ ] # Append zero on LSB side of xvec, sign-extend on MSB side of xvec. xtmp = [ ConstBit(0) ] + xvec + xvec[-1:] # Append zero on LSB side of yvec, sign-extend on MSB side of yvec. ytmp = [ ConstBit(0) ] + yvec + yvec[-1:] # Step through xvec, 2 bits at a time. for i in xrange(0, len(xvec), 2): # Select group of 3 bits from xvec (one bit overlap with last group). pat = xtmp[i:i+3] # Add either 0, +1, +2, -1 or -2 times yvec according to Booth method. # Step through the bits of yvec. for j in xrange(len(yvec)+1): # Use Booth encoder to choose between 0, yvec[j], yvec[j-1] or # inverted bits yvec[j] or yvec[j-1]. t = BoothProd(pat, ytmp[j:j+2]) # Invert the MSB bit, except on first row. if i > 0 and j == len(yvec): t = NotBit(t) # Add result as partial product. partial_products.append( (i+j, t) ) # For first row, sign-extend by two bits. # Apply sign inversion on the new MSB bit. if i == 0: partial_products.append( (i+len(yvec)+1, t) ) partial_products.append( (i+len(yvec)+2, NotBit(t)) ) # For each row except the first row, add constant 1 in the next column. if i > 0: partial_products.append( (i+len(yvec)+1, ConstBit(1)) ) # Use Booth encoder to add 1 in case of negative factor (-1 or -2). t = BoothNeg(pat) partial_products.append( (i, t) ) return partial_products def gen_dadda_tree(partial_products, nbits): """Generate carry save adder based on Dadda tree.""" # Sort partial products by bit position. tvec = [ [ ] for p in xrange(nbits) ] for (p, b) in partial_products: if p < nbits: tvec[p].append(b) # Build Dadda tree. while any([ len(t) > 3 for t in tvec ]): # New layer. nvec = [ [ ] for p in xrange(nbits+1) ] for p in xrange(nbits): t = tvec[p] i = 0 while i + 2 < len(t): # build full adder a = FullAdd(t[i], t[i+1], t[i+2]) nvec[p].append(AddBitD(a)) nvec[p+1].append(AddBitC(a)) i += 3 if i + 1 < len(t) and len(nvec[p]) % 3 == 2: # build half adder a = HalfAdd(t[i], t[i+1]) nvec[p].append(AddBitD(a)) nvec[p+1].append(AddBitC(a)) i += 2 if i < len(t): # pass through nvec[p] += t[i:] tvec = nvec[:nbits] # Last layer. nvec = [ [ ] for p in xrange(nbits+1) ] for p in xrange(nbits): t = tvec[p] if len(t) == 3: # full adder a = FullAdd(t[0], t[1], t[2]) nvec[p].append(AddBitD(a)) nvec[p+1].append(AddBitC(a)) elif len(t) == 2 and len(nvec[p]) > 0: # half adder a = HalfAdd(t[0], t[1]) nvec[p].append(AddBitD(a)) nvec[p+1].append(AddBitC(a)) else: # pass through nvec[p] += t tvec = nvec[:nbits] # Extract remaining two rows of bits. avec = [ (t[0] if len(t) > 0 else ConstBit(0)) for t in tvec ] bvec = [ (t[1] if len(t) > 1 else ConstBit(0)) for t in tvec ] return (avec, bvec) def gen_adder(avec, bvec): """Generate carry-lookahead adder.""" def carry_lookahead(pvec, cin): """Recursively determine carry propagation.""" if len(pvec) == 1: prop = pvec[0] cvec = [ cin ] else: k = (len(pvec) + 1) // 2 (p0, c0) = carry_lookahead(pvec[:k], cin) ctmp = CarryEval(p0, cin) (p1, c1) = carry_lookahead(pvec[k:], ctmp) prop = CarryMerge(p0, p1) cvec = c0 + c1 return (prop, cvec) assert len(avec) == len(bvec) # Determine carry-generate and carry-propagate for each position. pvec = [ CarryProp(a, b) for (a, b) in zip(avec, bvec) ] # Determine carry-in for each position. (prop, cvec) = carry_lookahead(pvec, ConstBit(0)) # Array of full adders. sumvec = [ AddBitD(FullAdd(a, b, c)) for (a, b, c) in zip(avec, bvec, cvec) ] return sumvec def gen_multiplier(xbits, ybits, npipe): """Generate expression tree describing multiplier logic.""" xvec = [ InBit('x', p) for p in xrange(xbits) ] yvec = [ InBit('y', p) for p in xrange(ybits) ] partial_products = gen_partial_products(xvec, yvec) (avec, bvec) = gen_dadda_tree(partial_products, xbits+ybits) if npipe > 0: avec = [ Reg(a) for a in avec ] bvec = [ Reg(b) for b in bvec ] zvec = gen_adder(avec, bvec) if npipe > 1: zvec = [ Reg(z) for z in zvec ] return zvec def gen_netlist(node, wires, insts): """Generate netlist consisting of wires and component instances.""" if node.done: # already processed this node return node.done = True if isinstance(node, ConstBit): # resolve during code generation node.wire = node elif isinstance(node, InBit): # resolve during code generation node.wire = node elif isinstance(node, Reg): # create output wire node.wire = 'wreg%d' % len(wires) wires.append(node.wire) # recurse gen_netlist(node.v, wires, insts) # create instance insts.append(node) elif isinstance(node, NotBit): # create output wire node.wire = 'winv%d' % len(wires) wires.append(node.wire) # recurse gen_netlist(node.v, wires, insts) # create instance insts.append(node) elif isinstance(node, BoothNeg): # create output wire node.wire = 'wboothneg%d' % len(wires) wires.append(node.wire) # recurse for v in node.pat: gen_netlist(v, wires, insts) # create instance insts.append(node) elif isinstance(node, BoothProd): # create output wire node.wire = 'wboothprod%d' % len(wires) wires.append(node.wire) # recurse for v in node.pat: gen_netlist(v, wires, insts) for v in node.b: gen_netlist(v, wires, insts) # create instance insts.append(node) elif isinstance(node, AddBitD): # recurse gen_netlist(node.v, wires, insts) node.wire = node.v.wire + 'd' elif isinstance(node, AddBitC): # recurse gen_netlist(node.v, wires, insts) node.wire = node.v.wire + 'c' elif isinstance(node, HalfAdd): # create output wires node.wire = 'wadd%d' % len(wires) wires.append(node.wire + 'd') wires.append(node.wire + 'c') # recurse gen_netlist(node.a, wires, insts) gen_netlist(node.b, wires, insts) # create instance insts.append(node) elif isinstance(node, FullAdd): # create output wires node.wire = 'wadd%d' % len(wires) wires.append(node.wire + 'd') wires.append(node.wire + 'c') # recurse gen_netlist(node.a, wires, insts) gen_netlist(node.b, wires, insts) gen_netlist(node.c, wires, insts) # create instance insts.append(node) elif isinstance(node, CarryProp): # create output wires node.wire = 'wcarry%d' % len(wires) wires.append(node.wire + 'g') wires.append(node.wire + 'p') # recurse gen_netlist(node.a, wires, insts) gen_netlist(node.b, wires, insts) # create instance insts.append(node) elif isinstance(node, CarryMerge): # create output wires node.wire = 'wcarry%d' % len(wires) wires.append(node.wire + 'g') wires.append(node.wire + 'p') # recurse gen_netlist(node.p0, wires, insts) gen_netlist(node.p1, wires, insts) # create instance insts.append(node) elif isinstance(node, CarryEval): # create output wire node.wire = 'wcarry%d' % len(wires) wires.append(node.wire) # recurse gen_netlist(node.p, wires, insts) gen_netlist(node.c, wires, insts) # create instance insts.append(node) else: assert False def vhdl_inst(node): """Return (name, ports) for a given instance.""" if isinstance(node, Reg): name = 'smul_flipflop' ports = ( 'clk', 'clken', node.v.wire, node.wire ) elif isinstance(node, NotBit): name = 'smul_inverter' ports = ( node.v.wire, node.wire ) elif isinstance(node, BoothNeg): name = 'smul_booth_neg' ports = ( node.pat[0].wire, node.pat[1].wire, node.pat[2].wire, node.wire ) elif isinstance(node, BoothProd): name = 'smul_booth_prod' ports = ( node.pat[0].wire, node.pat[1].wire, node.pat[2].wire, node.b[0].wire, node.b[1].wire, node.wire ) elif isinstance(node, HalfAdd): name = 'smul_half_add' ports = ( node.a.wire, node.b.wire, node.wire + 'd', node.wire + 'c' ) elif isinstance(node, FullAdd): name = 'smul_full_add' ports = ( node.a.wire, node.b.wire, node.c.wire, node.wire + 'd', node.wire + 'c' ) elif isinstance(node, CarryProp): name = 'smul_carry_prop' ports = ( node.a.wire, node.b.wire, node.wire + 'g', node.wire + 'p' ) elif isinstance(node, CarryMerge): name = 'smul_carry_merge' ports = ( node.p0.wire + 'g', node.p0.wire + 'p', node.p1.wire + 'g', node.p1.wire + 'p', node.wire + 'g', node.wire + 'p' ) elif isinstance(node, CarryEval): name = 'smul_carry_eval' ports = ( node.p.wire + 'g', node.p.wire + 'p', node.c.wire, node.wire ) else: assert False return (name, ports) def vhdl_wire(wire): """Resolve wire to VHDL expression string.""" if isinstance(wire, ConstBit): return "'%d'" % wire.v elif isinstance(wire, InBit): return "%sin(%d)" % (wire.xy, wire.p) else: assert isinstance(wire, str) return wire def gen_vhdl_lib(): """Generate VHDL code for library components.""" print """ -- -- Flip-flop. -- library ieee; use ieee.std_logic_1164.all; entity smul_flipflop is port ( clk: in std_ulogic; clken: in std_ulogic; d: in std_ulogic; q: out std_ulogic ); end entity; architecture smul_flipflop_arch of smul_flipflop is begin process (clk) is begin if rising_edge(clk) then if to_x01(clken) = '1' then q <= d; end if; end if; end process; end architecture; -- -- Inverter. -- library ieee; use ieee.std_logic_1164.all; entity smul_inverter is port ( d: in std_ulogic; q: out std_ulogic ); end entity; architecture smul_inverter_arch of smul_inverter is begin q <= not d; end architecture; -- -- Half-adder. -- library ieee; use ieee.std_logic_1164.all; entity smul_half_add is port ( x: in std_ulogic; y: in std_ulogic; d: out std_ulogic; c: out std_ulogic ); end entity; architecture smul_half_add_arch of smul_half_add is begin d <= x xor y; c <= x and y; end architecture; -- -- Full-adder. -- library ieee; use ieee.std_logic_1164.all; entity smul_full_add is port ( x: in std_ulogic; y: in std_ulogic; z: in std_ulogic; d: out std_ulogic; c: out std_ulogic ); end entity; architecture smul_full_add_arch of smul_full_add is begin d <= x xor y xor z; c <= (x and y) or (y and z) or (x and z); end architecture; -- -- Booth negative flag. -- library ieee; use ieee.std_logic_1164.all; entity smul_booth_neg is port ( p0: in std_ulogic; p1: in std_ulogic; p2: in std_ulogic; f: out std_ulogic ); end entity; architecture smul_booth_neg_arch of smul_booth_neg is begin f <= p2 and ((not p1) or (not p0)); end architecture; -- -- Booth partial product generation. -- library ieee; use ieee.std_logic_1164.all; entity smul_booth_prod is port ( p0: in std_ulogic; p1: in std_ulogic; p2: in std_ulogic; b0: in std_ulogic; b1: in std_ulogic; y: out std_ulogic ); end entity; architecture smul_booth_prod_arch of smul_booth_prod is begin process (p0, p1, p2, b0, b1) is variable p: std_ulogic_vector(2 downto 0); begin p := (p2, p1, p0); case p is when "000" => y <= '0'; -- factor 0 when "001" => y <= b1; -- factor 1 when "010" => y <= b1; -- factor 1 when "011" => y <= b0; -- factor 2 when "100" => y <= not b0; -- factor -2 when "101" => y <= not b1; -- factor -1 when "110" => y <= not b1; -- factor -1 when others => y <= '0'; -- factor 0 end case; end process; end architecture; -- -- Determine carry generate and carry propagate. -- library ieee; use ieee.std_logic_1164.all; entity smul_carry_prop is port ( a: in std_ulogic; b: in std_ulogic; g: out std_ulogic; p: out std_ulogic ); end entity; architecture smul_carry_prop of smul_carry_prop is begin g <= a and b; p <= a xor b; end architecture; -- -- Merge two carry propagation trees. -- library ieee; use ieee.std_logic_1164.all; entity smul_carry_merge is port ( g0: in std_ulogic; p0: in std_ulogic; g1: in std_ulogic; p1: in std_ulogic; g: out std_ulogic; p: out std_ulogic ); end entity; architecture smul_carry_merge of smul_carry_merge is begin g <= g1 or (g0 and p1); p <= p0 and p1; end architecture; -- -- Calculate carry-out through a carry propagation tree. -- library ieee; use ieee.std_logic_1164.all; entity smul_carry_eval is port ( g: in std_ulogic; p: in std_ulogic; cin: in std_ulogic; cout: out std_ulogic ); end entity; architecture smul_carry_eval of smul_carry_eval is begin cout <= g or (p and cin); end architecture; """ def gen_vhdl_mul(xbits, ybits, npipe, wires, insts, outputs): """Generate VHDL code and write to stdout.""" # Declaration. print """ --- --- %(xbits)d x %(ybits)d bit signed multiplier --- --- %(npipe)d cycles pipeline delay --- library ieee; use ieee.std_logic_1164.all; entity smul_%(xbits)d_%(ybits)d is port ( clk: in std_ulogic; clken: in std_ulogic; xin: in std_logic_vector(%(xleft)d downto 0); yin: in std_logic_vector(%(yleft)d downto 0); zout: out std_logic_vector(%(zleft)d downto 0) ); end entity; architecture arch of smul_%(xbits)d_%(ybits)d is """ % { 'xbits': xbits, 'ybits': ybits, 'npipe': npipe, 'xleft': xbits-1, 'yleft': ybits-1, 'zleft': xbits+ybits-1 } # Declare signals. for w in wires: print "signal %s: std_ulogic;" % w # Start architecture body. print print "begin" print # Instantiate components. for (i, node) in enumerate(insts): (name, ports) = vhdl_inst(node) print "u%d: entity work.%s port map (" % (i, name), print ", ".join([ vhdl_wire(p) for p in ports ]), print ");" print # Drive output signals. for (i, wire) in enumerate(outputs): print "zout(%d) <= %s;" % (i, vhdl_wire(wire)) # End architecture. print print "end architecture;" def verilog_wire(wire): """Resolve wire to Verilog expression string.""" if isinstance(wire, ConstBit): return "1'b%d" % wire.v elif isinstance(wire, InBit): return "%sin[%d]" % (wire.xy, wire.p) else: assert isinstance(wire, str) return wire def gen_verilog_lib(): """Generate Verilog code for library components.""" print """ // Flip-flop. module smul_flipflop ( input wire clk, input wire clken, input wire d, output reg q ); always @(posedge clk) begin if (clken) q <= d; end endmodule // Inverter. module smul_inverter ( input wire d, output wire q ); assign q = ~d; endmodule // Half-adder. module smul_half_add ( input wire x, input wire y, output wire d, output wire c ); assign d = x ^ y; assign c = x & y; endmodule // Full-adder. module smul_full_add ( input wire x, input wire y, input wire z, output wire d, output wire c ); assign d = x ^ y ^ z; assign c = (x & y) | (y & z) | (x & z); endmodule // Booth negative flag. module smul_booth_neg ( input wire p0, input wire p1, input wire p2, output wire f ); assign f = p2 & ((~p1) | (~p0)); endmodule // Booth partial product generator. module smul_booth_prod ( input wire p0, input wire p1, input wire p2, input wire u0, input wire u1, output reg y ); always @ (*) begin case ({p2, p1, p0}) 3'b000 : y = 1'b0; 3'b001 : y = u1; 3'b010 : y = u1; 3'b011 : y = u0; 3'b100 : y = ~u0; 3'b101 : y = ~u1; 3'b110 : y = ~u1; default : y = 1'b0; endcase end endmodule // Deterimine carry generate and carry propagate. module smul_carry_prop ( input wire a, input wire b, output wire g, output wire p ); assign g = a & b; assign p = a ^ b; endmodule // Merge two carry propagation trees. module smul_carry_merge ( input wire g0, input wire p0, input wire g1, input wire p1, output wire g, output wire p ); assign g = g1 | (g0 & p1); assign p = p0 & p1; endmodule // Calculate carry-out through a carry propagation tree. module smul_carry_eval ( input wire g, input wire p, input wire cin, output wire cout ); assign cout = g | (p & cin); endmodule """ def gen_verilog_mul(xbits, ybits, npipe, wires, insts, outputs): """Generate Verilog code and write to stdout.""" # Preamble. print """ /* * %(xbits)d x %(ybits)d bit signed multiplier * * %(npipe)d cycles pipeline delay */ module smul_%(xbits)d_%(ybits)d ( input wire clk, input wire clken, input wire [%(xleft)d:0] xin, input wire [%(yleft)d:0] yin, output wire [%(zleft)d:0] zout ); """ % { 'xbits': xbits, 'ybits': ybits, 'npipe': npipe, 'xleft': xbits-1, 'yleft': ybits-1, 'zleft': xbits+ybits-1 } # Declare signals. for w in wires: print "wire %s;" % w # Instantiate components. for (i, node) in enumerate(insts): (name, ports) = vhdl_inst(node) print "%s u%d (" % (name, i), print ", ".join([ verilog_wire(p) for p in ports ]), print ");" print # Drive output signals. for (i, wire) in enumerate(outputs): print "assign zout[%d] = %s;" % (i, verilog_wire(wire)) # End module. print print "endmodule" def main(): parser = argparse.ArgumentParser() parser.format_help = lambda: __doc__ parser.format_usage = lambda: __doc__ parser.add_argument('--lang', action='store', type=str) parser.add_argument('--nolib', action='store_true', default=False) parser.add_argument('--lib', action='store_true', default=False) parser.add_argument('Xbits', action='store', type=int, nargs='?') parser.add_argument('Ybits', action='store', type=int, nargs='?') parser.add_argument('npipe', action='store', type=int, nargs='?') args = parser.parse_args() if args.lang is None or args.lang.upper() not in ('VHDL', 'VERILOG'): print >>sys.stderr, __doc__ print >>sys.stderr, "ERROR: Must specify --lang=vhdl or --lang=verilog" sys.exit(1) if args.lib: if (args.nolib or args.Xbits is not None or args.Ybits is not None or args.npipe is not None): print >>sys.stderr, __doc__ print >>sys.stderr, "ERROR: Must specify either --lib or", print >>sys.stderr, "Xbits, Ybits, npipe" sys.exit(1) else: if (args.Xbits is None or args.Ybits is None or args.npipe is None): print >>sys.stderr, __doc__ print >>sys.stderr, "ERROR: Must specify either --lib or", print >>sys.stderr, "Xbits, Ybits, npipe" sys.exit(1) if args.Xbits < 4 or args.Ybits < args.Xbits: print >>sys.stderr, "ERROR: invalid word lengths" sys.exit(1) if args.npipe < 0 or args.npipe > 2: print >>sys.stderr, "ERROR: invalid number of register stages" sys.exit(1) if not args.lib: # Generate expression tree. zvec = gen_multiplier(args.Xbits, args.Ybits, args.npipe) # Generate wires and instances. wires = [ ] insts = [ ] for node in zvec: gen_netlist(node, wires, insts) outputs = [ node.wire for node in zvec ] # Write library components. if not args.nolib: if args.lang.upper() == 'VHDL': gen_vhdl_lib() elif args.lang.upper() == 'VERILOG': gen_verilog_lib() # Write multiplier. if not args.lib: if args.lang.upper() == 'VHDL': gen_vhdl_mul(args.Xbits, args.Ybits, args.npipe, wires, insts, outputs) elif args.lang.upper() == 'VERILOG': gen_verilog_mul(args.Xbits, args.Ybits, args.npipe, wires, insts, outputs) if __name__ == '__main__': main() # end