(* Python/NumPy backend: renders [Expr] expression trees and [Func]
   definitions as Python source text whose math calls go through [numpy]. *)

include Backend
open Printf
open Expr
open Func

(** [indent s] prefixes [s] with the indentation used inside a generated
    [def] body. *)
let indent s = " " ^ s

(** [comments l] turns every line of [l] into a Python [#] comment line. *)
let comments l = List.map (fun x -> "# " ^ x) l

(** [documentation f] is the included [Backend.documentation f], rendered as
    Python comment lines. Deliberately shadows the included definition. *)
let documentation f = comments (documentation f)

(** [const c] renders a constant as a Python expression.

    [string_of_float] already produces a valid Python float literal
    (["1."], ["1.5"], ["1e+20"]), so it must not get an extra ["."] suffix:
    that would yield invalid Python such as ["1.."] or ["1.5."].
    Non-finite floats have no Python literal and go through NumPy. *)
let const = function
  | Int x -> string_of_int x
  | Float x -> (
      match classify_float x with
      | FP_nan -> "numpy.nan"
      | FP_infinite -> if x > 0. then "numpy.inf" else "-numpy.inf"
      | FP_normal | FP_subnormal | FP_zero -> string_of_float x)
  | Math "pi" -> "numpy.pi"
  | Math x -> x

(** [func name] maps a function name to its NumPy spelling. Unknown names
    are emitted unchanged. Matrix inverse and determinant live in
    [numpy.linalg]; [numpy.inv] / [numpy.det] do not exist. *)
let func = function
  | "exp" -> "numpy.exp"
  | "sqrt" -> "numpy.sqrt"
  | "log" -> "numpy.log"
  | "abs" -> "numpy.abs"
  | "inv" -> "numpy.linalg.inv"
  | "det" -> "numpy.linalg.det"
  | s -> s

(** [never_paren e] holds for expression forms that [expr] emits without
    surrounding parentheses when used as an operand of [*], [/], [+], [-]
    or unary minus (subject to the extra checks in [expr]). *)
let never_paren = function
  | Cst _ | Var _ | Pow _ | App _ | Mul _ | Opp _ | Length _
  | Get (Var _, _) -> true
  | _ -> false

(** [never_paren_for_pow e] holds for forms that may appear unparenthesized
    as a [**] operand (subject to the extra checks in [expr]). *)
let never_paren_for_pow = function
  | Cst _ | Var _ | Pow _ | App _ | Get (Var _, _) -> true
  | _ -> false

(** [postfix_safe e] holds when the rendering of [e] can be directly
    followed by a postfix accessor ([.T], [.shape[0]], [[i]]) and the
    accessor still applies to the whole of [e]. Constants are excluded
    because e.g. ["2.T"] is not valid Python. *)
let postfix_safe = function
  | Var _ | App _ | Length _ | Sum _ | Vector _ | Tuple _ | Get _ -> true
  | _ -> false

(** [is_product e] holds when [e] is rendered as a bare [*] / [/] chain at
    top level, possibly behind unary minus signs (e.g. ["- b * c"]). Such a
    chain must be parenthesized as the right operand of [/], since Python
    parses [a / b * c] as [(a / b) * c]. *)
let rec is_product = function
  | Mul _ | Div _ -> true
  | Opp u -> is_product u
  | _ -> false

(** [starts_with_minus s] holds when rendered text [s] begins with ['-']. *)
let starts_with_minus s = String.length s > 0 && s.[0] = '-'

let paren x = sprintf "(%s)" x
let pow u v = sprintf "%s**%s" u v
let length x = sprintf "%s.shape[0]" x
let sum x = sprintf "%s.sum()" x
let transpose x = sprintf "%s.T" x

(** [expr e] renders [e] as a Python/NumPy expression, adding parentheses
    only where Python's precedence or associativity would otherwise change
    the meaning. Raises [Match_failure] on [Expr] forms not listed here. *)
let rec expr = function
  (* Variable *)
  | Var v -> var (Variable.name v)
  (* Constant *)
  | Cst x -> const x
  (* Multiplication: nested products/quotients are left bare; grouping them
     left-to-right gives the same mathematical value. *)
  | Mul (u, v) when never_paren u && never_paren v -> expr u ^ " * " ^ expr v
  | Mul (u, (Mul _ as v)) when never_paren u -> expr u ^ " * " ^ expr v
  | Mul ((Mul _ as u), v) when never_paren v -> expr u ^ " * " ^ expr v
  | Mul ((Mul _ as u), v) -> expr u ^ " * " ^ paren (expr v)
  | Mul (u, (Mul _ as v)) -> paren (expr u) ^ " * " ^ expr v
  | Mul (u, (Div _ as v)) when never_paren u -> expr u ^ " * " ^ expr v
  | Mul ((Div _ as u), v) when never_paren v -> expr u ^ " * " ^ expr v
  | Mul ((Div _ as u), v) -> expr u ^ " * " ^ paren (expr v)
  | Mul (u, (Div _ as v)) -> paren (expr u) ^ " * " ^ expr v
  | Mul (u, v) when never_paren u -> expr u ^ " * " ^ paren (expr v)
  | Mul (u, v) when never_paren v -> paren (expr u) ^ " * " ^ expr v
  | Mul (u, v) -> paren (expr u) ^ " * " ^ paren (expr v)
  (* Division: [/] is left-associative, so a product/quotient chain in the
     denominator must be parenthesized ([a / (b * c)], not [a / b * c]). *)
  | Div (u, v) ->
      let num = if never_paren u then expr u else paren (expr u) in
      let den =
        if never_paren v && not (is_product v) then expr v
        else paren (expr v)
      in
      num ^ " / " ^ den
  (* Addition *)
  | Add (u, v) when never_paren u && never_paren v -> expr u ^ " + " ^ expr v
  | Add (u, (Add _ as v)) when never_paren u -> expr u ^ " + " ^ expr v
  | Add ((Add _ as u), v) when never_paren v -> expr u ^ " + " ^ expr v
  | Add ((Add _ as u), v) -> expr u ^ " + " ^ expr v
  | Add (u, (Div _ as v)) when never_paren u -> expr u ^ " + " ^ expr v
  | Add ((Div _ as u), v) when never_paren v -> expr u ^ " + " ^ expr v
  | Add ((Div _ as u), v) -> expr u ^ " + " ^ expr v
  | Add (u, v) when never_paren u -> expr u ^ " + " ^ paren (expr v)
  | Add (u, (Add _ as v)) -> paren (expr u) ^ " + " ^ expr v
  | Add (u, v) when never_paren v -> paren (expr u) ^ " + " ^ expr v
  | Add (u, v) -> paren (expr u) ^ " + " ^ paren (expr v)
  (* Subtraction: [-] is left-associative, so only a left-hand [Add]/[Min]
     may stay bare; a right-hand one needs parentheses
     ([a - (b - c)], [a - (b + c)]). *)
  | Min (u, v) when never_paren u && never_paren v -> expr u ^ " - " ^ expr v
  | Min (u, v) when never_paren u -> expr u ^ " - " ^ paren (expr v)
  | Min ((Min _ as u), v) when never_paren v -> expr u ^ " - " ^ expr v
  | Min ((Min _ as u), v) -> expr u ^ " - " ^ paren (expr v)
  | Min ((Add _ as u), v) when never_paren v -> expr u ^ " - " ^ expr v
  | Min ((Add _ as u), v) -> expr u ^ " - " ^ paren (expr v)
  | Min (u, v) when never_paren v -> paren (expr u) ^ " - " ^ expr v
  | Min (u, v) -> paren (expr u) ^ " - " ^ paren (expr v)
  (* Power: Python's [**] is right-associative and binds tighter than unary
     minus on its left. So a nested power base ([(a**b)**c]) and a base
     rendered with a leading minus ([(-2)**2]) both need parentheses. *)
  | Pow (u, v) ->
      let base =
        let s = expr u in
        let nested_pow = (match u with Pow _ -> true | _ -> false) in
        if never_paren_for_pow u && (not nested_pow)
           && not (starts_with_minus s)
        then s
        else paren s
      in
      let exponent =
        if never_paren_for_pow v then expr v else paren (expr v)
      in
      pow base exponent
  (* Opposite *)
  | Opp u when never_paren u -> "- " ^ expr u
  | Opp (Mul _ as u) -> "- " ^ expr u
  | Opp (Div _ as u) -> "- " ^ expr u
  | Opp u -> "- " ^ paren (expr u)
  (* Application: [.T] is a postfix attribute, so the operand must be
     postfix-safe ([(a * b).T], not [a * b.T]). *)
  | App ("transpose", [ u ]) -> transpose (postfix_operand u)
  | App (f, lu) -> func f ^ "(" ^ String.concat ", " (List.map expr lu) ^ ")"
  (* Vector length *)
  | Length u -> length (postfix_operand u)
  (* Vector sum *)
  | Sum (Var _ as u) -> sum (expr u)
  | Sum u -> sum (paren (expr u))
  (* Vector index: any indexed expression is accepted; it is parenthesized
     unless it is postfix-safe. *)
  | Get (u, i) -> postfix_operand u ^ "[" ^ string_of_int i ^ "]"
  (* Tuple: a one-element Python tuple needs a trailing comma; [(x)] is
     just [x]. *)
  | Tuple [ u ] -> "(" ^ expr u ^ ",)"
  | Tuple lu -> "(" ^ String.concat ", " (List.map expr lu) ^ ")"
  (* Vector *)
  | Vector lu -> "numpy.array([" ^ String.concat ", " (List.map expr lu) ^ "])"

(** [postfix_operand u] renders [u] so that a postfix accessor can be
    appended to it: bare when [postfix_safe u], parenthesized otherwise. *)
and postfix_operand u = if postfix_safe u then expr u else paren (expr u)

(** [gettype (t, _, _)] returns the first component of a triple. *)
let gettype (t, _, _) = t

(** [docstring f] renders [Backend.documentation f] as Python docstring
    lines:
    - nothing when there is no documentation;
    - a one-line double-quoted string for a single line;
    - a triple-quoted block otherwise.
    NOTE(review): the text is not escaped, so a documentation line
    containing ["] or a backslash may produce broken Python. *)
let docstring f =
  let doc = Backend.documentation f in
  match doc with
  | [] -> []
  | [ x ] -> [ "\"" ^ x ^ "\"" ]
  | x :: q -> ("\"\"\" " ^ x) :: q @ [ "\"\"\"" ]

(** [def ?doc f] renders [f] as a Python [def] whose body is a single
    [return] of [expr f.expr]. When [doc] is true (the default), the
    docstring is emitted between the header and the [return]. *)
let def ?(doc = true) f =
  String.concat "\n"
    ([ sprintf "def %s(%s):" (var f.name)
         (String.concat ", " (List.map var f.args)) ]
    @ (if doc then List.map indent (docstring f) else [])
    @ [ sprintf " return %s" (expr f.expr) ])

(** [try_def f] is [def f], or the string ["Backend failure"] if rendering
    raises any exception. This is a deliberate best-effort catch-all; it
    covers, for example, [Match_failure] from an [Expr] form that [expr]
    does not handle. *)
let try_def f = try def f with _ -> "Backend failure"