hdl-multiplier/genmul.py

1031 lines
25 KiB
Python

"""
Generate VHDL or Verilog code for a signed multiplier.
Usage: genmul.py --lang=vhdl|verilog [--nolib] Xbits Ybits npipe
Usage: genmul.py --lang=vhdl|verilog --lib
--lang=... Specify VHDL or Verilog code
--nolib Do not generate library components
--lib Generate only library components
Xbits Length of input word in bits (Xbits >= 4)
Ybits Length of input word in bits (Ybits >= Xbits)
npipe Number of register stages (0 or 1 or 2)
See also:
G. Knagge, "ASIC Design for Signal Processing",
http://www.geoffknagge.com/fyp/booth.shtml, 2010.
X. Xiong, M. Lin, "Low Power 8-bit Baugh-Wooley Multiplier Based on Wallace
Tree Architecture", Lecture Notes in Electrical Engineering, 2012.
L. Dadda, "Some schemes for parallel multipliers",
Associazione Elettrotecnica et Elettronica Italiana, 1965.
R. P. Brent, H. T. Kung, "A Regular Layout for Parallel Adders",
IEEE Transactions on Computers, 1982.
"""
import sys
import argparse
class Expr:
    """Represent a node in the expression tree.

    Base class of all node types in this file.  The class attributes
    below are defaults; gen_netlist() overrides them per instance.
    """
    # Output wire of the node; stays None until gen_netlist() assigns it.
    wire = None
    # Set to True by gen_netlist() once the node has been processed, so a
    # node shared by several parents is only handled once.
    done = False
class ConstBit(Expr):
    """Represent constant '0' or '1' bit.

    Attributes:
        v: the bit value as an integer, either 0 or 1.
    """
    def __init__(self, v):
        # Only the two bit values are meaningful.
        assert v in (0, 1)
        self.v = v
class InBit(Expr):
    """Represent an input bit in the expression tree.

    Attributes:
        xy: 'x' or 'y', selecting the input word the bit belongs to.
        p: index of the bit within that input word.
    """
    def __init__(self, xy, p):
        # There are exactly two input words.
        assert xy in ('x', 'y')
        self.xy = xy
        self.p = p
class Reg(Expr):
    """Represent flip-flop.

    Attributes:
        v: the node whose output is registered.
    """
    def __init__(self, v):
        self.v = v
class NotBit(Expr):
    """Represent inverter.

    Attributes:
        v: the node whose output is inverted.
    """
    def __init__(self, v):
        self.v = v
class BoothNeg(Expr):
    """Represent calculation of radix-4 Booth sign-inversion flag.

    Attributes:
        pat: sequence of exactly three nodes forming the Booth bit group.
    """
    def __init__(self, pat):
        # A radix-4 Booth group always consists of three bits.
        assert len(pat) == 3
        self.pat = pat
class BoothProd(Expr):
    """Represent calculation of partial product bit with radix-4 Booth.

    Attributes:
        pat: sequence of exactly three nodes forming the Booth bit group.
        b: sequence of exactly two nodes, the candidate multiplicand bits
            the Booth group selects between.
    """
    def __init__(self, pat, b):
        # Three selector bits and two candidate data bits are required.
        assert len(pat) == 3
        assert len(b) == 2
        self.pat = pat
        self.b = b
class AddBitD(Expr):
    """Represent selection of data bit from adder.

    Attributes:
        v: the adder node (HalfAdd or FullAdd) whose data output is used.
    """
    def __init__(self, v):
        self.v = v
class AddBitC(Expr):
    """Represent selection of carry bit from adder.

    Attributes:
        v: the adder node (HalfAdd or FullAdd) whose carry output is used.
    """
    def __init__(self, v):
        self.v = v
class HalfAdd(Expr):
    """Represent half adder.

    Its two outputs are selected through AddBitD and AddBitC nodes.

    Attributes:
        a: node driving the first input.
        b: node driving the second input.
    """
    def __init__(self, a, b):
        self.a = a
        self.b = b
class FullAdd(Expr):
    """Represent full adder.

    Its two outputs are selected through AddBitD and AddBitC nodes.

    Attributes:
        a: node driving the first input.
        b: node driving the second input.
        c: node driving the third input.
    """
    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c
class CarryProp(Expr):
    """Represent base node of carry propagation tree.

    Produces the carry-generate and carry-propagate signals of one bit
    position of the final adder.

    Attributes:
        a: node for the first addend bit.
        b: node for the second addend bit.
    """
    def __init__(self, a, b):
        self.a = a
        self.b = b
class CarryMerge(Expr):
    """Represent internal node of carry propagation tree.

    Attributes:
        p0: propagation node covering the lower group of bit positions.
        p1: propagation node covering the upper group of bit positions.
    """
    def __init__(self, p0, p1):
        self.p0 = p0
        self.p1 = p1
class CarryEval(Expr):
    """Represent logic to calculate carry-out.

    Attributes:
        p: propagation node (CarryProp or CarryMerge) of a group of bits.
        c: node providing the carry-in of that group.
    """
    def __init__(self, p, c):
        self.p = p
        self.c = c
def gen_partial_products(xvec, yvec):
    """Generate list of partial products using radix-4 Booth algorithm.

    Args:
        xvec: list of nodes for the bits of the first operand, LSB first.
            It is scanned in overlapping groups of three bits.
        yvec: list of nodes for the bits of the second operand, LSB first.

    Returns:
        A list [ (exponent, bit), ... ] where ``exponent`` is the bit
        position (column) in which the node ``bit`` must be added.
    """
    partial_products = []
    ny = len(yvec)

    # Append zero on LSB side of xvec, sign-extend on MSB side of xvec.
    xtmp = [ConstBit(0)] + xvec + xvec[-1:]
    # Append zero on LSB side of yvec, sign-extend on MSB side of yvec.
    ytmp = [ConstBit(0)] + yvec + yvec[-1:]

    # Step through xvec, 2 bits at a time.
    # range() is used instead of xrange() so the code also runs on Python 3.
    for i in range(0, len(xvec), 2):

        # Select group of 3 bits from xvec (one bit overlap with last group).
        pat = xtmp[i:i+3]

        # Add either 0, +1, +2, -1 or -2 times yvec according to Booth method.
        # Step through the bits of yvec.
        for j in range(ny + 1):
            # Use Booth encoder to choose between 0, yvec[j], yvec[j-1] or
            # inverted bits yvec[j] or yvec[j-1].
            t = BoothProd(pat, ytmp[j:j+2])
            # Invert the MSB bit, except on first row.
            if i > 0 and j == ny:
                t = NotBit(t)
            # Add result as partial product.
            partial_products.append((i + j, t))

        # For first row, sign-extend by two bits.
        # Apply sign inversion on the new MSB bit.
        # (t still refers to the MSB product bit of this row.)
        if i == 0:
            partial_products.append((i + ny + 1, t))
            partial_products.append((i + ny + 2, NotBit(t)))

        # For each row except the first row, add constant 1 in the next column.
        if i > 0:
            partial_products.append((i + ny + 1, ConstBit(1)))

        # Use Booth encoder to add 1 in case of negative factor (-1 or -2).
        t = BoothNeg(pat)
        partial_products.append((i, t))

    return partial_products
def gen_dadda_tree(partial_products, nbits):
    """Generate carry save adder based on Dadda tree.

    Args:
        partial_products: list of (exponent, bit) tuples; ``exponent`` is
            the column of the bit.  Bits with exponent >= nbits are dropped.
        nbits: number of result columns.  Carries out of the highest
            column are discarded.

    Returns:
        A tuple (avec, bvec) of two lists of length ``nbits`` holding the
        two remaining rows of bits; missing entries are ConstBit(0).
    """

    # Sort partial products by bit position.
    # range() is used instead of xrange() so the code also runs on Python 3.
    tvec = [[] for _ in range(nbits)]
    for (p, b) in partial_products:
        if p < nbits:
            tvec[p].append(b)

    # Build Dadda tree: add layers until no column has more than 3 bits.
    while any(len(t) > 3 for t in tvec):
        # New layer (one extra column to catch the carry of the top column).
        nvec = [[] for _ in range(nbits + 1)]
        for p in range(nbits):
            t = tvec[p]
            i = 0
            while i + 2 < len(t):
                # build full adder
                a = FullAdd(t[i], t[i+1], t[i+2])
                nvec[p].append(AddBitD(a))
                nvec[p+1].append(AddBitC(a))
                i += 3
            if i + 1 < len(t) and len(nvec[p]) % 3 == 2:
                # build half adder
                a = HalfAdd(t[i], t[i+1])
                nvec[p].append(AddBitD(a))
                nvec[p+1].append(AddBitC(a))
                i += 2
            if i < len(t):
                # pass through
                nvec[p] += t[i:]
        tvec = nvec[:nbits]

    # Last layer: reduce every column to at most two bits.
    nvec = [[] for _ in range(nbits + 1)]
    for p in range(nbits):
        t = tvec[p]
        if len(t) == 3:
            # full adder
            a = FullAdd(t[0], t[1], t[2])
            nvec[p].append(AddBitD(a))
            nvec[p+1].append(AddBitC(a))
        elif len(t) == 2 and len(nvec[p]) > 0:
            # half adder (needed because a carry already arrived here)
            a = HalfAdd(t[0], t[1])
            nvec[p].append(AddBitD(a))
            nvec[p+1].append(AddBitC(a))
        else:
            # pass through
            nvec[p] += t
    tvec = nvec[:nbits]

    # Extract remaining two rows of bits.
    avec = [(t[0] if len(t) > 0 else ConstBit(0)) for t in tvec]
    bvec = [(t[1] if len(t) > 1 else ConstBit(0)) for t in tvec]

    return (avec, bvec)
def gen_adder(avec, bvec):
    """Build a carry-lookahead adder for two equally long bit vectors.

    Args:
        avec: list of nodes for the first addend, LSB first.
        bvec: list of nodes for the second addend, same length as avec.

    Returns:
        A list of nodes, one sum bit per input position.  The carry out
        of the highest position is not part of the result.
    """

    def lookahead(group, carry_in):
        """Return (propagation node, list of carry-ins) for a bit group."""
        # A single position simply receives the incoming carry.
        if len(group) == 1:
            return (group[0], [carry_in])
        # Split the group; the lower half gets the extra element.
        half = (len(group) + 1) // 2
        (low_prop, low_carries) = lookahead(group[:half], carry_in)
        # The carry out of the lower half feeds the upper half.
        mid_carry = CarryEval(low_prop, carry_in)
        (high_prop, high_carries) = lookahead(group[half:], mid_carry)
        return (CarryMerge(low_prop, high_prop), low_carries + high_carries)

    assert len(avec) == len(bvec)

    # Carry-generate / carry-propagate node for every bit position.
    props = [CarryProp(a, b) for (a, b) in zip(avec, bvec)]

    # Carry-in for every bit position; the overall carry-in is zero.
    carries = lookahead(props, ConstBit(0))[1]

    # One full adder per position; only its data output is used.
    return [AddBitD(FullAdd(a, b, c))
            for (a, b, c) in zip(avec, bvec, carries)]
def gen_multiplier(xbits, ybits, npipe):
    """Generate expression tree describing multiplier logic.

    Args:
        xbits: number of bits of input word x.
        ybits: number of bits of input word y.
        npipe: number of register stages.  A value above 0 registers the
            two rows produced by the adder tree; a value above 1 also
            registers the final sum.

    Returns:
        A list of xbits+ybits nodes, one per output bit, LSB first.
    """
    # range() is used instead of xrange() so the code also runs on Python 3.
    xvec = [InBit('x', p) for p in range(xbits)]
    yvec = [InBit('y', p) for p in range(ybits)]

    partial_products = gen_partial_products(xvec, yvec)
    (avec, bvec) = gen_dadda_tree(partial_products, xbits + ybits)

    # First pipeline stage sits between the adder tree and the final adder.
    if npipe > 0:
        avec = [Reg(a) for a in avec]
        bvec = [Reg(b) for b in bvec]

    zvec = gen_adder(avec, bvec)

    # Second pipeline stage registers the result.
    if npipe > 1:
        zvec = [Reg(z) for z in zvec]

    return zvec
def gen_netlist(node, wires, insts):
    """Generate netlist consisting of wires and component instances.

    Walks the expression tree below ``node`` depth-first.  Every node is
    marked ``done`` and gets its ``wire`` attribute assigned; nodes that
    were already visited are skipped.

    Args:
        node: root of the (sub)tree to process.
        wires: list of wire names; new names are appended in place.
        insts: list of nodes that need a component instance; nodes are
            appended in place after all their operands were processed.

    Raises:
        AssertionError: if ``node`` is of an unknown type.  This is raised
            explicitly so the check is not removed by ``python -O``.
    """
    if node.done:
        # already processed this node
        return
    node.done = True

    if isinstance(node, (ConstBit, InBit)):
        # resolve during code generation
        node.wire = node
        return

    if isinstance(node, (AddBitD, AddBitC)):
        # No wire or instance of its own: select one output of the adder.
        gen_netlist(node.v, wires, insts)
        node.wire = node.v.wire + ('d' if isinstance(node, AddBitD) else 'c')
        return

    # All remaining node types follow the same scheme.  Determine the
    # wire name prefix, the suffixes of the output wires and the operands.
    if isinstance(node, Reg):
        (prefix, suffixes, operands) = ('wreg', ('',), (node.v,))
    elif isinstance(node, NotBit):
        (prefix, suffixes, operands) = ('winv', ('',), (node.v,))
    elif isinstance(node, BoothNeg):
        (prefix, suffixes, operands) = ('wboothneg', ('',), tuple(node.pat))
    elif isinstance(node, BoothProd):
        (prefix, suffixes, operands) = (
            'wboothprod', ('',), tuple(node.pat) + tuple(node.b))
    elif isinstance(node, HalfAdd):
        (prefix, suffixes, operands) = ('wadd', ('d', 'c'), (node.a, node.b))
    elif isinstance(node, FullAdd):
        (prefix, suffixes, operands) = (
            'wadd', ('d', 'c'), (node.a, node.b, node.c))
    elif isinstance(node, CarryProp):
        (prefix, suffixes, operands) = ('wcarry', ('g', 'p'), (node.a, node.b))
    elif isinstance(node, CarryMerge):
        (prefix, suffixes, operands) = (
            'wcarry', ('g', 'p'), (node.p0, node.p1))
    elif isinstance(node, CarryEval):
        (prefix, suffixes, operands) = ('wcarry', ('',), (node.p, node.c))
    else:
        raise AssertionError('unknown node type: %r' % (node,))

    # create output wire(s); the current wire count makes the name unique
    node.wire = '%s%d' % (prefix, len(wires))
    for suffix in suffixes:
        wires.append(node.wire + suffix)

    # recurse
    for operand in operands:
        gen_netlist(operand, wires, insts)

    # create instance
    insts.append(node)
def vhdl_inst(node):
    """Return (name, ports) for a given instance.

    Args:
        node: a node that was added to the instance list by gen_netlist().

    Returns:
        A tuple (name, ports): ``name`` is the library component name and
        ``ports`` is a tuple of wires in port order.  Each wire is a wire
        name string, a ConstBit or an InBit.

    Raises:
        AssertionError: if there is no component for the node type.  This
            is raised explicitly so the check survives ``python -O``.
    """
    if isinstance(node, Reg):
        return ('smul_flipflop',
                ('clk', 'clken', node.v.wire, node.wire))
    if isinstance(node, NotBit):
        return ('smul_inverter',
                (node.v.wire, node.wire))
    if isinstance(node, BoothNeg):
        return ('smul_booth_neg',
                (node.pat[0].wire, node.pat[1].wire, node.pat[2].wire,
                 node.wire))
    if isinstance(node, BoothProd):
        return ('smul_booth_prod',
                (node.pat[0].wire, node.pat[1].wire, node.pat[2].wire,
                 node.b[0].wire, node.b[1].wire,
                 node.wire))
    if isinstance(node, HalfAdd):
        return ('smul_half_add',
                (node.a.wire, node.b.wire,
                 node.wire + 'd', node.wire + 'c'))
    if isinstance(node, FullAdd):
        return ('smul_full_add',
                (node.a.wire, node.b.wire, node.c.wire,
                 node.wire + 'd', node.wire + 'c'))
    if isinstance(node, CarryProp):
        return ('smul_carry_prop',
                (node.a.wire, node.b.wire,
                 node.wire + 'g', node.wire + 'p'))
    if isinstance(node, CarryMerge):
        return ('smul_carry_merge',
                (node.p0.wire + 'g', node.p0.wire + 'p',
                 node.p1.wire + 'g', node.p1.wire + 'p',
                 node.wire + 'g', node.wire + 'p'))
    if isinstance(node, CarryEval):
        return ('smul_carry_eval',
                (node.p.wire + 'g', node.p.wire + 'p', node.c.wire,
                 node.wire))
    raise AssertionError('no component for node: %r' % (node,))
def vhdl_wire(wire):
    """Translate a wire into the VHDL expression that refers to it.

    A wire is either an InBit (element of the xin/yin port), a ConstBit
    (character literal) or a string holding a signal name.
    """
    if isinstance(wire, InBit):
        return "%sin(%d)" % (wire.xy, wire.p)
    if isinstance(wire, ConstBit):
        return "'%d'" % wire.v
    # Anything else must already be a signal name.
    assert isinstance(wire, str)
    return wire
def gen_vhdl_lib():
    """Generate VHDL code for library components.

    Writes the fixed set of ``smul_*`` entities to stdout.
    """
    # print(...) with a single string argument gives the same output on
    # Python 2 and Python 3; the bare print statement only works on Python 2.
    print("""
--
-- Flip-flop.
--
library ieee;
use ieee.std_logic_1164.all;
entity smul_flipflop is
port (
clk: in std_ulogic;
clken: in std_ulogic;
d: in std_ulogic;
q: out std_ulogic );
end entity;
architecture smul_flipflop_arch of smul_flipflop is
begin
process (clk) is
begin
if rising_edge(clk) then
if to_x01(clken) = '1' then
q <= d;
end if;
end if;
end process;
end architecture;
--
-- Inverter.
--
library ieee;
use ieee.std_logic_1164.all;
entity smul_inverter is
port (
d: in std_ulogic;
q: out std_ulogic );
end entity;
architecture smul_inverter_arch of smul_inverter is
begin
q <= not d;
end architecture;
--
-- Half-adder.
--
library ieee;
use ieee.std_logic_1164.all;
entity smul_half_add is
port (
x: in std_ulogic;
y: in std_ulogic;
d: out std_ulogic;
c: out std_ulogic );
end entity;
architecture smul_half_add_arch of smul_half_add is
begin
d <= x xor y;
c <= x and y;
end architecture;
--
-- Full-adder.
--
library ieee;
use ieee.std_logic_1164.all;
entity smul_full_add is
port (
x: in std_ulogic;
y: in std_ulogic;
z: in std_ulogic;
d: out std_ulogic;
c: out std_ulogic );
end entity;
architecture smul_full_add_arch of smul_full_add is
begin
d <= x xor y xor z;
c <= (x and y) or (y and z) or (x and z);
end architecture;
--
-- Booth negative flag.
--
library ieee;
use ieee.std_logic_1164.all;
entity smul_booth_neg is
port (
p0: in std_ulogic;
p1: in std_ulogic;
p2: in std_ulogic;
f: out std_ulogic );
end entity;
architecture smul_booth_neg_arch of smul_booth_neg is
begin
f <= p2 and ((not p1) or (not p0));
end architecture;
--
-- Booth partial product generation.
--
library ieee;
use ieee.std_logic_1164.all;
entity smul_booth_prod is
port (
p0: in std_ulogic;
p1: in std_ulogic;
p2: in std_ulogic;
b0: in std_ulogic;
b1: in std_ulogic;
y: out std_ulogic );
end entity;
architecture smul_booth_prod_arch of smul_booth_prod is
begin
process (p0, p1, p2, b0, b1) is
variable p: std_ulogic_vector(2 downto 0);
begin
p := (p2, p1, p0);
case p is
when "000" => y <= '0'; -- factor 0
when "001" => y <= b1; -- factor 1
when "010" => y <= b1; -- factor 1
when "011" => y <= b0; -- factor 2
when "100" => y <= not b0; -- factor -2
when "101" => y <= not b1; -- factor -1
when "110" => y <= not b1; -- factor -1
when others => y <= '0'; -- factor 0
end case;
end process;
end architecture;
--
-- Determine carry generate and carry propagate.
--
library ieee;
use ieee.std_logic_1164.all;
entity smul_carry_prop is
port (
a: in std_ulogic;
b: in std_ulogic;
g: out std_ulogic;
p: out std_ulogic );
end entity;
architecture smul_carry_prop of smul_carry_prop is
begin
g <= a and b;
p <= a xor b;
end architecture;
--
-- Merge two carry propagation trees.
--
library ieee;
use ieee.std_logic_1164.all;
entity smul_carry_merge is
port (
g0: in std_ulogic;
p0: in std_ulogic;
g1: in std_ulogic;
p1: in std_ulogic;
g: out std_ulogic;
p: out std_ulogic );
end entity;
architecture smul_carry_merge of smul_carry_merge is
begin
g <= g1 or (g0 and p1);
p <= p0 and p1;
end architecture;
--
-- Calculate carry-out through a carry propagation tree.
--
library ieee;
use ieee.std_logic_1164.all;
entity smul_carry_eval is
port (
g: in std_ulogic;
p: in std_ulogic;
cin: in std_ulogic;
cout: out std_ulogic );
end entity;
architecture smul_carry_eval of smul_carry_eval is
begin
cout <= g or (p and cin);
end architecture;
""")
def gen_vhdl_mul(xbits, ybits, npipe, wires, insts, outputs):
    """Generate VHDL code and write to stdout.

    Args:
        xbits: number of bits of input port xin.
        ybits: number of bits of input port yin.
        npipe: number of pipeline stages (only mentioned in the header).
        wires: list of signal names to declare.
        insts: list of nodes to instantiate as library components.
        outputs: list of wires driving zout, LSB first.

    All output uses print(...) with exactly one string argument (and
    print("") for empty lines), which behaves identically on Python 2 and
    Python 3; the Python 2 print statement does not compile on Python 3.
    """

    # Declaration.
    print("""
---
--- %(xbits)d x %(ybits)d bit signed multiplier
---
--- %(npipe)d cycles pipeline delay
---
library ieee;
use ieee.std_logic_1164.all;
entity smul_%(xbits)d_%(ybits)d is
port (
clk: in std_ulogic;
clken: in std_ulogic;
xin: in std_logic_vector(%(xleft)d downto 0);
yin: in std_logic_vector(%(yleft)d downto 0);
zout: out std_logic_vector(%(zleft)d downto 0) );
end entity;
architecture arch of smul_%(xbits)d_%(ybits)d is
""" % {'xbits': xbits, 'ybits': ybits,
       'npipe': npipe,
       'xleft': xbits-1, 'yleft': ybits-1,
       'zleft': xbits+ybits-1})

    # Declare signals.
    for w in wires:
        print("signal %s: std_ulogic;" % w)

    # Start architecture body.
    print("")
    print("begin")
    print("")

    # Instantiate components.
    for (i, node) in enumerate(insts):
        (name, ports) = vhdl_inst(node)
        port_map = ", ".join([vhdl_wire(p) for p in ports])
        # The spaces inside the parentheses match what the comma-separated
        # print statements produced.
        print("u%d: entity work.%s port map ( %s );" % (i, name, port_map))
    print("")

    # Drive output signals.
    for (i, wire) in enumerate(outputs):
        print("zout(%d) <= %s;" % (i, vhdl_wire(wire)))

    # End architecture.
    print("")
    print("end architecture;")
def verilog_wire(wire):
    """Translate a wire into the Verilog expression that refers to it.

    A wire is either an InBit (bit select of the xin/yin port), a ConstBit
    (one-bit literal) or a string holding a net name.
    """
    if isinstance(wire, InBit):
        return "%sin[%d]" % (wire.xy, wire.p)
    if isinstance(wire, ConstBit):
        return "1'b%d" % wire.v
    # Anything else must already be a net name.
    assert isinstance(wire, str)
    return wire
def gen_verilog_lib():
    """Generate Verilog code for library components.

    Writes the fixed set of ``smul_*`` modules to stdout.
    """
    # print(...) with a single string argument gives the same output on
    # Python 2 and Python 3; the bare print statement only works on Python 2.
    print("""
// Flip-flop.
module smul_flipflop (
input wire clk,
input wire clken,
input wire d,
output reg q );
always @(posedge clk)
begin
if (clken)
q <= d;
end
endmodule
// Inverter.
module smul_inverter (
input wire d,
output wire q );
assign q = ~d;
endmodule
// Half-adder.
module smul_half_add (
input wire x,
input wire y,
output wire d,
output wire c );
assign d = x ^ y;
assign c = x & y;
endmodule
// Full-adder.
module smul_full_add (
input wire x,
input wire y,
input wire z,
output wire d,
output wire c );
assign d = x ^ y ^ z;
assign c = (x & y) | (y & z) | (x & z);
endmodule
// Booth negative flag.
module smul_booth_neg (
input wire p0,
input wire p1,
input wire p2,
output wire f );
assign f = p2 & ((~p1) | (~p0));
endmodule
// Booth partial product generator.
module smul_booth_prod (
input wire p0,
input wire p1,
input wire p2,
input wire u0,
input wire u1,
output reg y );
always @ (*)
begin
case ({p2, p1, p0})
3'b000 : y = 1'b0;
3'b001 : y = u1;
3'b010 : y = u1;
3'b011 : y = u0;
3'b100 : y = ~u0;
3'b101 : y = ~u1;
3'b110 : y = ~u1;
default : y = 1'b0;
endcase
end
endmodule
// Deterimine carry generate and carry propagate.
module smul_carry_prop (
input wire a,
input wire b,
output wire g,
output wire p );
assign g = a & b;
assign p = a ^ b;
endmodule
// Merge two carry propagation trees.
module smul_carry_merge (
input wire g0,
input wire p0,
input wire g1,
input wire p1,
output wire g,
output wire p );
assign g = g1 | (g0 & p1);
assign p = p0 & p1;
endmodule
// Calculate carry-out through a carry propagation tree.
module smul_carry_eval (
input wire g,
input wire p,
input wire cin,
output wire cout );
assign cout = g | (p & cin);
endmodule
""")
def gen_verilog_mul(xbits, ybits, npipe, wires, insts, outputs):
    """Generate Verilog code and write to stdout.

    Args:
        xbits: number of bits of input port xin.
        ybits: number of bits of input port yin.
        npipe: number of pipeline stages (only mentioned in the header).
        wires: list of net names to declare.
        insts: list of nodes to instantiate as library modules.
        outputs: list of wires driving zout, LSB first.

    All output uses print(...) with exactly one string argument (and
    print("") for empty lines), which behaves identically on Python 2 and
    Python 3; the Python 2 print statement does not compile on Python 3.
    """

    # Preamble.
    print("""
/*
* %(xbits)d x %(ybits)d bit signed multiplier
*
* %(npipe)d cycles pipeline delay
*/
module smul_%(xbits)d_%(ybits)d (
input wire clk,
input wire clken,
input wire [%(xleft)d:0] xin,
input wire [%(yleft)d:0] yin,
output wire [%(zleft)d:0] zout );
""" % {'xbits': xbits, 'ybits': ybits,
       'npipe': npipe,
       'xleft': xbits-1, 'yleft': ybits-1,
       'zleft': xbits+ybits-1})

    # Declare signals.
    for w in wires:
        print("wire %s;" % w)

    # Instantiate components.
    for (i, node) in enumerate(insts):
        # vhdl_inst() is also used for Verilog: only the component name
        # and the port order are taken from it, the wires are rendered
        # with verilog_wire().
        (name, ports) = vhdl_inst(node)
        port_list = ", ".join([verilog_wire(p) for p in ports])
        # The spaces inside the parentheses match what the comma-separated
        # print statements produced.
        print("%s u%d ( %s );" % (name, i, port_list))
    print("")

    # Drive output signals.
    for (i, wire) in enumerate(outputs):
        print("assign zout[%d] = %s;" % (i, verilog_wire(wire)))

    # End module.
    print("")
    print("endmodule")
def main():
    """Parse the command line and write the requested HDL code to stdout.

    On invalid arguments an error message is written to stderr (preceded
    by the module docstring for usage errors) and the process exits with
    status 1.
    """

    def fail(message, show_usage=False):
        """Report an error on stderr and exit with status 1."""
        # sys.stderr.write() replaces the Python-2-only "print >>sys.stderr"
        # statement; the text written is the same.
        if show_usage:
            sys.stderr.write("%s\n" % (__doc__,))
        sys.stderr.write("ERROR: %s\n" % message)
        sys.exit(1)

    parser = argparse.ArgumentParser()
    # Use the module docstring as help and usage text.
    parser.format_help = lambda: __doc__
    parser.format_usage = lambda: __doc__
    parser.add_argument('--lang', action='store', type=str)
    parser.add_argument('--nolib', action='store_true', default=False)
    parser.add_argument('--lib', action='store_true', default=False)
    parser.add_argument('Xbits', action='store', type=int, nargs='?')
    parser.add_argument('Ybits', action='store', type=int, nargs='?')
    parser.add_argument('npipe', action='store', type=int, nargs='?')

    args = parser.parse_args()

    if args.lang is None or args.lang.upper() not in ('VHDL', 'VERILOG'):
        fail("Must specify --lang=vhdl or --lang=verilog", show_usage=True)
    lang = args.lang.upper()

    if args.lib:
        # Library-only mode excludes --nolib and all multiplier parameters.
        if (args.nolib or
                args.Xbits is not None or
                args.Ybits is not None or
                args.npipe is not None):
            fail("Must specify either --lib or Xbits, Ybits, npipe",
                 show_usage=True)
    else:
        if (args.Xbits is None or args.Ybits is None or args.npipe is None):
            fail("Must specify either --lib or Xbits, Ybits, npipe",
                 show_usage=True)
        if args.Xbits < 4 or args.Ybits < args.Xbits:
            fail("invalid word lengths")
        if args.npipe < 0 or args.npipe > 2:
            fail("invalid number of register stages")

    if not args.lib:

        # Generate expression tree.
        zvec = gen_multiplier(args.Xbits, args.Ybits, args.npipe)

        # Generate wires and instances.
        wires = []
        insts = []
        for node in zvec:
            gen_netlist(node, wires, insts)
        outputs = [node.wire for node in zvec]

    # Write library components.
    if not args.nolib:
        if lang == 'VHDL':
            gen_vhdl_lib()
        elif lang == 'VERILOG':
            gen_verilog_lib()

    # Write multiplier.
    if not args.lib:
        if lang == 'VHDL':
            gen_vhdl_mul(args.Xbits, args.Ybits, args.npipe,
                         wires, insts, outputs)
        elif lang == 'VERILOG':
            gen_verilog_mul(args.Xbits, args.Ybits, args.npipe,
                            wires, insts, outputs)
# Run the generator only when this file is executed as a script.
if __name__ == '__main__':
    main()
# end