Category: Java

2011-07-02 09:13:10

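Type inference is an increasingly common feature of modern high-level languages, although functional languages have relied on it for a long time. The Hindley-Milner algorithm is the classic foundation that most type inferencers build on.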

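While reading up on type inference, I came across a simple Hindley-Milner inferencer that another developer had implemented based on Luca Cardelli's paper "Basic Polymorphic Typechecking". I modified it slightly; the code is below. (PS: the last line of his header comment says "Do with it what you will", and my modified version keeps it.) My changes fall into two areas: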
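1.  Readability improvements:
     a.  Added the grammar definition to the comments;
     b.  Used clearer names;
     c.  Used Identifier instead of bare strings;
     d.  Used Constant instead of literals.

2.  Syntax-level support for IF expressions:
     In the original implementation, the conditional expression was a built-in function, whereas
     the paper treats it as syntax. I extended the original author's grammar to support IF
     expressions.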
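To make the difference between the two designs concrete, here is a small self-contained sketch. It is not part of the program below and not the original author's code: the object `IfEncodings`, the built-in name `cond`, and the helpers `ifAsSyntax` and `ifAsBuiltin` are all hypothetical, and it swaps in an immutable substitution map in place of the program's mutable `instance` fields. It checks that typing `if` as syntax imposes the same constraints as applying a built-in `cond : Boolean → α → α → α` to three arguments.

```scala
// Hypothetical sketch: compares a syntax-level `if` rule with a built-in `cond` function.
// Uses an immutable substitution map, not the mutable `instance` fields of the program.
object IfEncodings {
  sealed trait Ty
  final case class TVar(id: Int) extends Ty
  final case class TCon(name: String, args: List[Ty] = Nil) extends Ty

  val BoolT: Ty = TCon("Boolean")
  val IntT: Ty = TCon("Integer")
  def fn(from: Ty, to: Ty): Ty = TCon("→", List(from, to))

  // Maps a type-variable id to the type it has been bound to.
  type Subst = Map[Int, Ty]

  // Follows bindings until an unbound variable or a constructor is reached.
  def walk(t: Ty, s: Subst): Ty = t match {
    case TVar(id) if s.contains(id) => walk(s(id), s)
    case _                          => t
  }

  // Occurs check: does variable `id` appear inside `t`?
  def occurs(id: Int, t: Ty, s: Subst): Boolean = walk(t, s) match {
    case TVar(j)       => j == id
    case TCon(_, args) => args.exists(occurs(id, _, s))
  }

  // Returns an extended substitution making `a` and `b` equal, or throws.
  def unify(a: Ty, b: Ty, s: Subst): Subst = (walk(a, s), walk(b, s)) match {
    case (TVar(i), TVar(j)) if i == j => s
    case (TVar(i), other) =>
      if (occurs(i, other, s)) throw new IllegalArgumentException("recursive unification")
      s + (i -> other)
    case (other, v: TVar) => unify(v, other, s)
    case (TCon(n1, as1), TCon(n2, as2)) =>
      if (n1 != n2 || as1.length != as2.length)
        throw new IllegalArgumentException(s"type mismatch: $n1 vs $n2")
      as1.zip(as2).foldLeft(s) { case (acc, (x, y)) => unify(x, y, acc) }
  }

  // Applies the substitution everywhere inside `t`.
  def resolve(t: Ty, s: Subst): Ty = walk(t, s) match {
    case TCon(n, args) => TCon(n, args.map(resolve(_, s)))
    case v             => v
  }

  // Syntax-level rule: the condition is Boolean and both branches share one type.
  def ifAsSyntax(c: Ty, t: Ty, e: Ty): Ty = {
    val s = unify(t, e, unify(c, BoolT, Map.empty))
    resolve(t, s)
  }

  // Built-in encoding: apply cond : Boolean → α → α → α to three arguments.
  // Ids 100 and 101 are assumed not to occur in c, t or e.
  def ifAsBuiltin(c: Ty, t: Ty, e: Ty): Ty = {
    val alpha  = TVar(100)
    val result = TVar(101)
    val condTy = fn(BoolT, fn(alpha, fn(alpha, alpha)))
    val s = unify(condTy, fn(c, fn(t, fn(e, result))), Map.empty)
    resolve(result, s)
  }
}
```

With this sketch, `ifAsSyntax(BoolT, IntT, IntT)` and `ifAsBuiltin(BoolT, IntT, IntT)` both give `IntT`, and both reject a non-Boolean condition or mismatched branches. The syntactic rule states the two constraints directly instead of routing them through a polymorphic function type.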

```scala
/*
 *
 * Andrew Forrest
 *
 * Implementation of basic polymorphic type-checking for a simple language.
 * Based heavily on Nikita Borisov’s Perl implementation at
 * ~nikitab/courses/cs263/hm.html
 * which in turn is based on the paper by Luca Cardelli at
 *
 *
 * If you run it with "scala HindleyMilner.scala" it will attempt to report the types
 * for a few example expressions. (It uses UTF-8 for output, so you may need to set your
 * terminal accordingly.)
 *
 * Changes
 * June 30, 2011 by Liang Kun(liangkun(AT)baidu.com)
 * 1. Modify to enhance readability
 * 2. Extend to support if expressions in the syntax
 *
 * Do with it what you will. :)
 */

/** Syntax definition. This is a simple lambda calculus syntax.
 * Expression ::= Identifier
 *              | Constant
 *              | "if" Expression "then" Expression "else" Expression
 *              | "lambda(" Identifier ")" Expression
 *              | Expression "(" Expression ")"
 *              | "let" Identifier "=" Expression "in" Expression
 *              | "letrec" Identifier "=" Expression "in" Expression
 *              | "(" Expression ")"
 * See the examples in the main function below.
 */
sealed abstract class Expression

case class Identifier(name: String) extends Expression {
    override def toString = name
}

case class Constant(value: String) extends Expression {
    override def toString = value
}

// `then` is reserved in Scala 3 (and deprecated as an identifier in Scala 2),
// so the branches are named thenBranch/elseBranch.
case class If(condition: Expression, thenBranch: Expression, elseBranch: Expression) extends Expression {
    override def toString = "(if " + condition + " then " + thenBranch + " else " + elseBranch + ")"
}

case class Lambda(argument: Identifier, body: Expression) extends Expression {
    override def toString = "(lambda " + argument + " → " + body + ")"
}

case class Apply(function: Expression, argument: Expression) extends Expression {
    override def toString = "(" + function + " " + argument + ")"
}

case class Let(binding: Identifier, definition: Expression, body: Expression) extends Expression {
    override def toString = "(let " + binding + " = " + definition + " in " + body + ")"
}

case class Letrec(binding: Identifier, definition: Expression, body: Expression) extends Expression {
    override def toString = "(letrec " + binding + " = " + definition + " in " + body + ")"
}


/** Exceptions that may be raised during analysis */
class TypeError(msg: String) extends Exception(msg)
class ParseError(msg: String) extends Exception(msg)


/** Type inference system */
object TypeSystem {
    type Env = Map[Identifier, Type]
    val EmptyEnv: Env = Map.empty

    // Type variables and type operators
    sealed abstract class Type

    // A type variable; `instance` is set once unification binds it to another type.
    case class Variable(id: Int) extends Type {
        var instance: Option[Type] = None
        lazy val name = nextUniqueName()

        override def toString = instance match {
            case Some(t) => t.toString
            case None => name
        }
    }

    // A type constructor (Integer, Boolean, →, ...) applied to its argument types.
    case class Operator(name: String, args: Seq[Type]) extends Type {
        override def toString = {
            if (args.isEmpty)
                name
            else if (args.length == 2)
                "[" + args(0) + " " + name + " " + args(1) + "]"
            else
                args.mkString(name + "[", ", ", "]")
        }
    }

    // Built-in types; more types can be added through the environment.
    def Function(from: Type, to: Type) = Operator("→", Seq(from, to))
    val Integer = Operator("Integer", Seq.empty)
    val Boolean = Operator("Boolean", Seq.empty)


    // Printable names for type variables: α, β, γ, ...
    protected var _nextVariableName = 'α'
    protected def nextUniqueName() = {
        val result = _nextVariableName
        _nextVariableName = (_nextVariableName.toInt + 1).toChar
        result.toString
    }
    protected var _nextVariableId = 0
    def newVariable(): Variable = {
        val result = _nextVariableId
        _nextVariableId += 1
        Variable(result)
    }


    // Main entry point
    def analyze(expr: Expression, env: Env): Type = analyze(expr, env, Set.empty[Variable])

    // `nongeneric` holds the type variables of enclosing lambda parameters (and of a letrec
    // binding while its own definition is analyzed); these must not be generalized.
    def analyze(expr: Expression, env: Env, nongeneric: Set[Variable]): Type = expr match {
        case i: Identifier => getIdentifierType(i, env, nongeneric)

        case Constant(value) => getConstantType(value)

        case If(cond, thenBranch, elseBranch) => {
            // The condition must be Boolean and both branches must have the same type.
            val condType = analyze(cond, env, nongeneric)
            val thenType = analyze(thenBranch, env, nongeneric)
            val elseType = analyze(elseBranch, env, nongeneric)
            unify(condType, Boolean)
            unify(thenType, elseType)
            thenType
        }

        case Apply(func, arg) => {
            // func must be a function from arg's type to some fresh result type.
            val funcType = analyze(func, env, nongeneric)
            val argType = analyze(arg, env, nongeneric)
            val resultType = newVariable()
            unify(Function(argType, resultType), funcType)
            resultType
        }

        case Lambda(arg, body) => {
            // A lambda parameter is monomorphic inside the body, so its variable is non-generic.
            val argType = newVariable()
            val resultType = analyze(body,
                                     env + (arg -> argType),
                                     nongeneric + argType)
            Function(argType, resultType)
        }

        case Let(binding, definition, body) => {
            // Variables of the definition's type that are not non-generic stay generic,
            // so each use of the binding in the body gets a fresh copy (see `fresh`).
            val definitionType = analyze(definition, env, nongeneric)
            val newEnv = env + (binding -> definitionType)
            analyze(body, newEnv, nongeneric)
        }

        case Letrec(binding, definition, body) => {
            // The binding may refer to itself: it is non-generic while its own definition
            // is analyzed, and generic again in the body.
            val newType = newVariable()
            val newEnv = env + (binding -> newType)
            val definitionType = analyze(definition, newEnv, nongeneric + newType)
            unify(newType, definitionType)
            analyze(body, newEnv, nongeneric)
        }
    }

    protected def getIdentifierType(id: Identifier, env: Env, nongeneric: Set[Variable]): Type = {
        if (env.contains(id))
            fresh(env(id), nongeneric)
        else
            throw new ParseError("Undefined symbol: " + id)
    }

    protected def getConstantType(value: String): Type = {
        if (isIntegerLiteral(value))
            Integer
        else
            throw new ParseError("Undefined symbol: " + value)
    }

    // Copies t, replacing each generic type variable with a new one (consistently)
    // and sharing the non-generic ones.
    protected def fresh(t: Type, nongeneric: Set[Variable]): Type = {
        import scala.collection.mutable
        val mappings = new mutable.HashMap[Variable, Variable]
        def freshrec(tp: Type): Type = {
            prune(tp) match {
                case v: Variable =>
                    if (isgeneric(v, nongeneric))
                        mappings.getOrElseUpdate(v, newVariable())
                    else
                        v

                case Operator(name, args) =>
                    Operator(name, args.map(freshrec))
            }
        }

        freshrec(t)
    }

    // Makes t1 and t2 equal by binding type variables; throws TypeError if impossible.
    protected def unify(t1: Type, t2: Type): Unit = {
        val type1 = prune(t1)
        val type2 = prune(t2)
        (type1, type2) match {
            case (a: Variable, b) =>
                if (a != b) {
                    if (occursintype(a, b))
                        throw new TypeError("Recursive unification")
                    a.instance = Some(b)
                }
            case (a: Operator, b: Variable) => unify(b, a)
            case (a: Operator, b: Operator) =>
                if (a.name != b.name || a.args.length != b.args.length)
                    throw new TypeError("Type mismatch: " + a + " ≠ " + b)

                for (i <- a.args.indices)
                    unify(a.args(i), b.args(i))
        }
    }

    // Returns the currently defining instance of t.
    // As a side effect, collapses the list of type instances.
    protected def prune(t: Type): Type = t match {
        case v: Variable if v.instance.isDefined => {
            val inst = prune(v.instance.get)
            v.instance = Some(inst)
            inst
        }
        case _ => t
    }

    // Note: must be called with v 'pre-pruned'
    protected def isgeneric(v: Variable, nongeneric: Set[Variable]): Boolean = !occursin(v, nongeneric)

    // Note: must be called with v 'pre-pruned'
    protected def occursintype(v: Variable, type2: Type): Boolean = {
        prune(type2) match {
            case `v` => true
            case Operator(_, args) => occursin(v, args)
            case _ => false
        }
    }

    protected def occursin(t: Variable, list: Iterable[Type]): Boolean =
        list exists (t2 => occursintype(t, t2))

    protected val checkDigits = "^(\\d+)$".r
    protected def isIntegerLiteral(name: String) = checkDigits.findFirstIn(name).isDefined
}


/** Demo program */
object HindleyMilner {
    def main(args: Array[String]): Unit = {
        // Extend the system with a new type (pair) and some built-in functions.
        val left = TypeSystem.newVariable()
        val right = TypeSystem.newVariable()
        val pairType = TypeSystem.Operator("×", Seq(left, right))

        val myenv: TypeSystem.Env = TypeSystem.EmptyEnv ++ Seq(
            Identifier("pair") -> TypeSystem.Function(left, TypeSystem.Function(right, pairType)),
            Identifier("true") -> TypeSystem.Boolean,
            Identifier("false")-> TypeSystem.Boolean,
            Identifier("zero") -> TypeSystem.Function(TypeSystem.Integer, TypeSystem.Boolean),
            Identifier("pred") -> TypeSystem.Function(TypeSystem.Integer, TypeSystem.Integer),
            Identifier("times")-> TypeSystem.Function(TypeSystem.Integer,
                    TypeSystem.Function(TypeSystem.Integer, TypeSystem.Integer))
        )

        // pair (f 4) (f true); f is bound only in some of the examples below
        val pair = Apply(
            Apply(
                Identifier("pair"), Apply(Identifier("f"), Constant("4"))
            ),
            Apply(Identifier("f"), Identifier("true"))
        )
        val examples = Seq[Expression](
            // factorial
            Letrec(Identifier("factorial"), // letrec factorial =
                Lambda(Identifier("n"), // lambda n =>
                    If(
                        Apply(Identifier("zero"), Identifier("n")),

                        Constant("1"),

                        Apply(
                            Apply(Identifier("times"), Identifier("n")),
                            Apply(
                                Identifier("factorial"),
                                Apply(Identifier("pred"), Identifier("n"))
                            )
                        )
                    )
                ), // in
                Apply(Identifier("factorial"), Constant("5"))
            ),

            // Should fail:
            // fn x => (pair(x(3) (x(true))))
            Lambda(Identifier("x"),
                Apply(
                    Apply(Identifier("pair"),
                        Apply(Identifier("x"), Constant("3"))
                    ),
                    Apply(Identifier("x"), Identifier("true"))
                )
            ),

            // Should fail (f is undefined):
            // pair(f(4), f(true))
            pair,

            // let f = (fn x => x) in ((pair (f 4)) (f true))
            Let(Identifier("f"), Lambda(Identifier("x"), Identifier("x")), pair),

            // Should fail:
            // fn f => f f
            Lambda(Identifier("f"), Apply(Identifier("f"), Identifier("f"))),

            // let g = fn f => 5 in g g
            Let(
                Identifier("g"),
                Lambda(Identifier("f"), Constant("5")),
                Apply(Identifier("g"), Identifier("g"))
            ),

            // example that demonstrates generic and non-generic variables:
            // fn g => let f = fn x => g in pair (f 3, f true)
            Lambda(Identifier("g"),
                Let(Identifier("f"),
                    Lambda(Identifier("x"), Identifier("g")),
                    Apply(
                        Apply(Identifier("pair"),
                              Apply(Identifier("f"), Constant("3"))
                        ),
                        Apply(Identifier("f"), Identifier("true"))
                    )
                )
            ),

            // Function composition
            // fn f => fn g => fn arg => g (f arg)
            Lambda( Identifier("f"),
                Lambda( Identifier("g"),
                    Lambda( Identifier("arg"),
                        Apply(Identifier("g"), Apply(Identifier("f"), Identifier("arg")))
                    )
                )
            )
        )

        // Print with UTF-8 so that →, ×, ≠ and the Greek type-variable names display correctly.
        val utf8Out = new java.io.PrintStream(System.out, true, "UTF-8")
        Console.withOut(utf8Out) {
            for (eg <- examples) {
                tryexp(myenv, eg)
            }
        }
    }

    def tryexp(env: TypeSystem.Env, expr: Expression): Unit = {
        try {
            val t = TypeSystem.analyze(expr, env)
            print(t)
        } catch {
            case e: ParseError => print(e.getMessage)
            case e: TypeError => print(e.getMessage)
        }
        println(":\t" + expr)
    }
}

// Script-mode entry point: the Scala 2 script runner ("scala HindleyMilner.scala")
// exposes the command-line arguments as `args`.
HindleyMilner.main(args)
```
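The occurs check in `unify` is what rejects self-application. Below is a hand trace of the `Lambda`, `Apply` and `unify` cases above on the example `fn f => f f`. It is my own reading of the code, not program output, and α and β simply stand for whatever fresh variables `newVariable` happens to return:

```latex
\begin{aligned}
&\textbf{Lambda:}\quad f : \alpha,\qquad \mathit{nongeneric} = \{\alpha\} \\
&\textbf{Identifier (both uses of } f\textbf{):}\quad \mathit{fresh}(\alpha, \{\alpha\}) = \alpha \quad (\alpha \text{ is non-generic, so it is not copied}) \\
&\textbf{Apply:}\quad \text{fresh result } \beta,\qquad \mathit{unify}(\alpha \to \beta,\ \alpha) \;\leadsto\; \mathit{unify}(\alpha,\ \alpha \to \beta) \\
&\textbf{occurs check:}\quad \alpha \text{ occurs in } \alpha \to \beta \;\Rightarrow\; \texttt{TypeError("Recursive unification")}
\end{aligned}
```

Binding α := α → β would require the infinite type α = (α → β) → β = ..., which these simple types cannot express, so `occursintype` stops the binding.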

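The `nongeneric` set is what gives `let` its polymorphism. Here is a hand derivation of two of the examples, again my own trace rather than program output, with Greek letters standing for the fresh variables the program would create:

```latex
\begin{aligned}
&\texttt{let f = fn x => x in pair (f 4) (f true)} \\
&\quad f : \alpha \to \alpha,\quad \alpha \notin \mathit{nongeneric} \text{ in the body} \\
&\quad f\ 4:\quad \mathit{fresh}(\alpha \to \alpha) = \gamma \to \gamma,\quad \gamma := \mathrm{Integer} \\
&\quad f\ \mathit{true}:\quad \mathit{fresh}(\alpha \to \alpha) = \delta \to \delta,\quad \delta := \mathrm{Boolean} \\
&\quad \text{result: } \mathrm{Integer} \times \mathrm{Boolean} \\[6pt]
&\texttt{fn x => pair (x 3) (x true)} \\
&\quad x : \alpha,\quad \alpha \in \mathit{nongeneric} \\
&\quad x\ 3:\quad \alpha := \mathrm{Integer} \to \beta \\
&\quad x\ \mathit{true}:\quad \mathit{fresh}(\alpha) = \mathrm{Integer} \to \beta \ \text{(shared, not copied)},\quad \mathit{unify}(\mathrm{Boolean} \to \varepsilon,\ \mathrm{Integer} \to \beta) \\
&\quad \mathrm{Boolean} \neq \mathrm{Integer} \;\Rightarrow\; \text{a } \texttt{TypeError} \text{ (type mismatch)}
\end{aligned}
```

In the `let` case the binding's variable α is generic, so every use of `f` gets its own copy. A lambda parameter stays non-generic, so both uses of `x` must agree on a single type.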