/*-
 * Copyright Sergey Kosyakov ks@itp.ac.ru 1999
 * Copyright (C) @BABOLO  2002 http://www.babolo.ru/
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

/*
 * Copyright and RCS identification strings.  The "@(#)" prefixes follow the
 * what(1) convention and rcsid carries the RCS Id keyword for ident(1).
 * They are compiled only when `lint` is not defined.
 */
#ifndef lint
static const char copyright[] = "\
@(#)Copyright Sergey Kosyakov ks@itp.ac.ru 1999\n\
@(#)Copyright (C) @BABOLO  2002 http://www.babolo.ru/\n\
@(#)All rights reserved.\n";
static const char rcsid[] = "$Id: secur.c,v 1.2 2010/10/18 16:51:36 babolo Exp $";
#endif /* not lint */

#include <openssl/rsa.h>
#include <sys/time.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include "tund.h"

/* Daemon-wide state defined in other translation units. */
extern unsigned char packetBuf[PACKET_MAX_SIZE];   /* current packet being built/parsed */
extern int socket_control;
extern int control_port;
extern int debug;
extern Seq *stun;                                  /* Tunnel * per tunnel index */

/* Scratch output buffer for the symmetric data cipher. */
static unsigned char cipherBuf[PACKET_MAX_SIZE];
static void send_new_password(int tun_ind);
/* Real prototype: an empty "()" list would disable argument checking. */
static void generate_new_password(void);
/* Scratch output buffer for RSA operations; 1024 bytes covers RSA_size()
 * of moduli up to 8192 bits. */
static unsigned char rsa_str[1024];
/* Most recently generated 16-byte symmetric key. */
static unsigned char sKey[16];
/* This host's RSA private key, decoded by init_secur(). */
static RSA *rsa_priv=NULL;
/* Known peers: a Seq of HostKey *, filled by init_secur(). */
static Seq *hostsKey;
/* Scratch pointer used while decoding each peer public key. */
static RSA *rsa_pub;

/* One entry of hostsKey: a remote host name and its RSA public key. */
typedef struct {
    char *host;
    RSA *rsa;
} HostKey;

/* Initialization procedure */

void
init_secur() {
    unsigned char key[8192];
    char hostname[128];
    unsigned char *p;
    char str[8192];
    unsigned int n;
    HostKey *hk;
    RSA *rsa;
    FILE *f;
    int ret;
    int len;

    /* Read own RSA private key */

    f = fopen(get_real_path(RSAKEY_FILE), "r");
    if  (!f) Error("Can not open %s", get_real_path(RSAKEY_FILE));
    bzero(key, 8192);
    len = 0;
    do {
        ret = fscanf(f, "%2x", &n);
        if  (ret != 1) break;
        key[len] = n;
        len++;
    } while(len < 8190);
    if  (debug) Log("Read %d bytes as RSA private key", len);
    fclose(f);
    p = key;
    rsa = d2i_RSAPrivateKey(&rsa_priv, &p, (long)len);
    if  (!rsa || !rsa_priv) Error("Error while d2i_RSAPrivateKey");
    /* Now scan other hosts RSA public key database */
    f = fopen(get_real_path(RSAPUB_DB), "r");
    if  (!f) Error("Can not open %s", get_real_path(RSAPUB_DB));
    hostsKey = seq_new(); 
    do {
        bzero(key, 8192);
        bzero(str, 8192);
        bzero(hostname, 128);
        ret = fscanf(f, "%s %s", hostname, str);
        if  (ret <= 0) break;
#ifdef DEBUG
        if  (debug) Log("Host [%s] key [%s]", hostname, str);
#endif
        for (len = 0; len * 2 < strlen(str); len++) {
            ret = sscanf(str + len * 2, "%2x", &n);
            if  (ret != 1) Error("Corrupted %s", get_real_path(RSAPUB_DB));
            key[len] = n;
        }
        if  (debug) Log("Read %d bytes as RSA public key", len);
        p = key;
        rsa_pub = NULL;
        rsa = d2i_RSAPublicKey(&rsa_pub, &p, (long)len);
        if  (!rsa || !rsa_pub) Error("Error while d2i_RSAPublicKey");
        hk = (HostKey *)malloc(sizeof(HostKey));
        hk->host = strdup(hostname);
        hk->rsa = rsa_pub;
        seq_append(hostsKey, (void *)hk);
    } while(1);
}

/*
 * Start a symmetric key change on tunnel tun_ind.
 * Steps:
 *   - draw a fresh key into sKey;
 *   - copy it into tun->local_key and call set_cipher_key(tun_ind, 0);
 *   - switch the tunnel to TUNS_SPAW;
 *   - send the key to the peer with send_new_password().
 */
static void
change_password(int tun_ind) {
    Tunnel *t = (Tunnel *)stun->buf[tun_ind];

    generate_new_password();
    bcopy(sKey, t->local_key, 16);
    set_cipher_key(tun_ind, 0);
    t->state = TUNS_SPAW;
    send_new_password(tun_ind);
}

/*
 * Encrypt len bytes of tunnel data in place.
 * Refuses (returns -1) unless the tunnel:
 *   - has TUNF_ENC set;
 *   - holds the peer's key (TUNF_HAS_RPWD);
 *   - is in TUNS_PRC.
 * Returns the ciphertext length written back into buf, or -1 on refusal or
 * cipher failure.
 */
int
data_encrypt(unsigned char *buf, int len, int tun_ind) {
    Tunnel *t = (Tunnel *)stun->buf[tun_ind];
    int out_len;

    if  ((t->flags & TUNF_ENC) == 0) return(-1);
    if  ((t->flags & TUNF_HAS_RPWD) == 0) return(-1);
    if  (t->state != TUNS_PRC) return(-1);
#ifdef DEBUG
    if  (debug) Log("Request for encryption of %d bytes t=%d", len, tun_ind);
#endif /*DEBUG*/
    out_len = cipher_encrypt(buf, cipherBuf, len, tun_ind);
    if  (out_len < 0) return(-1);
    bcopy(cipherBuf, buf, out_len);
    return(out_len);
}

/*
 * Decrypt len bytes of tunnel data in place.  The result goes through the
 * cipherBuf scratch buffer and then overwrites the first len bytes of buf.
 * The return value of cipher_decrypt() is not examined.
 */
void
data_decrypt(unsigned char *buf, int len, int tun_ind) {
    unsigned char *plain = cipherBuf;

#ifdef DEBUG
    if  (debug) Log("Request for decryption of %d bytes t=%d", len, tun_ind);
#endif /*DEBUG*/
    cipher_decrypt(buf, plain, len, tun_ind);
    bcopy(plain, buf, len);
}

int
rsa_encrypt(unsigned char *buf, int len) {
    int ret;

    ret = RSA_private_encrypt(len, buf, rsa_str, rsa_priv, RSA_PKCS1_PADDING);
    if  (ret <= 0) Error("rsa_encrypt(0x%x,%d) returns %d", buf, len, ret);
    bcopy(rsa_str, buf, ret);
#ifdef DEBUG
    if  (debug) Log("RSA encrypted %d bytes into %d bytes", len, ret);
#endif /* DEBUG */
    return(ret);
}

/*
 * Decrypt len bytes at buf in place with the peer's RSA public key
 * (tun->r_key, RSA_public_decrypt, PKCS#1 v1.5 padding).
 * On success the recovered plaintext overwrites the start of buf and its
 * length is returned; bytes of buf beyond that length are left unchanged.
 * Returns a value <= 0 on failure: no key, a non-positive length, an
 * oversized modulus, or an OpenSSL error.
 */
int
rsa_decrypt(Tunnel *tun, unsigned char *buf, int len) {
    int ret;

#ifdef DEBUG
    if  (debug) Log("Request for RSA decrypt of %d bytes", len);
#endif /* DEBUG */
    /* Never hand OpenSSL a missing key, a negative/zero length, or a modulus
     * whose output could overrun rsa_str. */
    if  (!tun->r_key || len <= 0 || RSA_size(tun->r_key) > (int)sizeof(rsa_str)) {
        if  (debug) Log("rsa_decrypt(%p,%d): no usable RSA key or bad length", (void *)buf, len);
        return(-1);
    }
    ret = RSA_public_decrypt(len, buf, rsa_str, tun->r_key, RSA_PKCS1_PADDING);
    if  (ret <= 0) {
        /* %p: printing a pointer with %x is undefined behaviour. */
        if  (debug) Log("rsa_decrypt(%p,%d) returns %d", (void *)buf, len, ret);
        return(ret);
    }
#ifdef DEBUG
    if  (debug) Log("RSA decrypted %d bytes into %d bytes", len, ret);
#endif /* DEBUG */
    bcopy(rsa_str, buf, ret);
    return(ret);
}

/*
 * RSA-encrypt a control packet in place, leaving its PacketHead in clear.
 *
 * A fresh get_rand() value is stored in cp->c.rn first, presumably so that
 * repeated messages do not encrypt to identical bytes.  The encrypted region
 * is the ControlHead plus cp->c.body_len body bytes.
 *
 * Returns the total packet length: PacketHead plus the RSA output.
 */
int
control_encrypt(CP *cp) {
    unsigned char *payload;
    int enc_len;

    cp->c.rn = get_rand();
    payload = (unsigned char *)cp + sizeof(PacketHead);
    enc_len = rsa_encrypt(payload, cp->c.body_len + sizeof(ControlHead));
    return(enc_len + sizeof(PacketHead));
}

/*
 * RSA-decrypt the part of a received control packet that follows its
 * PacketHead, in place.
 * size: total packet length including the PacketHead.
 * Returns the plaintext length, or a value <= 0 on failure.
 */
int
control_decrypt(CP *cp, Tunnel *tun, int size) {
    /* size - sizeof(PacketHead) is evaluated in unsigned arithmetic, so a
     * runt packet would otherwise produce a bogus (negative) length. */
    if  (size <= (int)sizeof(PacketHead)) return(-1);
    return(rsa_decrypt(tun, ((unsigned char *)cp) + sizeof(PacketHead),
                       size - (int)sizeof(PacketHead)));
}

/*
 * Encrypt the control packet in packetBuf and send it via ks_sendto() on
 * socket_control.  The destination is the tunnel's peer address (tun->to)
 * with its port replaced by control_port.
 */
static void
send_control(int tun_ind) {
    Tunnel *tun = (Tunnel *)stun->buf[tun_ind];
    struct sockaddr_in dst;
    int pkt_len;

    pkt_len = control_encrypt((CP *)packetBuf);
    bcopy(&tun->to, &dst, sizeof(dst));
    dst.sin_port = htons(control_port);
    ks_sendto(tun_ind, socket_control, packetBuf, pkt_len, 0,
              (struct sockaddr *)&dst, sizeof(dst));
}

/*
 * Send a CMT_RESET control message to the peer.
 *
 * The call re-arms itself through add_expired_operation() with a delay of 10
 * (units as defined by add_expired_operation; presumably seconds).
 * Retransmission continues for as long as the tunnel stays in TUNS_RESET_AW.
 * In any other state the call does nothing, which ends the retry cycle.
 */
static void
reset_remote_tunnel(int tun_ind) {
    Tunnel *tun = (Tunnel *)stun->buf[tun_ind];
    CP *msg;

    if  (tun->state == TUNS_RESET_AW) {
        msg = (CP *)packetBuf;
        msg->h.ver = PPF_PROTO_VER;
        msg->h.flags = PPF_CNT;
        msg->h.tunAddr = tun->tunAddr;
        msg->c.mt = CMT_RESET;
        msg->c.body_len = 0;
        send_control(tun_ind);
        tun->state = TUNS_RESET_AW;
        add_expired_operation(tun_ind, 10, reset_remote_tunnel);
#ifdef DEBUG
        if  (debug) Log("Reset tunnel [%s] requested", tun->label);
#endif /* DEBUG */
    }
}

/*
 * Begin securing tunnel tun_ind.
 *
 * The peer's RSA public key is looked up in hostsKey by tun->remoteName and
 * attached as tun->r_key.  Error() is raised if the tunnel ends up without a
 * key.  The tunnel is then put into TUNS_RESET_AW and a reset request is sent
 * to the peer.
 */
void
start_secur_tunnel(int tun_ind) {
    HostKey *hk;
    Tunnel *tun;
    int i;

    tun = (Tunnel *)stun->buf[tun_ind];
    /* First find remote end RSA public key */
    for (i = 0; i < hostsKey->length; i++) {
        hk = (HostKey *)hostsKey->buf[i];
        if  (strcmp(hk->host, tun->remoteName) != 0) continue;
        tun->r_key = hk->rsa;
#ifdef DEBUG
        /* %p: printing a pointer with %x is undefined behaviour. */
        if  (debug)
            Log("Tun [%s] remote h=%s RSA PKEY at %p", tun->label, tun->remoteName, (void *)tun->r_key);
#endif /*DEBUG*/
        break;
    }
    if  (!tun->r_key) Error("Can not find RSA PKEY for tun [%s]", tun->label);
    /* First reset tunnel on peer */
    tun->state = TUNS_RESET_AW;
    reset_remote_tunnel(tun_ind);
}

/*
 * Send this side's 16-byte symmetric key (tun->local_key) to the peer in a
 * CMT_SPWD control message.
 *
 * The call re-arms itself through add_expired_operation() with a delay of 10
 * (units as defined there; presumably seconds).  Retransmission continues for
 * as long as the tunnel stays in TUNS_SPAW; in any other state the call does
 * nothing.
 */
static void
send_new_password(int tun_ind) {
    Tunnel *tun = (Tunnel *)stun->buf[tun_ind];
    CP *msg;

    if  (tun->state == TUNS_SPAW) {
        msg = (CP *)packetBuf;
        msg->h.ver = PPF_PROTO_VER;
        msg->h.flags = PPF_CNT;
        msg->h.tunAddr = tun->tunAddr;
        msg->c.mt = CMT_SPWD;
        msg->c.body_len = 16;
        bcopy(tun->local_key, msg->body, 16);
        send_control(tun_ind);
        tun->state = TUNS_SPAW;
        add_expired_operation(tun_ind, 10, send_new_password);
#ifdef DEBUG
        if  (debug) Log("New password for tunnel [%s] sent", tun->label);
#endif /* DEBUG */
    }
}

/*
 * Handle CMT_SPWD from the peer:
 *   - store the 16-byte body as tun->remote_key;
 *   - call set_cipher_key(tun_ind, 1);
 *   - mark the tunnel TUNF_HAS_RPWD;
 *   - reply with CMT_SPWD_ACK.
 * The reply is built in place in packetBuf and keeps the received
 * PacketHead unchanged.
 */
static void
accept_new_password(int tun_ind) {
    CP *msg = (CP *)packetBuf;
    Tunnel *tun = (Tunnel *)stun->buf[tun_ind];

    bcopy(msg->body, tun->remote_key, 16);
    set_cipher_key(tun_ind, 1);
    tun->flags |= TUNF_HAS_RPWD;
    /* Turn the received message into the acknowledgement. */
    msg->c.body_len = 0;
    msg->c.mt = CMT_SPWD_ACK;
    send_control(tun_ind);
}

/*
 * Handle CMT_RESET from the peer:
 *   - cancel this tunnel's pending expired operation;
 *   - return the tunnel to TUNS_INI and forget the peer key flag;
 *   - acknowledge with CMT_RESET_ACK (reusing the received PacketHead);
 *   - start a new symmetric key change.
 */
static void
accept_reset(int tun_ind) {
    Tunnel *tun;
    CP *msg;

    remove_expired_operation(tun_ind);
    tun = (Tunnel *)stun->buf[tun_ind];
    tun->flags &= ~TUNF_HAS_RPWD;
    tun->state = TUNS_INI;
    msg = (CP *)packetBuf;
    msg->c.mt = CMT_RESET_ACK;
    msg->c.body_len = 0;
    send_control(tun_ind);
    change_password(tun_ind);
}

/*
 * Handle CMT_RESET_ACK from the peer.
 *
 * The pending expired operation is cancelled; presumably this is the
 * reset_remote_tunnel() retry.  The tunnel returns to TUNS_INI, and a
 * symmetric key change is then started toward the peer.
 */
static void
accept_reset_ack(int tun_ind) {
    Tunnel *tun = (Tunnel *)stun->buf[tun_ind];

    remove_expired_operation(tun_ind);
    tun->state = TUNS_INI;
    if  (debug) Log("Tunnel [%s] now in initial state f=0x%x", tun->label, tun->flags);
    change_password(tun_ind);
}

/*
 * Handle CMT_SPWD_ACK: the peer has our key.
 *
 * The tunnel enters TUNS_PRC, the state data_encrypt() requires.  The pending
 * expired operation is then cancelled; presumably this is the
 * send_new_password() retry.
 */
static void
accept_npwd_ack(int tun_ind) {
    Tunnel *tun = (Tunnel *)stun->buf[tun_ind];

    tun->state = TUNS_PRC;
    remove_expired_operation(tun_ind);
    if  (debug) Log("Tunnel [%s] now in operation state, f=0x%x", tun->label, tun->flags);
}

void
proceed_cp(int tun_ind, int len) {
    int ret;
    Tunnel *tun;
    CP *cp;

    tun = (Tunnel *)stun->buf[tun_ind];
    cp = (CP *)packetBuf;
#ifdef DEBUG
    if  (debug) Log("proceed_cp(%d,%d)", tun_ind, len);
#endif /* DEBUG */
    ret = control_decrypt(cp, tun, len);
    if  (ret < 0) {
        if  (debug) Log("Dropped packet (illegal RSA encryption) from tun [%s] l=%d", tun->label, len);
        return;
    }
#ifdef DEBUG
    if  (debug) Log("CP type=%d", cp->c.mt);
#endif /* DEBUG */
    switch(cp->c.mt) {
    case CMT_SPWD:
        accept_new_password(tun_ind);
        break;
    case CMT_SPWD_ACK:
        accept_npwd_ack(tun_ind);
        break;
    case CMT_RESET:
        accept_reset(tun_ind);
        break;
    case CMT_RESET_ACK:
        accept_reset_ack(tun_ind);
        break;
    default:
        if  (debug) Log("Dropped illegal CP (unknown type): t=[%s] mt=%d", tun->label, cp->c.mt);
}   }

static void
generate_new_password() {
    static struct timeval tp;
    static struct {
        int rn1;
        int t1;
        int rn2;
        int t2;
    } buffer;
    unsigned char *hash;

    if  (gettimeofday(&tp, NULL)) Error("Error while gettimeofday: %s", strerror(errno));
    buffer.rn1 = get_rand();
    buffer.rn2 = get_rand();
    buffer.t1 = tp.tv_sec;
    buffer.t2 = tp.tv_usec;
    hash = get_md5_hash((unsigned char*)(&buffer), sizeof(buffer));
    bcopy(hash, sKey, 16);
}
