Chinaunix首页 | 论坛 | 博客
  • 博客访问: 1069711
  • 博文数量: 252
  • 博客积分: 4561
  • 博客等级: 上校
  • 技术积分: 2833
  • 用 户 组: 普通用户
  • 注册时间: 2008-03-15 08:23
文章分类

全部博文(252)

文章存档

2015年(2)

2014年(1)

2013年(1)

2012年(16)

2011年(42)

2010年(67)

2009年(87)

2008年(36)

分类: LINUX

2010-07-02 09:14:12

#include <asm/atomic.h>
#include <asm/byteorder.h>
#include <asm/checksum.h>
#include <asm/unaligned.h>
#include <linux/module.h>
#include <linux/init.h>
#include <linux/kernel.h>
#include <linux/net.h>
#include <linux/in.h>
#include <linux/types.h>
#include <linux/kthread.h>
#include <linux/wait.h>
#include <linux/skbuff.h>
#include <linux/string.h>
#include <linux/sysctl.h>
#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>
#include <linux/ip.h>
#include <linux/workqueue.h>
#include <linux/jiffies.h>
#include <linux/route.h>
#include <linux/if_arp.h>
#include <linux/inetdevice.h>
#include <linux/slab.h>
#include <linux/stddef.h>
#include <linux/mutex.h>
#include <linux/compiler.h>
#include <linux/icmp.h>
#include <linux/jhash.h>
#include <linux/list.h>
#include <linux/inet.h>
#include <linux/ctype.h>
#include <linux/spinlock_types.h>
#include <net/net_namespace.h>
#include <net/sock.h>
#include <net/route.h>
#include <net/inet_connection_sock.h>
#include <net/request_sock.h>
#include <net/icmp.h>
#include <net/ip.h>
#include <net/tcp.h>

#include "main.h"

/* Log "<msg> failed." at KERN_ERR; msg is expected to be a C string. */
#define err(msg) printk(KERN_ERR "%s failed.\n", (msg))
/* Shorthand for struct sockaddr (not referenced by the code shown here). */
#define SA struct sockaddr

/* The one connection entry this module rewrites packets for (type from main.h). */
static struct conn_struct conn;

static void mangle_data(struct sk_buff *skb)
{
    struct iphdr *iph;
    struct tcphdr *tcph;
    char *data_start;
    int dlen;
    char replace[] = "songtao\n";
    int replace_len = strlen(replace);

    iph = ip_hdr(skb);
    tcph = (struct tcphdr *)(skb_network_header(skb) + ip_hdrlen(skb));

    data_start = ((char *)tcph) + (tcph->doff << 2);
    if (!strncmp(data_start, "tmp", 3)) {
        if (pskb_expand_head(skb, skb_headroom(skb), 100, GFP_ATOMIC | __GFP_ZERO) != 0) {
            err("pskb_expand_head");
            goto out;                    /* songtao = 7 characters */
        }

        iph = ip_hdr(skb);
        tcph = (struct tcphdr *)(skb_network_header(skb) + ip_hdrlen(skb));
        if (iph == NULL || tcph == NULL) {
            err("ip_hdr or tcphdr");
            goto out;
        }

        dlen = ntohs(iph->tot_len) - ((iph->ihl << 2) + (tcph->doff << 2));
        conn.inseq.init_seq = tcph->seq;
        conn.inseq.delta += replace_len - dlen;

        data_start = ((char *)tcph) + (tcph->doff << 2);    /* get data addr */
        memmove(data_start, replace, replace_len);
        skb->len += replace_len - dlen;
        skb_trim(skb, skb->len);                /* set len and tail pointer */
        iph->tot_len = htons(skb->len);
    }
out:    
    return;
}

/*
 * Look up an IPv4 output route in init_net.
 *
 * @oif:   output interface index (0 = any)
 * @daddr: destination address
 * @saddr: source address (0 = let routing choose)
 * @tos:   type-of-service value for the lookup
 *
 * Returns the route from ip_route_output_key() on success. On failure it logs
 * the error and returns NULL.
 */
static struct rtable * get_output_route(int oif, __be32 daddr, __be32 saddr, __u8 tos)
{
    struct flowi key = {
        .oif               = oif,
        .nl_u.ip4_u.daddr  = daddr,
        .nl_u.ip4_u.saddr  = saddr,
        .nl_u.ip4_u.tos    = tos,
    };
    struct rtable *route;

    if (ip_route_output_key(&init_net, &route, &key) == 0)
        return route;

    err("ip_route_output_key");
    return NULL;
}

static unsigned int sg_vs_in(unsigned int hooknum,
        struct sk_buff *skb,
        const struct net_device *in,
        const struct net_device *out,
        int (*okfn)(struct sk_buff *))
{
    struct rtable *rt;
    struct tcphdr *tcph;
    struct iphdr *iph;

    if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL) {
        err("skb_share_check");
        goto out;
    }

    if (!skb_make_writable(skb, skb->len)) {            /* writable and linearize */
        err("skb_make_writable");
        goto out;
    }

    iph = ip_hdr(skb);
    if(iph == NULL){
        err("ip_hdr");
        return NF_DROP;
    }

    tcph = (struct tcphdr *)(skb_network_header(skb) + ip_hdrlen(skb));
    if (tcph == NULL) {
        err("tcph is NULL.");
        return NF_DROP;
    }

    if (iph->daddr == conn.vaddr && tcph->dest == conn.vport) {     /* packet from client to load balancer */
        if (conn.inseq.delta != 0)
            tcph->seq = htonl(ntohl(tcph->seq) + conn.inseq.delta);

        mangle_data(skb);
    

        iph = ip_hdr(skb);
        tcph = (struct tcphdr *)(skb_network_header(skb) + ip_hdrlen(skb));
        if(iph == NULL || tcph == NULL){
            err("ip_hdr");
            goto out;
        }

        conn.cport = tcph->source;    /* fill cport */

        tcph->dest = conn.dport;
        iph->saddr = conn.vaddr;
        iph->daddr = conn.daddr;     /* input */

        tcph->check = 0;     /* tcp_hdrlen(skb) get L4 header length */
        skb->csum = skb_checksum(skb, ip_hdrlen(skb), skb->len - ip_hdrlen(skb), 0);
        tcph->check = csum_tcpudp_magic(iph->saddr, iph->daddr, skb->len - ip_hdrlen(skb), iph->protocol, skb->csum);

        ip_send_check(ip_hdr(skb));

        rt = get_output_route(0, iph->daddr, 0, RT_TOS(iph->tos));

        skb_dst_drop(skb);
        skb_dst_set(skb, &rt->u.dst);
        skb->local_df = 1;

        dst_output(skb);

        return NF_STOLEN;

    } else if (iph->saddr == conn.daddr && tcph->source == conn.dport) {     /* packet from real server to load balancer */

        tcph->source = conn.vport;
        tcph->dest = conn.cport;
        iph->saddr = conn.vaddr;    /* output */
        iph->daddr = conn.caddr;

        if (conn.inseq.delta != 0)
            tcph->ack_seq = htonl(ntohl(tcph->ack_seq) - conn.inseq.delta);

        tcph->check = 0;
        skb->csum = skb_checksum(skb, ip_hdrlen(skb), skb->len - ip_hdrlen(skb), 0);
        tcph->check = csum_tcpudp_magic(iph->saddr, iph->daddr, skb->len - ip_hdrlen(skb), iph->protocol, skb->csum);

        ip_send_check(ip_hdr(skb));


        rt = get_output_route(0, iph->daddr, 0, RT_TOS(iph->tos));

        skb_dst_drop(skb);
        skb_dst_set(skb, &rt->u.dst);
        skb->local_df = 1;

        dst_output(skb);

        return NF_STOLEN;

    }
out:
    return NF_ACCEPT;
}

/*
 * Netfilter registration record: run sg_vs_in() on IPv4 packets at
 * NF_INET_LOCAL_IN. .priority is not set and is therefore 0.
 */
static struct nf_hook_ops sg_nf_ops = {
    .pf       = PF_INET,
    .hooknum  = NF_INET_LOCAL_IN,
    .hook     = sg_vs_in,
    .owner    = THIS_MODULE,
};

/*
 * Fill *conn with the hard-coded endpoints used by the hook.
 *
 * Addresses come from in_aton() and ports from htons(), so all are stored in
 * network byte order. cport starts at 0; the hook overwrites it from
 * incoming client packets. Both sequence trackers are reset to zero.
 */
static void conn_struct_init(struct conn_struct *conn)
{
    /* Sequence-number bookkeeping. */
    conn->outseq.delta = 0;
    conn->outseq.init_seq = 0;
    conn->inseq.delta = 0;
    conn->inseq.init_seq = 0;

    /* Ports: real-server, virtual, and (not yet known) client. */
    conn->dport = htons(7777);
    conn->vport = htons(7777);
    conn->cport = 0;

    /* Addresses: real server, virtual (load balancer), client. */
    conn->daddr = in_aton("192.168.10.19");
    conn->vaddr = in_aton("192.168.10.188");
    conn->caddr = in_aton("192.168.10.184");
}

static int main_init(void)
{
    conn_struct_init(&conn);

    if (nf_register_hook(&sg_nf_ops) != 0) {
        err("nf_register_hook");
        goto out;
    }

    return 0;
out:
    return -1;
}

static void main_exit(void)
{
    nf_unregister_hook(&sg_nf_ops);
}

/* Declare the module's load/unload entry points and its license. */
module_init(main_init);
module_exit(main_exit);
MODULE_LICENSE("GPL");


阅读(812) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~