/* $Id: sfsconnect.C,v 1.11 2000/04/01 00:39:27 dm Exp $ */

/*
 *
 * Copyright (C) 1999 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "sfskey.h"
#include "srp.h"

/*
 * Per-attempt state for an asynchronous SFS connection (used by
 * sfs_connect and sfs_connect_path).  Instances are allocated with New
 * and free themselves: fail () and succeed () each invoke the user
 * callback once and then `delete this', so no member may be touched
 * after either returns.
 */
struct constate {
  // Client Rabin key shared by every encrypted connection in this
  // process; created on first use in getconres ().
  static ptr<rabin_priv> ckey;

  const sfs_connect_cb cb;	// user callback: (sc, NULL) or (NULL, errmsg)
  sfs_connectarg carg;		// argument sent with SFSPROC_CONNECT
  sfs_connectres cres;		// server's SFSPROC_CONNECT reply

  ref<sfscon> sc;		// connection handed to cb on success
  ptr<aclnt> c;			// sfs_program_1 RPC client over sc->x
  
  bool encrypt;			// negotiate a session key after connecting
  bool check_hostid;		// fail unless server name/hostid match carg
  
  // Note: the parameter `c' (the callback) shadows the member `c'; the
  // initializer `cb (c)' uses the parameter.
  constate (const sfs_connect_cb &c)
    : cb (c), sc (New refcounted<sfscon>),
      encrypt (true), check_hostid (true) { sc->hostid_valid = false; }

  // Report msg to the caller, then destroy this object.
  void fail (str msg) { (*cb) (NULL, msg); delete this; }
  // Drop the RPC client, hand sc to the caller, then destroy this object.
  void succeed () { c = NULL; (*cb) (sc, NULL); delete this; }

  void start ();				// begin connecting to carg.name
  void getfd (int fd);				// socket ready, or fd < 0 on error
  void getconres (enum clnt_stat err);		// SFSPROC_CONNECT completed
  void cryptcb (const sfs_hash *sessidp);	// session-key negotiation done
};

// Storage for the shared client key; getconres () tests it with !ckey
// and generates it lazily.
ptr<rabin_priv> constate::ckey;

/*
 * Kick off the connection described by carg.
 *
 * The name "-" with service SFS_AUTHSERV selects the local authserv,
 * reached through suidgetfd instead of TCP.  Otherwise, when hostid
 * checking is on, the name must contain a '.' (fully qualified), and a
 * TCP connection to sfs_port is started.  Completion continues in getfd.
 */
void
constate::start ()
{
  // Paired with the rndsync () in getconres, which runs before the
  // client key is generated.
  if (encrypt)
    rndstart ();

  if (carg.name == "-" && carg.service == SFS_AUTHSERV) {
    int fd = suidgetfd ("authserv");
    // If suidgetfd failed, getfd reports strerror (errno); make that a
    // meaningful "connection refused" rather than a stale value.
    errno = ECONNREFUSED;
    getfd (fd);
  }
  else if (!strchr (carg.name, '.') && check_hostid) {
    // No trailing newline: consistent with every other fail () message,
    // whose callers supply their own line termination.
    fail (carg.name << ": must be fully qualified domain name");
  }
  else
    tcpconnect (carg.name, sfs_port, wrap (this, &constate::getfd));
}

/*
 * Called once a socket to the server exists (fd >= 0) or could not be
 * obtained (fd < 0, with errno describing the failure).  On success,
 * wraps the descriptor in a crypt-capable transport, builds an
 * sfs_program_1 RPC client over it, and issues SFSPROC_CONNECT.
 */
void
constate::getfd (int fd)
{
  if (fd >= 0) {
    sc->x = axprt_crypt::alloc (fd);
    c = aclnt::alloc (sc->x, sfs_program_1);
    c->call (SFSPROC_CONNECT, &carg, &cres, wrap (this, &constate::getconres));
  }
  else
    fail (carg.name << ": " << strerror (errno));
}

/*
 * Handle the SFSPROC_CONNECT reply.
 *
 * Fails on RPC error, non-zero server status, or unparseable host info.
 * Marks sc->hostid_valid when the server's hostname and computed hostid
 * both match carg; a mismatch is fatal only when check_hostid is set.
 * Then either finishes immediately (no encryption) or starts session-key
 * negotiation with the process-wide client key, continuing in cryptcb.
 */
void
constate::getconres (enum clnt_stat err)
{
  if (err)
    return fail (carg.name << ": " << err);
  if (cres.status)
    return fail (carg.name << ": " << cres.status);

  sc->servinfo = cres.reply->servinfo;
  if (!sfs_mkhostid (&sc->hostid, sc->servinfo.host))
    return fail (carg.name << ": Server returned garbage hostinfo");

  bool const identity_ok = sc->servinfo.host.hostname == carg.name
                           && sc->hostid == carg.hostid;
  if (identity_ok)
    sc->hostid_valid = true;
  else if (check_hostid)
    return fail (carg.name << ": Server key does not match hostid");

  sc->path = sfs_hostinfo2path (sc->servinfo.host);

  if (!encrypt)
    return succeed ();

  if (!ckey) {
    rndsync ();
    // Re-tested after rndsync (); presumably another connection may have
    // created the key in the meantime -- keep this second check.
    if (!ckey)
      ckey = New refcounted<rabin_priv> (rabin_keygen (sfs_minpubkeysize));
  }
  sfs_client_crypt (c, ckey, carg, *cres.reply,
		    wrap (this, &constate::cryptcb));
}

/*
 * Completion of session-key negotiation.  A NULL sidp means negotiation
 * failed.  Otherwise records the session ID, derives sc->authid from the
 * service, server hostname, hostid, and session ID, and succeeds.
 */
void
constate::cryptcb (const sfs_hash *sidp)
{
  if (sidp) {
    sc->sessid = *sidp;
    sfs_get_authid (&sc->authid, carg.service, sc->servinfo.host.hostname,
		    &sc->hostid, &sc->sessid);
    succeed ();
  }
  else
    fail (carg.name << ": Session key negotiation failed");
}

void
sfs_connect (const sfs_connectarg &carg, sfs_connect_cb cb,
	     bool encrypt, bool check_hostid)
{
  constate *sc = New constate (cb);
  sc->carg = carg;
  sc->encrypt = encrypt;
  sc->check_hostid = check_hostid;
  sc->start ();
}

/*
 * Continuation for sfs_connect_path ("-"): after reaching the local
 * authserv, retarget the pending connection cs at the host it reported
 * (name and hostid), enforce hostid checking, and restart.  On failure,
 * forward the error to cs's caller.
 */
void
sfs_connect_local_cb (constate *cs, ptr<sfscon> sc, str err)
{
  if (!sc) {
    cs->fail (err);
    return;
  }
  cs->carg.hostid = sc->hostid;
  cs->carg.name = sc->servinfo.host.hostname;
  cs->check_hostid = true;
  cs->start ();
}

/*
 * Connect to the server named by an SFS pathname.
 *
 * path "-" first contacts the local authserv (unencrypted, unchecked),
 * then connects to the host it reports (see sfs_connect_local_cb).
 * Any other path is parsed into a hostname and hostid with
 * sfs_parsepath; an unparseable path fails immediately.
 */
void
sfs_connect_path (str path, sfs_service service, sfs_connect_cb cb,
		  bool encrypt, bool check_hostid)
{
  constate *sc = New constate (cb);
  sc->carg.release = sfs_release;
  sc->carg.service = service;
  sc->encrypt = encrypt;
  sc->check_hostid = check_hostid;
  if (path == "-") {
    sfs_connectarg carg = sc->carg;
    carg.name = "-";
    sfs_connect (carg, wrap (sfs_connect_local_cb, sc), false, false);
  }
  else if (!sfs_parsepath (path, &sc->carg.name, &sc->carg.hostid))
    // Previously read "not a file SFS file system" (garbled wording).
    sc->fail (path << ": not a valid SFS file system");
  else
    sc->start ();
}

/*
 * Per-attempt state for sfs_connect_srp: connects to the authserver of
 * a [user]@host, then drives the SRP exchange through srpp.  Instances
 * are allocated with New and free themselves in fail () / succeed ()
 * after invoking cb once.
 */
struct srpcon {
  const sfs_connect_cb cb;	// user callback: (sc, NULL) or (NULL, errmsg)
  str user;			// user name (from "user@" or myusername ())
  str host;			// host part of the [user]@host argument
  str pwd;			// passphrase last entered; NULL before prompting
  ptr<sfscon> sc;		// connection to host's authserver
  ptr<aclnt> c;			// sfsauth_program_1 RPC client over sc->x
  sfsauth_srpres sres;		// reply buffer shared by the SRP RPCs
  srp_client *srpp;		// caller-supplied SRP client (not deleted here)
  str *pwdp;			// optional out-param (may be NULL): passphrase
  str *userp;			// optional out-param (may be NULL): user@host
  int ntries;			// passphrase retries used so far

  srpcon (const sfs_connect_cb &c)
    : cb (c), srpp (NULL), pwdp (NULL), userp (NULL), ntries (0) {}
  // Report msg to the caller, then destroy this object.
  void fail (str msg) { (*cb) (NULL, msg); delete this; }
  // Drop the RPC client, hand sc to the caller, then destroy this object.
  void succeed () { c = NULL; (*cb) (sc, NULL); delete this; }

  void start (str &u);				// parse [user]@host and connect
  void getcon (ptr<sfscon> sc, str err);	// authserver connection result
  void initsrp ();				// send SFSAUTHPROC_SRP_INIT
  void srpcb (clnt_stat err);			// handle each SRP reply
};

/*
 * Parse u as "[user]@host" and open an encrypted, hostid-unchecked
 * connection to host's authserver; continues in getcon.  A missing user
 * part defaults to the local user name.
 *
 * userp is an optional out-parameter (NULL unless the caller supplied
 * one), so it is only written when non-NULL; previously a malformed u
 * with userp == NULL dereferenced a null pointer.
 */
void
srpcon::start (str &u)
{
  static rxx usrhost ("^([^@]+)?@(.*)$");
  if (!usrhost.search (u)) {
    if (userp)
      *userp = u;
    fail ("not of form [user]@host");
    return;
  }

  user = usrhost[1];
  host = usrhost[2];
  if (!user && !(user = myusername ())) {
    fail ("Could not get local username");
    return;
  }

  rndstart ();

  sfs_connectarg carg;
  carg.release = sfs_release;
  carg.service = SFS_AUTHSERV;
  carg.name = host;

  // Hostid is not checked here; srpcb decides hostid_valid from the
  // host name the SRP exchange authenticates.
  sfs_connect (carg, wrap (this, &srpcon::getcon), true, false);
}

/*
 * Result of connecting to the authserver.  Stores the connection; on
 * success creates an sfsauth_program_1 RPC client over its transport
 * and begins SRP, otherwise forwards err to the caller.
 */
void
srpcon::getcon (ptr<sfscon> s, str err)
{
  sc = s;
  if (s) {
    c = aclnt::alloc (sc->x, sfsauth_program_1);
    initsrp ();
  }
  else
    fail (err);
}

/*
 * Start (or restart) the SRP exchange: let srpp build the first message
 * bound to this connection's authid, then send SFSAUTHPROC_SRP_INIT.
 * Replies are handled by srpcb.
 */
void
srpcon::initsrp ()
{
  sfssrp_init_arg arg;
  arg.username = user;
  if (srpp->init (&arg.msg, sc->authid, user))
    c->call (SFSAUTHPROC_SRP_INIT, &arg, &sres, wrap (this, &srpcon::srpcb));
  else
    fail ("SRP client initialization failed");
}

/*
 * Handle one SRP reply from the authserver.
 *
 * A non-OK status after a passphrase was entered is treated as a rejected
 * passphrase and SRP is restarted (up to a few retries).  Otherwise the
 * message is fed to srpp, which may ask for a passphrase, want another
 * round trip, or report completion.  On completion the optional
 * out-parameters are filled and hostid_valid is set if the SRP-verified
 * host matches the server's hostname.
 *
 * Fixes: the userp out-parameter is now guarded by `userp' (previously
 * `user', which is always non-NULL here, so a NULL userp was
 * dereferenced).  The passphrase is also validated before it is given to
 * srpp, so a NULL result from getpwd is not dereferenced.
 */
void
srpcon::srpcb (clnt_stat err)
{
  if (err) {
    fail (host << ": " << err);
    return;
  }
  if (sres.status != SFSAUTH_OK) {
    // Never prompted, or too many attempts: give up.
    if (!pwd || ntries++ >= 3) {
      fail ("Server aborted SRP protocol");
      return;
    }
    pwd = NULL;
    warnx ("Server rejected passphrase.\n");
    initsrp ();
    return;
  }

 reswitch:
  switch (srpp->next (sres.msg, sres.msg.addr ())) {
  case SRP_SETPWD:
    pwd = getpwd (strbuf () << "Passphrase for " << srpp->getname () << ": ");
    // Check before use: a NULL or empty passphrase aborts the login.
    if (!pwd || !pwd.len ()) {
      fail ("Aborted.");
      return;
    }
    srpp->setpwd (pwd);
    goto reswitch;
  case SRP_NEXT:
    c->call (SFSAUTHPROC_SRP_MORE, sres.msg.addr (), &sres, 
	     wrap (this, &srpcon::srpcb));
    break;
  case SRP_DONE:
    if (userp)
      *userp = user << "@" << srpp->host;
    if (pwdp)
      *pwdp = pwd;
    sc->hostid_valid = (srpp->host == sc->servinfo.host.hostname);
    succeed ();
    break;
  default:
    fail (host << ": server returned invalid SRP message");
    break;
  }
}

/*
 * Log in to the authserver of user ("[user]@host") using SRP via srpp,
 * which must be non-NULL.  cb receives the authenticated connection or
 * an error.  userp and pwdp are optional out-parameters for the final
 * user@host string and the passphrase entered.
 */
void
sfs_connect_srp (str &user, srp_client *srpp, sfs_connect_cb cb,
		 str *userp, str *pwdp)
{
  assert (srpp);
  srpcon *const con = New srpcon (cb);
  con->userp = userp;
  con->pwdp = pwdp;
  con->srpp = srpp;
  con->start (user);
}
