
// A program to analyse Serpent S box boolean functions to find minimum
// boolean expressions for evaluating them. This is achieved by using
// a list of terms, each of which consists of a binary combination of
// other terms on the list using 'xor', 'and' or 'or'.  The list starts
// with 5 terms: '1', 'a', 'b', 'c', 'd' and is grown recursively until
// hopefully a match is found with e, f, g, h, the function output values.
// Note that 'negate' is implemented as '1 xor term'.  The four-variable
// boolean functions are represented by a 16 bit variable in which each
// bit represents a particular minterm in the canonical representation.
// Thus if bit 0 is set the term ~a & ~b & ~c & ~d is present - that is,
// if bit n (0 <= n <= 15) in the word representing the function is set
// then bits 0..3 in the binary representation of n (bit 0 for a, bit 1
// for b, bit 2 for c, bit 3 for d) determine whether the variables in
// the minterm are normal (bit = 1) or inverted (bit = 0).  Note that if
// two values are combined with an operator then the minterm
// representation of the result is obtained by applying the same operator
// to the two minterm representations.

// Copyright: Dr B. R. Gladman (June 1998)

#include <iostream>
#include <fstream>
#include <iomanip>
#include <string>

using namespace std;

#include "f_box.h"

// Select the inline-assembler versions of the search loops (eq_loop,
// xor_loop, and_loop, or_loop) defined below; remove this define to use
// the portable C versions instead.
// NOTE(review): the assembler versions use MSVC '__asm { }' blocks and
// 32-bit registers, which only 32-bit x86 MSVC accepts - confirm the
// intended toolchain before enabling this elsewhere.
#define ASM

// fixed-size integer names used throughout
// NOTE(review): unsigned long is 64 bits on LP64 platforms; u4byte is not
// referenced in the code visible in this file.
typedef unsigned char	u1byte;
typedef unsigned short	u2byte;
typedef unsigned long	u4byte;
typedef char			s1byte;
typedef short			s2byte;

// capacity of each of the fixed-size per-term arrays in t_str
const unsigned short term_lim = 40;

// how a term is formed: atomic (one of the initial terms) or the xor, and
// or or of two earlier terms; add_terms relies on xor/and/or being 1/2/3
// (it converts its loop index j directly with static_cast<op_type>(j)).
// NOTE(review): xor, and and or are alternative operator tokens in
// standard C++, so this enum only compiles where they are treated as
// plain identifiers (e.g. older MSVC modes).
enum op_type { atomic = 0, xor = 1, and = 2, or = 3};

// the basic variables (a, b, c, d)  in boolean function form: bit n of
// each word is set when the corresponding bit of the minterm index n is 1
 
const u2byte	a = 0xaaaa;	// minterm representation of a (bit 0 of n)
const u2byte	b = 0xcccc;	// minterm representation of b (bit 1 of n)
const u2byte	c = 0xf0f0;	// minterm representation of c (bit 2 of n)
const u2byte	d = 0xff00;	// minterm representation of d (bit 3 of n)

// The structure holding the list of terms being searched.  Entry n of the
// parallel arrays t_valu / t_trm1 / t_trm2 / t_oper / t_char describes
// term n, and the first nterm entries are valid.  Bits 0..3 of the flag
// word are set when e..h have been matched - that is, if bit 0 is set e
// has been matched, and so on.
// NOTE: the member order must match the positional initializer of the
// global 'term' below.

typedef struct
{	u2byte	nterm;				// the number of valid terms in list
	u2byte	flags;				// flag bits if (e, f, g, h) matched 
	u2byte	tvals[4];			// the target values (e, f, g, h)
	u2byte  t_valu[term_lim];	// minterm representation of each term
	u2byte  t_trm1[term_lim];	// list index of the 1st sub term
	u2byte  t_trm2[term_lim];	// list index of the 2nd sub term
	op_type t_oper[term_lim];	// combining operator
	s1byte	t_char[term_lim];	// char name of term ('1', 'a'..'h'), 0 if none
} t_str;

// The global term list, initialised with the five atomic terms: entry 0
// is the constant '1' (all 16 minterms set) and entries 1..4 are the
// inputs a, b, c, d.  The targets (tvals) are filled in by main(); all
// array entries past the five listed here are zero-initialised.

t_str	term =
{
	5, 								// nterm: the five atomic terms
	0, 								// flags: nothing matched yet
	{ 0, 0, 0, 0 }, 				// tvals: set from the chosen S box
	{ 0xffff, a, b, c, d },			// t_valu
	{ 0, 1, 2, 3, 4 },				// t_trm1 (not read for atomic terms in the code shown)
	{ 0, 0, 0, 0, 0 },				// t_trm2
	{ atomic, atomic, atomic, atomic, atomic },	// t_oper
	{ '1', 'a', 'b', 'c', 'd' }		// t_char: printable names
};

// n_found[flags] is the number of bits set in the 4-bit match mask, i.e.
// how many of the outputs e, f, g, h have been matched so far.

u1byte n_found[16] =
{
	0, 1, 1, 2,		// flags 0x0 .. 0x3
	1, 2, 2, 3,		// flags 0x4 .. 0x7
	1, 2, 2, 3,		// flags 0x8 .. 0xb
	2, 3, 3, 4		// flags 0xc .. 0xf
};

// Search limits: max_terms bounds the list sizes explored, cur_best is
// the list size of the best complete solution reported so far and
// sch_dep bounds how far the search is extended after a partial match.

u2byte	max_terms = 25;
u2byte	cur_best  = 25;
u2byte	sch_dep   = 3;

// S box tables, filled in by com_boxes()

u2byte		s_box[8][4];	// forward S boxes: one minterm word per output bit
u2byte		i_box[8][4];	// inverse S boxes, same layout
u1byte		ss_box[8][16];	// forward S boxes as 16-entry lookup tables
u1byte		ii_box[8][16];	// inverse S boxes as 16-entry lookup tables

u2byte		r_level;		// recursion level (not referenced in the code shown)
ofstream	fout;			// progress log, opened in main()

// compile S box arrays from 'f_box.h'

void com_boxes(void)		
{
	u2byte i,j,a,b,c,d,e,f,g,h;

	for(i = 0; i < 8; ++i)
	{
		s_box[i][0] = 0; s_box[i][1] = 0;
		s_box[i][2] = 0; s_box[i][3] = 0;

		for(j = 0; j < 16; ++j)
		{
			a = (j & 1 ? 0xffff : 0); b = (j & 2 ? 0xffff : 0);
			c = (j & 4 ? 0xffff : 0); d = (j & 8 ? 0xffff : 0);

			switch(i)
			{
				case 0: sb0(a,b,c,d,e,f,g,h); break;
				case 1: sb1(a,b,c,d,e,f,g,h); break;
				case 2: sb2(a,b,c,d,e,f,g,h); break;
				case 3: sb3(a,b,c,d,e,f,g,h); break;
				case 4: sb4(a,b,c,d,e,f,g,h); break;
				case 5: sb5(a,b,c,d,e,f,g,h); break;
				case 6: sb6(a,b,c,d,e,f,g,h); break;
				case 7: sb7(a,b,c,d,e,f,g,h); break;
			}
		
			s_box[i][0] |= (e ? 1 << j : 0); s_box[i][1] |= (f ? 1 << j : 0);
			s_box[i][2] |= (g ? 1 << j : 0); s_box[i][3] |= (h ? 1 << j : 0);

			ss_box[i][j] = (e ? 1 : 0) | (f ? 2 : 0) 
									| (g ? 4 : 0) | (h ? 8 : 0); 
		}
	}

	for(i = 0; i < 8; ++i)
	{
		i_box[i][0] = 0; i_box[i][1] = 0;
		i_box[i][2] = 0; i_box[i][3] = 0;

		for(j = 0; j < 16; ++j)
		{
			a = (j & 1 ? 0xffff : 0); b = (j & 2 ? 0xffff : 0);
			c = (j & 4 ? 0xffff : 0); d = (j & 8 ? 0xffff : 0);

			switch(i)
			{
				case 0: ib0(a,b,c,d,e,f,g,h); break;
				case 1: ib1(a,b,c,d,e,f,g,h); break;
				case 2: ib2(a,b,c,d,e,f,g,h); break;
				case 3: ib3(a,b,c,d,e,f,g,h); break;
				case 4: ib4(a,b,c,d,e,f,g,h); break;
				case 5: ib5(a,b,c,d,e,f,g,h); break;
				case 6: ib6(a,b,c,d,e,f,g,h); break;
				case 7: ib7(a,b,c,d,e,f,g,h); break;
			}
		
			i_box[i][0] |= (e ? 1 << j : 0); i_box[i][1] |= (f ? 1 << j : 0);
			i_box[i][2] |= (g ? 1 << j : 0); i_box[i][3] |= (h ? 1 << j : 0);

			ii_box[i][j] = (e ? 1 : 0) | (f ? 2 : 0) 
									| (g ? 4 : 0) | (h ? 8 : 0); 
		}
	}
};

// Return in s (and by reference) the printable name of term ix in list t.
// A term with a character of its own ('1', 'a'..'d' for the initial
// terms, 'e'..'h' for matched outputs) prints as that character.  Every
// other term prints as a temporary "tN" with N = ix - 4, so index 5 is
// "t1".  Testing for any non-zero character (rather than >= 'a') keeps
// the constant '1' from falling into the temporary-name path, where
// ix - 4 wrapped around and produced garbage.

string &var_out(string &s, u2byte ix, t_str &t)
{
	if(t.t_char[ix] != 0)		// the term has a character representation
	{
		s = string(1, t.t_char[ix]);
		return s;
	}

	// otherwise output as a temporary variable: "t" + decimal(ix - 4)
	u2byte	no = ix - 4;
	char	digits[8];			// enough for any u2byte value plus NUL
	int 	pos = sizeof(digits) - 1;

	digits[pos] = '\0';
	do
	{
		digits[--pos] = static_cast<char>('0' + no % 10);
		no /= 10;
	}
	while(no != 0);

	s = "t";
	s += &digits[pos];
	return s;
}

// output the current terms in the list

void terms_out(string &s, t_str &t)
{	
	u2byte	i; 
	string	t1, t2;

	s = "";

	for(i = 5; i < t.nterm; ++i)	// for each term in the list
	{
		s += var_out(t1, i, t) + " = ";

		if(t.t_trm1[i] == 0)		// if '1 xor term' use form ~term
		{
			s += "~" + var_out(t1, t.t_trm2[i], t);
		}
		else
		{	
			s += var_out(t1, t.t_trm1[i], t);		// 1st sub term

			switch(t.t_oper[i])						// operator
			{
				case xor: t1 = " ^ "; break;
				case and: t1 = " & "; break;
				case  or: t1 = " | "; break;
			}
			
			s += t1 + var_out(t2, t.t_trm2[i], t);	// 2nd sub term
		}
			
		s += "; ";
	}
};

// The main recursive routine to add terms to list and look for
// a match with e, f, g and h.

// The following code sequences are somewhat strange in C because
// they are intended to form a basis for assembler code sections

// 1. search the list of terms for a match with a particular value

// eq_loop(mv): reject a candidate term whose value mv is already in the
// list.  Entries 1 .. nterm-1 are compared (entry 0, the constant '1', is
// not).  On a match control jumps to the label 'next_term' in the
// expanding function; otherwise execution falls through.  Expects the
// locals v, v_s and v_e (v_e = v_s + number of list entries) in scope.

#ifndef ASM

#define eq_loop(mv) 						\
{	v = v_s; 								\
	while(++v < v_e) 						\
		if(*v == mv)						\
			goto next_term; 				\
};

#else

// 32-bit x86 MSVC version: repne scasw over the same entries 1 .. nterm-1.
// NOTE(review): with a zero count repne scasw would not execute and je
// would test the flags left by 'sar ecx,1' (ZF set), wrongly skipping the
// candidate; add_terms always has at least five entries, so this does not
// arise there.

#define eq_loop(mv) 						\
{	__asm	{								\
	__asm		mov 	edi,dword ptr [v_s]	\
	__asm		add		edi,2				\
	__asm		mov 	ecx,dword ptr [v_e]	\
	__asm		sub 	ecx,edi				\
	__asm		sar		ecx,1				\
	__asm		mov 	ax,word ptr [mv]	\
	__asm		repne	scasw				\
	__asm		je 		next_term			\
			}								\
}

#endif

// 2. search the list of terms for a value such that m_cv ^ val == m_nv

// xor_loop(x1): look for an existing term val (entries 0 .. nterm-1) with
// (val ^ m_cv) == m_nv, i.e. a term that xored with the candidate term
// m_cv gives the target m_nv.  m_op is set to xor; on a match v points at
// the matching entry and control jumps to the label 'found' in the
// expanding function.  Otherwise execution falls through with m_nv
// unchanged, so that and_loop and or_loop can test the same target.
// x1 is a label name used only by the assembler version; each expansion
// must pass a distinct name.

#ifndef ASM

// Portable version: compare *v ^ m_cv with m_nv directly.  m_nv must not
// be modified here - xoring m_cv into it would leave a corrupted target
// for and_loop / or_loop whenever no match is found.

#define xor_loop(x1)							\
{	m_op = xor; v = v_s - 1;					\
	while(++v < v_e)							\
	{	if((*v ^ m_cv) == m_nv)					\
			goto found;							\
	}											\
}

#else

// 32-bit x86 MSVC version: repne scasw searches for m_nv ^ m_cv; after a
// match edi points one word past it, hence the 'sub edi,2'.

#define xor_loop(x1)	 					\
{	m_op = xor;								\
	__asm	{								\
	__asm		mov 	edi,dword ptr [v_s]	\
	__asm		mov 	ecx,dword ptr [v_e]	\
	__asm		sub		ecx,edi				\
	__asm		sar		ecx,1				\
	__asm		mov 	ax,word ptr [m_nv]	\
	__asm		xor 	ax,word ptr [m_cv]	\
	__asm		repne	scasw				\
	__asm		jne		x1					\
	__asm		sub		edi,2				\
	__asm		mov		dword ptr [v],edi	\
	__asm		jmp 	found				\
	__asm	x1: nop 						\
			}								\
}

#endif

// 3. search the term list for a value such that m_cv & val == m_nv
//    note that this equality requires that m_nv & ~m_cv == 0

// and_loop(x1,x2): look for an existing term val (entries 0 .. nterm-1)
// with (val & m_cv) == m_nv.  A match is only possible when m_nv has no
// bits outside m_cv, so otherwise the search is skipped entirely.  When
// the search runs m_op is set to and; on a match v points at the matching
// entry and control jumps to the label 'found'.  Otherwise execution falls
// through with m_nv unchanged.  x1 and x2 are label names used only by
// the assembler version; each expansion in add_terms passes distinct names.

#ifndef ASM

#define and_loop(x1,x2) 					\
{	if(!(m_nv & ~m_cv)) 					\
	{	m_op = and; v = v_s - 1; 			\
		while(++v < v_e) 					\
		{	if((*v & m_cv) == m_nv)		 	\
				goto found; 				\
		}									\
	}										\
}

#else

// 32-bit x86 MSVC version: lodsw / and / cmp per entry; loopne repeats
// while entries remain and no match has been seen.  loopne leaves the
// flags alone, so jne x2 tests the result of the last cmp.  After a match
// esi points one word past it, hence the 'sub esi,2'.

#define and_loop(x1,x2) 					\
{	if(!(m_nv & ~m_cv)) 					\
	{	m_op = and;							\
	__asm	{								\
	__asm		mov 	esi,dword ptr [v_s]	\
	__asm		mov 	ecx,dword ptr [v_e]	\
	__asm		sub		ecx,esi				\
	__asm		sar		ecx,1				\
	__asm		mov 	dx,word ptr [m_nv]	\
	__asm		mov		bx,word ptr [m_cv]	\
	__asm	x1:	lodsw						\
	__asm		and		ax,bx				\
	__asm		cmp		ax,dx				\
	__asm		loopne	x1					\
	__asm		jne		x2					\
	__asm		sub		esi,2				\
	__asm		mov		dword ptr [v],esi	\
	__asm		jmp		found				\
	__asm	x2: nop 						\
			}								\
	}										\
}

#endif

// 4. search the term list for a value such that m_cv | val == m_nv
//    note that this equality requires that ~m_nv & m_cv == 0.

// or_loop(x1,x2): look for an existing term val (entries 0 .. nterm-1)
// with (val | m_cv) == m_nv.  A match is only possible when m_cv has no
// bits outside m_nv, so otherwise the search is skipped entirely.  When
// the search runs m_op is set to or; on a match v points at the matching
// entry and control jumps to the label 'found'.  Otherwise execution falls
// through with m_nv unchanged.  x1 and x2 are label names used only by
// the assembler version; each expansion in add_terms passes distinct names.

#ifndef ASM

#define or_loop(x1,x2)						\
{	if(!(~m_nv & m_cv)) 					\
	{	m_op = or; v = v_s - 1;				\
		while(++v < v_e) 					\
		{	if((*v | m_cv) == m_nv)		 	\
					goto found; 			\
		}									\
	}										\
}

#else

// 32-bit x86 MSVC version: same structure as and_loop with 'or' in place
// of 'and'; loopne leaves the flags alone, so jne x2 tests the last cmp.

#define or_loop(x1,x2)						\
{	if(!(~m_nv & m_cv)) 					\
	{	m_op = or;							\
	__asm	{								\
	__asm		mov 	esi,dword ptr [v_s]	\
	__asm		mov 	ecx,dword ptr [v_e]	\
	__asm		sub		ecx,esi				\
	__asm		sar		ecx,1				\
	__asm		mov 	dx,word ptr [m_nv]	\
	__asm		mov		bx,word ptr [m_cv]	\
	__asm	x1:	lodsw						\
	__asm		or		ax,bx				\
	__asm		cmp		ax,dx				\
	__asm		loopne	x1					\
	__asm		jne		x2					\
	__asm		sub		esi,2				\
	__asm		mov		dword ptr [v],esi	\
	__asm		jmp		found				\
	__asm	x2: nop 						\
			}								\
	}										\
}

#endif

// The main recursive search.  Every pair of list terms t[i], t[k] with
// i < k is combined with each operator to form a candidate new term.
// Only xor (i.e. negation) is used when t[i] is the constant '1'.
// Candidates whose value already appears in the list (other than as the
// constant '1') are skipped.  While the list is shorter than
// req_terms - 1 (and req_terms <= max_terms) the candidate is appended and
// the search recurses.  Otherwise the candidate is tested against every
// unmatched target.  If (existing term op candidate) equals e, f, g or h,
// both terms are appended and the target is flagged.  Then either the
// solution is reported (all four matched) or the search continues from
// this partial solution for list sizes nterm .. nterm + sch_dep (capped
// at max_terms).
//
// req_terms : list size to build up to before testing for matches
// t         : the term list; nterm and flags are restored before return,
//             but the entries at and beyond nterm are used as scratch
//
// Returns true if a complete solution (all of e, f, g, h matched) was
// reached in this call or in any recursive call it made.

bool add_terms(u2byte req_terms, t_str &t)
{	u2byte	i, j, k, l, m_cv, m_nv, lim;
	u2byte	*v_s, *v, *v_e;
	u1byte	m_ch;
	op_type m_op;
	string	ss;
	bool	solved = false;			// set once a complete solution is reached

	v_s = t.t_valu; v_e = v_s + t.nterm; 

	for(i = 0; i < t.nterm; ++i)		// for each term in list and each
	{									// operator (xor, and, or) and 
	  for(j = 1; j < (i ? 4 : 2); ++j)	// for all other terms in list
	  {									// prepare to add a new term
		for(k = i + 1; k < t.nterm; ++k) // 't[i] op[j] t[k]'  
		{								 
			m_cv = v_s[i]; m_nv = v_s[k]; 
											
			switch(j)					 // evaluate its representation 	
			{
				case 1: m_cv ^= m_nv; break;
				case 2: m_cv &= m_nv; break;
				case 3: m_cv |= m_nv; break;
			}
			
			eq_loop(m_cv);	// skip this term if it is in the list already

			t.t_valu[t.nterm] = m_cv;	// set the term up
			t.t_oper[t.nterm] = static_cast<op_type>(j);
			t.t_trm1[t.nterm] = i;
			t.t_trm2[t.nterm] = k;
			t.t_char[t.nterm] = 0;		// but don't add it yet

			if(t.nterm < req_terms - 1 && req_terms <= max_terms)
			{
				t.nterm++;				// put current term in list
				
				if(add_terms(req_terms, t))	// and add more
					solved = true;
				
				t.nterm--;
			}
			else
			{
				// check if new term will combine with others in the	
				// list to produce a match with a target.  Although
				// this looks costly it reduces the recursion level
				// by one and thus gives a high return

				if(!(t.flags & 0x01))	// if e not yet matched
				{
					m_ch = 'e'; m_nv = t.tvals[0];
					xor_loop(ex1);
					and_loop(ea1, ea2);
					or_loop(eo1, eo2);
				}

				if(!(t.flags & 0x02))	// if f not yet matched
				{
					m_ch = 'f'; m_nv = t.tvals[1];
					xor_loop(fx1);
					and_loop(fa1, fa2);
					or_loop(fo1, fo2);
				}

				if(!(t.flags & 0x04))	// if g not yet matched
				{
					m_ch = 'g'; m_nv = t.tvals[2];
					xor_loop(gx1);
					and_loop(ga1, ga2);
					or_loop(go1, go2);
				}

				if(!(t.flags & 0x08))	// if h not yet matched
				{
					m_ch = 'h'; m_nv = t.tvals[3];
					xor_loop(hx1);
					and_loop(ha1, ha2);
					or_loop(ho1, ho2);
				}
			
			next_term:	
				
				continue;

			found:	// if a match found, add the additional term
					// (one more than normal at this recursion)

				t.nterm++;					// add the current term

				t.t_valu[t.nterm] = m_nv;	// add the term for the match
				t.t_oper[t.nterm] = m_op;	// = 'e', 'f', 'g' or 'h'
				t.t_trm1[t.nterm] = v - v_s;
				t.t_trm2[t.nterm] = t.nterm - 1;
				t.t_char[t.nterm] = m_ch;

				// add the matched term and set flags for the match

				t.nterm++; t.flags |= (1 << (m_ch - 'e'));

				if(t.flags == 0x0f)			// if all 4 outputs (e, f, g, h) are matched
				{
					solved = true;

					if(t.nterm < cur_best)	// and this is least no of terms so far
					{
						cout << endl << setw(2) << t.nterm - 5 
								<< " term solution:";
						terms_out(ss, t); 
						
						cout << endl << ss << endl; 
							
						cout.flush(); cur_best = t.nterm;

						if(fout)
						{
							fout << endl << ss << endl; fout.flush();
						}
					}

					// NOTE(review): this also runs when the solution is no
					// better than cur_best, so max_terms can be raised above
					// its current value; confirm whether it should only ever
					// decrease.
					max_terms = t.nterm - 1;
				}
				else	// look for solutions based on this partial solution
				{
					cout << '.' << m_ch << setw(1) << (int)n_found[t.flags]; 

					lim = (t.nterm + sch_dep < max_terms ? 
									t.nterm + sch_dep : max_terms);

					for(l = t.nterm; l <= lim; ++l)
					{
						if(add_terms(l, t))
							solved = true;

						if(t.nterm > max_terms)

							break;
					}
				}

				t.flags &= ~(1 << (m_ch - 'e')); t.nterm -= 2;
			}

			if(t.nterm > max_terms)		// a solution has tightened the limit
			
				return solved;
		}
	  }
	}

	return solved;
}

/*
 If the analyser needs to be steered to look at particular subtrees
 extra terms can be added to the start of the list. The list starts
 with the 5 terms '1', 'a', 'b', 'c', 'd'. To add extra terms this
 list is extended with the added terms and term.nterm (normally 5) is
 set to the array index of the first free term in the list (8 in the
 following example). Each added term must be built by combining two
 lower terms in the list. For example, the term ~a is added as a term
 that refers to list[0] = '1' and list[1] = 'a' with the operation
 'xor'. set_term() does this interactively, one term at a time. The
 following shows how the fields of added terms are set.

	term.t_valu[5] = a | d;	// the minterm expression for this term
	term.t_trm1[5] = 1;		// index of 1st sub-term from list (a)
	term.t_trm2[5] = 4;		// index of 2nd sub-term from list (d)
	term.t_oper[5] = or; 	// how these terms are combined
	term.t_char[5] = 0;		// no name of its own (printed as t1)

	term.t_valu[6] = a & d;
	term.t_trm1[6] = 1;
	term.t_trm2[6] = 4;
	term.t_oper[6] = and;
	term.t_char[6] = 0;

	term.t_valu[7] = a ^ c;
	term.t_trm1[7] = 1;
	term.t_trm2[7] = 3;
	term.t_oper[7] = xor;
	term.t_char[7] = 0;

	term.nterm = 8;			// first free entry
*/

void set_term(void)
{	int 	nt,n1,n2;
	char	opc;
	op_type opp;
	u2byte	t1,t2;
	

	n1 = n2 = 1000; opc = '\0'; nt = term.nterm;

	cout << endl << "Term " << setw(2) << nt << ':';
	  
	while(n1 < 0 || n1 >= nt)
	{
		cout << endl << "1st Component Term? "; cin >> n1;
	}

	while(n2 < 0 || n2 >= nt)
	{
		cout << endl << "2nd Component Term? "; cin >> n2;
	}

	while(opc != 'a' && opc != 'o' && opc != 'x')
	{
		cout << endl << "Combine with (a)nd, (o)r or (x)or? "; cin >> opc;
	}

	t1 = term.t_valu[n1]; t2 = term.t_valu[n2];

	if(opc == 'a')
	{
		opp = and; t1 &= t2;
	}
	else if(opc == 'o')
	{
		opp = or; t1 |= t2;
	}
	else
	{
		opp = xor; t1 ^= t2;
	}
	
	term.t_valu[nt] = t1;
	term.t_trm1[nt] = n1;
	term.t_trm2[nt] = n2;
	term.t_oper[nt] = opp;
	term.t_char[nt] = '\0';

	(term.nterm)++;
};

// Program entry point.
// - Reads the search depth and the S box number: 0-7 select the forward
//   boxes, 8-15 the inverse boxes.
// - Optionally lets the user seed the term list.
// - Searches for solutions of increasing size, logging them to
//   sbox_NN.log.
// Returns 1 if input fails before the search starts, 0 otherwise.
// ('void main' is not valid standard C++.)

int main(void)
{
	u2byte	i = 0;
	char	q;
	string	fn;

	cout << endl << "Search Depth (2-6)? ";

	if(!(cin >> i))
	{
		cout << endl << "input error" << endl; return 1;
	}

	sch_dep = (i > 1 && i < 7 ? i : 4);	// out of range: use the default

	do	// re-prompt until a valid box number: larger values would index
	{	// past the end of i_box below
		cout << endl << "S Box Number (0-15)? ";

		if(!(cin >> i))
		{
			cout << endl << "input error" << endl; return 1;
		}
	}
	while(i > 15);

	fn = "sbox_00.log"; com_boxes();

	fn[5] += i / 10; fn[6] += i % 10;	// "sbox_NN.log", NN = box number

	if(i < 8)
	{
		term.tvals[0] = s_box[i][0]; term.tvals[1] = s_box[i][1]; 
		
		term.tvals[2] = s_box[i][2]; term.tvals[3] = s_box[i][3];
	}
	else
	{
		i -= 8;

		term.tvals[0] = i_box[i][0]; term.tvals[1] = i_box[i][1]; 
		
		term.tvals[2] = i_box[i][2]; term.tvals[3] = i_box[i][3];
	}

	term.nterm = 5; term.flags = 0; 
	
	q = 'y';

	while(q != 'n' && q != 'N')
	{ 
		cout << endl << "Add Terms? ";

		if(!(cin >> q))		// end of input: stop adding terms
			break;

		if(q == 'y' || q == 'Y')
		{
			set_term();
		}
	}

	fout.open(fn.c_str());

	if(!fout)
		cout << endl << "cannot open " << fn;

	cout << endl << fn << endl;

	max_terms = 22;
	
	for(i = term.nterm; i < 22; ++i)
	{
		cout << endl << "looking for solutions with " << i - 4 << " terms";

		fout << endl << "looking for solutions with " << i - 4 << " terms";

		if(!add_terms(i, term))
		
			cout << ": none found";
	}

	cout << endl; 
	
	fout.close();

	return 0;
}
