/*
 *             Automatically Tuned Linear Algebra Software v3.0Beta
 *                    (C) Copyright 1999 R. Clint Whaley                     
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *   1. Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice, this list of conditions, and the following disclaimer in the
 *      documentation and/or other materials provided with the distribution.
 *   3. The name of the University, the ATLAS group, or the names of its 
 *      contributers may not be used to endorse or promote products derived
 *      from this software without specific written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE. 
 *
 */

#include <assert.h>
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Values of the per-case "flag" field read from the case description file
 * and passed to emit_transhead:
 *   MVIsMM   : the generated header includes atlas_lvl3.h and defines
 *              ATL_mv<TA>CallsGemm (the gemv kernel calls gemm)
 *   MVIsAxpy : for the no-transpose header, ATL_AXPYMV is defined
 *   MVIsMV   : not referenced by name in this file; presumably the default
 *              (plain gemv kernel) case -- TODO confirm
 */
#define MVIsMV   0
#define MVIsMM   1
#define MVIsAxpy 2

/*
 * Reduces three timing samples to one representative MFLOP figure.
 *
 * n         : number of samples in mflop; the filtering logic reads
 *             mflop[0..2], so n must be at least 3 (callers pass 3)
 * tolerance : maximum accepted ratio between two samples (e.g. 1.20)
 * mflop     : samples; sorted in place into descending order
 *
 * RETURNS: the average of the samples that agree to within tolerance, or
 *          -1.0 if no two adjacent samples agree (caller should rerun).
 */
double GetAvg(int n, double tolerance, double *mflop)
{
   int i, j;
   double t0, tavg;

   assert(n >= 3);
/*
 * Sort results, largest first
 */
   for (i=0; i < n; i++)
   {
      for (j=i+1; j < n; j++)
      {
         if (mflop[i] < mflop[j])
         {
            t0 = mflop[i];
            mflop[i] = mflop[j];
            mflop[j] = t0;
         }
      }
   }
/*
 * Throw out result if it is outside tolerance; rerun if two mflop not within
 * tolerance;  this code assumes n == 3
 */
   if (tolerance*mflop[1] < mflop[0])  /* largest sample is an outlier */
   {
      if (tolerance*mflop[2] < mflop[1]) return(-1.0);
/*
 *    Average the two remaining samples; the sum must be parenthesized,
 *    otherwise only mflop[2] is halved
 */
      tavg = (mflop[1] + mflop[2]) / 2.0;
   }
   else if (tolerance*mflop[2] < mflop[0])  /* smallest sample is an outlier */
      tavg = (mflop[0] + mflop[1]) / 2.0;
   else tavg = (mflop[0] + mflop[1] + mflop[2]) / 3.0;

   return(tavg);
}

/*
 * Reads the L1 cache size from the file res/L1CacheSize.  If the file does
 * not exist, "make res/L1CacheSize" is run to create it.
 *
 * RETURNS: the integer stored in the file (reported to stderr as KB).
 * Exits the program if the file cannot be produced or parsed.
 */
int GetL1CacheSize()
{
   FILE *L1f;
   int L1Size, ierr;

   L1f = fopen("res/L1CacheSize", "r");
   if (!L1f)
   {
/*
 *    The command is run outside of assert() so that it still executes
 *    when the file is compiled with NDEBUG
 */
      ierr = system("make res/L1CacheSize\n");
      if (ierr != 0)
      {
         fprintf(stderr, "GetL1CacheSize: 'make res/L1CacheSize' failed\n");
         exit(-1);
      }
      L1f = fopen("res/L1CacheSize", "r");
      if (L1f == NULL)
      {
         fprintf(stderr, "GetL1CacheSize: cannot open res/L1CacheSize\n");
         exit(-1);
      }
   }
   if (fscanf(L1f, "%d", &L1Size) != 1)
   {
      fprintf(stderr, "GetL1CacheSize: cannot parse res/L1CacheSize\n");
      fclose(L1f);
      exit(-1);
   }
   fclose(L1f);
   fprintf(stderr, "\n      Read in L1 Cache size as = %dKB.\n",L1Size);
   return(L1Size);
}

void emit_mvhead(char pre, double l1mul)
/*
 * Writes atlas_<pre>mv.h, which defines ATL_L1mvelts (the number of elements
 * of the given precision in l1mul times the L1 cache) and includes the
 * atlas_<pre>mvN.h and atlas_<pre>mvT.h headers.
 *
 * pre   : precision/type prefix: 's', 'd', 'c' or 'z'
 * l1mul : fraction of the L1 cache to use
 *
 * the routine assumes sizeof(TYPE) is same on machine running this code
 * (possibly a cross-compiler), and the target machine -- this is bad
 */
{
   char fnam[64];
   int l1, upre;
   FILE *fp;

   l1 = 1024 * GetL1CacheSize();   /* cache size is stored in KB */
   if (pre == 's') l1 /= sizeof(float);
   else if (pre == 'd') l1 /= sizeof(double);
   else if (pre == 'c') l1 /= 2*sizeof(float);
   else if (pre == 'z') l1 /= 2*sizeof(double);
   l1 = (int)(l1mul * l1);
   sprintf(fnam, "atlas_%cmv.h", pre);
   fp = fopen(fnam, "w");
   if (fp == NULL)
   {
      fprintf(stderr, "emit_mvhead: cannot open %s for writing\n", fnam);
      exit(-1);
   }
/*
 * ctype functions require an argument representable as unsigned char
 */
   upre = toupper((unsigned char)pre);
   fprintf(fp, "#ifndef ATLAS_%cMV_H\n", upre);
   fprintf(fp, "#define ATLAS_%cMV_H\n\n", upre);

   fprintf(fp, "#define ATL_L1mvelts %d\n", l1);
   fprintf(fp, "#include \"atlas_%cmvN.h\"\n", pre);
   fprintf(fp, "#include \"atlas_%cmvT.h\"\n", pre);

   fprintf(fp, "\n#endif\n");
   fclose(fp);
}

/*
 * Writes the header atlas_<pre>mv<TA>.h for one transpose setting.
 *
 * pre  : precision/type prefix, used in the file name and in the name of
 *        the no-transpose header included from the transpose header
 * TA   : 'N' emits the ATL_GetPartMVN macro; any other value emits
 *        ATL_GetPartMVT and the four variants of ATL_GetPartSYMV
 * flag : MVIsMV, MVIsMM or MVIsAxpy
 * mu   : written as ATL_mv<TA>MU unless flag is MVIsMM; in the transpose
 *        header, mu == 0 selects the gemm-based partitioning
 * nu   : written as ATL_mv<TA>NU unless flag is MVIsMM
 *
 * Every emitted macro stores its results through the pointer arguments
 * mb_ and nb_.
 */
void emit_transhead(char pre, char TA, int flag, int mu, int nu)
{
   char fnam[128];
   FILE *fp;

   sprintf(fnam, "atlas_%cmv%c.h", pre, TA);
   fp = fopen(fnam, "w");
   assert(fp);
   fprintf(fp, "#ifndef ATLAS_MV%c_H\n", TA);
   fprintf(fp, "#define ATLAS_MV%c_H\n\n", TA);

   fprintf(fp, "#include \"atlas_misc.h\"\n");
   if (flag == MVIsMM) fprintf(fp, "#include \"atlas_lvl3.h\"\n");
   fprintf(fp, "\n");

/*
 * Unrolling factors: literal values for gemv kernels, gemm's parameters
 * when the kernel calls gemm
 */
   if (flag != MVIsMM)
   {
      fprintf(fp, "#define ATL_mv%cMU %d\n", TA, mu);
      fprintf(fp, "#define ATL_mv%cNU %d\n", TA, nu);
   }
   else
   {
      fprintf(fp, "#define ATL_mv%cMU ATL_mmMU\n", TA);
      fprintf(fp, "#define ATL_mv%cNU NB\n", TA);
      fprintf(fp, "#define ATL_mv%cCallsGemm\n", TA);
   }
/*
 * Fallback for ATL_L1mvelts when this header is used without atlas_<pre>mv.h
 */
   fprintf(fp, "#ifndef ATL_L1mvelts\n");
   fprintf(fp, "   #define ATL_L1mvelts ((3*ATL_L1elts)>>2)\n");
   fprintf(fp, "#endif\n");

   if (TA == 'N')
   {
      if (flag == MVIsAxpy) fprintf(fp, "#define ATL_AXPYMV\n");
      else if (nu < 32)
      {
         fprintf(fp, "#ifndef ATL_mvpagesize\n");
         fprintf(fp, "   #define ATL_mvpagesize ATL_DivBySize(4096)\n");
         fprintf(fp, "#endif\n");
         fprintf(fp, "#ifndef ATL_mvntlb\n");
         fprintf(fp, "   #define ATL_mvntlb 56\n");
         fprintf(fp, "#endif\n");
      }

      fprintf(fp, "\n#define ATL_GetPartMVN(A_, lda_, mb_, nb_) \\\n{ \\\n");
      if (flag == MVIsAxpy)
      {
         fprintf(fp, 
"   *(mb_) = (ATL_L1mvelts - ATL_mvNNU*(ATL_mvNNU+2)) / (ATL_mvNNU+1); \\\n");
         fprintf(fp, 
"   if (*(mb_) > ATL_mvNMU) *(mb_) = ATL_mvNMU*( *(mb_)/ATL_mvNMU ); \\\n");
         fprintf(fp, "   else *(mb_) = ATL_mvNMU; \\\n");
         fprintf(fp, "   *(nb_) = ATL_mvNNU; \\\n");
      }
      else if (nu >= 32)
      {
         fprintf(fp, "   *(nb_) = ATL_mvNNU; \\\n");
         fprintf(fp, 
"   *(mb_) = (ATL_L1mvelts - ATL_mvNNU - (ATL_mvNNU+1)*ATL_mvNMU) / (ATL_mvNNU+1); \\\n");
         fprintf(fp, 
"   if (*(mb_) > ATL_mvNMU) *(mb_) = (*(mb_) / ATL_mvNMU)*ATL_mvNMU; \\\n");
         fprintf(fp, "   else *(mb_) = ATL_mvNMU; \\\n");
      }
      else if (flag == MVIsMM) /* gemv calls gemm */
      {
         fprintf(fp, 
            "   *(mb_) = (ATL_L1mvelts - (ATL_mmMU+1)*NB) / (NB + 2); \\\n");
         fprintf(fp, 
         "   if (*(mb_) > NB) *(mb_) = ATL_MulByNB(ATL_DivByNB(*(mb_))); \\\n");
         fprintf(fp, 
         "   else if (*(mb_) < ATL_mmMU) *(mb_) = NB; \\\n");
         fprintf(fp, 
         "   else  *(mb_) = (*(mb_) / ATL_mmMU) * ATL_mmMU; \\\n");
         fprintf(fp, "   *(nb_) = NB; \\\n");
      }
      else
      {
/*
 *       This variant uses ATL_mvpagesize and ATL_mvntlb, which are only
 *       emitted above when nu < 32
 */
         assert(mu && nu);
         fprintf(fp, 
"   *(nb_) = (ATL_L1mvelts - (ATL_mvNMU<<1)) / ((ATL_mvNMU<<1)+1); \\\n");
         fprintf(fp, "   if (ATL_mvpagesize > (lda_)) \\\n   { \\\n");
         fprintf(fp, 
            "      *(mb_) = (ATL_mvpagesize / (lda_)) * ATL_mvntlb; \\\n");
         fprintf(fp, 
            "      if ( *(mb_) < *(nb_) ) *(nb_) = *(mb_); \\\n");
         fprintf(fp, "   } \\\n");
         fprintf(fp, 
            "   else if (ATL_mvntlb < *(nb_)) *(nb_) = ATL_mvntlb; \\\n");
         fprintf(fp, 
"   if (*(nb_) > ATL_mvNNU) *(nb_) = (*(nb_) / ATL_mvNNU) * ATL_mvNNU; \\\n");
         fprintf(fp, "   else *(nb_) = ATL_mvNNU; \\\n");
         fprintf(fp, 
"   *(mb_) = (ATL_L1mvelts - *(nb_) * (ATL_mvNMU+1)) / (*(nb_)+2); \\\n");
         fprintf(fp, 
"   if (*(mb_) > ATL_mvNMU) *(mb_) = (*(mb_) / ATL_mvNMU) * ATL_mvNMU; \\\n");
         fprintf(fp, "   else *(mb_) = ATL_mvNMU; \\\n");
      }
      fprintf(fp, "}\n");
   }
   else
   {
/*
 *    The SYMV macros below need the no-transpose parameters as well
 */
      if (mu != 0)
      {
         fprintf(fp, "#ifndef ATL_mvNNU\n");
         fprintf(fp, "   #include \"atlas_%cmvN.h\"\n", pre);
         fprintf(fp, "#endif\n");
      }
      fprintf(fp, "\n");
      fprintf(fp, "#define ATL_GetPartMVT(A_, lda_, mb_, nb_) \\\n{ \\\n");
      if (mu == 0)
      {
         fprintf(fp, 
         "   *(mb_) = (ATL_L1mvelts - NB - ATL_mmMU*(NB+1)) / (NB+1); \\\n");
         fprintf(fp, "   if (*(mb_) > NB) \\\n   { \\\n");
         fprintf(fp, "      *(mb_) = ATL_MulByNB(ATL_DivByNB(*(mb_))); \\\n");
         fprintf(fp, "      *(nb_) = NB; \\\n   } \\\n");
         fprintf(fp, "   else \\\n   { \\\n");
         fprintf(fp, 
"      if (*(mb_) > ATL_mmMU) *(nb_) = (*(mb_) / ATL_mmMU)*ATL_mmMU; \\\n");
         fprintf(fp, "      else *(nb_) = ATL_mmMU; \\\n");
         fprintf(fp, "      *(mb_) = NB; \\\n   } \\\n");
         fprintf(fp, "}\n");
      }
      else
      {
         fprintf(fp, 
"   *(mb_) = (ATL_L1mvelts - (ATL_mvTMU<<1)) / ((ATL_mvTMU<<1)+1); \\\n");
         fprintf(fp, 
"   if (*(mb_) > ATL_mvTNU) *(mb_) = (*(mb_)/ATL_mvTNU)*ATL_mvTNU; \\\n");
/*
 *       The macro argument is parenthesized before it is dereferenced,
 *       as everywhere else in the emitted macros
 */
         fprintf(fp, "   else *(mb_) = ATL_mvTNU; \\\n");
         fprintf(fp, "   *(nb_) = ATL_mvTMU; \\\n");
         fprintf(fp, "}\n");
      }

/*
 *    ATL_GetPartSYMV: one variant per combination of kernel types
 */
      fprintf(fp, 
         "#if defined(ATL_mvNCallsGemm) && defined(ATL_mvTCallsGemm) \n\n");

      fprintf(fp, 
         "#define ATL_GetPartSYMV(A_, lda_, mb_, nb_) *(mb_) = *(nb_) = NB \n");

      fprintf(fp, "\n#elif defined(ATL_mvNCallsGemm) \n\n");

      fprintf(fp,"#define ATL_GetPartSYMV(A_, lda_, mb_, nb_) \\\n{ \\\n");
      fprintf(fp, "   *(nb_) = Mmax(ATL_mmMU, ATL_mvTMU); \\\n");
      fprintf(fp, "   *(mb_) = (ATL_L1mvelts - NB2) / (NB +2 + *(nb_)); \\\n");
      fprintf(fp, 
         "   if (*(mb_) > *(nb_)) *(mb_) = ( *(mb_) / *(nb_) ) * *(nb_); \\\n");
      fprintf(fp, "   else *(mb_) = *(nb_); \\\n");
      fprintf(fp, "   *(nb_) = NB; \\\n");
      fprintf(fp, "}\n");

      fprintf(fp, "\n#elif defined(ATL_AXPYMV)\n\n");

      fprintf(fp,"#define ATL_GetPartSYMV(A_, lda_, mb_, nb_) \\\n{ \\\n");
      fprintf(fp, "   *(nb_) = ATL_lcm(ATL_mvNNU, ATL_mvTNU); \\\n");
      fprintf(fp, "   *(mb_) = Mmax(ATL_mvNMU, ATL_mvTNU); \\\n");
/*
 *    The subtraction is parenthesized so that it is evaluated before the
 *    division, matching the (L1elts - overhead) / rowcost form used by the
 *    final #else variant below
 */
      fprintf(fp, "   *(mb_) = (ATL_L1mvelts - ((*(nb_) + *(mb_))<<1)) / (*(nb_) + 2 + *(mb_)); \\\n");
      fprintf(fp, "}\n");

      fprintf(fp, "\n#elif defined(ATL_mvTCallsGemm)\n\n");

      fprintf(fp,"#define ATL_GetPartSYMV(A_, lda_, mb_, nb_) \\\n{ \\\n");
      fprintf(fp, "   *(mb_) = Mmax(ATL_mmMU, ATL_mvTMU); \\\n");
      fprintf(fp, "   *(nb_) = (ATL_L1mvelts - NB2) / (NB + 2 + *(mb_)); \\\n");
      fprintf(fp, 
         "   if (*(nb_) > *(mb_)) *(nb_) = ( *(nb_) / *(mb_) ) * *(mb_); \\\n");
      fprintf(fp, "   else *(nb_) = *(mb_); \\\n");
      fprintf(fp, "   *(mb_) = NB; \\\n");
      fprintf(fp, "}\n");

      fprintf(fp, "\n#else\n\n");

      fprintf(fp,"#define ATL_GetPartSYMV(A_, lda_, mb_, nb_) \\\n{ \\\n");
      fprintf(fp, "   for(*(nb_)=ATL_mvNNU; (*(nb_))*(*(nb_))+((*(nb_))<<2) + ATL_mvTMU*(*(nb_)+2) < ATL_L1mvelts; *(nb_) += ATL_mvNNU); \\\n");
      fprintf(fp, 
            "   *(mb_) = Mmax(ATL_mvTNU, ATL_mvNMU) + *(nb_); \\\n");
      fprintf(fp, 
      "   *(mb_) = ( ATL_L1mvelts - ((*(mb_))<<1) ) / (*(mb_) + 2); \\\n");
      fprintf(fp,
"   if (*(mb_) > ATL_mvTNU) *(mb_) = (*(mb_)/ATL_mvTNU)*ATL_mvTNU; \\\n");
      fprintf(fp, "   else *(mb_) = ATL_mvTNU; \\\n");
      fprintf(fp, "}\n");

      fprintf(fp, "\n#endif\n");
   }

   fprintf(fp, "\n#endif\n");
   fclose(fp);
}

/*
 * Returns the performance of one gemv kernel case, timing it if necessary.
 *
 * Results are cached in res/<pre>gemv<TA>_<cas>.  When that file is absent,
 * the headers are regenerated and "make <pre>mvcase ..." is run to create it.
 *
 * pre   : precision/type prefix
 * mvnam : kernel file name, passed to make as mvrout
 * TA    : transpose setting being timed ('N'/'n' or other for transpose)
 * flag, mu, nu : kernel description passed on to emit_transhead
 * cas   : case number, used in the result file name
 * l1mul : fraction of L1 cache passed on to emit_mvhead
 *
 * RETURNS: the GetAvg reduction of the three values in the result file.
 * Exits if the timing cannot be run or parsed, or if the three timings
 * vary by more than 20% (the result file is deleted first in that case).
 */
double mvcase(char pre, char *mvnam, char TA, int flag, int mu, int nu, 
              int cas, double l1mul)
{
   char nTA;
   char ln[256], fnam[64];   /* ln must hold the make command plus mvnam */
   double mfs[3], mf;
   int ierr;
   FILE *fp;

   if (TA == 'n' || TA == 'N') nTA = 'T';
   else nTA = 'N';

   sprintf(fnam, "res/%cgemv%c_%d", pre, TA, cas);
   fp = fopen(fnam, "r");
   if (fp == NULL)
   {
      emit_mvhead(pre, l1mul);
      emit_transhead(pre, TA, flag, mu, nu);
      emit_transhead(pre, nTA, flag, mu, nu);
      sprintf(ln, "make %cmvcase ta=%c nta=%c mvrout=%s cas=%d\n", 
              pre, TA, nTA, mvnam, cas);
      fprintf(stderr, "%s", ln);
/*
 *    Run the command outside of assert() so that it is not compiled away
 *    under NDEBUG
 */
      ierr = system(ln);
      if (ierr != 0)
      {
         fprintf(stderr, "mvcase: command failed: %s", ln);
         exit(-1);
      }
      fp = fopen(fnam, "r");
      if (fp == NULL)
      {
         fprintf(stderr, "mvcase: cannot open %s\n", fnam);
         exit(-1);
      }
   }
   ierr = fscanf(fp, " %lf %lf %lf", mfs, mfs+1, mfs+2);
   fclose(fp);
   if (ierr != 3)
   {
      fprintf(stderr, "mvcase: cannot read three timings from %s\n", fnam);
      exit(-1);
   }
   mf = GetAvg(3, 1.20, mfs);
   if (mf == -1.0)
   {
      fprintf(stderr, 
"\n\n%s : VARIATION EXCEEDS TOLERENCE, RERUN WITH HIGHER REPS.\n\n", fnam);
/*
 *    Delete the cached result so that the next run retimes this case
 */
      remove(fnam);
      exit(-1);
   }
   return(mf);
}

/*
 * Bisection search over the interval [0.5, 1.0] for the fraction of the L1
 * cache that gives the best performance for the given kernel.  Each step
 * times both end points with mvcase and moves the slower end point half way
 * towards the faster one; the search stops once both end points truncate to
 * the same integer percentage.
 *
 * RETURNS: the lower end point of the final interval.
 */
double FindL1Mul(char pre, char *mvnam, char TA, int flag, int mu, int nu)
{
   double lo = .5, hi = 1.0;
   double mfLo, mfHi;
   int pctLo, pctHi;

   for (;;)
   {
/*
 *    The truncated percentages double as the case numbers for mvcase
 */
      pctLo = (int)(lo * 100.0);
      pctHi = (int)(hi * 100.0);
      mfLo = mvcase(pre, mvnam, TA, flag, mu, nu, pctLo, lo);
      mfHi = mvcase(pre, mvnam, TA, flag, mu, nu, pctHi, hi);
      fprintf(stdout, "      %.2f%% %.2fMFLOP  ---  %.2f%% %.2fMFLOP\n",
              lo*100.0, mfLo, hi*100.0, mfHi);
      if (mfLo < mfHi)
         lo += 0.5*(hi-lo);
      else
         hi -= 0.5 * (hi-lo);
/*
 *    The exit test uses the percentages computed before the update above
 */
      if (pctHi == pctLo) break;
   }
   fprintf(stdout, "\n\nBEST %% of L1 cache: %.2f\n", lo*100.0);
   return(lo);
}

/*
 * Reads one list of kernel cases from fp.  The expected layout is a case
 * count n (0 < n < 100) followed by n records "flag mu nu filename".
 *
 * fp    : open stream positioned at the case count
 * N     : OUTPUT: number of cases read
 * fnams : OUTPUT: array of N file names (each at most 63 characters)
 * flags : OUTPUT: array of N flag values
 * mus   : OUTPUT: array of N mu values (each >= 0)
 * nus   : OUTPUT: array of N nu values (each >= 0)
 *
 * All output arrays, and every string in fnams, are allocated with malloc
 * and must be freed by the caller.  Exits on malformed input or when out of
 * memory.
 */
void GetCases(FILE *fp, int *N, char ***fnams, int **flags, 
              int **mus, int **nus)
{
   char **fnam;
   int i, n, nread;
   int *mu, *nu, *flag;

   if (fscanf(fp, " %d", &n) != 1 || n <= 0 || n >= 100)
   {
      fprintf(stderr, "GetCases: missing or invalid case count\n");
      exit(-1);
   }
   fnam = malloc(n * sizeof(char*));
   mu = malloc(n * sizeof(int));
   nu = malloc(n * sizeof(int));
   flag = malloc(n * sizeof(int));
   if (!fnam || !mu || !nu || !flag)
   {
      fprintf(stderr, "GetCases: out of memory\n");
      exit(-1);
   }
/*
 * Allocation is done outside of assert() so that it still happens when
 * the file is compiled with NDEBUG
 */
   for (i=0; i < n; i++)
   {
      fnam[i] = malloc(64*sizeof(char));
      if (!fnam[i])
      {
         fprintf(stderr, "GetCases: out of memory\n");
         exit(-1);
      }
   }
   for (i=0; i < n; i++)
   {
/*
 *    The field width keeps the name within its 64-byte buffer
 */
      nread = fscanf(fp, " %d %d %d %63s", flag+i, mu+i, nu+i, fnam[i]);
      if (nread != 4 || mu[i] < 0 || nu[i] < 0 || fnam[i][0] == '\0')
      {
         fprintf(stderr, "GetCases: malformed description of case %d\n", i+1);
         exit(-1);
      }
   }

   *N = n;
   *fnams = fnam;
   *flags = flag;
   *mus = mu;
   *nus = nu;
}

/*
 * Times every candidate kernel for one transpose setting, using 75% of the
 * L1 cache and case numbers 1..ncases, and reports the fastest one.
 *
 * RETURNS: 0-based index of the fastest case.  At least one case must
 *          report a performance greater than zero.
 */
int RunTransCases(char pre, char TA, int ncases, char **fnams, 
                  int *flags, int *mus, int *nus)
{
   int k, best = -1;
   double perf, bestperf = 0.0;

   for (k=0; k < ncases; k++)
   {
      perf = mvcase(pre, fnams[k], TA, flags[k], mus[k], nus[k], k+1, 0.75);
      fprintf(stdout, "%s : %.2f\n", fnams[k], perf);
      if (perf > bestperf)
      {
         bestperf = perf;
         best = k;
      }
   }
   assert(best >= 0);
/*
 * Cases are reported to the user with 1-based numbering
 */
   fprintf(stdout, 
           "\nbest %cgemv%c : case %d, mu=%d, nu=%d at %.2f MFLOPS\n\n", 
           pre, TA, best+1, mus[best], nus[best], bestperf);
   return(best);
}

/*
 * Writes the summary file res/<pre>MVRES.  It holds two lines, the first for
 * the no-transpose kernel and the second for the transpose kernel, each of
 * the form "flag mu nu l1mul mflop filename".
 */
void CreateSum(char pre, double l1mul, char *fnamN, int flagN, int muN, int nuN,
               double mfN, char *fnamT, int flagT, int muT, int nuT, double mfT)
{
   char sumnam[32];
   FILE *fpout;

   sprintf(sumnam, "res/%cMVRES", pre);
   fpout = fopen(sumnam, "w");
   assert(fpout);
/*
 * No-transpose line first, then transpose line
 */
   fprintf(fpout, "%d %d %d %.2f %.2f %s\n",
           flagN, muN, nuN, l1mul, mfN, fnamN);
   fprintf(fpout, "%d %d %d %.2f %.2f %s\n",
           flagT, muT, nuT, l1mul, mfT, fnamT);
   fclose(fpout);
}

/*
 * Installs the selected kernels: obtains the performance of the chosen
 * no-transpose and transpose kernels at the chosen L1 fraction, regenerates
 * the three headers, runs "make <pre>install ..." and writes the summary
 * file with the measured performance.
 *
 * Exits if the make command fails.
 */
void mvinstall(char pre, double l1mul, char *fnamN, int flagN, int muN, int nuN,
               char *fnamT, int flagT, int muT, int nuT)
{
   char ln[512];   /* must hold the make command plus both file names */
   double mfN, mfT;
   int ierr;

   mfN = mvcase(pre, fnamN, 'N', flagN, muN, nuN, (int)(l1mul*100), l1mul);
   mfT = mvcase(pre, fnamT, 'T', flagT, muT, nuT, (int)(l1mul*100), l1mul);
   emit_mvhead(pre, l1mul);
   emit_transhead(pre, 'N', flagN, muN, nuN);
   emit_transhead(pre, 'T', flagT, muT, nuT);
   sprintf(ln, "make %cinstall mvNrout=%s mvTrout=%s\n", pre, fnamN, fnamT);
   fprintf(stderr, "%s", ln);
/*
 * Run the command outside of assert() so that it is not compiled away
 * under NDEBUG
 */
   ierr = system(ln);
   if (ierr != 0)
   {
      fprintf(stderr, "mvinstall: command failed: %s", ln);
      exit(-1);
   }
   CreateSum(pre, l1mul, fnamN, flagN, muN, nuN, mfN, 
             fnamT, flagT, muT, nuT, mfT);
}

/*
 * Runs the full search for one precision: reads the candidate kernels from
 * ../CASES/<pre>cases.dsc (no-transpose list first, then transpose list),
 * times all of them, picks the fastest of each list, searches for the best
 * L1 fraction using the transpose winner, and writes the summary file.
 * The performance fields of the summary are written as -1.0 here.
 */
void RunCases(char pre)
{
   char fnam[128];
   char Nfnam[64], Tfnam[64];
   char **fnamN, **fnamT;
   int i, nNTcases, nTcases, Nbest, Tbest;
   int Nflag, Nmu, Nnu, Tflag, Tmu, Tnu;
   int *flagN, *muN, *nuN, *flagT, *muT, *nuT;
   double l1mul;
   FILE *fp;

/*
 * Read in cases to try
 */
   sprintf(fnam, "../CASES/%ccases.dsc", pre);
   fp = fopen(fnam, "r");
   assert(fp);
   GetCases(fp, &nNTcases, &fnamN, &flagN, &muN, &nuN);
   GetCases(fp, &nTcases,  &fnamT, &flagT, &muT, &nuT);
   fclose(fp);
/*
 * Try all cases for each trans case
 */
   Nbest = RunTransCases(pre, 'N', nNTcases, fnamN, flagN, muN, nuN);
   Tbest = RunTransCases(pre, 'T', nTcases, fnamT, flagT, muT, nuT);

/*
 * Save the winners in local storage so the case lists can be freed;
 * make sure the names fit before copying them
 */
   if (strlen(fnamN[Nbest]) >= sizeof(Nfnam) ||
       strlen(fnamT[Tbest]) >= sizeof(Tfnam))
   {
      fprintf(stderr, "RunCases: kernel file name too long\n");
      exit(-1);
   }
   Nflag = flagN[Nbest]; Tflag = flagT[Tbest];
   Nmu = muN[Nbest]; Nnu = nuN[Nbest]; strcpy(Nfnam, fnamN[Nbest]);
   Tmu = muT[Tbest]; Tnu = nuT[Tbest]; strcpy(Tfnam, fnamT[Tbest]);

   free(flagN); free(flagT); free(muN); free(muT); free(nuN); free(nuT);
   for (i=0; i < nNTcases; i++) free(fnamN[i]);
   free(fnamN);
   for (i=0; i < nTcases; i++) free(fnamT[i]);
   free(fnamT);
   l1mul = FindL1Mul(pre, Tfnam, 'T', Tflag, Tmu, Tnu);
   CreateSum(pre, l1mul, Nfnam, Nflag, Nmu, Nnu, -1.0, 
             Tfnam, Tflag, Tmu, Tnu, -1.0);
}

/*
 * Top-level driver for one precision.  Reads the summary file
 * res/<pre>MVRES, running the full search first if the file does not exist,
 * and then installs the kernels it describes.
 *
 * pre : precision/type prefix character
 *
 * Exits if the summary file cannot be opened or parsed.
 */
void GoToTown(int pre)
{
   char fnamN[128], fnamT[128], ln[128];
   int flagN, muN, nuN, flagT, muT, nuT;
   int nread;
   double l1mul, mfN, mfT;
   FILE *fp;

   sprintf(ln, "res/%cMVRES", pre);
   fp = fopen(ln, "r");
   if (fp == NULL)
   {
      RunCases(pre);
      fp = fopen(ln, "r");
      if (fp == NULL)
      {
         fprintf(stderr, "GoToTown: cannot open %s\n", ln);
         exit(-1);
      }
   }
/*
 * First line describes the no-transpose kernel, second line the transpose
 * kernel; l1mul is read from both lines and the second value is the one
 * used.  The reads are done outside of assert() so that they are not
 * compiled away under NDEBUG, and the field width keeps the names within
 * their buffers.
 */
   nread = fscanf(fp, " %d %d %d %lf %lf %127s", 
                  &flagN, &muN, &nuN, &l1mul, &mfN, fnamN);
   if (nread == 6)
      nread = fscanf(fp, " %d %d %d %lf %lf %127s", 
                     &flagT, &muT, &nuT, &l1mul, &mfT, fnamT);
   fclose(fp);
   if (nread != 6)
   {
      fprintf(stderr, "GoToTown: cannot parse %s\n", ln);
      exit(-1);
   }
   mvinstall(pre, l1mul, fnamN, flagN, muN, nuN, fnamT, flagT, muT, nuT);
}

/*
 * Prints the command-line synopsis to stderr and terminates the program
 * with exit status -1.
 *
 * fnam : program name to show in the synopsis
 */
void PrintUsage(char *fnam)
{
   static const char usage[] = "USAGE: %s [-p <s,d,c,z>]\n";

   fprintf(stderr, usage, fnam);
   exit(-1);
}

/*
 * Parses the command line.
 *
 *   -p <s,d,c,z> : precision/type prefix       (default 'd')
 *   -m <int>     : mu, must be positive        (default 0)
 *   -n <int>     : nu, must be positive        (default 0)
 *   -A <n,t>     : transpose setting           (default ' ')
 *
 * If only one of mu/nu is given, the other is set to the same value; if
 * either is given without -A, TA becomes 'N'.  Any malformed argument,
 * including a flag with no value after it, prints the usage message and
 * exits.
 */
void GetFlags(int nargs, char **args, char *pre, char *TA, int *mu, int *nu)
{
   char ctmp;
   int i;
   *pre = 'd';
   *mu = *nu = 0;
   *TA = ' ';
   for (i=1; i < nargs; i++)
   {
      if (args[i][0] != '-') PrintUsage(&args[0][0]);
      switch(args[i][1])
      {
      case 'p':
/*
 *       Every flag takes a value: make sure it is there before reading it
 */
         if (++i >= nargs) PrintUsage(&args[0][0]);
         ctmp = tolower((unsigned char)args[i][0]);
         if (ctmp == 's' || ctmp == 'd' || ctmp == 'c' || ctmp == 'z')
            *pre = ctmp;
         else PrintUsage(&args[0][0]);
         break;
      case 'm':
         if (++i >= nargs) PrintUsage(&args[0][0]);
         *mu = atoi(args[i]);
         assert(*mu > 0);
         break;
      case 'n':
         if (++i >= nargs) PrintUsage(&args[0][0]);
         *nu = atoi(args[i]);
         assert(*nu > 0);
         break;
      case 'A':
         if (++i >= nargs) PrintUsage(&args[0][0]);
         ctmp = toupper((unsigned char)args[i][0]);
         if (ctmp == 'N' || ctmp == 'T') *TA = ctmp;
         else PrintUsage(args[0]);
         break;
      default:
         fprintf(stderr, "Unknown flag : %s\n", args[i]);
         PrintUsage(&args[0][0]);
      }
   }
   if (*mu == 0 && *nu) *mu = *nu;
   else if (*nu == 0 && *mu) *nu = *mu;
   if ( (*mu || *nu) && *TA == ' ') *TA = 'N';
}


/*
 * Program entry point: parses the command line and runs the search and
 * install for the selected precision.  Only the precision flag is used
 * here; the TA, mu and nu values returned by GetFlags are ignored.
 */
int main(int nargs, char **args)
{
   char pre, TA;
   int mu, nu;
   GetFlags(nargs, args, &pre, &TA, &mu, &nu);
   GoToTown(pre);
   return(0);
}
