/* trrpitr.c  Version 0.1
 * Written 8/2007 William Herrin <bill@herrin.us>
 *
 * Proof of concept code for an IPv4 Ingress Transit Router used for the
 * Tunnelling Route Reduction Protocol
 * http://bill.herrin.us/network/trrp.html
 *
 * This code is offered as a proof of concept prototype. It is not intended
 * to be a reference implementation and should not be treated as a reference
 * implementation. In particular, it:
 *   1. Handles only one packet at a time
 *   2. Does no caching at all
 *   3. Implements only a tiny bit of the required error handling.
 *
 * This is intended to be an "end of the line" ITR. You should set a
 * CIDR block intended to be the only destinations it will encapsulate
 * and put it in networktoencapsulate/netmasktoencapsulate. Then use
 * the sample iptables command to prevent Linux's IP stack from interacting
 * with the packet.
 */

/* Author's configuration, 3 linux servers:
 * host 70.184.240.82:
   iptables -I FORWARD -d 192.168.255.0/24 -j DROP
   runs this program
 * host 70.184.240.83
   ip tunnel add etr mode gre local 192.168.99.5 key 1
   ip link set etr up
   ip addr add 10.0.3.2/32 dev etr
   echo 0 > /proc/sys/net/ipv4/conf/etr/rp_filter
   ifconfig lo:0 192.168.255.99 netmask 255.255.255.255
 * host 70.184.240.91:
   ip route add 192.168.255.0/24 via 70.184.240.82
   ping 192.168.255.99
 */

/* Site configuration. Edit these values for your installation.
 *
 * networktoencapsulate/netmasktoencapsulate: the only destination block
 * this ITR encapsulates. Keep the kernel from forwarding those packets
 * itself, e.g.:
 *   iptables -I FORWARD -d 192.168.255.0/24 -j DROP */
const char *networktoencapsulate = "192.168.255.0";
const int netmasktoencapsulate = 24; /* prefix length; indexes the 33-entry netmasks table */

const char *itrsourceaddress = "70.184.240.82"; /* source address of the GRE and ICMP packets we emit */
const char *dnsresolver = "127.0.0.1"; /* resolver that receives the TXT queries (UDP port 53) */
/* const char *trrpdomain = "arpa"; */
const char *trrpdomain = "dirtside.com"; /* suffix of the queried "d.c.b.a.v4.trrp.<domain>" names */

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h> /* uint8_t, uint16_t, uint32_t */
#include <unistd.h> /* close */
#include <string.h> /* memcmp */
#include <sys/socket.h>
#include <netpacket/packet.h>
#include <net/ethernet.h>     /* the L2 protocols */
#include <netinet/in.h>
#include <sys/types.h>
#include <arpa/inet.h> /* inet_pton */
#include <errno.h>
#include <fcntl.h> /* O_NONBLOCK */
#include <sys/select.h> /* select */
#include <sys/time.h> /* select */
#include <time.h> /* time() */


/* Exact-width integer aliases used throughout for on-the-wire header
 * layouts. They are built on <stdint.h> so the widths are guaranteed by
 * the compiler instead of assumed (plain "unsigned int" is only 32 bits
 * on common ABIs). */
typedef uint8_t uint8;
typedef uint16_t uint16;
typedef uint32_t uint32;

/* IPv4 header without options (20 bytes), laid out exactly as on the
 * wire so it can be overlaid on received packets. Multi-byte fields hold
 * network byte order values; convert with ntohs()/ntohl() before doing
 * arithmetic on them. */
struct IPHEADER {
  uint8 versionlength;   /* high nibble: version (4); low nibble: header length in 32-bit words */
  uint8 tos;             /* type of service */
  uint16 length;         /* total packet length, header included */
  uint16 identification; /* fragment identification */
  uint16 flagsfragment;  /* flag bits (0x4000 = don't fragment) and fragment offset */
  uint8 ttl;             /* time to live */
  uint8 protocol;        /* payload protocol; this program uses 1 = ICMP and 47 = GRE */
  uint16 checksum;       /* header checksum */
  uint32 sourceip;
  uint32 destip;
};

/* struct GREHEADER {
  uint16 flagsver;
  uint16 protocol;
  uint16 checksum;
  uint16 offset;
  uint32 key;
  uint32 sequence;
  uint32 routing;
} */

/* GRE header with only the optional key field present (8 bytes, wire
 * layout, network byte order). This is the short form of the commented-out
 * header above, used for every encapsulated packet. */
struct GREHEADER {
  uint8 flags;     /* flag bits; 0x20 marks the key field as present */
  uint8 flagsver;  /* remaining flag bits and the GRE version */
  uint16 protocol; /* EtherType of the payload; 0x0800 = IPv4 */
  uint32 key;      /* GRE key; the sample ETR tunnel above is configured with key 1 */
}; 

/* Outer header prepended to each encapsulated packet: the IPv4 header
 * immediately followed by the GRE header. sizeof() of this struct is used
 * as the encapsulation overhead, which assumes the compiler inserts no
 * padding (20 + 8 = 28 bytes). TODO confirm with a static assertion. */
struct OUTHEADER {
  struct IPHEADER ip;
  struct GREHEADER gre;
};

/* Some data structures we'll pre-initialize for efficiency.
 * All of them are filled in once by initializedatastructures(). */
uint32 *netmasks; /* "/nn" to bitmask map: 33 entries, network byte order */
struct OUTHEADER headertemplate; /* Precomputed IP+GRE header copied in front of each packet */
uint32 acceptdest;  /* networktoencapsulate, masked, network byte order */
struct sockaddr_in dnsresolver_dest; /* where DNS queries are sent */
uint16 g4;  /* the raw bytes 'g','4' (identifier of a GRE-over-IPv4 ETR entry) */

void initializedatastructures (void) {
/* Initialize the global variables above from the configuration values at
 * the top of the file. If the configuration is invalid or memory cannot
 * be allocated, the ITR cannot work correctly, so this prints an error
 * and exits instead of continuing with zeroed (0.0.0.0) addresses. */
  int i;

  if ((netmasktoencapsulate<0) || (netmasktoencapsulate>32)) {
    printf ("error: netmasktoencapsulate %d is not in 0-32\n",
	netmasktoencapsulate);
    exit (1);
  }

  /* netmasks[n] is the /n netmask in network byte order. The n==0 case is
   * handled separately because shifting a 32-bit value by 32 is undefined. */
  netmasks = (uint32*) malloc (sizeof(uint32)*33);
  if (!netmasks) {
    printf ("error: out of memory\n");
    exit (1);
  }
  for (i=0; i<=32; i++)
    netmasks[i] = htonl (i ? (uint32) (0xffffffffU << (32-i)) : 0);

  /* inet_pton returns 1 only for a valid address and stores it already
   * in network byte order. */
  if (inet_pton (AF_INET,networktoencapsulate,&acceptdest) != 1) {
    printf ("error: bad networktoencapsulate address %s\n",
	networktoencapsulate);
    exit (1);
  }
  acceptdest &= netmasks[netmasktoencapsulate];

  /* Build IP+GRE header template for encapsulated packets */
  memset (&headertemplate,0,sizeof(headertemplate));
  headertemplate.ip.versionlength = 0x45; /* IPv4, 20-byte header */
  headertemplate.ip.flagsfragment = htons (0x4000); /* Don't fragment */
  if (inet_pton (AF_INET,itrsourceaddress,&(headertemplate.ip.sourceip)) != 1) {
    printf ("error: bad itrsourceaddress %s\n",itrsourceaddress);
    exit (1);
  }
  headertemplate.ip.protocol = 47; /* GRE */
  headertemplate.ip.ttl = 64;
  headertemplate.gre.flags |= 0x20; /* key present and valid */
  headertemplate.gre.key = htonl(1); /* key 1 */
  headertemplate.gre.protocol = htons(0x0800); /* IPv4 = 0x0800 */

  dnsresolver_dest.sin_family=AF_INET;
  dnsresolver_dest.sin_port = htons (53);
  if (inet_pton (AF_INET,dnsresolver,&(dnsresolver_dest.sin_addr.s_addr)) != 1) {
    printf ("error: bad dnsresolver address %s\n",dnsresolver);
    exit (1);
  }
  memcpy (&g4,"g4",sizeof(char)*2);

  return;
}


int filter (struct IPHEADER *ip) {
/* Return true if we should consider this packet, false if not. */
  if (acceptdest == (ip->destip&netmasks[netmasktoencapsulate])) return 1;
  return 0;
}

/* Rendering of every byte value for printheader's ASCII column: printable
 * ASCII characters stand for themselves and everything else is shown as
 * '.'. The second half (bytes 128-255) repeats the first, so a byte's
 * high bit is effectively ignored. Index with an unsigned byte value. */
char printheaderascii[] =  /* used by printheader */
"................................ !\"#$%&'()*+,-./0123456789:;<=>?"
"@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~."
"................................ !\"#$%&'()*+,-./0123456789:;<=>?"
"@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~.";

void printheader (void *ip, int max) {
/* Debugging aid: hex-dump the first max bytes at ip, sixteen per row.
 * Each row ends with the bytes' ASCII rendering; a short final row is
 * padded so its ASCII column lines up with the rows above. */
  uint8 *bytes = (uint8*) ip;
  char ascii[17];
  int pos = 0;
  int col;

  while (pos < max) {
    col = pos % 16;
    printf ("%02x ", (int) bytes[pos]);
    ascii[col] = printheaderascii[(int) bytes[pos]];
    pos++;
    if (col == 15) {
      ascii[16] = 0;
      printf (" %s\n", ascii);
    }
  }

  col = pos % 16;
  if (col != 0) {
    ascii[col] = 0;
    while ((pos % 16) != 0) {
      printf ("   ");
      pos++;
    }
    printf (" %s\n", ascii);
  }
}

int openreceiver (void) {
/* Open a socket which receives all IP packets received by this machine.
 * It is a PF_PACKET datagram socket for ETH_P_IP (delivered without the
 * link-layer header), with a large receive buffer, a small send buffer
 * and non-blocking I/O. Returns the descriptor, or a negative value on
 * failure. */
  int receiver;
  int bufsize;
  socklen_t bufsizesize;
  int flags;

  receiver = socket (PF_PACKET, SOCK_DGRAM, htons(ETH_P_IP));
  if (receiver<0) { /* socket() signals failure with -1; 0 is a valid fd */
    printf ("error: receiver socket returned %d\n",receiver);
    return receiver;
  }
  /* Make as big a receive buffer as we can so that we don't lose packets.
   * Make as small a send buffer as we can so we don't waste memory.
   */
  bufsize = 500000;
  setsockopt (receiver,SOL_SOCKET,SO_RCVBUF,&bufsize,sizeof(bufsize));
  bufsize = 100;
  setsockopt (receiver,SOL_SOCKET,SO_SNDBUF,&bufsize,sizeof(bufsize));
  /* Read back what the kernel actually granted. */
  bufsizesize = sizeof(bufsize);
  getsockopt (receiver,SOL_SOCKET,SO_RCVBUF,&bufsize,&bufsizesize);
  if (bufsize < 260000) {
    printf ("error: receiver socket bufsize %d\n",bufsize);
    close (receiver);
    return -1;
  }
  /* put the socket in non-blocking mode */
  if (-1 == (flags = fcntl(receiver, F_GETFL, 0))) flags = 0;
  if (-1 == fcntl(receiver, F_SETFL, flags | O_NONBLOCK)) {
    printf ("error: receiver socket could not set O_NONBLOCK\n");
    close (receiver);
    return -1;
  }
  return receiver;
}

int opensender (void) {
/* Open a socket with which to send IP packets including ICMP errors and
 * the encoded GRE packets. The caller supplies the IP header
 * (IP_HDRINCL); some of its fields are filled in by the kernel after the
 * packet is handed to the socket. See "man -S 7 raw".
 * Returns the descriptor, or a negative value on failure. */
  int sender;
  int bufsize;
  socklen_t bufsizesize;
  int flags;
  int one=1;

  sender = socket (PF_INET, SOCK_RAW, IPPROTO_RAW);
  if (sender<0) { /* socket() signals failure with -1; 0 is a valid fd */
    printf ("error: sender socket returned %d\n",sender);
    return sender;
  }
  /* Advise that the IP header is included when sent and that certain
   * fields should be filled in per "man -S 7 raw" */
  if (setsockopt (sender, IPPROTO_IP, IP_HDRINCL, &one, sizeof (one)) < 0) {
    printf ("error: Cannot set HDRINCL!\n");
    close (sender);
    return -1;
  }
  /* Make as big a send buffer as we can so we don't lose packets
   * Make as small a receive buffer as we can so we don't waste memory.
   */
  bufsize = 500000;
  setsockopt (sender,SOL_SOCKET,SO_SNDBUF,&bufsize,sizeof(bufsize));
  bufsize = 100;
  setsockopt (sender,SOL_SOCKET,SO_RCVBUF,&bufsize,sizeof(bufsize));
  /* Read back what the kernel actually granted. */
  bufsizesize = sizeof(bufsize);
  getsockopt (sender,SOL_SOCKET,SO_SNDBUF,&bufsize,&bufsizesize);
  if (bufsize < 260000) {
    printf ("error: sender socket bufsize %d\n",bufsize);
    close (sender);
    return -1;
  }
  /* put the socket in non-blocking mode */
  if (-1 == (flags = fcntl(sender, F_GETFL, 0))) flags = 0;
  if (-1 == fcntl(sender, F_SETFL, flags | O_NONBLOCK)) {
    printf ("error: sender socket could not set O_NONBLOCK\n");
    close (sender);
    return -1;
  }

  return sender;
}

int opendns (void) {
/* Open a non-blocking UDP socket, bound to a kernel-chosen local port, for
 * talking to the DNS resolver. Returns the descriptor, or a negative value
 * on failure. */
  int dns;
  int bufsize, flags;
  struct sockaddr_in sin;

  dns = socket(AF_INET, SOCK_DGRAM, 0);
  if (dns<0) { /* socket() signals failure with -1; 0 is a valid fd */
    printf ("error: dns socket returned %d\n",dns);
    return dns;
  }
  /* Make the send and receive buffers as big as we can so we don't lose
   * packets.
   */
  bufsize = 500000;
  setsockopt (dns,SOL_SOCKET,SO_SNDBUF,&bufsize,sizeof(bufsize));
  setsockopt (dns,SOL_SOCKET,SO_RCVBUF,&bufsize,sizeof(bufsize));
  /* Make the DNS socket non-blocking */
  if (-1 == (flags = fcntl(dns, F_GETFL, 0))) flags = 0;
  if (-1 == fcntl(dns, F_SETFL, flags | O_NONBLOCK)) {
    printf ("error: dns socket could not set O_NONBLOCK\n");
    close (dns);
    return -1;
  }
  /* Bind the DNS socket so that it gets a source port.
   * Don't really care which port: port 0 lets the kernel pick one.
   * Zero the whole address first so sin_zero is not left uninitialized.
   */
  memset (&sin,0,sizeof(sin));
  sin.sin_family = AF_INET;
  sin.sin_port = htons (0);
  sin.sin_addr.s_addr = htonl (INADDR_ANY);
  if (-1 == bind (dns, (struct sockaddr*) &sin, sizeof(sin))) {
    printf ("error: dns socket could not be bound\n");
    close (dns);
    return -1;
  }

  return dns;
}

/* Outgoing DNS query message. The first six members are the 12-byte DNS
 * message header (multi-byte fields in network byte order). question holds
 * the encoded name followed by the type/class trailer. len is bookkeeping
 * only: it records how many leading bytes of this struct make up the
 * message and is not itself sent. */
struct DNSQUERY {
  uint16 identification; /* message ID */
  uint8 flagsa;          /* QR/opcode/AA/TC/RD flag byte */
  uint8 flagsb;          /* RA/Z/RCODE flag byte */
  uint16 numquestions;
  uint16 numanswers;
  uint16 numauthority;
  uint16 numadditional;
  uint8 question[512];   /* length-prefixed labels, 0 terminator, then DNSQUERYEND */
  uint16 len;
};

/* The 4-byte trailer that follows a question's name: query type and
 * query class, both in network byte order. */
struct DNSQUERYEND {
  uint16 querytype;  /* e.g. 16 = TXT */
  uint16 queryclass; /* e.g. 1 = IN */
};

/* One decoded question from a DNS response. type and class are in host
 * byte order. name is a malloced, dot-separated C string.
 * NOTE: type and class are kept first and adjacent; decoding code may fill
 * both with a single 4-byte copy, so do not reorder these members. */
struct DECODEDDNSQUESTION {
  uint16 type;
  uint16 class;
  char *name;
};

/* One decoded resource record. name is a malloced, dot-separated C string.
 * type, class, ttl and datalen are in host byte order. data holds datalen
 * bytes of raw RDATA and is really variable length: allocate
 * sizeof(struct DECODEDDNSRESOURCE) + datalen.
 * NOTE: after name, the members follow the wire order of a record's fixed
 * part (type, class, ttl, rdlength, rdata); decoding code may copy wire
 * bytes straight over them, so do not reorder. */
struct DECODEDDNSRESOURCE {
  char *name;
  uint16 type;
  uint16 class;
  uint32 ttl;     /* seconds */
  uint16 datalen;
  uint8 data[1];
};

/* A fully decoded DNS response, released with dnsfreeanswer().
 * identification is copied as received (network byte order); the four
 * counts are converted to host byte order and give the length of the
 * matching pointer array below. */
struct DECODEDDNSANSWER {
  uint16 identification;
  uint8 flagsa;
  uint8 flagsb;
  uint16 numquestions;
  uint16 numanswers;
  uint16 numauthority;
  uint16 numadditional;
  struct DECODEDDNSQUESTION **questions; /* numquestions entries */
  struct DECODEDDNSRESOURCE **answers;    /* numanswers entries */
  struct DECODEDDNSRESOURCE **authority;  /* numauthority entries */
  struct DECODEDDNSRESOURCE **additional; /* numadditional entries */
};

char *dnsreadname (
/* Part of the DNS packet decoding. Read the DNS name that starts at
 * namestart from the buffer whose first byte is bufstart and last
 * byte is bufend following all "compression" and convert the name
 * into a standard C string. Return as a newly malloced string.
 * nameendptr will be the byte after the string or the byte after
 * the first redirection.
 *
 * The packet comes from the network and is treated as untrusted. No byte
 * outside bufstart..bufend is read. At most MAXPOINTERS compression
 * pointers are followed, so a pointer loop cannot hang the program.
 * A name that runs off the end of the packet or loops is rejected:
 * NULL is returned and *nameendptr is NULL. A compression pointer aimed
 * past the end of the packet still yields the partial name decoded so
 * far (best effort).
 */
	uint8 *namestart   /* Read the name starting here */
,	uint8 *bufstart    /* The whole DNS packet starts here */
,	uint8 *bufend      /* The whole DNS packet ends here */
,	uint8 **nameendptr /* Output: Index into the packet after consuming
			    * the bytes used up by the name. */
) {
  enum { MAXPOINTERS = 128 }; /* far more than any legitimate name needs */
  char name[512];
  int namei;
  int pointers = 0;
  uint16 offset;
  uint8 count;

  *nameendptr = NULL;
  if ((namestart==NULL) || (bufstart==NULL) || (bufend==NULL)) return NULL;
  for (namei=0; (namestart<=bufend) && (*namestart != 0) && (namei<511); ) {
    if (( (*namestart) & 0xc0) != 0) {
      /* If the two high-bits are set, then the low bits of this byte and
       * all bits of the next are the byte offset within the packet at
       * which the name continues. */
      if ((namestart >= bufend) || (++pointers > MAXPOINTERS)) {
        /* Pointer cut off by the end of the packet, or a pointer loop. */
        *nameendptr = NULL;
        return NULL;
      }
      if (!(*nameendptr)) *nameendptr = namestart+2;
      memcpy (&offset,namestart,2);
      offset = ntohs (offset) & 0x3fff;
      /* Compare the offset before forming the pointer so we never build
       * a pointer outside the packet. */
      if (offset > (bufend - bufstart)) {
	/* This is really a failure, but do the best we can. */
        name[namei]=0;
        return strdup (name);
      }
      namestart = bufstart + offset;
      continue;
    }
    /* *namestart indicates the length of the next token in the name.
     * Add the next token to the name. */
    for (count = *namestart, namestart++; (namei<511) && (count>0) &&
	(namestart <= bufend) ; count--, namei++, namestart++)
	name[namei] = (char) *namestart;
    /* Add a dot after the token */
    if (namei<511) {
      name[namei] = '.';
      namei++;
    }
  }
  if (namestart > bufend) {
    /* Ran off the end of the packet without finding the terminating 0. */
    *nameendptr = NULL;
    return NULL;
  }
  if (!(*nameendptr)) *nameendptr = namestart+1; /* next byte after the 0 */
  if (namei>=511) namei=510; /* leave room for the 0 terminator */
  if ((namei>0) && (name[namei-1]=='.')) namei--; /* trim the trailing . */
  name[namei] = 0; /* add the 0 terminator */
  return strdup(name);
}

void dnsfreeanswer (struct DECODEDDNSANSWER *d) {
/* Free the data structure which holds the decoded DNS answer */
  int i;

  if (!d) return;
  if (d->questions) {
    for (i=0; i<d->numquestions; i++) {
      if (d->questions[i]) {
        if (d->questions[i]->name) free (d->questions[i]->name);
        free (d->questions[i]); 
      }
    }
    free (d->questions);
  }
  if (d->answers) {
    for (i=0; i<d->numanswers; i++) {
      if (d->answers[i]) {
        if (d->answers[i]->name) free (d->answers[i]->name);
        free (d->answers[i]); 
      }
    }
    free (d->answers);
  }
  if (d->authority) {
    for (i=0; i<d->numauthority; i++) {
      if (d->authority[i]) {
        if (d->authority[i]->name) free (d->authority[i]->name);
        free (d->authority[i]); 
      }
    }
    free (d->authority);
  }
  if (d->additional) {
    for (i=0; i<d->numadditional; i++) {
      if (d->additional[i]) {
        if (d->additional[i]->name) free (d->additional[i]->name);
        free (d->additional[i]); 
      }
    }
    free (d->additional);
  }
  free (d);
  return;
}

struct DECODEDDNSRESOURCE *decodednsresource (
/* Decode a "resource" field in a DNS packet. The Answer, Authority and
 * Additional sections of the packet consist of one or more resource
 * fields. Returns a newly malloced record (free its name, then the
 * record), or NULL with *after set to NULL if the record is malformed or
 * memory runs out. If RDLENGTH claims more data than the packet holds,
 * the data is truncated at the end of the packet. */
	uint8 *buf     /* First byte of the DNS packet */
,	uint8 *end     /* Last byte of the DNS packet */
,	uint8 *current /* Index to the start of the resource record */
,	uint8 **after  /* Output: index to the byte following the
			* resource record. */
) {
  char *name;
  uint16 datalen;
  long available;
  struct DECODEDDNSRESOURCE *resource;

  /* The resource record starts with a name. In the answer section this
   * is supposed to match the query. In the authority section this is
   * supposed to be the name of the domain's SOA record. In the additional
   * section it can be just about anything. */
  name = dnsreadname (current,buf,end,&current);
  /* The name must be followed by the 10-byte fixed part:
   * type(2) class(2) ttl(4) rdlength(2). current..end holds end-current+1
   * bytes. */
  if ((!name) || (!current) || ((end-current+1)<10)) { /* invalid DNS response */
    *after = NULL;
    free (name);
    return NULL;
  }
  /* Fetch RDLENGTH and clamp it to the bytes really left in the packet so
   * the copy below can never read past end. */
  memcpy (&datalen,current+8,2);
  datalen = ntohs (datalen);
  available = (long) (end-current+1) - 10;
  if ((long) datalen > available) datalen = (uint16) available;
  resource = (struct DECODEDDNSRESOURCE *) malloc (
	sizeof (struct DECODEDDNSRESOURCE) + (sizeof(uint8)*datalen));
  if (!resource) {
    *after = NULL;
    free (name);
    return NULL;
  }
  /* Decode each field explicitly rather than relying on the struct's
   * in-memory layout matching the wire format. */
  resource->name = name;
  memcpy (&resource->type,current,2);
  memcpy (&resource->class,current+2,2);
  memcpy (&resource->ttl,current+4,4);
  resource->type = ntohs (resource->type);
  resource->class = ntohs (resource->class);
  resource->ttl = ntohl (resource->ttl);
  resource->datalen = datalen;
  memcpy (resource->data,current+10,datalen);
  current += 10 + datalen;
printf ("Resource: %d,%d,%u,%d: %s\n",resource->type, resource->class,
(unsigned) resource->ttl, resource->datalen, resource->name);
  *after = current;
  return resource;
}

struct DECODEDDNSANSWER *decodednsanswer (uint8 *buf, int len) {
/* Decode a complete DNS answer packet of len bytes at buf into the C data
 * structure DECODEDDNSANSWER which the program can easily handle.
 * Returns a newly malloced structure, to be released with dnsfreeanswer(),
 * or NULL if the packet is too short or malformed or memory runs out.
 * The section counts and each question's type/class are converted to
 * host byte order; identification is left as received. */
  struct DECODEDDNSANSWER *answer;
  struct DECODEDDNSQUESTION *question;
  uint8 *end, *current;
  uint16 i;

  if ((buf==NULL) || (len<12)) return NULL; /* Not a valid DNS packet */
  answer = (struct DECODEDDNSANSWER*) calloc (1, sizeof(struct DECODEDDNSANSWER));
  if (!answer) return NULL;
  /* Extract the basic info from the 12-byte header. memcpy avoids
   * unaligned 16-bit loads from the receive buffer. */
  memcpy (&answer->identification, buf, 2);
  answer->flagsa = buf[2];
  answer->flagsb = buf[3];
  memcpy (&answer->numquestions, buf+4, 2);
  memcpy (&answer->numanswers, buf+6, 2);
  memcpy (&answer->numauthority, buf+8, 2);
  memcpy (&answer->numadditional, buf+10, 2);
  answer->numquestions = ntohs (answer->numquestions);
  answer->numanswers = ntohs (answer->numanswers);
  answer->numauthority = ntohs (answer->numauthority);
  answer->numadditional = ntohs (answer->numadditional);
  end = buf + len -1;

  /* Extract the question records: a name followed by 4 bytes of type and
   * class. calloc leaves unfilled slots NULL so dnsfreeanswer can clean up
   * a partially decoded answer. */
  answer->questions = (struct DECODEDDNSQUESTION **) calloc (
	answer->numquestions, sizeof(struct DECODEDDNSQUESTION *));
  if ((answer->numquestions>0) && !answer->questions) goto fail;
  for (i=0, current=buf+12; i<answer->numquestions; i++) {
    question = (struct DECODEDDNSQUESTION *) calloc (1,
	sizeof (struct DECODEDDNSQUESTION));
    if (!question) goto fail;
    answer->questions[i] = question;
    question->name = dnsreadname (current,buf,end,&current);
    if ((!current) || (!question->name) || ((end-current)<3)) goto fail;
    memcpy (&question->type,current,2);
    memcpy (&question->class,current+2,2);
    current += 4;
    question->type = ntohs (question->type);
    question->class = ntohs (question->class);
printf ("Question: %d,%d: %s\n",question->type,
question->class,question->name);
  }

  /* Extract the answer resource records */
  answer->answers = (struct DECODEDDNSRESOURCE **) calloc (
	answer->numanswers, sizeof(struct DECODEDDNSRESOURCE *));
  if ((answer->numanswers>0) && !answer->answers) goto fail;
  for (i=0; i<answer->numanswers; i++) {
    answer->answers[i] = decodednsresource (buf,end,current,&current);
    if ((!current) || (!answer->answers[i])) goto fail;
  }

  /* Extract the authority resource records */
  answer->authority = (struct DECODEDDNSRESOURCE **) calloc (
	answer->numauthority, sizeof(struct DECODEDDNSRESOURCE *));
  if ((answer->numauthority>0) && !answer->authority) goto fail;
  for (i=0; i<answer->numauthority; i++) {
    answer->authority[i] = decodednsresource (buf,end,current,&current);
    if ((!current) || (!answer->authority[i])) goto fail;
  }

  /* Extract the additional resource records */
  answer->additional = (struct DECODEDDNSRESOURCE **) calloc (
	answer->numadditional, sizeof(struct DECODEDDNSRESOURCE *));
  if ((answer->numadditional>0) && !answer->additional) goto fail;
  for (i=0; i<answer->numadditional; i++) {
    answer->additional[i] = decodednsresource (buf,end,current,&current);
    if ((!current) || (!answer->additional[i])) goto fail;
  }
  return answer;

fail: /* invalid DNS response or out of memory */
  dnsfreeanswer (answer);
  return NULL;
}

int senddnsquery (uint32 ipaddr, int dnssock, int id) {
/* Accept an ip address in network byte order. Create the UDP DNS query
 * record based on the address and send it to the DNS resolver.
 * For address a.b.c.d the query asks for the TXT record of
 * "d.c.b.a.v4.trrp.<trrpdomain>", with recursion desired; id becomes the
 * DNS message ID. Returns 1 if the whole query was handed to the socket,
 * 0 if the name could not be encoded or the send failed.
 */
  uint8 *ip;
  struct DNSQUERY dns;
  struct DNSQUERYEND e;
  int l, i, n;
  size_t room;

  ip = (uint8*) &ipaddr;
  memset (&dns,0,sizeof(dns));
  dns.flagsa = 0x01; /* Query: 0x80=0, Recursion: 0x01=1 */
  dns.identification = htons ((short) id);
  dns.numquestions = htons (1);
  /* The dotted name goes after question[0] (its first length byte) and
   * must leave room for its 0 terminator plus the type/class trailer. */
  room = sizeof(dns.question) - 1 - sizeof(e);
  n = snprintf ((char*) dns.question+1, room, "%d.%d.%d.%d.v4.trrp.%s",
	(int) ip[3], (int) ip[2],(int) ip[1],(int) ip[0],trrpdomain);
  if ((n<0) || ((size_t) n >= room)) {
    printf ("Error building DNS query\n");
    return 0;
  }
  /* In the actual query, the .'s are replaced by a number from 0 to 63
   * which indicates how many bytes the next non-dotted component takes.
   * The first component is also preceded by such a byte. The end of
   * the string remains null. A longer component cannot be encoded (its
   * length byte would look like a compression pointer). */
  for (l=0, i=1; ; i++) {
    if ((dns.question[i]=='.') || (dns.question[i]==0)) {
      if (i-l-1 > 63) {
        printf ("Error building DNS query\n");
        return 0;
      }
      dns.question[l] = (uint8) (i - l - 1);
      if (dns.question[i]==0) break;
      l = i;
    }
  }
  /* Set the other parameters of the packet */
  e.querytype = htons (16); /* TXT */
  e.queryclass = htons (1); /* IN (Internet) */
  memcpy (dns.question+i+1,&e,sizeof(e));
  /* Message length = header + encoded name + terminator + trailer.
   * Computed on byte pointers; arithmetic on void* is not standard C. */
  dns.len = (uint16) ((dns.question + i + 1 + sizeof(e)) - (uint8*) &dns);
  /* Send the packet to the DNS resolver */
  if (sendto(dnssock,&dns,dns.len,0,
	(struct sockaddr*) &dnsresolver_dest, sizeof(dnsresolver_dest))
	!= (ssize_t) dns.len) {
    printf ("Error sending DNS query\n");
    return 0;
  }

  return 1;
}

/* One candidate ETR parsed from a "PP,ID,route" entry of a TRRP TXT
 * record; candidates are chained through next. Each node is a single
 * malloc holding the struct plus the route text. */
struct ETRANSWER {
  struct ETRANSWER *next;
  time_t validuntil; /* seconds since epoch */
  uint8 priority;    /* parsed from the two hex digits PP; lower is preferred */
  union {
    uint16 identifier; /* the two ID characters viewed as one value, compared with g4 */
    char id[2];        /* the two ID characters as received */
  };
  char route[1]; /* really variable length; NUL-terminated route text */
};

/* The ETR chosen from a list of candidates. candidates points at the
 * list the choice was made from; in this program it is freed separately
 * by the caller, not together with this struct. */
struct ETR {
  struct ETRANSWER *candidates;
  uint32 etr;          /* ETR IPv4 address, network byte order */
  time_t validuntil;   /* seconds since epoch */
};

/* Lookup table indexed by an unsigned byte: maps the ASCII hex digits
 * '0'-'9', 'A'-'F' and 'a'-'f' to their values 0-15; every other byte maps
 * to '.', which can never be a digit value and so marks "not hex". Used to
 * parse the two-digit priority of an ETR entry. */
char decodehex[256] = /* Decode a hexidecimal string with a table lookup */
"................................................"
"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09......"
".\x0a\x0b\x0c\x0d\x0e\x0f.........................."
"\x0a\x0b\x0c\x0d\x0e\x0f........................."
"................................................................"
"................................................................";

void etranswerfree (struct ETRANSWER *e) {
/* Release every node of the ETRANSWER list that starts at e.
 * Passing NULL (an empty list) is harmless. */
  struct ETRANSWER *following;

  while (e != NULL) {
    following = e->next;
    free (e);
    e = following;
  }
}

struct ETRANSWER *dnstoetr (struct DECODEDDNSANSWER *dns, uint32 ip) {
/* Extract any valid priority/identifier/route sets from each TXT answer
 * in the decoded DNS record which matches the requested ip address
 * (network byte order). A TXT string is a space-separated list of
 * "PP,ID,route" entries, where PP is a two-digit hex priority and ID is a
 * two-character identifier. Returns a newly malloced linked list (free
 * with etranswerfree), or NULL if dns is NULL or nothing valid is found. */
  int i, len, routelen;
  char name[512];
  uint8 *cip;
  char *buf, *end, *space;
  struct DECODEDDNSRESOURCE *rr;
  struct ETRANSWER *answers=NULL, *rec;
  time_t t;

  /* A failed DNS decode arrives here as NULL. */
  if ((dns==NULL) || (dns->answers==NULL)) return NULL;
  /* Determine the correct name for the answers */
  cip = (uint8*) &ip;
  snprintf (name, sizeof(name), "%d.%d.%d.%d.v4.trrp.%s",
	(int) cip[3], (int) cip[2],(int) cip[1],(int) cip[0],trrpdomain);
  t = time(NULL); /* Will change TTL to validuntil by adding now. */
  /* Consider each DNS answer */
  for (i=0; i<dns->numanswers; i++) {
    rr = dns->answers[i];
    if ((rr==NULL) || (rr->name==NULL) || (rr->type != 16) ||
	(0!=strcmp(name,rr->name)) || (rr->datalen<1)) continue;
    /* Have a valid TXT record. Its data starts with a length byte; only
     * the first character-string is examined. Read the length byte as
     * unsigned so strings of 128-255 bytes are not lost. */
    buf = (char*) rr->data;
    len = rr->datalen-1;
    if (len > (int) (uint8) *buf) len = (int) (uint8) *buf;
    buf++;
    end = buf+len-1;
    /* Get each space-separated ETR entry. */
    for (space=buf; (space<=end); space++, buf=space) {
      for (; (space<=end) && (*space!=' '); space++ );
      if (((space-buf)<7) || (buf[2]!=',') || (buf[5]!=',') ||
	  (decodehex[(int) ((uint8) buf[0])]=='.') ||
	  (decodehex[(int) ((uint8) buf[1])]=='.')) {
        /* bad ETR entry */
        continue;
      }
      /* Got a validly formatted entry. Copy it to an ETRANSWER structure
       * and add it to the linked list. The struct's route[1] provides the
       * byte for the NUL terminator. */
      routelen = space-(buf+6);
      rec = (struct ETRANSWER*) malloc (sizeof(struct ETRANSWER) +
		(sizeof(char)*routelen));
      if (rec==NULL) continue; /* out of memory: skip this entry */
      rec->next = answers;
      answers = rec;
      rec->priority=(decodehex[(int) ((uint8) buf[0])]<<4)|
		decodehex[(int) ((uint8) buf[1])];
      rec->id[0]=buf[3];
      rec->id[1]=buf[4];
      rec->validuntil = t + ((time_t) rr->ttl);
      memcpy (rec->route,buf+6,routelen);
      rec->route[routelen] = 0;
printf ("Got: %d,%c%c=%d?=%d,%d,%s\n",(int) rec->priority,rec->id[0],rec->id[1],
rec->identifier, g4,(int)rr->ttl,rec->route);
    }
  }
  return answers;
}

struct ETR *choosebestetr (struct ETRANSWER *etra) {
/* Find the lowest-numbered priority answer in etra which has
 * identifier "g4" and whose "route" is a valid dotted-quad IPv4 address.
 * On a tie, the first such entry in the list wins. Returns a newly
 * malloced ETR (free it with free(); the etra list is not owned by it),
 * or NULL if no entry qualifies or memory runs out.
 */
  struct ETR *result;
  struct ETRANSWER *p;
  int pri;
  uint32 ip;

  result = (struct ETR*) malloc (sizeof(struct ETR));
  if (!result) return NULL;
  result->candidates = etra;
  /* Priorities are 0-255, so 300 means "nothing chosen yet". */
  for (pri=300, p=etra; p; p=p->next) {
    if (pri <= ((int) p->priority)) continue; /* already have a better one */
    if (p->identifier != g4) continue; /* unsupported identifier */
    /* inet_pton returns 1 on success and 0 for a malformed address (-1 is
     * only for an unsupported family), so anything but 1 is rejected. */
    if (inet_pton (AF_INET,p->route,&ip) != 1) continue; /* bad address */
printf ("Selected etr %s\n",p->route);
    result->etr = ip;
    result->validuntil = p->validuntil;
    pri = (int) p->priority;
  }
  if (pri<256) return result;
  free (result);
  return NULL;
}

uint32 lookupetr (uint32 destip,int dns, uint16 *mtu) {
/* Given the destination IP address (network byte order), find the best
 * GRE4 ETR by querying the DNS resolver over socket dns. Returns the ETR
 * address in network byte order, or 0 when the query could not be sent,
 * no reply arrived within TIMEOUT seconds, the reply was malformed, or it
 * held no usable ETR. *mtu is always set to 1500 minus the encapsulation
 * overhead. */
  enum { TIMEOUT = 5 }; /* seconds to wait for the resolver's reply */
  uint32 etr;
  uint8 dnspacket[600];
  fd_set readfds;
  struct timeval timeout;
  int len;
  struct DECODEDDNSANSWER *dnsanswer;
  struct ETRANSWER *etra;
  struct ETR *etre;

  etr = 0;
  *mtu = 1500 - sizeof(struct OUTHEADER);
  if (senddnsquery (destip,dns,1)) {
    /* Prototype shortcut: Camp on the dns socket until we get an answer.
     * If this was a real reference implementation, we'd add this to the
     * stack of pending requests and come back to it when the answer
     * arrived. The timeout keeps a silent resolver from hanging the ITR.
     * NOTE(review): every query uses id 1, so a late reply to an earlier
     * timed-out query could be read here instead. */
    FD_ZERO (&readfds);
    FD_SET (dns,&readfds);
    timeout.tv_sec = TIMEOUT;
    timeout.tv_usec = 0;
    if (select (dns+1,&readfds,NULL,NULL,&timeout) > 0) {
      len = recv (dns,dnspacket,sizeof(dnspacket)-1,0);
      /* printf ("DNS Response (%d):\n",len);
       * printheader (dnspacket,len); */
      dnsanswer = decodednsanswer (dnspacket,len);
      if (dnsanswer) { /* NULL after a failed recv or a malformed reply */
        etra = dnstoetr (dnsanswer,destip);
        dnsfreeanswer (dnsanswer);
        etre = choosebestetr (etra);
        if (etre) {
          etr = etre->etr;
          free (etre);
        }
        /* Prototype shortcut: don't bother caching the answer. Just
         * re-query the DNS resolver for every single packet. Yikes! */
        etranswerfree (etra);
      }
    }
  }

  printf ("Selected address:\n");
  printheader (&etr,4);

  return etr;
}

/* From Stevens, UNP2ev1 */
uint16 in_cksum(uint16 *addr, int len)
/* Compute the Internet checksum (RFC 1071) of len bytes starting at addr.
 * Words are summed in host byte order, so the value returned is the
 * complemented checksum in host byte order: store it with htons(). For
 * data that already contains a correct checksum, the result is 0.
 */
{
    int nleft = len;
    uint32 sum = 0;
    uint16 *w = addr;

    while (nleft > 1) {
        sum += ntohs(*w);
        w++;
        nleft -= 2;
    }

    /* A trailing odd byte is the high-order byte of a final word padded
     * with zero. Because words are summed in host order, shift it into
     * the high byte explicitly; copying it into the first byte of a
     * uint16 would make it the low byte on little-endian machines. */
    if (nleft == 1)
        sum += ((uint32) *(uint8 *) w) << 8;

    /* Fold the carries back into the low 16 bits (end-around carry). */
    sum = (sum >> 16) + (sum & 0xffff);
    sum += (sum >> 16);
    return (uint16) ~sum;
}
      
void sendunreachable (
/* Compose and send an ICMP destination unreachable message (type 3) with
 * the given code to the source of the offending packet ip (len bytes
 * long). mtu goes into the next-hop MTU field and should be 0 unless code
 * is 4 (fragmentation needed). Quietly sends nothing when an ICMP error is
 * not allowed: for a malformed header, an ICMP error message, a fragment
 * other than the first, or a multicast/broadcast-looking source.
 */
	int sender
,	struct IPHEADER *ip
,	int len
,	uint8 code
,	uint16 mtu
) {
  uint8 buf[100 /* 20+8+60+8: IP + ICMP + largest quoted header + 8 */], *d, t;
  struct IPHEADER *u;
  int hlen, result;
  uint16 *sum;
  struct sockaddr_in dest;

  hlen = (int) ((ip->versionlength) & 0x0f) << 2;
  if (hlen<20) return; /* malformed header */
  if (len<(hlen+8)) return; /* Need 8 bytes of payload */
  /* Never send an ICMP error about a fragment other than the first
   * (RFC 1122 section 3.2.2): the fragment offset must be zero. */
  if (ntohs (ip->flagsfragment) & 0x1fff) return;
  if (ip->protocol == 1) {
    /* ICMP! May not send ICMP errors in response to ICMP errors! */
    d = (uint8*) ip;
    t = *(d+hlen);
    if (! ((t==0)||((t>=8)&&(t<=10))||((t>=13)&&(t<=18))) ) return;
    /* But may send in response to ICMP queries... */
  }
  d = (uint8*) &(ip->sourceip);
  if ((*d)>=((uint8) 224)) {
    /* must not send an unreachable message to a multicast or broadcast
     * address. This should mean 224-239 and 255. I'll exclude 240-254
     * too. I won't look up the lan-local broadcast addresses 'cause I'm
     * lazy. I can't know whether a remote address is a broadcast address.
     */
    return;
  }
  memset (buf,0,100);
  u = (struct IPHEADER*) buf;
  u->versionlength = 0x45;  /* version 4 length 20/4=5 */
  u->flagsfragment = htons (0x4000); /* Don't fragment */
  u->sourceip = headertemplate.ip.sourceip;
  u->destip = ip->sourceip;
  u->protocol = 1; /* ICMP */
  u->ttl = 64;
  u->length = htons (/*20+8+hlen+8*/ 36+hlen);
  u->identification = ip->identification;
  buf[20] = 3; /* Destination Unreachable */
  buf[21] = code;
  mtu = htons (mtu);
  memcpy (buf+26,&mtu,2); /* 0 unless code=4 */
  memcpy (buf+28,ip,hlen+8);  /* original IP header + 8 bytes */
  /* ICMP checksum covers the 8-byte ICMP header plus the quoted data. */
  sum = (uint16*) (buf+22);
  *sum = htons (in_cksum ((uint16*) (buf+20),hlen+16));

printheader ((struct IPHEADER*) buf, hlen+36);
  /* And send the ICMP packet */
  dest.sin_family = AF_INET;
  dest.sin_port = 0;
  dest.sin_addr.s_addr = u->destip;
  result = sendto (sender,buf,hlen+36,0,
	(struct sockaddr*) &dest, sizeof(dest));
  if (result<0) printf ("Error sending ICMP unreachable\n");
  return;
}

int main (int argc, char **argv) {
/* Run the ITR. Capture every IPv4 packet this host receives. For each
 * packet addressed to networktoencapsulate/netmasktoencapsulate, look up
 * its ETR through DNS and forward the packet GRE-encapsulated, or answer
 * with an ICMP unreachable. Loops forever; returns 1 if a socket cannot
 * be opened. */
  int receiver, sender, len, result, numfds, dns;
  char packet[2000];
  struct IPHEADER *ip;
  uint8 *d, *s;
  struct OUTHEADER *header;
  struct sockaddr_in dest;
  struct timeval timeout;
  fd_set readfds;
  uint16 sum, mtu;
  int hlen, iplen, csum;

  initializedatastructures();
  if ((receiver = openreceiver())<0) return 1;
  if ((sender = opensender())<0) return 1;
  if ((dns = opendns())<0) return 1;
  numfds = receiver;
  if (sender>numfds) numfds = sender;
  if (dns>numfds) numfds = dns;
  numfds++;

  while (1) {
    timeout.tv_sec = 1;
    timeout.tv_usec = 0;
    FD_ZERO (&readfds);
    FD_SET (receiver,&readfds);
    select (numfds,&readfds,NULL,NULL,&timeout);

    /* Receive behind the space reserved for the outer IP+GRE header so
     * the packet can be encapsulated in place. */
    len = recv (receiver,packet+sizeof(headertemplate),
	1999-sizeof(headertemplate),0);

    if (len<20) continue; /* not an ip packet. The select probably timed out */

    ip = (struct IPHEADER*)  (packet+sizeof(headertemplate));

    if (!filter(ip)) continue;

    /* Reject malformed headers before trusting hlen: it must cover the
     * fixed 20 bytes and lie within what was actually received. */
    hlen = (int) ((ip->versionlength) & 0x0f) << 2;
    if ((hlen<20) || (hlen>len)) continue;

    /* Verify the IP checksum */
    sum = in_cksum ((uint16*) ip,hlen);
    if (sum!=0) continue; /* Bad checksum. Drop packet. */

    /* Short Ethernet frames carry padding that is not part of the packet,
     * so trust the IP total length. A total length beyond what was
     * received means the packet is bogus or truncated: drop it. */
    iplen = (int) ntohs (ip->length);
    if ((iplen<hlen) || (iplen>len)) continue;
    len = iplen;

    if (ip->ttl<2) { /* Time exceeded */
      /* Send ICMP time-exceeded message here */
      /* Turns out we don't need to because the IP stack on this Linux
       * machine will send it for us even though we're dropping the packets
       * to the destination address. */
      continue;
    }

    /* Decrement TTL and patch the checksum incrementally: TTL is the high
     * byte of its 16-bit word, so the checksum rises by 0x100 with an
     * end-around carry. */
    ip->ttl --;
    csum = ((int) ntohs(ip->checksum)) + 0x100;
    ip->checksum = htons ((uint16) (csum + (csum>>16)));

    /* Start building GRE packet */
    memcpy (packet,&headertemplate,sizeof(headertemplate));
    header = (struct OUTHEADER*) packet;
    header->ip.destip = lookupetr (ip->destip,dns,&mtu);
    if (header->ip.destip==0) { /* No ETR; send a host-unreachable */
      sendunreachable (sender,ip,len,7,0);
      continue;
    }
    if (len>mtu) {
      /* must fragment packet or send an ICMP fragmentation needed */
      d = (uint8*) &(ip->flagsfragment);
      if (*d & 0x40) { /* Don't fragment bit set */
        sendunreachable (sender,ip,len,4,mtu);
        continue;
      }
      /* fragment packet here */
      /* Will skip this part for the prototype code */
    }

    header->ip.length = htons(len+sizeof(headertemplate));
    header->ip.ttl = ip->ttl; /* Copy TTL to GRE packet per TRRP spec */

    /* Send the GRE packet */
    dest.sin_family = AF_INET;
    dest.sin_port = 0;
    dest.sin_addr.s_addr = header->ip.destip;
    result = sendto (sender,packet,len+sizeof(headertemplate),0,
	(struct sockaddr*) &dest, sizeof(dest));

    /* Print some debugging info */
    s = (uint8*) &(ip->sourceip);
    d = (uint8*) &(ip->destip);
    printf ("Got: %d.%d.%d.%d > %d.%d.%d.%d (%d), %d, %d\n",
     (int)s[0], (int)s[1], (int)s[2],(int)s[3],
     (int)d[0], (int)d[1], (int)d[2],(int)d[3],
     len,result,errno);
    printheader ((struct IPHEADER*) packet, len+sizeof(headertemplate));
  }

  close (receiver);
  close (sender);
  close (dns);
  return 0;
}


