/*	  File:	  restruct.c
 *	  Author:	Sebastian Skalberg <skalberg@diku.dk>
 *	  Content:  C-Mix restructuring: functions.
 *
 *	  Copyright  1999. The TOPPS group at DIKU, U of Copenhagen.
 *	  Redistribution and modification are allowed under certain
 *	  terms; see the file COPYING.cmix for details.
 */
  
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#define cmixSPECLIB_SOURCE
#include "speclib.h"
#include "code.h"

/* Forward declarations of file-local helpers defined further down. */
static unsigned makeInterval( struct cmixStmtLabel * );
static void addReference( Interval *, Interval * );

/* Allocate a new iList cell holding `interval` and linking to `next`,
 * i.e. prepend `interval` to the list `next`.
 *
 * Returns the new head.  The program aborts with a diagnostic if memory
 * is exhausted: every caller dereferences the result immediately, so a
 * NULL return would only turn into a crash later anyway.
 */
static iList *
makeIList( Interval *interval, iList *next )
{
  iList *node = malloc( sizeof *node );

  if ( node == NULL ) {
    fprintf( stderr, "restruct: out of memory allocating interval list\n" );
    abort();
  }

  node->interval = interval;
  node->next = next;

  return node;
}

/* Record the control-flow edge from -> to in both directions: `to` is
 * prepended to from->forw_refs and `from` is prepended to to->back_refs.
 * Because entries are prepended, the most recently added edge comes first. */
static void
addReference( Interval *from, Interval *to )
{
  assert( from != NULL );
  assert( to != NULL );

  from->forw_refs = makeIList( to, from->forw_refs );   /* successor   */
  to->back_refs   = makeIList( from, to->back_refs );   /* predecessor */
}

/* Fold chains of conditional branches into one compound condition.
 *
 * `interval` is a basic interval whose control statement is an If
 * (makeInterval calls this only in its "normal branch" case).  While the
 * then- or else-target of that If (x) is a label referenced once
 * (refcount == 1) and directly followed by another If (y) that shares a
 * target with x, y is merged into x:
 *
 *   then side:  x ? (y ? T : E) : E   ==>   (x && y)  ? T : E
 *               x ? (y ? E : F) : E   ==>   (!x || y) ? E : F
 *   else side:  x ? T : (y ? T : F)   ==>   (x || y)  ? T : F
 *               x ? T : (y ? U : T)   ==>   (!x && y) ? U : T
 *
 * Each rule's extra inequality test skips the fold when y's other branch
 * jumps back to y's own label (the label being bypassed).  After every
 * successful fold the scan restarts, so longer chains collapse completely.
 * New conditions are built with cmixMkExp, sharing the old sub-expressions.
 * NOTE(review): label refcounts are not updated, and the bypassed label and
 * its If stay in the statement list.
 */
static void
collapseCompounds( Interval *interval )
{
  struct cmixStmtIf *xif = &interval->member.basic.control->cond;
  struct cmixStmtIf *yif;
  
  while ( 1 ) {
    /* Then-target is a single-use label whose first statement is an If. */
    if ( xif->then_target->next->common.tag == If &&
         xif->then_target->refcount == 1 ) {
      yif = &xif->then_target->next->cond;
      
      /* x ? (y ? T : E) : E  ==>  (x && y) ? T : E */
      if ( yif->else_target == xif->else_target &&
           yif->then_target != xif->then_target ) {
        xif->cond = &cmixMkExp( "(?) && (?)", xif->cond, yif->cond )->inner;
        xif->then_target = yif->then_target;
        continue;
      }
      
      /* x ? (y ? E : F) : E  ==>  (!x || y) ? E : F */
      if ( yif->then_target == xif->else_target &&
           yif->else_target != xif->then_target ) {
        xif->cond =
                   &cmixMkExp( "! (?) || (?)", xif->cond, yif->cond )->inner;
        xif->then_target = yif->then_target;
        xif->else_target = yif->else_target;
        continue;
      }
      
    }
    
    /* Else-target is a single-use label whose first statement is an If. */
    if ( xif->else_target->next->common.tag == If &&
         xif->else_target->refcount == 1 ) {
      yif = &xif->else_target->next->cond;
      
      /* x ? T : (y ? T : F)  ==>  (x || y) ? T : F */
      if ( yif->then_target == xif->then_target &&
           yif->else_target != xif->else_target ) {
        xif->cond = &cmixMkExp( "(?) || (?)", xif->cond, yif->cond )->inner;
        xif->else_target = yif->else_target;
        continue;
      }
      
      /* x ? T : (y ? U : T)  ==>  (!x && y) ? U : T */
      if ( yif->else_target == xif->then_target &&
           yif->then_target != xif->else_target ) {
        xif->cond = &cmixMkExp( "! (?) && (?)", xif->cond, yif->cond )->inner;
        xif->then_target = yif->then_target;
        xif->else_target = yif->else_target;
        continue;
      }

    }

    /* Nothing left to fold. */
    break;
  }
}

/* Build the basic interval (basic block) that starts at `label`, then
 * recursively build every interval reachable from it, recording each
 * control-flow edge with addReference().
 *
 * The block runs from `label` up to, but not including, the next Label
 * statement (or the end of the list).  Its last statement is the block's
 * control statement and must be an If, Goto or Return; any other tag is
 * reported on stderr and trips an assertion.  The statement list is
 * normalised on the way:
 *   - an If whose then- and else-targets coincide is replaced by a Goto;
 *   - an If preceded by other statements in the same block is moved into a
 *     block of its own (behind a new Label), reached through a new Goto, so
 *     If-terminated blocks consist of the If alone;
 *   - otherwise chains of Ifs are folded by collapseCompounds().
 *
 * Returns the number of intervals allocated by this call and its recursive
 * calls (0 if `label` already had an interval).
 */
static unsigned
makeInterval( struct cmixStmtLabel *label )
{
  unsigned noOfIntervals = 0;
  union cmixStmt *stmt, **pstmt;

  /* Already built: nothing new to count. */
  if ( label->interval != NULL )
    return noOfIntervals;

  /* Walk to the block's last statement; pstmt is the link that points at
     it, so the statement can be replaced in place below. */
  pstmt = NULL;
  stmt = (union cmixStmt *)label;

  while ( stmt->common.next != NULL &&
          stmt->common.next->common.tag != Label ) {
    pstmt = &stmt->common.next;
    stmt = stmt->common.next;
  }

  /* The label must be followed by at least one statement. */
  assert( pstmt != NULL );
  assert( *pstmt == stmt );

  label->interval = (Interval *)malloc(sizeof(Interval));
  label->interval->parent = NULL;
  label->interval->number = -1;
  label->interval->type = basic;
  label->interval->member.basic.label = label;
  label->interval->member.basic.control = stmt;
  label->interval->member.basic.immDom = NULL;
  label->interval->member.basic.ifFollow = NULL;
  label->interval->member.basic.loop = NULL;
  label->interval->member.basic.printed = 0;
  label->interval->back_refs = NULL;
  label->interval->forw_refs = NULL;
  label->interval->loopCheckDone = 0;
  label->interval->path = 0;

  noOfIntervals++;

  switch ( stmt->common.tag ) {
  case If:
    if ( stmt->cond.then_target == stmt->cond.else_target ) {
      /* degenerate branch */
      union cmixStmt *newGoto =
                               (union cmixStmt *)malloc(sizeof(struct cmixStmtGoto));

      newGoto->jump.tag = Goto;
      newGoto->jump.next = stmt->common.next;
      newGoto->jump.target = stmt->cond.then_target;
      *pstmt = newGoto;

      /* hack: inner expression not free'd */
      free( stmt );

      /* fall through to Goto case below, may need further */
      /* handling, e.g. infinite loop handling             */
      stmt = newGoto;
      label->interval->member.basic.control = stmt;
    } else if ( label->next != stmt ) {
      /* conditional has "plain" prefix: end this block with a Goto to a
         new label that starts a block holding only the If */
      union cmixStmt *newGoto =
                               (union cmixStmt *)malloc(sizeof(struct cmixStmtGoto));
      union cmixStmt *newLabel = cmixMakeLabel();

      newGoto->jump.tag = Goto;
      newGoto->jump.next = newLabel;
      newGoto->jump.target = &newLabel->label;

      newLabel->label.next = stmt;
      newLabel->label.refcount = 1;
      newLabel->label.defined = 1;

      *pstmt = newGoto;
      stmt = newGoto;
      label->interval->member.basic.control = stmt;

      noOfIntervals += makeInterval( &newLabel->label );

      addReference( label->interval, newLabel->label.interval );
      break;
    } else {
      collapseCompounds( label->interval );

      noOfIntervals += makeInterval( stmt->cond.then_target );
      noOfIntervals += makeInterval( stmt->cond.else_target );

      /* normal branch.  addReference prepends, so forw_refs ends up as
         { else-target, then-target }; swapBranches and the loop/conditional
         detectors read forw_refs by position. */
      addReference( label->interval, stmt->cond.then_target->interval );
      addReference( label->interval, stmt->cond.else_target->interval );
      break;
    }
    /* the 'degenerate branch' case above may fall through */
  case Goto:
    noOfIntervals += makeInterval( stmt->jump.target );
    addReference( label->interval, stmt->jump.target->interval );
    break;
  case Return:
    /* no successors */
    break;
  default:
    fprintf(stderr,"Tag = %d\n",stmt->common.tag);
    assert( 0 == 1 );
    break;
  }

  return noOfIntervals;
}

/* Number the intervals reachable from `interval` in depth-first order.
 *
 * An interval receives its number only after all of its successors are
 * finished, and numbers are handed out counting down from `lastFree`.
 * The result is therefore a reverse post-order numbering.  Each interval
 * is stored at order[number].  For basic intervals the label's number is
 * set as well.
 *
 * number == -1 means "not visited yet"; number == 0 marks an interval that
 * is still on the current DFS path.  Returns the next free number.
 */
static unsigned
enumIntervals( Interval *interval, unsigned lastFree, Interval *order[] )
{
  iList *succ;
  unsigned slot;

  if ( interval->number != -1 )
    return lastFree;               /* visited or in progress */

  interval->number = 0;            /* on the DFS path */

  slot = lastFree;
  for ( succ = interval->forw_refs; succ != NULL; succ = succ->next )
    slot = enumIntervals( succ->interval, slot, order );

  interval->number = slot;
  if ( interval->type == basic )
    interval->member.basic.label->number = slot;
  order[slot] = interval;

  return slot - 1;
}

/* Build the next-level (derived) interval headed by `interval`, and then,
 * recursively, the derived intervals of everything it leads to.
 *
 * A new node of type `ilist` is allocated as interval->parent.  It inherits
 * the header's number.  Successors of its members are absorbed as long as
 * all of their predecessors already belong to this parent (the usual
 * interval-partitioning rule).  Members are prepended, so the header stays
 * at the tail of member.ilist.  Each successor left outside becomes the
 * header of its own derived interval, and an edge from this parent to that
 * interval's parent is added unless it already exists.
 *
 * Returns the number of derived intervals allocated by this call and its
 * recursive calls; 0 if `interval` already has a parent.
 * NOTE(review): only parent, number, type, member.ilist, back_refs,
 * forw_refs and loopCheckDone are initialised on the new node.
 */
static unsigned
iterateIntervals( Interval *interval )
{
  iList *ref, *ref2;
  Interval *parent;
  int changed;
  unsigned noOfIntervals = 0;

  if ( interval->parent != NULL )
    return noOfIntervals;

  interval->parent = (Interval *)malloc(sizeof(Interval));
  parent = interval->parent;

  parent->parent = NULL;
  parent->number = interval->number;
  parent->type = ilist;
  parent->member.ilist = makeIList( interval, NULL );
  parent->back_refs = NULL;
  parent->forw_refs = NULL;
  parent->loopCheckDone = 0;

  noOfIntervals++;

  /* Grow the interval until no successor qualifies any more.  Members
     prepended during a pass lie before the point where that pass started,
     so they are scanned by the next pass (triggered by `changed`). */
  do {
    changed = 0;
    for ( ref2 = parent->member.ilist; ref2 != NULL; ref2 = ref2->next )
      for ( ref = ref2->interval->forw_refs; ref != NULL; ref = ref->next ) {
        assert( ref->interval != NULL );
        if ( ref->interval->parent == NULL ) {
          iList *bref;
          int all_same = 1;

          /* Absorb only if every predecessor is already a member. */
          for ( bref = ref->interval->back_refs;
                bref != NULL;
                bref = bref->next ) {
            if ( bref->interval->parent == NULL ||
                 bref->interval->parent != parent ) {
              all_same = 0;
              break;
            }
          }

          if ( all_same ) {
            ref->interval->parent = parent;
            parent->member.ilist =
                                  makeIList( ref->interval, parent->member.ilist );
            changed = 1;
          }
        }
      }
  } while ( changed );

  /* Successors outside this interval head their own derived intervals;
     link this parent to each such derived interval exactly once. */
  for ( ref2 = parent->member.ilist; ref2 != NULL; ref2 = ref2->next ) {
    for ( ref = ref2->interval->forw_refs; ref != NULL; ref = ref->next ) {
      iList *pref;

      if ( ref->interval->parent == parent )
        continue;

      noOfIntervals += iterateIntervals( ref->interval );
      for ( pref = parent->forw_refs; pref != NULL; pref = pref->next )
        if ( pref->interval == ref->interval->parent )
          break;
      if ( pref == NULL )
        addReference( parent, ref->interval->parent );
    }
  }
  return noOfIntervals;
}

/* Invert the If that ends basic interval `interval`.
 *
 * The condition is wrapped in "! (?)" and the then/else targets are
 * exchanged.  The first two forw_refs entries are swapped to keep the
 * edge list in step with the targets; the list is expected to hold exactly
 * the two branch edges, and anything after the second entry is dropped.
 */
static void
swapBranches( Interval *interval )
{
  struct cmixStmtIf *branch;
  struct cmixStmtLabel *oldThen;
  iList *first, *second;

  assert( interval->type == basic );
  assert( interval->member.basic.control->common.tag == If );

  branch = &interval->member.basic.control->cond;

  /* negate the test ... */
  branch->cond = &cmixMkExp( "! (?)", branch->cond )->inner;

  /* ... and exchange the destinations */
  oldThen = branch->then_target;
  branch->then_target = branch->else_target;
  branch->else_target = oldThen;

  /* reverse the two outgoing edges */
  first = interval->forw_refs;
  second = first->next;
  first->next = NULL;
  second->next = first;
  interval->forw_refs = second;
}
	
/* Mark the body of `loop`, starting at `interval` and walking backwards
 * through predecessors that have smaller numbers.
 *
 * An interval already marked with `loop` ends the walk; loopDetect marks
 * the header before calling this.  An interval belonging to a different
 * (inner) loop keeps that marking, and the walk continues from that loop's
 * header instead.
 */
static void
markLoop( Loop *loop, Interval *interval )
{
  Loop *owner = interval->member.basic.loop;
  Interval *start;
  iList *pred;

  if ( owner == loop )
    return;

  if ( owner == NULL ) {
    interval->member.basic.loop = loop;
    start = interval;
  } else {
    start = owner->header;     /* skip over the nested loop */
  }

  for ( pred = start->back_refs; pred != NULL; pred = pred->next ) {
    if ( pred->interval->number < start->number )
      markLoop( loop, pred->interval );
  }
}

/* Detect the loop (if any) headed by the header of derived interval
 * `interval`, then recurse into its successors that have larger numbers.
 *
 * `interval` must be a top-level node (no parent) of the current derived
 * graph.  Its basic header is found by repeatedly taking the tail of the
 * ilist, which is the first member added.  Among the header's predecessors
 * that lie in the same top-level interval and are not yet in a loop, the
 * one with the highest number becomes the latch.  A Loop is then created,
 * its body marked via markLoop(), and its type chosen:
 *   - dowile:  the latch ends in an If.  If its first forward reference
 *              (the else-target, as makeInterval adds it last) is the header,
 *              the branches are swapped.  The first forward reference then
 *              becomes the follow.
 *   - wile:    the header is a lone If.  If its then-target is outside the
 *              loop, the branches are swapped.  The else-target becomes the
 *              follow, unless it is inside the loop too; then the loop is
 *              endless after all.
 *   - endless: otherwise (follow stays NULL).
 * `order` is only passed along to the recursive calls.
 */
static void
loopDetect( Interval *interval, Interval *order[] )
{
  iList *ref;
  Interval *header;

  assert( interval != NULL );
  assert( interval->parent == NULL );

  if ( interval->loopCheckDone == 1 )
    return;

  interval->loopCheckDone = 1;

  /* Descend to the basic header: the last element of each ilist. */
  header = interval;
  while ( header->type == ilist ) {
    ref = header->member.ilist;
    while ( ref != NULL ) {
      header = ref->interval;
      ref = ref->next;
    }
  }

  if ( header->member.basic.loop == NULL ) {
    Interval *latch;
    Interval *lparent;

    /* Pick the highest-numbered predecessor of the header that lies in
       the same top-level interval and is not yet part of a loop. */
    latch = NULL;
    for ( ref = header->back_refs; ref != NULL; ref = ref->next ) {
      lparent = ref->interval;
      while ( lparent->parent != NULL )
        lparent = lparent->parent;
      if ( lparent == interval &&
           ref->interval->member.basic.loop == NULL &&
           ( latch == NULL || latch->number < ref->interval->number ) )
        latch = ref->interval;
    }

    if ( latch != NULL ) {
      Loop *loop;

      loop = (Loop *)malloc(sizeof(Loop));
      loop->header = header;
      loop->latch = latch;
      loop->follow = NULL;

      /* Mark the header first so markLoop stops there. */
      header->member.basic.loop = loop;
      markLoop( loop, latch );

      if ( latch->member.basic.control->common.tag == If ) {
        loop->type = dowile;
        if ( latch->forw_refs->interval == header )
          swapBranches( latch );
        loop->follow = latch->forw_refs->interval;
      } else if ( header->member.basic.label->next->common.tag == If ) {
        /* potential while-loop - header consists of a single
           if statement */
        loop->type = wile;
        if ( header->forw_refs->next->interval->member.basic.loop != loop )
          swapBranches( header );
        loop->follow = header->forw_refs->interval;
	if ( header->forw_refs->interval->member.basic.loop == loop ) {
	  loop->type = endless;
	  loop->follow = NULL;
	}
      } else
        loop->type = endless;
    }
  }

  /* Continue with forward successors (larger numbers) only. */
  for ( ref = interval->forw_refs; ref != NULL; ref = ref->next )
    if ( ref->interval->number > interval->number )
      loopDetect( ref->interval, order );

}

/* Return the nearest common dominator of `currImmDom` and `predImmDom`.
 *
 * The two member.basic.immDom chains are climbed, always advancing the
 * node with the larger number, until the two meet.  If one chain runs out
 * (NULL), the other node is returned.  A NULL argument simply yields the
 * other argument.  `order` is not used.
 */
static Interval *
commonDom( Interval *currImmDom, Interval *predImmDom, Interval *order[] )
{
  if ( currImmDom == NULL )
    return predImmDom;
  if ( predImmDom == NULL )
    return currImmDom;

  for ( ; currImmDom != predImmDom; ) {
    if ( currImmDom == NULL || predImmDom == NULL )
      break;
    if ( predImmDom->number > currImmDom->number )
      predImmDom = predImmDom->member.basic.immDom;   /* pred is deeper */
    else
      currImmDom = currImmDom->member.basic.immDom;   /* curr is deeper */
  }

  return currImmDom;
}

/* Compute member.basic.immDom for the intervals order[1..numIs].
 *
 * There is a single pass in increasing number order.  Each interval's
 * immediate dominator is folded with commonDom() over its predecessors that
 * have smaller numbers.  Edges from higher-numbered predecessors are
 * ignored.  NOTE(review): a single pass that skips those edges presumably
 * relies on the flow graph being reducible.
 */
static void
findDominators( unsigned numIs, Interval *order[] )
{
  unsigned pos;

  for ( pos = 1; pos <= numIs; pos++ ) {
    Interval *node = order[pos];
    iList *pred;

    for ( pred = node->back_refs; pred != NULL; pred = pred->next ) {
      if ( pred->interval->number >= pos )
        continue;                        /* back edge: ignored */
      node->member.basic.immDom =
        commonDom( node->member.basic.immDom, pred->interval, order );
    }
  }
}

/* Visit stamp for pathExists().  An interval whose `path` equals pathNum
   has already been explored in the current search; condDetect increments
   pathNum before each search. */
static unsigned pathNum = 0;

/* Return 1 if `to` is reachable from `from` using only forward edges
 * (towards intervals with larger numbers), 0 otherwise.  Intervals numbered
 * above `to` are pruned, and intervals already stamped with pathNum are not
 * explored again.
 */
static int
pathExists( Interval *from, Interval *to )
{
  iList *succ;

  if ( from == to )
    return 1;
  if ( from->number > to->number )
    return 0;

  from->path = pathNum;

  for ( succ = from->forw_refs; succ != NULL; succ = succ->next ) {
    Interval *child = succ->interval;

    if ( child->number <= from->number )
      continue;                        /* not a forward edge */
    if ( child->path == pathNum )
      continue;                        /* already explored */
    if ( pathExists( child, to ) )
      return 1;
  }

  return 0;
}

/* Find the follow node of two-way conditionals (if/then/else).
 *
 * `count` is the number of basic intervals, stored in order[1..count].
 * Intervals are visited from the highest number down.  Every If-terminated
 * interval that is neither a loop header nor a loop latch is pushed on the
 * `unref` (unresolved) list.  Because the numbers decrease, the list stays
 * sorted by ascending number from its head.
 *
 * The follow candidate is the highest-numbered interval whose immediate
 * dominator is the current If and that has at least two predecessors.
 * When such a candidate exists, unresolved Ifs numbered below it are
 * popped from the list.  If a forward path leads from one of them to the
 * candidate (pathExists), the candidate becomes its ifFollow.  Its branches
 * are then swapped if needed, so the follow is not the then-target.  An If
 * without such a path is dropped from the list, as before.
 *
 * Every list node is freed when popped.  Nodes still unresolved at the end
 * are released before returning.
 */
static void
condDetect( unsigned count, Interval *order[] )
{
  iList *unref = NULL, *tmpRef;
  int idx;

  for ( idx = count; idx >= 1; idx-- ) {
    if ( order[idx]->member.basic.control->common.tag == If &&
         ( order[idx]->member.basic.loop == NULL ||
           ( order[idx]->member.basic.loop->header != order[idx] &&
             order[idx]->member.basic.loop->latch != order[idx] ) ) ) {
      int idx2;

      /* Push onto the unresolved list. */
      tmpRef = (iList *)malloc(sizeof(iList));
      if ( tmpRef == NULL ) {
        fprintf( stderr, "condDetect: out of memory\n" );
        abort();
      }
      tmpRef->interval = order[idx];
      tmpRef->next = unref;
      unref = tmpRef;

      for ( idx2 = count; idx2 > idx; idx2-- ) {
        if ( order[idx2]->member.basic.immDom == order[idx] &&
             order[idx2]->back_refs->next != NULL ) {
          while ( unref != NULL && unref->interval->number < idx2 ) {
            pathNum++;                 /* fresh visit stamp */
            if ( pathExists( unref->interval, order[idx2] ) ) {
              unref->interval->member.basic.ifFollow = order[idx2];
              if ( unref->interval->forw_refs->next->interval == order[idx2] )
                swapBranches( unref->interval );
            }
            /* The node leaves the list in both cases; free it (previously
               the no-path case leaked it). */
            tmpRef = unref;
            unref = unref->next;
            free( tmpRef );
          }
          break;
        }
      }
    }
  }

  /* Conditionals that never found a follow: release their list nodes. */
  while ( unref != NULL ) {
    tmpRef = unref;
    unref = unref->next;
    free( tmpRef );
  }
}

/* Set `number` back to -1 for the intervals order[1..count].
 * Despite the name, the labels' own `number` fields are not touched. */
static void
resetLabels( unsigned count, Interval *order[] )
{
  unsigned idx = count;

  while ( idx > 0 ) {
    order[idx]->number = -1;
    idx--;
  }
}

struct cmixStmtLabel *
compress( struct cmixStmtLabel *target )
{
  union cmixStmt *afterLabel;

  if ( target->refcount > 1 )
    return target;

  afterLabel = target->next;

  if ( afterLabel->common.tag == Goto ) {
    afterLabel->jump.target = compress( afterLabel->jump.target );
    return afterLabel->jump.target;
  }

  if ( afterLabel->common.tag == If && afterLabel->cond.then_target ==
       afterLabel->cond.else_target ) {
    afterLabel->cond.then_target = compress( afterLabel->cond.then_target );
    afterLabel->cond.else_target = afterLabel->cond.then_target;
    return afterLabel->cond.then_target;
  }

  return target;
}

/* Shortcut jump chains in the statement list starting at `stmt`.
 *
 * Every If and Goto target is replaced by compress(target).  Plain, Return
 * and Label statements are left alone.  Any other tag is reported on stderr
 * and trips an assertion.
 */
static void
initRestruct( union cmixStmt *stmt )
{
  union cmixStmt *cur;

  for ( cur = stmt; cur != NULL; cur = cur->common.next ) {
    if ( cur->common.tag == If ) {
      cur->cond.then_target = compress( cur->cond.then_target );
      cur->cond.else_target = compress( cur->cond.else_target );
    } else if ( cur->common.tag == Goto ) {
      cur->jump.target = compress( cur->jump.target );
    } else if ( cur->common.tag != Plain &&
                cur->common.tag != Return &&
                cur->common.tag != Label ) {
      fprintf(stderr,"Tag = %d\n", cur->common.tag);
      assert( 0 == 1 );
    }
  }
}

/* Debug dump of the basic interval graph.
 *
 * Prints one line per interval order[1..numIs]: its index, the numbers of
 * its successors, and, in parentheses, its immediate dominator's number if
 * it has one.  Only referenced from a commented-out call in
 * cmixRestructStmts.
 */
static void
printIntervals( unsigned numIs, Interval *order[] )
{
  unsigned idx;

  for ( idx = 1; idx <= numIs; idx++ ) {
    Interval *node = order[idx];
    iList *succ;

    printf( "%u: ", idx );
    for ( succ = node->forw_refs; succ != NULL; succ = succ->next )
      printf( "%u ", succ->interval->number );

    if ( node->member.basic.immDom != NULL )
      printf( "(%u)", node->member.basic.immDom->number );

    printf( "\n" );
  }
}

/* Restructure the statement list `stmt` into intervals, loops and
 * conditionals, annotating the intervals for structured output.
 *
 * Steps:
 *   1. Ensure the list starts with a Label, creating one if needed.  Put a
 *      fresh entry label, followed by a Goto to that label, in front of it.
 *   2. Shortcut jump chains (initRestruct).
 *   3. Build the basic intervals (makeInterval) and number them 1..count in
 *      reverse post-order into result->order (enumIntervals).  Then compute
 *      immediate dominators (findDominators).
 *   4. Alternately detect loops and build the next derived interval graph,
 *      until the number of intervals stops changing.
 *   5. Detect if/then/else follows (condDetect) and reset the interval
 *      numbers to -1 (resetLabels).
 *
 * Returns a malloc'ed Restruct.  result->count is the number of basic
 * intervals, and result->order has count+1 entries (index 0 unused).
 * NOTE(review): the labels created here get no explicit refcount/defined
 * values, unlike the split label in makeInterval.  The entry Goto's
 * reference to firstLabel is also not added to its refcount.  Confirm that
 * cmixMakeLabel's defaults are what later code expects.
 */
Restruct *
cmixRestructStmts( union cmixStmt *stmt )
{
  Restruct *result;
  unsigned count, lastCount;
  union cmixStmt *newLabel;
  union cmixStmt *newGoto;
  union cmixStmt *firstLabel;
  Interval *currentTop;

  /* The code must start at a label so it can be a jump target. */
  if ( stmt->common.tag != Label ) {
    firstLabel = cmixMakeLabel();
    firstLabel->label.next = stmt;
  } else
    firstLabel = stmt;

  /* Synthetic entry block: newLabel: goto firstLabel; */
  newGoto = (union cmixStmt *)malloc(sizeof(struct cmixStmtGoto));
  newGoto->jump.tag = Goto;
  newGoto->jump.next = firstLabel;
  newGoto->jump.target = &firstLabel->label;
  newLabel = cmixMakeLabel();
  newLabel->label.next = newGoto;

  initRestruct( newLabel );

  result = (Restruct *)malloc(sizeof(Restruct));

  count = makeInterval( &newLabel->label );
  currentTop = newLabel->label.interval;

  /* Slots 1..count are used; slot 0 is unused. */
  result->order = (Interval **)malloc((count+1)*sizeof(Interval *));
  result->count = count;

  enumIntervals( currentTop, count, result->order );

  findDominators( count, result->order );

  /** debugging
  printIntervals( count, result->order );
  **/

  /* From here on `count` holds the size of the current derived graph;
     result->count keeps the number of basic intervals.  Stop once a new
     level has as many intervals as the previous one. */
  lastCount = 0;
  while ( count != lastCount ) {
    loopDetect( currentTop, result->order );
    lastCount = count;
    count = iterateIntervals( currentTop );
    currentTop = currentTop->parent;
  }

  condDetect( result->count, result->order );

  /*
  printf( "Detected conditionals\n" );
  */

  resetLabels( result->count, result->order );

  return result;
}
