/* $Id: rex.h,v 1.9 2001/05/16 09:40:09 ericp Exp $ */

class rexfd;

/* Interface of whatever owns the table that maps REX descriptor numbers
   to rexfd objects.  A rexfd registers itself through insert_fd when it
   is constructed and unregisters through remove_fd when it is
   destroyed. */
class fdkeeper
{
 public:
  // Associate descriptor number fdn with rfd.
  virtual void insert_fd (int fdn, rexfd *rfd) = 0;
  // Forget whatever is associated with descriptor number fdn.
  virtual void remove_fd (int fdn) = 0;
  // Virtual so that destroying an implementation through an fdkeeper
  // pointer still runs the implementation's own destructor.
  virtual ~fdkeeper () {}
};




/* Base class for one descriptor carried over a REX channel.  The default
   hooks simply acknowledge incoming requests (or do nothing); subclasses
   override them to attach real behaviour.  An instance registers itself
   with its fdkeeper on construction and unregisters on destruction. */
class rexfd
{
 protected:
  fdkeeper &fdk;       // table this object is registered in
  ptr <aclnt> proxy;   // client handle supplied by the creator
  u_int32_t channo;    // channel number supplied by the creator
  int fd;              // descriptor number used as the key in fdk

 public:
  /* Stores the arguments and registers this object in fdk under fd.
     A negative fd is reported through fatal () first. */
  rexfd (u_int32_t channo, int fd, ptr <aclnt> pr, fdkeeper &fdk)
    : fdk (fdk), proxy (pr), channo (channo), fd (fd)
  {
    if (fd < 0) {
      fatal ("attempt to create negative fd: %d\n", fd);
    }
    this->fdk.insert_fd (this->fd, this);
  }

  //todo : does this interface need shutdown?
  // Default: nothing to shut down.
  virtual void
  shutdown (int how, int err = 0)
  {
  }

  // Default: acknowledge the payload without looking at it.
  virtual void
  data (svccb *sbp)
  {
#if 0
    rex_payload *argp = sbp->template getarg<rex_payload> ();
    str data (argp->data.base (), argp->data.size ());
    warn ("received data on dummy fd: %s\n", data.cstr ());
#endif
    sbp->replyref (true);
  }

  // Default: acknowledge the request without creating anything.
  virtual void
  newfd (svccb *sbp)
  {
    sbp->replyref (true);
  }

  // Default: no reaction.
  virtual void
  exited ()
  {
  }

  // Unregisters this object from the table it was inserted into.
  virtual ~rexfd ()
  {
    fdk.remove_fd (fd);
  }
};



// Throw-away result slot: code in this file passes its address as the
// result pointer of proxy->call () together with aclnt_cb_null and never
// reads the value back.
static bool garbage_bool;

class unixfd : public rexfd {
 protected:
  int localfd;

 private:
  ptr<aios> paios;

  bool weof;
  bool reof;

  bool noclose;
  bool shutrdonexit;


  
 public:
  //todo: does this need to be public?
  virtual void shutdown (int how, int err = 0) {
    if (!noclose)
      ::shutdown (localfd, how);
    
    if (how == SHUT_WR) weof = true; 
    else if (how == SHUT_RD) reof = true;
    else weof = reof = true;
    
    
    if (weof && reof) {
      if (!noclose) 
	::close (localfd);
      
      rex_int_arg arg;
      arg.channel = channo;
      arg.val = fd;
      proxy->call (REX_CLOSE, &arg, &garbage_bool, aclnt_cb_null);
      delete this;
    }
  }

 private:   
  void rcb (const str data, int err) {
    //I think this gets called for write errors as well
    rex_payload arg;
    arg.channel = channo;
    arg.fd = fd;
    
    if (!data) {
      if (err) {
	shutdown (SHUT_RDWR);
	return;
      }
      else {
	arg.data.set ((char *)NULL, 0);
	proxy->call (REX_DATA, &arg, &garbage_bool, aclnt_cb_null);
	shutdown (SHUT_RD);
	return;
      }
    }
    else {
      arg.data.set (const_cast<char *> (data.cstr ()), 
		    data.len (), freemode::NOFREE);
      
      proxy->call (REX_DATA, &arg, &garbage_bool, aclnt_cb_null);
      
      // todo:  flow control
      paios->readany (wrap (this, &unixfd::rcb));
    }
  }

 public:

  virtual void
    newfd (svccb *sbp)
    {
      rexcb_newfd_arg *argp = sbp->template getarg<rexcb_newfd_arg> ();
      
      int s[2];

      if(socketpair(AF_UNIX, SOCK_STREAM, 0, s)) {
	warn << "error creating socketpair for agent forwarding";
	sbp->replyref(false);
	return;
      }

      make_async (s[1]);
      make_async (s[0]);

      if (writefd (localfd, NULL, 0, s[1]) < 0) {
	warn ("failed to pass newfd along existing fd\n");
	::close (s[0]);
	::close (s[1]);
	sbp->replyref(false);
	return;
      }

      vNew unixfd (channo, argp->newfd, proxy, fdk, s[0]);

      ::close (s[1]);

      sbp->replyref (true);
    }
  
  virtual void
    data (svccb *sbp)
    {
      assert (paios);
      
      rex_payload *argp = sbp->template getarg<rex_payload> ();
      
      if (argp->data.size () > 0) {
	if (weof) {
	  sbp->replyref (false);
	  return;
	}
	else {
	  str data (argp->data.base (), argp->data.size ());
	  paios << data;
	  sbp->replyref (true);
	}
      }
      else {
	sbp->replyref (true);
	
	//we don't shutdown immediately to give data a chance to
	//asynchronously flush
	paios->setwcb (wrap (this, &unixfd::shutdown, SHUT_WR));
      }
    }

  void
    exited ()
    {
      if (shutrdonexit)
	shutdown (SHUT_RD);
    }

  
  /* unixfd specific arguments:
       localfd:
                 local file descriptor
       noclose:
                 will not use close or shutdown calls on the local file descriptor, useful for terminal
                 descriptors, which must hang around so that raw mode can be disabled, etc.
       shutrdonexit:
                 when the remote module exits, shutdown the read direction of the local file descriptor.
		 this isn't always done since not all file descriptors managed on the REX channel are
		 currently connected to the remote module.  for example file descriptors passed to
		 the local module are potentially not connected to the remote module.               */
  unixfd (u_int32_t channo, int fd, ptr <aclnt> pr, fdkeeper &fdk,
	  int localfd, bool noclose = false,  bool shutrdonexit = false) :
    rexfd::rexfd (channo, fd, pr, fdk),
    localfd (localfd),  weof (false), reof (false), noclose (noclose), shutrdonexit (shutrdonexit)
    {
      paios = aios::alloc (localfd);
      paios->readany (wrap (this, &unixfd::rcb));
    }

  virtual ~unixfd ()
    {
      if (paios)
	paios->flush ();
    }

};






class rexchannel: public fdkeeper {

  vec <rexfd *>  vfds;
  
 protected:
  
  ptr <aclnt> proxy;
  u_int32_t channo;
  
 public:

  int initnfds;

  vec <str> command;

  ////////////member functions

  //these two implement the fdkeeper interface
  virtual void
    insert_fd (int fdn, rexfd *rfd)
    {
      assert (fdn >= 0);

      size_t oldsize = vfds.size ();
      size_t neededsize = fdn + 1;

      if (neededsize > oldsize) {
	vfds.setsize (neededsize);
	for (int ix = oldsize; implicit_cast <size_t> (ix) < neededsize; ix++)
	  vfds[ix] = NULL;
      }

      if (vfds[fdn]) {
	warn ("creating fd on busy fd %d at rexfd::rexfd, overwriting\n", fdn);
	assert (false);
      }

      vfds [fdn] = rfd;
    }

  virtual void
    remove_fd (int fdn)
    {
      vfds [fdn] = NULL;
    }
      
  rexchannel (int initialfdcount, vec <str> command)
    : initnfds (initialfdcount), command(command) {}

  virtual void
    madechannel () {};

  void
    channelinit (u_int32_t chnumber, ref <aclnt> proxyaclnt)
    {
      proxy = proxyaclnt;
      channo = chnumber;
      madechannel();
    }
  
  virtual void data(svccb *sbp)
    {
      assert (sbp->prog () == REXCB_PROG && sbp->proc () == REXCB_DATA);
      rex_payload *dp = sbp->template getarg<rex_payload> ();
      assert (dp->channel == channo);
      if (dp->fd < 0 ||
	  implicit_cast<size_t> (dp->fd) >= vfds.size () ||
	  !vfds[dp->fd]) {
	warn ("payload fd %d out of range\ndata:%s\n", dp->fd, dp->data.base ());
	sbp->replyref (false);
	return;
      }

      vfds[dp->fd]->data(sbp);
    }

  
  virtual void newfd (svccb *sbp)
    {
      assert (sbp->prog () == REXCB_PROG && sbp->proc () == REXCB_NEWFD);
      rexcb_newfd_arg *arg = sbp->template getarg<rexcb_newfd_arg> ();

      int fd = arg->fd;

      if (fd < 0 || implicit_cast<size_t> (fd) >= vfds.size () || !vfds[fd]) {
	warn ("newfd received on invalid fd %d at rexchannel::newfd\n", fd);
	sbp->replyref (false);
	return;
      }
      
      vfds[fd]->newfd (sbp);
    }

  virtual void
    exited ()
    {
      for (size_t ix = 0; ix < vfds.size();  ix++) {
	if (!vfds[ix]) continue;
	vfds[ix]->exited();
      } 
    }

  virtual ~rexchannel () {}
  
};


/* Client side of a REX session: owns the transport to the proxy, the
   client handle used to issue REX calls, the server handle that receives
   REXCB callbacks, and the table of channels keyed by channel number. */
class rexsession {
 private:

  ptr<aclnt> proxy;
  ptr<axprt_crypt> proxyxprt;
  ptr<asrv> rexserv;

  // XXX: I was going to use a regular (rexchannel *) as the values
  // but I kept getting compiler errors, array notation was returning
  // a **rexchannel , an extra level of indirection.  dereferencing it
  // would segfault the program.  then I looked at the [] operator in
  // qhash but it uses these incomprehensible "typename R::type" things
  //        is this a bug?  also ptr <rexchannel> wouldn't work, had to
  //                        finally settle with ref <rexchannel>

  qhash<u_int32_t, ref <rexchannel> > channels;

  /* Dispatcher for REXCB requests.  Looks up the channel named in the
     argument and forwards the request to it.  Every request is answered
     on every path. */
  void
    rexcb_dispatch (svccb *sbp)
    {
      if (!sbp) {
        //todo: add callback argument for this case
        warn << "rexcb_dispatch: error\n";
        return;
      }

      switch (sbp->proc ()) {

      case REXCB_NULL:
        sbp->reply (NULL);
        break;

      case REXCB_EXIT:
        {
          rex_int_arg *argp = sbp->template getarg<rex_int_arg> ();
          rexchannel *chan = channels[argp->channel];

          if(chan)
            chan->exited();
          // The request was previously left unanswered on this path.
          // NOTE(review): assumes REXCB_EXIT has a void result like
          // REXCB_NULL -- confirm against the protocol definition.
          sbp->reply (NULL);
          break;
        }

      case REXCB_DATA:
        {
          rex_payload *argp = sbp->template getarg<rex_payload> ();
          rexchannel *chan = channels[argp->channel];

          if(chan)
            chan->data(sbp);
          else
            sbp->replyref (false);
          break;
        }

      case REXCB_NEWFD:
        {
          // Decode with the same type the channel and fd handlers use
          // for this procedure instead of rex_int_arg.
          // NOTE(review): assumes rexcb_newfd_arg has a `channel' member
          // -- confirm against the protocol definition.
          rexcb_newfd_arg *argp = sbp->template getarg<rexcb_newfd_arg> ();
          rexchannel *chan = channels[argp->channel];
          if(chan)
            chan->newfd(sbp);
          else
            sbp->replyref(false);
          break;
        }

      default:
        sbp->reject (PROC_UNAVAIL);
        break;
      }
    }


  /* Completion of REX_MKCHANNEL.  Any failure goes to fatal; otherwise
     the channel is recorded under the number in the reply, initialised,
     and the caller's callback is run.  Frees resp. */
  void
    madechannel (rex_mkchannel_res *resp, ref <rexchannel> newchan, cbv madechannelcb, clnt_stat err)
    {
      if (err) {
        fatal << "FAILED (" << err << ")\n";
      }
      else if (resp->err != SFS_OK) {
        // XXX
        fatal << "FAILED (mkchannel " << int (resp->err) << ")\n";
      }
      warnx << "made channel\n";

      channels.insert(resp->resok->channel, newchan);

      newchan->channelinit(resp->resok->channel, proxy);

      madechannelcb ();

      delete resp;
    }

  /* Stamp seqno into both key structures, hash each into a temporary
     sfs_sessinfo, and hand the results back: the hash of the sessinfo in
     *sidp and/or a copy of the sessinfo in *sip (either may be NULL).
     The temporary's key buffers are zeroed before returning. */
  void
    seq2sessinfo (u_int64_t seqno, sfs_hash *sidp, sfs_sessinfo *sip,
                  rex_sesskeydat *kcsdat, rex_sesskeydat *kscdat)
    {
      kcsdat->seqno = seqno;
      kscdat->seqno = seqno;

      sfs_sessinfo si;
      si.type = SFS_SESSINFO;
      si.kcs.setsize (sha1::hashsize);
      sha1_hashxdr (si.kcs.base (), *kcsdat, true);
      si.ksc.setsize (sha1::hashsize);
      sha1_hashxdr (si.ksc.base (), *kscdat, true);

      if (sidp)
        sha1_hashxdr (sidp->base (), si, true);
      if (sip)
        *sip = si;

      bzero (si.kcs.base (), si.kcs.size ());
      bzero (si.ksc.base (), si.ksc.size ());
    }

  /* Completion of REXD_ATTACH.  Any failure goes to fatal; otherwise the
     transport is reclaimed from sessxprt, encryption is enabled with the
     keys in sessinfo, the client and server handles are created on it,
     and the caller's callback is run.  Frees resp and sessinfo, zeroing
     the keys in sessinfo first. */
  void
    attached (rexd_attach_res *resp, ptr<axprt_crypt> sessxprt, sfs_sessinfo *sessinfo, cbv sessioncreatedcb, clnt_stat err)
    {
      if (err) {
        fatal << "FAILED (" << err << ")\n";
      }
      else if (*resp != SFS_OK) {
        // XXX
        fatal << "FAILED (attach err " << int (*resp) << ")\n";
      }
      delete resp;
      warnx << "attached\n";

      proxyxprt = axprt_crypt::alloc (sessxprt->reclaim ());
      proxyxprt->encrypt (sessinfo->kcs.base (), sessinfo->kcs.size (),
                         sessinfo->ksc.base (), sessinfo->ksc.size ());

      bzero (sessinfo->kcs.base (), sessinfo->kcs.size ());
      bzero (sessinfo->ksc.base (), sessinfo->ksc.size ());
      delete sessinfo;

      proxy = aclnt::alloc (proxyxprt, rex_prog_1);
      rexserv = asrv::alloc (proxyxprt, rexcb_prog_1, wrap (this, &rexsession::rexcb_dispatch));

      sessioncreatedcb ();
    }



  /* Completion of sfs_connect_path.  A failed connection goes to fatal;
     otherwise the session identifiers are derived from the key data and
     REXD_ATTACH is issued on the new connection.  Frees kcsdat, kscdat
     and rexseqno. */
  void connected (rex_sesskeydat *kcsdat, rex_sesskeydat *kscdat,
                  sfs_seqno *rexseqno, cbv sessioncreatedcb, ptr<sfscon> sc, str err)
    {
      if (!sc) {
        fatal << schost << ": FAILED (" << err << ")\n";
      }

      ptr <axprt_crypt> sessxprt = sc->x;
      ptr <aclnt> sessclnt = aclnt::alloc (sessxprt, rexd_prog_1);

      rexd_attach_arg arg;

      arg.seqno = *rexseqno;
      sfs_sessinfo *sessinfo = New sfs_sessinfo;

      seq2sessinfo (0, &arg.sessid, NULL, kcsdat, kscdat);
      seq2sessinfo (arg.seqno, &arg.newsessid, sessinfo, kcsdat, kscdat);

      //ECP comment: why doesn't agent just give us sessid,newsessid,sessinfo ??

      // resp and sessinfo are handed to attached (), which frees them.
      rexd_attach_res *resp = New rexd_attach_res;
      sessclnt->call (REXD_ATTACH, &arg, resp, wrap (this, &rexsession::attached, resp, sessxprt, sessinfo, sessioncreatedcb));

      delete kcsdat;
      delete kscdat;
      delete rexseqno;
    }


 public:
  str schost;

  // Build a session on a transport that is already set up.
  rexsession (str schostname, ptr<axprt_crypt> proxyxprt): proxyxprt (proxyxprt), schost (schostname)
    {
      proxy = aclnt::alloc (proxyxprt, rex_prog_1);
      rexserv = asrv::alloc (proxyxprt, rexcb_prog_1, wrap (this, &rexsession::rexcb_dispatch));
    }

  ~rexsession ()
    {
      channels.clear ();
    }

  /*todo: add proxy eof callback */
  /* Build a session from scratch: ask the agent for key material for
     schostname, then start connecting; sessioncreatedcb runs from
     attached () once the session is usable.  Failure to get an answer
     from the agent goes to fatal. */
  rexsession (callback<void, void>::ref sessioncreatedcb, str schostname): schost(schostname)
    {

      ref <agentconn> aconn = New refcounted<agentconn> ();
      ptr<sfsagent_rex_res> ares = aconn->rex (schost);
      if (!ares || !ares->status)
        fatal << "could not connect to agent\n";



      // These three are handed to connected (), which frees them.
      rex_sesskeydat *kscdat = New rex_sesskeydat;
      rex_sesskeydat *kcsdat = New rex_sesskeydat;
      sfs_seqno *rexseqno = New sfs_seqno;

      kcsdat->type = SFS_KCS;
      kcsdat->cshare = ares->resok->kcs.kcs_share;
      kcsdat->sshare = ares->resok->kcs.ksc_share;
      kscdat->type = SFS_KSC;
      kscdat->cshare = ares->resok->ksc.kcs_share;
      kscdat->sshare = ares->resok->ksc.ksc_share;
      *rexseqno = ares->resok->seqno;
      sfs_connect_path (schostname, SFS_REX,
                        wrap (this, &rexsession::connected, kcsdat, kscdat, rexseqno, sessioncreatedcb),
                        false);
    }

  // Create a channel without a completion callback.
  void
    makechannel (ptr<rexchannel> newchan)
    {
      makechannel (newchan, cbv_null);
    }


  /* Ask the proxy to create a channel running newchan's command with
     newchan's initial descriptor count; madechannelcb runs once the
     channel has been recorded and initialised. */
  void
    makechannel (ptr<rexchannel> newchan, cbv madechannelcb)
    {
      rex_mkchannel_arg arg;

      arg.av.setsize (newchan->command.size ());
      for (size_t i = 0; i < newchan->command.size (); i++)
        arg.av[i] = newchan->command[i];
      arg.nfds = newchan->initnfds;

      // resp is handed to madechannel (), which frees it.
      rex_mkchannel_res *resp = New rex_mkchannel_res;
      proxy->call (REX_MKCHANNEL, &arg, resp, wrap (this, &rexsession::madechannel, resp, newchan, madechannelcb));
    }
};


//requires that "lookup" is already absolute
/* Turn the target of a symbolic link into an absolute path.
     readlinkres: NUL-terminated link target, as read from the link
     lookup:      absolute path of the link itself
   An absolute target is returned unchanged; a relative one is appended
   to the directory part of lookup (everything up to and including the
   slash that precedes the last component). */
static str
mk_readlinkres_abs (char *readlinkres, str lookup) {
  assert (lookup[0] == '/');

  if (readlinkres[0] == '/')
    return readlinkres;


  int ix = lookup.len() - 1;
  //skip over trailing slashes, but never past the leading one, so that
  //a lookup consisting only of slashes does not index below zero
  while (ix > 0 && lookup[ix] == '/')
    ix--;

  //skip over basename; the bound only matters if the assert above is
  //compiled out and lookup contains no slash at all
  while (ix >= 0 && lookup[ix] != '/')
    ix--;

  return strbuf () << str (lookup.cstr (), ix + 1) << readlinkres;
}

//returns destination:hostid portion of self-certifying pathname
/* path is expected to look like <sfsroot>/<component>[/...]; the result
   is <component>.  Returns the empty string when path does not start
   with sfsroot followed by a slash. */
static str
path2sch (str path) {
  const char *p = path.cstr ();
  size_t rootlen = strlen (sfsroot);

  // Prefix test: a whole-string comparison would reject every real
  // pathname, and skipping a fixed number of bytes after it could step
  // past the terminator.
  if (strncmp (sfsroot, p, rootlen) || p[rootlen] != '/')
    return "";

  // first byte after "<sfsroot>/"
  const char *nosfs = p + rootlen + 1;

  const char *firstslash = strchr (nosfs, '/');
  if (firstslash)
    return str (nosfs, firstslash - nosfs);
  else
    return nosfs;
}
  
//for looking up self-certifying hostnames in certprog interface
//   returns empty string on failure
/* An empty host, or one that already contains a ':', is returned as is.
   Otherwise "/sfs/<host>" is followed through symbolic links one hop at
   a time until something that is not a link is found, and the
   self-certifying component of that path is returned. */
str
certproglookup (str host) {
  if (!host.len ())
    return host;
  if (strchr (host.cstr(), ':'))
    return host;

  str lookup = strbuf () << "/sfs/" << host;
  char readlinkbuf [PATH_MAX + 1];

  struct stat sb;

  // Upper bound on links followed, so a cycle of links ends in failure
  // instead of looping forever.
  const int maxhops = 64;
  int hops = 0;

  while (!lstat (lookup.cstr (), &sb)) {
    //if we didn't lookup a symlink, then it's the self-certifying pathname
    if ( (sb.st_mode & S_IFMT) != S_IFLNK)
      return path2sch (lookup);

    if (++hops > maxhops)
      return "";

    int len = readlink (lookup.cstr (), readlinkbuf, PATH_MAX);
    // readlink returns -1 on error; don't use that as an index
    if (len < 0)
      return "";
    readlinkbuf[len] = 0;

    //warn << "readlink of " << lookup << " returned " << readlinkbuf << "\n";

    lookup = mk_readlinkres_abs (readlinkbuf, lookup);
  }

  //if we got permission denied, lookup probably contains self-certifying path
  return (errno == EACCES) ? path2sch (lookup) : str ("");
}
