diff --git a/crates/javac-ast/src/lib.rs b/crates/javac-ast/src/lib.rs index c5b6ae0..425b9ef 100644 --- a/crates/javac-ast/src/lib.rs +++ b/crates/javac-ast/src/lib.rs @@ -249,6 +249,7 @@ pub enum JavaSyntaxKind { PatternExpr, LambdaExpr, + LambdaParam, MethodRefExpr, Annotation, diff --git a/crates/javac-bytecode/src/class_gen.rs b/crates/javac-bytecode/src/class_gen.rs index d6e66fc..d1e2bf7 100644 --- a/crates/javac-bytecode/src/class_gen.rs +++ b/crates/javac-bytecode/src/class_gen.rs @@ -1,11 +1,15 @@ -use crate::codegen::CodegenCtx; +use crate::codegen::{CodegenCtx, LambdaInfo}; use crate::error::BytecodeError; +use crate::expr_gen; +use crate::local_var::return_opcode; use javac_call_resolver::ClassCatalog; use javac_classfile::ClassFileWriter; use javac_hir::hir::*; use javac_ty::Ty; use rust_asm::constants::V21; +use rust_asm::insn::Handle; use rust_asm::opcodes; +use std::collections::HashMap; const OBJECT_CLASS: &str = "java/lang/Object"; const INIT_METHOD: &str = ""; @@ -64,7 +68,28 @@ fn gen_type_decl(writer: &mut ClassFileWriter, type_decl: &TypeDecl, catalog: &C if needs_default_constructor(type_decl) { gen_default_constructor(writer, type_decl, &super_name, catalog); } - gen_methods(writer, type_decl, &super_name, catalog); + + let mut counter = 0u32; + for method in &type_decl.methods { + let mut method_lambda_infos: HashMap = HashMap::new(); + scan_and_gen_lambdas( + writer, + type_decl, + &super_name, + catalog, + method, + &mut method_lambda_infos, + &mut counter, + ); + gen_method( + writer, + type_decl, + method, + &super_name, + catalog, + &method_lambda_infos, + ); + } } fn gen_fields(writer: &mut ClassFileWriter, fields: &[FieldDecl]) { @@ -78,14 +103,163 @@ fn gen_fields(writer: &mut ClassFileWriter, fields: &[FieldDecl]) { } } -fn gen_methods( +struct SamInfo { + interface: String, + method_name: String, + method_type: String, + return_ty: Ty, +} + +fn resolve_sam_interface(expr: &Expr, catalog: &ClassCatalog, param_count: usize) -> SamInfo { + if let Expr::Lambda { + target_ty: Some(Ty::Class(name)), + .. + } = expr + { + if let Some(method) = catalog.functional_interface_method(name) { + let (method_type, return_ty) = erased_descriptor_from_method_ref(&method); + return SamInfo { + interface: name.to_string(), + method_name: method.name.clone(), + method_type, + return_ty, + }; + } + } + match param_count { + 0 => SamInfo { + interface: "java/util/function/Supplier".into(), + method_name: "get".into(), + method_type: "()Ljava/lang/Object;".into(), + return_ty: Ty::object(), + }, + 1 => SamInfo { + interface: "java/util/function/Function".into(), + method_name: "apply".into(), + method_type: "(Ljava/lang/Object;)Ljava/lang/Object;".into(), + return_ty: Ty::object(), + }, + _ => SamInfo { + interface: "java/util/function/BiFunction".into(), + method_name: "apply".into(), + method_type: "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;".into(), + return_ty: Ty::object(), + }, + } +} + +fn erased_descriptor_from_method_ref(mr: &javac_call_resolver::MethodRef) -> (String, Ty) { + let param_descs: String = mr + .params + .iter() + .map(|_| "Ljava/lang/Object;") + .collect::>() + .join(""); + let (ret, return_ty) = if matches!(mr.return_ty, Ty::Void) { + ("V", Ty::Void) + } else { + ("Ljava/lang/Object;", Ty::object()) + }; + (format!("({}){}", param_descs, ret), return_ty) +} + +fn scan_and_gen_lambdas( writer: &mut ClassFileWriter, type_decl: &TypeDecl, super_name: &str, catalog: &ClassCatalog, + method: &MethodDecl, + lambda_infos: &mut HashMap, + counter: &mut u32, ) { - for method in &type_decl.methods { - gen_method(writer, type_decl, method, super_name, catalog); + for (expr_id, expr) in method.body.exprs.iter() { + if let Expr::Lambda { + params, + body: lambda_body, + .. + } = expr + { + let synthetic_name = format!("lambda${}${}", method.name, counter); + *counter += 1; + + let sam_info = resolve_sam_interface(expr, catalog, params.len()); + + let param_descs: String = params + .iter() + .map(|_| "Ljava/lang/Object;") + .collect::>() + .join(""); + let impl_descriptor = format!("({}){}", param_descs, sam_info.return_ty.descriptor()); + let sam_descriptor = format!("()L{};", sam_info.interface); + + let impl_method_handle = Handle { + reference_kind: rust_asm::constants::REF_INVOKE_STATIC, + owner: type_decl.name.to_string(), + name: synthetic_name.clone(), + descriptor: impl_descriptor.clone(), + is_interface: false, + }; + + { + let mut mw = writer.visit_method( + javac_classfile::ACC_PRIVATE + | javac_classfile::ACC_STATIC + | javac_classfile::ACC_SYNTHETIC, + &synthetic_name, + &impl_descriptor, + ); + mw.visit_code(); + + let mut ctx = CodegenCtx::new(writer, type_decl.name, catalog); + ctx.set_super_name(ustr::Ustr::from(super_name)); + ctx.set_fields(&type_decl.fields); + ctx.set_methods(&type_decl.methods); + + ctx.return_ty = sam_info.return_ty.clone(); + ctx.next_local = 0; + ctx.locals.clear(); + ctx.local_types.clear(); + for (i, param) in params.iter().enumerate() { + let ty = param.ty.clone().unwrap_or(Ty::object()); + mw.visit_local_variable( + param.name.as_str(), + &ty.erasure().descriptor(), + i as u16, + ); + ctx.locals.insert(param.name, i as u16); + ctx.local_types.insert(param.name, ty); + ctx.next_local = (i as u16) + 1; + } + + match lambda_body { + LambdaBody::Expr(body_expr_id) => { + expr_gen::gen_expr(&mut mw, &mut ctx, &method.body, *body_expr_id); + let body_ty = expr_gen::expr_ty(&ctx, &method.body, *body_expr_id); + mw.visit_insn(return_opcode(&body_ty)); + } + LambdaBody::Block(block) => { + crate::method_gen::gen_method_body(&mut mw, &mut ctx, &method.body, block); + } + } + + mw.visit_maxs(0, 0); + mw.visit_end(writer); + } + + lambda_infos.insert( + expr_id, + LambdaInfo { + synthetic_name, + sam_interface: sam_info.interface.clone(), + sam_method_name: sam_info.method_name.clone(), + sam_method_type: sam_info.method_type.clone(), + sam_descriptor: sam_descriptor.to_string(), + impl_descriptor, + params: params.clone(), + impl_method_handle, + }, + ); + } } } @@ -95,6 +269,7 @@ fn gen_method( method: &MethodDecl, super_name: &str, catalog: &ClassCatalog, + lambda_infos: &HashMap, ) { let descriptor = method.signature.descriptor(); let mut mw = writer.visit_method(method.access_flags, &method.name, &descriptor); @@ -113,6 +288,7 @@ fn gen_method( ctx.set_super_name(ustr::Ustr::from(super_name)); ctx.set_fields(&type_decl.fields); ctx.set_methods(&type_decl.methods); + ctx.lambda_info = lambda_infos.clone(); ctx.begin_method(method); declare_method_locals(&mut mw, type_decl, method); gen_constructor_prelude(&mut mw, &ctx, method); diff --git a/crates/javac-bytecode/src/codegen.rs b/crates/javac-bytecode/src/codegen.rs index 19fce1a..78a31cd 100644 --- a/crates/javac-bytecode/src/codegen.rs +++ b/crates/javac-bytecode/src/codegen.rs @@ -1,7 +1,8 @@ use javac_call_resolver::ClassCatalog; use javac_classfile::{ClassFileWriter, Label}; -use javac_hir::hir::{Block, FieldDecl, MethodDecl}; +use javac_hir::hir::{Block, ExprId, FieldDecl, LambdaParam, MethodDecl}; use javac_ty::{MethodSig, Ty}; +use rust_asm::insn::Handle; use std::collections::HashMap; use ustr::Ustr; @@ -29,6 +30,18 @@ pub struct ControlTarget { pub cleanup_depth: usize, } +#[derive(Clone)] +pub struct LambdaInfo { + pub synthetic_name: String, + pub sam_interface: String, + pub sam_method_name: String, + pub sam_method_type: String, + pub sam_descriptor: String, + pub impl_descriptor: String, + pub params: Vec, + pub impl_method_handle: Handle, +} + pub struct CodegenCtx<'a> { pub writer: &'a mut ClassFileWriter, pub catalog: ClassCatalog, @@ -45,6 +58,7 @@ pub struct CodegenCtx<'a> { pub labeled_break_labels: Vec<(Ustr, ControlTarget)>, pub labeled_continue_labels: Vec<(Ustr, ControlTarget)>, pub cleanup_scopes: Vec, + pub lambda_info: HashMap, } impl<'a> CodegenCtx<'a> { @@ -65,6 +79,7 @@ impl<'a> CodegenCtx<'a> { labeled_break_labels: Vec::new(), labeled_continue_labels: Vec::new(), cleanup_scopes: Vec::new(), + lambda_info: HashMap::new(), } } diff --git a/crates/javac-bytecode/src/expr_gen.rs b/crates/javac-bytecode/src/expr_gen.rs index 52de0ec..618fced 100644 --- a/crates/javac-bytecode/src/expr_gen.rs +++ b/crates/javac-bytecode/src/expr_gen.rs @@ -23,8 +23,13 @@ use crate::codegen::CodegenCtx; use javac_classfile::MethodWriter; use javac_hir::hir::*; use javac_ty::Ty; +use rust_asm::insn::{BootstrapArgument, Handle}; use rust_asm::opcodes; +const LAMBDA_METAFACTORY: &str = "java/lang/invoke/LambdaMetafactory"; +const METAFACTORY_NAME: &str = "metafactory"; +const METAFACTORY_DESC: &str = "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;"; + pub(crate) use arrays::array_load_opcode; pub(crate) use convert::{cast, coerce, pop_ty, push_default_value}; pub(crate) use types::expr_ty; @@ -116,6 +121,9 @@ pub fn gen_expr(mw: &mut MethodWriter, ctx: &mut CodegenCtx, body: &Body, expr_i gen_expr(mw, ctx, body, *expr); mw.visit_type_insn(opcodes::INSTANCEOF, &ty.internal_name()); } + Expr::Lambda { .. } => { + emit_lambda(mw, ctx, expr_id); + } _ => push_default_value(mw, &expr_ty(ctx, body, expr_id)), } } @@ -180,3 +188,26 @@ fn emit_ternary( coerce(mw, &expr_ty(ctx, body, else_expr), &result_ty); mw.visit_label(end_label); } + +fn emit_lambda(mw: &mut MethodWriter, ctx: &CodegenCtx, expr_id: ExprId) { + let Some(info) = ctx.lambda_info.get(&expr_id) else { + mw.visit_insn(opcodes::ACONST_NULL); + return; + }; + + let bsm = Handle { + reference_kind: rust_asm::constants::REF_INVOKE_STATIC, + owner: LAMBDA_METAFACTORY.to_string(), + name: METAFACTORY_NAME.to_string(), + descriptor: METAFACTORY_DESC.to_string(), + is_interface: false, + }; + + let args = vec![ + BootstrapArgument::MethodType(info.sam_method_type.clone()), + BootstrapArgument::Handle(info.impl_method_handle.clone()), + BootstrapArgument::MethodType(info.impl_descriptor.clone()), + ]; + + mw.visit_invoke_dynamic_insn(&info.sam_method_name, &info.sam_descriptor, bsm, &args); +} diff --git a/crates/javac-bytecode/src/validation.rs b/crates/javac-bytecode/src/validation.rs index 7269308..edd31f1 100644 --- a/crates/javac-bytecode/src/validation.rs +++ b/crates/javac-bytecode/src/validation.rs @@ -333,10 +333,22 @@ impl Validator { } Ok(()) } - Expr::Lambda { body: lambda, .. } => match lambda { - LambdaBody::Expr(expr) => self.validate_expr(body, scope, *expr), - LambdaBody::Block(block) => self.validate_block(body, &mut scope.clone(), block), - }, + Expr::Lambda { + params, + body: lambda, + .. + } => { + let mut lambda_scope = scope.clone(); + for param in params { + lambda_scope + .locals + .insert(param.name, param.ty.clone().unwrap_or(Ty::object())); + } + match lambda { + LambdaBody::Expr(expr) => self.validate_expr(body, &mut lambda_scope, *expr), + LambdaBody::Block(block) => self.validate_block(body, &mut lambda_scope, block), + } + } Expr::MethodRef { target, .. } => self.validate_expr(body, scope, *target), Expr::IntLiteral(_) | Expr::LongLiteral(_) diff --git a/crates/javac-call-resolver/src/catalog.rs b/crates/javac-call-resolver/src/catalog.rs index f46bc1c..32bd98e 100644 --- a/crates/javac-call-resolver/src/catalog.rs +++ b/crates/javac-call-resolver/src/catalog.rs @@ -184,6 +184,31 @@ impl ClassCatalog { } } + pub fn is_interface(&self, internal_name: &str) -> bool { + self.interfaces.contains(internal_name) + } + + pub fn functional_interface_method(&self, internal_name: &str) -> Option { + if !self.interfaces.contains(internal_name) { + return None; + } + + let mut sam: Option = None; + for ((owner, _), methods) in &self.methods { + if owner == internal_name { + for m in methods { + if m.is_interface { + if sam.is_some() { + return None; + } + sam = Some(m.clone()); + } + } + } + } + sam + } + pub fn resolve_static_field(&self, owner: &str, name: &str) -> Option { self.lookup_order(owner).into_iter().find_map(|owner| { self.fields diff --git a/crates/javac-classfile/src/writer.rs b/crates/javac-classfile/src/writer.rs index 8507861..91a8a2e 100644 --- a/crates/javac-classfile/src/writer.rs +++ b/crates/javac-classfile/src/writer.rs @@ -5,6 +5,7 @@ use rust_asm::class_writer::{ use rust_asm::constant_pool::{ConstantPoolBuilder, CpInfo}; pub use rust_asm::insn::Label; use rust_asm::insn::LabelNode; +pub use rust_asm::insn::{BootstrapArgument, Handle}; use std::collections::HashMap; pub struct ClassFileWriter { @@ -188,6 +189,17 @@ impl MethodWriter { .visit_method_insn(opcode, owner, name, descriptor, is_interface); } + pub fn visit_invoke_dynamic_insn( + &mut self, + name: &str, + descriptor: &str, + bootstrap_method: Handle, + bootstrap_args: &[BootstrapArgument], + ) { + self.inner + .visit_invokedynamic_insn(name, descriptor, bootstrap_method, bootstrap_args); + } + pub fn visit_ldc_insn_int(&mut self, value: i32) { self.inner .visit_ldc_insn(rust_asm::insn::LdcInsnNode::int(value)); diff --git a/crates/javac-hir/src/hir.rs b/crates/javac-hir/src/hir.rs index fdbcffc..db24825 100644 --- a/crates/javac-hir/src/hir.rs +++ b/crates/javac-hir/src/hir.rs @@ -282,6 +282,7 @@ pub enum Expr { Lambda { params: Vec, body: LambdaBody, + target_ty: Option, }, MethodRef { diff --git a/crates/javac-hir/src/lowering/expr.rs b/crates/javac-hir/src/lowering/expr.rs index 8d77a99..48424fa 100644 --- a/crates/javac-hir/src/lowering/expr.rs +++ b/crates/javac-hir/src/lowering/expr.rs @@ -67,6 +67,145 @@ impl BodyBuilder { self.pattern_names.contains(&name) && self.local_ty(name).is_none() } + pub(super) fn resolve_lambda_target_types(&mut self, method_return_ty: &Ty) { + let mut targets: Vec<(ExprId, Ty)> = Vec::new(); + for (_, stmt) in self.body.stmts.iter() { + self.collect_lambda_targets(stmt, method_return_ty, &mut targets); + } + for (expr_id, ty) in targets { + if let Expr::Lambda { target_ty: t, .. } = &mut self.body.exprs[expr_id] { + *t = Some(ty); + } + } + } + + fn collect_lambda_targets( + &self, + stmt: &Stmt, + method_return_ty: &Ty, + targets: &mut Vec<(ExprId, Ty)>, + ) { + match stmt { + Stmt::LocalVar(decl) => { + if let Some(init) = decl.initializer { + self.push_lambda_target(init, &decl.ty, targets); + } + } + Stmt::Return(expr) => { + if let Some(expr_id) = expr { + self.push_lambda_target(*expr_id, method_return_ty, targets); + } + } + Stmt::Expr(expr_id) => { + self.collect_expr_lambda_targets(*expr_id, targets); + } + Stmt::Block(block) => { + for &s in &block.stmts { + self.collect_lambda_targets(&self.body.stmts[s], method_return_ty, targets); + } + } + Stmt::If { + then_branch, + else_branch, + .. + } => { + self.collect_lambda_targets( + &self.body.stmts[*then_branch], + method_return_ty, + targets, + ); + if let Some(eb) = else_branch { + self.collect_lambda_targets(&self.body.stmts[*eb], method_return_ty, targets); + } + } + Stmt::For { body, .. } + | Stmt::ForEach { body, .. } + | Stmt::While { body, .. } + | Stmt::Do { body, .. } => { + self.collect_lambda_targets(&self.body.stmts[*body], method_return_ty, targets); + } + Stmt::Labeled { body, .. } => { + self.collect_lambda_targets(&self.body.stmts[*body], method_return_ty, targets); + } + Stmt::Synchronized(_, block) => { + for &s in &block.stmts { + self.collect_lambda_targets(&self.body.stmts[s], method_return_ty, targets); + } + } + Stmt::Try(try_stmt) => { + for &s in &try_stmt.body.stmts { + self.collect_lambda_targets(&self.body.stmts[s], method_return_ty, targets); + } + for catch in &try_stmt.catches { + for &s in &catch.body.stmts { + self.collect_lambda_targets(&self.body.stmts[s], method_return_ty, targets); + } + } + if let Some(finally) = &try_stmt.finally { + for &s in &finally.stmts { + self.collect_lambda_targets(&self.body.stmts[s], method_return_ty, targets); + } + } + } + Stmt::Switch { cases, .. } => { + for case in cases { + let stmts = match case { + SwitchCase::Case { body, .. } | SwitchCase::Default { body, .. } => body, + }; + for &s in stmts { + self.collect_lambda_targets(&self.body.stmts[s], method_return_ty, targets); + } + } + } + Stmt::Assert { + condition: _, + message, + } => { + if let Some(msg) = message { + self.collect_expr_lambda_targets(*msg, targets); + } + } + Stmt::Throw(expr_id) | Stmt::Yield(expr_id) => { + self.collect_expr_lambda_targets(*expr_id, targets); + } + Stmt::Empty | Stmt::Break(_) | Stmt::Continue(_) => {} + } + } + + fn collect_expr_lambda_targets(&self, expr_id: ExprId, targets: &mut Vec<(ExprId, Ty)>) { + match &self.body.exprs[expr_id] { + Expr::Assign { target, value, .. } => { + let target_ty = self.expr_ty(*target); + self.push_lambda_target(*value, &target_ty, targets); + } + Expr::MethodCall { + target: _, + method: _, + args: _, + } => {} + Expr::Parens(inner) => { + self.collect_expr_lambda_targets(*inner, targets); + } + _ => {} + } + } + + fn push_lambda_target(&self, expr_id: ExprId, target_ty: &Ty, targets: &mut Vec<(ExprId, Ty)>) { + let unwrapped = self.unwrap_parens(expr_id); + if matches!(self.body.exprs[unwrapped], Expr::Lambda { .. }) { + targets.push((unwrapped, target_ty.clone())); + } + } + + fn unwrap_parens(&self, mut expr_id: ExprId) -> ExprId { + loop { + match &self.body.exprs[expr_id] { + Expr::Parens(inner) => expr_id = *inner, + _ => break expr_id, + } + } + } + pub(super) fn pattern_binding(&self, expr_id: ExprId) -> Option<(Ustr, Ty, ExprId)> { match &self.body.exprs[expr_id] { Expr::Instanceof { @@ -452,6 +591,31 @@ impl ExprLowerer<'_, '_> { } JavaSyntaxKind::NewKw => self.parse_new_expr(), JavaSyntaxKind::Ident => { + let is_lambda = self + .tokens + .get(self.pos + 1) + .is_some_and(|t| t.kind == JavaSyntaxKind::Arrow); + if is_lambda { + self.pos += 1; + let name = Ustr::from(token.text.as_str()); + self.pos += 1; + self.body.define_local(name, Ty::object()); + let body = if self.at_lambda_block() { + self.skip_block_tokens(); + LambdaBody::Block(Block { stmts: vec![] }) + } else { + LambdaBody::Expr(self.parse_expr()?) + }; + let params = vec![LambdaParam { + name, + ty: Some(Ty::object()), + }]; + return Ok(self.body.alloc_expr(Expr::Lambda { + params, + body, + target_ty: None, + })); + } let name = self.expect_ident()?; let name = Ustr::from(&name); if self.body.pattern_name_is_out_of_scope(name) { @@ -470,6 +634,32 @@ impl ExprLowerer<'_, '_> { } Ok(self.body.alloc_expr(Expr::Ident(name))) } + JavaSyntaxKind::LParen if self.is_lambda_paren() => { + self.pos += 1; + let mut params = Vec::new(); + while !self.eat(JavaSyntaxKind::RParen) { + let name = self.expect_ident()?; + let name = Ustr::from(&name); + self.body.define_local(name, Ty::object()); + params.push(LambdaParam { + name, + ty: Some(Ty::object()), + }); + self.eat(JavaSyntaxKind::Comma); + } + self.expect(JavaSyntaxKind::Arrow)?; + let body = if self.at_lambda_block() { + self.skip_block_tokens(); + LambdaBody::Block(Block { stmts: vec![] }) + } else { + LambdaBody::Expr(self.parse_expr()?) + }; + Ok(self.body.alloc_expr(Expr::Lambda { + params, + body, + target_ty: None, + })) + } JavaSyntaxKind::LParen => { self.pos += 1; let inner = self.parse_expr()?; @@ -643,6 +833,38 @@ impl ExprLowerer<'_, '_> { None } + fn is_lambda_paren(&self) -> bool { + let mut depth = 1i32; + let mut i = self.pos + 1; + loop { + if i >= self.tokens.len() { + return false; + } + match self.tokens[i].kind { + JavaSyntaxKind::LParen => depth += 1, + JavaSyntaxKind::RParen => { + depth -= 1; + if depth == 0 { + i += 1; + while i < self.tokens.len() + && matches!( + self.tokens[i].kind, + JavaSyntaxKind::Whitespace | JavaSyntaxKind::Comment + ) + { + i += 1; + } + return i < self.tokens.len() + && self.tokens[i].kind == JavaSyntaxKind::Arrow; + } + } + JavaSyntaxKind::Arrow => return false, + _ => {} + } + i += 1; + } + } + fn peek(&self) -> Option<&ExprToken> { self.tokens.get(self.pos) } @@ -651,6 +873,23 @@ impl ExprLowerer<'_, '_> { self.peek().map(|token| token.kind) } + fn at_lambda_block(&self) -> bool { + self.peek_kind() == Some(JavaSyntaxKind::LBrace) + } + + fn skip_block_tokens(&mut self) { + let mut depth = 1; + self.pos += 1; + while self.pos < self.tokens.len() && depth > 0 { + match self.tokens[self.pos].kind { + JavaSyntaxKind::LBrace => depth += 1, + JavaSyntaxKind::RBrace => depth -= 1, + _ => {} + } + self.pos += 1; + } + } + fn peek_binary_op(&self) -> Option<(BinaryOp, u8)> { let token = self.peek()?; let op = match token.kind { diff --git a/crates/javac-hir/src/lowering/member.rs b/crates/javac-hir/src/lowering/member.rs index 8d5a94c..48e6908 100644 --- a/crates/javac-hir/src/lowering/member.rs +++ b/crates/javac-hir/src/lowering/member.rs @@ -185,6 +185,8 @@ fn lower_method_decl( let mut body_builder = BodyBuilder::new(resolver.clone()); define_params(&mut body_builder, ¶ms); let root_block = lower_method_body(access_flags, &method, &mut body_builder)?; + let ret_ty = signature.return_type.clone(); + body_builder.resolve_lambda_target_types(&ret_ty); Ok(MethodDecl { id: HirId(method_index + 1), diff --git a/crates/javac-parser/src/parser/expr.rs b/crates/javac-parser/src/parser/expr.rs index 316caec..de69ef9 100644 --- a/crates/javac-parser/src/parser/expr.rs +++ b/crates/javac-parser/src/parser/expr.rs @@ -109,7 +109,14 @@ pub(crate) fn is_cast(p: &Parser) -> bool { } la.skip_type(); la.skip_array_dims(); - la.at(RParen) + if !la.at(RParen) { + return false; + } + la.advance(); + while la.pos < la.tokens.len() && matches!(la.tokens[la.pos].kind, Whitespace | Comment) { + la.pos += 1; + } + !la.at(Arrow) } pub(crate) fn postfix_suffix(p: &mut Parser) { @@ -166,6 +173,9 @@ pub(crate) fn primary_expr(p: &mut Parser) { SwitchKw => { stmt::switch_expr(p); } + LParen if is_lambda_paren(p) => { + lambda_expr_from_paren(p); + } LParen => { let m = p.start(); p.bump(); @@ -173,6 +183,9 @@ pub(crate) fn primary_expr(p: &mut Parser) { p.expect(RParen); m.complete(p, ParenExpr); } + Ident if is_ident_lambda(p) => { + lambda_expr_from_ident(p); + } Ident => { name_expr(p); } @@ -256,3 +269,75 @@ pub(crate) fn expr_list(p: &mut Parser) { expr(p); } } + +fn is_lambda_paren(p: &mut Parser) -> bool { + use JavaSyntaxKind::*; + let mut depth = 1; + let mut i = p.pos + 1; + loop { + if i >= p.tokens.len() { + return false; + } + match p.tokens[i].kind { + LParen => depth += 1, + RParen => { + depth -= 1; + if depth == 0 { + i += 1; + while i < p.tokens.len() && matches!(p.tokens[i].kind, Whitespace | Comment) { + i += 1; + } + return i < p.tokens.len() && p.tokens[i].kind == Arrow; + } + } + Whitespace | Comment => {} + Arrow => return false, + _ => {} + } + i += 1; + } +} + +fn is_ident_lambda(p: &mut Parser) -> bool { + use JavaSyntaxKind::*; + let mut next = p.pos + 1; + while next < p.tokens.len() && matches!(p.tokens[next].kind, Whitespace | Comment) { + next += 1; + } + next < p.tokens.len() && p.tokens[next].kind == Arrow +} + +fn lambda_expr_from_paren(p: &mut Parser) { + use JavaSyntaxKind::*; + let m = p.start(); + p.bump(); + while !p.at(RParen) && p.kind() != Error { + let pm = p.start(); + p.expect(Ident); + pm.complete(p, LambdaParam); + p.eat(Comma); + } + p.expect(RParen); + p.expect(Arrow); + if p.at(LBrace) { + stmt::block(p); + } else { + expr(p); + } + m.complete(p, LambdaExpr); +} + +fn lambda_expr_from_ident(p: &mut Parser) { + use JavaSyntaxKind::*; + let m = p.start(); + let pm = p.start(); + p.expect(Ident); + pm.complete(p, LambdaParam); + p.expect(Arrow); + if p.at(LBrace) { + stmt::block(p); + } else { + expr(p); + } + m.complete(p, LambdaExpr); +} diff --git a/tests/java/LambdaConsumerTest.java b/tests/java/LambdaConsumerTest.java new file mode 100644 index 0000000..20e44bd --- /dev/null +++ b/tests/java/LambdaConsumerTest.java @@ -0,0 +1,8 @@ +import java.util.function.Consumer; + +public class LambdaConsumerTest { + public static void main(String[] args) { + Consumer c = x -> System.out.println(x); + c.accept("hello from consumer"); + } +} diff --git a/tests/java/LambdaTest.java b/tests/java/LambdaTest.java new file mode 100644 index 0000000..157af08 --- /dev/null +++ b/tests/java/LambdaTest.java @@ -0,0 +1,8 @@ +import java.util.function.Supplier; + +public class LambdaTest { + public static void main(String[] args) { + Supplier r = () -> "hello from lambda"; + System.out.println(r.get()); + } +}