{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TemplateHaskell #-}
-- | Constraint-based type inference for rlp expressions: 'gather' generates
-- equality constraints, @unify'@ solves them, and 'infer' annotates every
-- node of an expression with its solved type.
module Rlp.HindleyMilner
    ( typeCheckRlpProgR
    , annotate
    , TypeError(..)
    , runHM'
    , HM
    , prettyVars
    , prettyVars'
    )
    where
--------------------------------------------------------------------------------
import Control.Lens hiding (Context', Context, (:<), para)
import Control.Lens.Unsound
import Control.Monad.Errorful
import Control.Monad.State
import Control.Monad.Accum
import Control.Monad
import Control.Arrow ((>>>))
import Control.Monad.Writer.Strict
import Data.List
import Data.Monoid
import Data.Text qualified as T
import Data.Foldable (fold)
import Data.Function
import Data.Pretty hiding (annotate)
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as H
import Data.HashSet (HashSet)
import Data.HashSet qualified as S
import Data.Maybe (fromMaybe)
import Data.Traversable
import GHC.Generics (Generic(..), Generically(..))
import Debug.Trace
import Data.Functor
import Data.Functor.Foldable hiding (fold)
import Data.Fix hiding (cata, para)
import Control.Comonad.Cofree
import Control.Comonad

import Compiler.RLPC
import Compiler.RlpcError
import Rlp.AltSyntax as Rlp
import Core.Syntax qualified as Core
import Core.Syntax (ExprF(..), Lit(..))
import Rlp.HindleyMilner.Types
--------------------------------------------------------------------------------

-- | Convert a fixpoint into a 'Cofree' annotated with @()@ at every node; the
-- reverse direction rebuilds a fixpoint and discards the annotations.
fixCofree :: (Functor f, Functor g)
          => Iso (Fix f) (Fix g) (Cofree f ()) (Cofree g b)
fixCofree = iso sa bt
    where
        sa = foldFix (() :<)
        bt (_ :< as) = Fix $ bt <$> as

-- | Look up the type of a variable in a context, failing fatally with
-- 'TyErrUntypedVariable' when it is absent.
lookupVar :: PsName -> Context -> HM (Type PsName)
lookupVar n g = case g ^. contextVars . at n of
    Just t  -> pure t
    Nothing -> addFatal $ TyErrUntypedVariable n

-- | Memoised wrapper around 'gather''.
--
-- Results are cached in the HM accumulator (via 'look' / 'add'), keyed by the
-- expression value itself. 'annotate' calls this on every subtree, and the
-- cache makes the type variable a node receives there identical to the one
-- its parent used when generating constraints.
--
-- NOTE(review): because the key is the expression value, structurally equal
-- subexpressions share one type variable even when they occur under
-- different binders (e.g. the same variable name bound by two lambdas).
gather :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather e = look >>= (H.lookup e >>> maybe memoise pure)
    where
        memoise = do
            r <- gather' e
            add (H.singleton e r)
            pure r

-- | Generate a type and a partial judgement (assumptions about free variables
-- plus equality constraints) for one expression node.
--
-- * integer literal: type 'IntT', no assumptions or constraints;
-- * variable: a new type variable, recorded as an assumption for the name;
-- * application: a new result variable @r@ and the constraint
--   @type(f) ~ type(x) -> r@;
-- * lambda: a new variable per binder, each assumption about a binder is
--   equated with that binder's variable and then discharged.
--
-- NOTE(review): no other constructor is matched here.
gather' :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather' = \case
    Finl (LitF (IntL _)) -> pure (IntT, mempty)

    Finl (VarF n) -> do
        t <- freshTv
        let j = mempty & assumptions .~ H.singleton n [t]
        pure (t,j)

    Finl (AppF f x) -> do
        tfx <- freshTv
        (tf,jf) <- gather f
        (tx,jx) <- gather x
        let jtfx = mempty & constraints .~ [Equality tf (tx :-> tfx)]
        pure (tfx, jf <> jx <> jtfx)

    Finl (LamF bs e) -> do
        tbs <- for bs (const freshTv)
        (te,je) <- gather e
        let cs = concatMap (uncurry . equals $ je ^. assumptions) $ bs `zip` tbs
            -- the binders are now accounted for; drop their assumptions
            as = foldr H.delete (je ^. assumptions) bs
            j = mempty & constraints .~ (je ^. constraints <> cs)
                       & assumptions .~ as
            t = foldr (:->) te tbs
        pure (t,j)
      where
        -- one constraint per recorded use of binder @b@
        equals as b tb = maybe [] (fmap $ Equality tb) (as ^. at b)

    -- Finl (LamF [b] e) -> do
    --     tb <- freshTv
    --     (te,je) <- gather e
    --     let cs = maybe [] (fmap $ Equality tb) (je ^. assumptions . at b)
    --         as = je ^. assumptions & at b .~ Nothing
    --         j = mempty & constraints .~ cs & assumptions .~ as
    --         t = tb :-> te
    --     pure (t,j)

-- | Solve a list of equality constraints, returning a 'Context' that maps each
-- eliminated type variable to the type it was bound to.
--
-- Bindings are not composed: the type recorded for one variable may still
-- mention variables that were eliminated after it.
unify :: [Constraint] -> HM Context
unify [] = pure mempty

unify (Equality (sx :-> sy) (tx :-> ty) : cs) =
    unify $ Equality sx tx : Equality sy ty : cs

-- elim
unify (Equality (ConT s) (ConT t) : cs) | s == t = unify cs
unify (Equality (VarT s) (VarT t) : cs) | s == t = unify cs

unify (Equality (VarT s) t : cs)
    | occurs s t = addFatal $ TyErrRecursiveType s t
    | otherwise  = unify cs' <&> contextVars . at s ?~ t
    where
        cs' = cs & each . constraintTypes %~ subst s t

-- swap
unify (Equality s (VarT t) : cs) = unify (Equality (VarT t) s : cs)

unify (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t

-- | Same algorithm as 'unify', but returns the bindings as a list in the order
-- the variables were eliminated.
--
-- Each eliminated variable is substituted out of the remaining constraints,
-- so a binding never mentions a variable eliminated before it, though it may
-- mention ones eliminated after it.
unify' :: [Constraint] -> HM [(PsName, Type PsName)]
unify' [] = pure mempty

unify' (Equality (sx :-> sy) (tx :-> ty) : cs) =
    unify' $ Equality sx tx : Equality sy ty : cs

-- elim
unify' (Equality (ConT s) (ConT t) : cs) | s == t = unify' cs
unify' (Equality (VarT s) (VarT t) : cs) | s == t = unify' cs

unify' (Equality (VarT s) t : cs)
    | occurs s t = addFatal $ TyErrRecursiveType s t
    | otherwise  = unify' cs' <&> ((s,t):)
    where
        cs' = cs & each . constraintTypes %~ subst s t

-- swap
unify' (Equality s (VarT t) : cs) = unify' (Equality (VarT t) s : cs)

unify' (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t

-- | Annotate every node of an expression with the (memoised) result of
-- 'gather' for the subtree rooted there.
annotate :: RlpExpr PsName
         -> HM (Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
annotate = sequenceA . fixtend (gather . wrapFix)

-- infer1 :: RlpExpr PsName -> HM (Type PsName)
-- infer1 = infer1' mempty

-- infer1' :: Context -> RlpExpr PsName -> HM (Type PsName)
-- infer1' g1 e = do
--     ((t,j) :< _) <- annotate e
--     g2 <- unify (j ^. constraints)
--     g <- unionContextWithKeyM unifyTypes g1 g2
--     pure $ ifoldrOf (contextVars . itraversed) subst t g
--     where
--         -- intuitively, we'd return mgu(s,t) but the union is left-biased
--         -- making `s` the user-specified type: prioritise her.
--         unifyTypes _ s t = unify [Equality s t] $> s

-- | Indexed traversal over the values of an association list, where each
-- value's index is its key.
assocs :: IndexedTraversal k [(k,v)] [(k,v')] v v'
assocs f [] = pure []
assocs f ((k,v):xs) =
    (\v' xs' -> (k,v') : xs') <$> indexed f k v <*> assocs f xs

-- Debugging aid: 'subst' that also emits a 'trace' message describing the
-- substitution. Not used on any normal code path.
traceSubst k v t =
    trace ("subst " <> show' k <> " " <> show' v <> " " <> show' t)
        $ subst k v t
    where
        show' a = showsPrec 11 a mempty

-- | Infer a type for every node of an expression.
--
-- Each node is annotated via 'annotate', the constraints recorded at all nodes
-- are solved together with @unify'@, and the resulting bindings are applied to
-- each node's type. (A node's judgement already includes its children's
-- constraints, so many constraints appear more than once in the input.)
--
-- NOTE(review): the 'Context' argument is currently ignored, so declared type
-- signatures (see 'buildInitialContext') do not constrain the result.
infer :: Context -> RlpExpr PsName
      -> HM (Cofree (RlpExprF PsName) (Type PsName))
infer _g1 e = do
    e' <- annotate e
    g2 <- unify' $ concatOf (folded . _2 . constraints) e'
    -- unify' lists bindings in elimination order and a binding's type may
    -- mention variables eliminated later, so they must be applied
    -- first-to-last; folding right over the reversed list makes the first
    -- binding the innermost (first applied) substitution.
    let sub t = ifoldrOf (reversed . assocs) subst t g2
    pure $ sub . view _1 <$> e'

-- Hand-written fixtures, apparently left over from debugging. Neither is
-- exported nor referenced elsewhere in this module.
e :: Cofree (RlpExprF PsName) (Type PsName)
e = AppT (AppT FunT (VarT "$a2")) (AppT (AppT FunT (VarT "$a3")) (VarT "$a4"))
    :< InL (LamF ["f","x"]
          (VarT "$a4" :< InL (AppF (VarT "$a5" :< InL (VarF "f"))
              (VarT "$a6" :< InL (AppF (VarT "$a5" :< InL (VarF "f"))
                  (VarT "$a1" :< InL (VarF "x")))))))

g = Context
    { _contextVars = H.fromList
        [ ("$a1", VarT "$a6")
        , ("$a3", VarT "$a4")
        , ("$a2", AppT (AppT FunT (VarT "$a4")) (VarT "$a4"))
        , ("$a5", AppT (AppT FunT (VarT "$a1")) (VarT "$a6"))
        , ("$a6", VarT "$a4")
        ]
    }

-- | Union two contexts; a name bound in both is resolved by the monadic
-- combining function.
unionContextWithKeyM :: Monad m
                     => (PsName -> Type PsName -> Type PsName
                                -> m (Type PsName))
                     -> Context -> Context -> m Context
unionContextWithKeyM f a b = Context <$> unionWithKeyM f a' b'
    where
        a' = a ^. contextVars
        b' = b ^. contextVars

-- | 'H.unionWithKey' with a monadic combining function. Values are lifted with
-- 'pure', combined inside @m@ on collision, and the whole map is then
-- sequenced.
unionWithKeyM :: forall m k v. (Eq k, Hashable k, Monad m)
              => (k -> v -> v -> m v) -> HashMap k v -> HashMap k v
              -> m (HashMap k v)
unionWithKeyM f a b = sequenceA $ H.unionWithKey f' ma mb
    where
        f' k x y = join $ liftA2 (f k) x y
        ma = fmap (pure @m) a
        mb = fmap (pure @m) b

-- solve :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
-- solve = solve' mempty

-- solve' :: Context -> RlpExpr PsName
--        -> HM (Cofree (RlpExprF PsName) (Type PsName))
-- solve' g = sequenceA . fixtend (infer1' g . wrapFix)

-- | Occurs check: does the named variable appear anywhere in the type?
-- Binders are not treated specially.
occurs :: PsName -> Type PsName -> Bool
occurs n = cata \case
    VarTF m | n == m -> True
    t                -> or t

-- | @subst n t' t@ replaces occurrences of variable @n@ in @t@ with @t'@,
-- leaving the body of a @forall n@ untouched.
--
-- NOTE(review): not capture-avoiding; a free variable of @t'@ can be captured
-- by a @forall@ in @t@.
subst :: PsName -> Type PsName -> Type PsName -> Type PsName
subst n t' = para \case
    VarTF m | n == m -> t'
    -- shadowing
    ForallTF x (pre,post) | x == n -> ForallT x pre
    t -> embed $ t <&> view _2

-- | Apply several variable substitutions at once.
--
-- Every variable occurrence is looked up in the original map exactly once, so
-- an inserted type is never rewritten by another entry. Under @forall x@, @x@
-- is removed from the map (shadowing), mirroring 'subst'. Like 'subst', this
-- is not capture-avoiding.
substSimultaneous :: HashMap PsName (Type PsName) -> Type PsName -> Type PsName
substSimultaneous s0 t0 = cata go t0 s0
    where
        go (VarTF x) s         = fromMaybe (VarT x) (H.lookup x s)
        go (ForallTF x body) s = ForallT x (body (H.delete x s))
        go t s                 = embed (($ s) <$> t)

-- | Render both components of a successful result with 'rout'.
prettyHM :: (Out a)
         => Either [TypeError] (a, [Constraint])
         -> Either [TypeError] (String, [String])
prettyHM = over (mapped . _1) rout
         . over (mapped . _2 . each) rout

-- | Annotate every node of a fixpoint with @c@ applied to that node's layer.
fixtend :: Functor f => (f (Fix f) -> b) -> Fix f -> Cofree f b
fixtend c (Fix f) = c f :< fmap (fixtend c) f

-- infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
-- infer = infer' mempty

-- infer' :: Context -> RlpExpr PsName
--        -> HM (Cofree (RlpExprF PsName) (Type PsName))
-- infer' g = sequenceA . fixtend (infer1' g . wrapFix)

-- | Build a context from the program's type-signature declarations.
buildInitialContext :: Program PsName a -> Context
buildInitialContext =
    Context . H.fromList . toListOf (programDecls . each . _TySigD)

-- | Type-check a program: eta-expand its function declarations, run 'infer'
-- over every declaration body, and lift the HM result into 'RLPCT'.
typeCheckRlpProgR :: (Monad m)
                  => Program PsName (RlpExpr PsName)
                  -> RLPCT m (Program PsName
                                (Cofree (RlpExprF PsName) (Type PsName)))
typeCheckRlpProgR p = tc p
    where
        g = buildInitialContext p
        tc = liftHM . traverse (infer g) . etaExpandAll
        etaExpandAll = programDecls . each %~ etaExpand

-- | Turn @f x y = e@ into @f = \\x y -> e@ when every argument pattern is a
-- variable pattern; any other declaration is returned unchanged.
etaExpand :: Decl b (RlpExpr b) -> Decl b (RlpExpr b)
etaExpand (FunD n [] e) = FunD n [] e
etaExpand (FunD n as e)
    | Right as' <- allVarP as
    = FunD n [] (Finl . LamF as' $ e)
    where
        allVarP = traverse (matching _VarP)
etaExpand a = a

-- | Run an HM computation with 'runHM'' and lift its result into 'RLPCT'
-- with 'liftEither'.
liftHM :: (Monad m) => HM a -> RLPCT m a
liftHM = liftEither . runHM'

-- | The set of variables not bound by an enclosing @forall@.
freeVariables :: Type PsName -> HashSet PsName
freeVariables = cata \case
    VarTF x      -> S.singleton x
    ForallTF x m -> m `S.difference` S.singleton x
    vs           -> fold vs

-- | The set of variables bound by some @forall@ in the type.
boundVariables :: Type PsName -> HashSet PsName
boundVariables = cata \case
    ForallTF x m -> S.singleton x <> m
    vs           -> fold vs

-- | rename all free variables for aesthetic purposes
prettyVars' :: Type PsName -> Type PsName
prettyVars' = join prettyVars

-- | Free variables in left-to-right order of first occurrence, without
-- duplicates.
freeVariablesLTR :: Type PsName -> [PsName]
freeVariablesLTR = nub . cata \case
    VarTF x      -> [x]
    -- remove every occurrence of the binder: (\\) would remove only the
    -- first, so e.g. @forall a. a -> a@ would report @a@ as free
    ForallTF x m -> filter (/= x) m
    vs           -> concat vs

-- | for some type, compute a substitution which will rename all free variables
-- for aesthetic purposes
--
-- Free variables of @root@ are renamed, in left-to-right order of first
-- occurrence, to the single letters @a@..@z@ that are not bound anywhere in
-- @root@; free variables beyond the available letters are left unchanged.
-- The renaming is applied simultaneously. Applying one 'subst' per variable
-- in sequence could rename a variable to a letter that a later step renames
-- again, e.g. turning @b -> a@ into @a -> a@.
prettyVars :: Type PsName -> Type PsName -> Type PsName
prettyVars root = substSimultaneous renaming
    where
        alphabetNames = [ T.pack [c] | c <- ['a'..'z'] ]
        names = alphabetNames \\ S.toList (boundVariables root)
        renaming = H.fromList $ zip (freeVariablesLTR root) (VarT <$> names)

-- test :: Type PsName -> [(PsName, PsName)]
-- test root = subs
--     where
--         alphabetNames = [ T.pack [c] | c <- ['a'..'z'] ]
--         names = alphabetNames \\ S.toList (boundVariables root)
--         subs = zip (freeVariablesLTR root) names